summaryrefslogtreecommitdiffhomepage
path: root/tools/go_marshal/gomarshal/generator.go
blob: 28ae6c4ef0a8b25ace88ae441b938c049ecaad6d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package gomarshal implements the go_marshal code generator. See README.md.
package gomarshal

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"os"
	"sort"
	"strings"

	"gvisor.dev/gvisor/tools/tags"
)

// List of identifiers we use in generated code that may conflict with a
// similarly-named source identifier. Abort gracefully when we see these to
// avoid potentially confusing compilation failures in generated code.
//
// This only applies to import aliases at the moment. All other identifiers
// are qualified by a receiver argument, since they're struct fields.
//
// All receivers are single letters, so we don't allow import aliases to be a
// single letter.
//
// NOTE(review): collectMarshallableTypes reuses receiver names found on
// existing methods, and those need not be single letters. Confirm this
// assumption still holds.
var badIdents = []string{
	"addr", "blk", "buf", "cc", "dst", "dsts", "count", "err", "hdr", "idx",
	"inner", "length", "limit", "ptr", "size", "src", "srcs", "val",
	// All single-letter identifiers are also rejected; that check lives in
	// collectImports rather than in this list.
}

// badIdentsMap is a set view of badIdents, used for constant-time membership
// checks. It is constructed from badIdents in init().
var badIdentsMap map[string]struct{}

func init() {
	badIdentsMap = make(map[string]struct{}, len(badIdents))
	for _, ident := range badIdents {
		badIdentsMap[ident] = struct{}{}
	}
}

// Generator drives code generation for a single invocation of the go_marshal
// utility.
//
// The Generator holds arguments passed to the tool, and drives parsing,
// processing and code generation for all types marked with +marshal declared
// in the input files.
//
// See Generator.Run() as the entry point.
type Generator struct {
	// Paths to input go source files.
	inputs []string
	// Output file to write generated go source.
	output *os.File
	// Output file to write generated tests. This file carries the same build
	// tags as the input files.
	outputTest *os.File
	// Output file to write unconditionally generated tests. It is compiled
	// regardless of build tags, so the test package is never empty.
	outputTestUC *os.File
	// Package name for the generated file.
	pkg string
	// Imports available to the generated file. Entries are marked used when
	// the generated code needs them (see NewGenerator and Run).
	imports *importTable
}

// NewGenerator creates a new code Generator.
//
// It opens three output files, creating or truncating each one:
//   - out receives the generated marshalling code;
//   - outTest receives tests that are subject to the inputs' build tags;
//   - outTestUnconditional receives the always-compiled test file.
//
// If any of them cannot be opened, every file opened so far is closed and an
// error naming the failing path is returned.
//
// srcs lists the input Go source files. pkg is the package name used in the
// generated files. imports lists package paths that are always imported by the
// generated code.
func NewGenerator(srcs []string, out, outTest, outTestUnconditional, pkg string, imports []string) (*Generator, error) {
	f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
	if err != nil {
		return nil, fmt.Errorf("couldn't open output file %q: %w", out, err)
	}
	fTest, err := os.OpenFile(outTest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
	if err != nil {
		// Best-effort cleanup; the open failure is the error worth reporting.
		f.Close()
		return nil, fmt.Errorf("couldn't open test output file %q: %w", outTest, err)
	}
	fTestUC, err := os.OpenFile(outTestUnconditional, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
	if err != nil {
		// Best-effort cleanup; the open failure is the error worth reporting.
		f.Close()
		fTest.Close()
		return nil, fmt.Errorf("couldn't open unconditional test output file %q: %w", outTestUnconditional, err)
	}
	g := Generator{
		inputs:       srcs,
		output:       f,
		outputTest:   fTest,
		outputTestUC: fTestUC,
		pkg:          pkg,
		imports:      newImportTable(),
	}
	for _, i := range imports {
		// All imports on the extra imports list are unconditionally marked as
		// used, so that they're always added to the generated code.
		g.imports.add(i).markUsed()
	}

	// The following imports may or may not be used by the generated code,
	// depending on what's required for the target types. Don't mark these as
	// used by default.
	g.imports.add("io")
	g.imports.add("reflect")
	g.imports.add("runtime")
	g.imports.add("unsafe")
	g.imports.add("gvisor.dev/gvisor/pkg/gohacks")
	g.imports.add("gvisor.dev/gvisor/pkg/safecopy")
	g.imports.add("gvisor.dev/gvisor/pkg/usermem")
	g.imports.add("gvisor.dev/gvisor/pkg/marshal")

	return &g, nil
}

// writeHeader emits the preamble of the generated source file to g.output.
// The preamble consists of a banner comment, any build constraints aggregated
// from the input files, the package clause, and finally the import block.
func (g *Generator) writeHeader() error {
	var hdr sourceBuffer
	hdr.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n")

	// Carry the inputs' build constraints over to the generated file.
	buildTags := tags.Aggregate(g.inputs)
	if len(buildTags) > 0 {
		hdr.emit(strings.Join(buildTags.Lines(), "\n"))
		hdr.emit("\n\n")
	}

	hdr.emit("package %s\n\n", g.pkg)
	err := hdr.write(g.output)
	if err != nil {
		return err
	}
	return g.imports.write(g.output)
}

// writeTypeChecks writes compile-time assertions to g.output. Each assertion
// forces the compiler to check that a Marshallable type referenced by the
// generated code really implements marshal.Marshallable.
//
// ms is the set of referenced type names. The assertions are emitted in
// sorted order so the output is deterministic. Nothing is written when ms is
// empty. Returns any error from writing to g.output.
func (g *Generator) writeTypeChecks(ms map[string]struct{}) error {
	if len(ms) == 0 {
		return nil
	}

	msl := make([]string, 0, len(ms))
	for m := range ms {
		msl = append(msl, m)
	}
	// Map iteration order is random; sort for reproducible output.
	sort.Strings(msl)

	var buf bytes.Buffer
	buf.WriteString("// Marshallable types used by this file.\n")
	for _, m := range msl {
		fmt.Fprintf(&buf, "var _ marshal.Marshallable = (*%s)(nil)\n", m)
	}
	buf.WriteString("\n")

	_, err := buf.WriteTo(g.output)
	return err
}

// parse reads and parses every input file of g, keeping comments so that
// +marshal directives can be found later. It returns one AST and one FileSet
// per input, in the same order as g.inputs; each input gets its own FileSet.
// The first file that fails to parse aborts the operation with an error.
func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) {
	n := len(g.inputs)
	debugf("go_marshal invoked with %d input files:\n", n)
	for _, p := range g.inputs {
		debugf("  %s\n", p)
	}

	var (
		asts = make([]*ast.File, 0, n)
		sets = make([]*token.FileSet, 0, n)
	)
	for _, p := range g.inputs {
		fs := token.NewFileSet()
		parsed, err := parser.ParseFile(fs, p, nil, parser.ParseComments)
		if err != nil {
			// Not a valid input file?
			return nil, nil, fmt.Errorf("input %q can't be parsed: %w", p, err)
		}

		if debugEnabled() {
			debugf("AST for %q:\n", p)
			ast.Print(fs, parsed)
		}

		asts = append(asts, parsed)
		sets = append(sets, fs)
	}
	return asts, sets, nil
}

// sliceAPI describes the slice API requested through a
// '+marshal slice:<IDENTIFIER>[:inner]' directive.
type sliceAPI struct {
	// comment is the AST comment that holds the +marshal tag; it is used to
	// position error messages.
	comment *ast.Comment
	// ident is the name fragment used when naming the generated slice-API
	// functions.
	ident string
	// inner selects whether generated functions refer to the inner type name
	// instead of the newtype name. It only matters for newtypes declared on
	// primitives.
	inner bool
}

// marshallableType carries information about a type marked with the '+marshal'
// directive.
type marshallableType struct {
	// spec is the type declaration carrying the +marshal directive.
	spec *ast.TypeSpec
	// slice describes the requested slice API. It is nil when the directive
	// has no 'slice:' clause.
	slice *sliceAPI
	// recv is the receiver name used by generated methods. It reuses the
	// receiver name of an existing method on the type when one is found in the
	// same file, and otherwise falls back to a default.
	recv string
}

// newMarshallableType builds the marshallableType for spec from the +marshal
// directive in tagLine.
//
// The only clause currently understood is 'slice:<IDENTIFIER>[:inner]'. Any
// other clause, or a malformed slice clause, aborts code generation with an
// error positioned at tagLine. The recv field is left empty for the caller to
// fill in.
func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.TypeSpec) *marshallableType {
	pos := fset.Position(tagLine.Slash)
	mt := &marshallableType{spec: spec}

	var unknown []string
	clauses := strings.Fields(strings.TrimPrefix(tagLine.Text, "// +marshal"))
	for _, clause := range clauses {
		if !strings.HasPrefix(clause, "slice:") {
			unknown = append(unknown, clause)
			continue
		}

		parts := strings.Split(clause, ":")
		if n := len(parts); n < 2 || n > 3 {
			abortAt(pos, fmt.Sprintf("+marshal directive has invalid 'slice' clause. Expecting format 'slice:<IDENTIFIER>[:inner]', got '%v'", clause))
		}
		if parts[1] == "" {
			abortAt(pos, "+marshal slice directive has empty identifier argument. Expecting '+marshal slice:identifier'")
		}

		api := &sliceAPI{comment: tagLine, ident: parts[1]}
		mt.slice = api

		if len(parts) == 3 {
			if parts[2] != "inner" {
				abortAt(pos, "+marshal slice directive has an invalid argument. Expecting '+marshal slice:<IDENTIFIER>[:inner]'")
			}
			api.inner = true
		}
	}

	if len(unknown) > 0 {
		abortAt(pos, fmt.Sprintf("+marshal directive contained the following unknown clauses: %v", strings.Join(unknown, " ")))
	}
	return mt
}

// collectMarshallableTypes walks the top-level declarations of a and collects
// the type specs marked with a '+marshal' directive, for which we need to
// generate the Marshallable interface. The result is keyed by type spec.
//
// Along the way it records the receiver names used by methods declared in the
// same file, so that generated methods can reuse them. Types without such a
// method fall back to receiverName. A marked declaration of an unsupported
// kind aborts code generation.
func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) map[*ast.TypeSpec]*marshallableType {
	recv := make(map[string]string) // Type name to receiver name.
	types := make(map[*ast.TypeSpec]*marshallableType)
	for _, decl := range a.Decls {
		gdecl, ok := decl.(*ast.GenDecl)
		// Type declaration?
		if !ok || gdecl.Tok != token.TYPE {
			// Is this a method declaration? We remember receiver names.
			d, ok := decl.(*ast.FuncDecl)
			if ok && d.Recv != nil && len(d.Recv.List) == 1 {
				field := d.Recv.List[0]
				// Accept concrete methods & pointer methods.
				ident, ok := field.Type.(*ast.Ident)
				if !ok {
					var st *ast.StarExpr
					st, ok = field.Type.(*ast.StarExpr)
					if ok {
						ident, ok = st.X.(*ast.Ident)
					}
				}
				// The receiver may be unnamed, or named with the blank
				// identifier "_". Generated code cannot reference either, so
				// only remember real names.
				if ok && len(field.Names) == 1 && field.Names[0].Name != "_" {
					recv[ident.Name] = field.Names[0].Name
				}
			}
			debugfAt(f.Position(decl.Pos()), "Skipping declaration since it's not a type declaration.\n")
			continue
		}
		// Does it have a comment?
		if gdecl.Doc == nil {
			debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment.\n")
			continue
		}
		// Does the comment contain a "+marshal" line?
		marked := false
		var tagLine *ast.Comment
		for _, c := range gdecl.Doc.List {
			if strings.HasPrefix(c.Text, "// +marshal") {
				marked = true
				tagLine = c
				break
			}
		}
		if !marked {
			debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment containing +marshal line.\n")
			continue
		}
		for _, spec := range gdecl.Specs {
			// We already confirmed we're in a type declaration earlier, so this
			// cast will succeed.
			t := spec.(*ast.TypeSpec)
			switch t.Type.(type) {
			case *ast.StructType:
				debugfAt(f.Position(t.Pos()), "Collected marshallable struct %s.\n", t.Name.Name)
			case *ast.Ident: // Newtype on primitive.
				debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name)
			case *ast.ArrayType: // Newtype on array.
				debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on array %s.\n", t.Name.Name)
			default:
				// A user specifically requested marshalling on this type, but we
				// don't support it.
				abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name))
			}
			types[t] = newMarshallableType(f, tagLine, t)
		}
	}
	// Update the types with the last seen receiver. As long as the
	// receiver name is consistent for the type, then we will generate
	// code that is still consistent with itself.
	for t, mt := range types {
		r, ok := recv[t.Name.Name]
		if !ok {
			mt.recv = receiverName(t) // Default.
			continue
		}
		mt.recv = r // Last seen.
	}
	return types
}

// collectImports registers every import declared in source file a with
// g.imports. Generated code that references one of these packages later marks
// it used, which copies it into the generated output.
//
// Import local names that could clash with identifiers emitted by go_marshal
// abort code generation. That covers any single-character name other than "_",
// and any name listed in badIdents.
//
// NOTE(review): the returned map is never populated, so callers always receive
// an empty map. The imports themselves are recorded only in g.imports.
func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]importStmt {
	collected := make(map[string]importStmt)
	for _, decl := range a.Decls {
		gd, isGen := decl.(*ast.GenDecl)
		if !isGen || gd.Tok != token.IMPORT {
			continue // Not an import declaration.
		}
		for _, spec := range gd.Specs {
			stmt := g.imports.addFromSpec(spec.(*ast.ImportSpec), f)
			debugf("Collected import '%s' as '%s'\n", stmt.path, stmt.name)

			// Reject local names that would collide with identifiers used in
			// the generated code.
			if len(stmt.name) == 1 && stmt.name != "_" {
				abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", stmt.name))
			}
			if _, bad := badIdentsMap[stmt.name]; bad {
				abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import name '%s' is likely to conflict with code generated by go_marshal, use a different import alias", stmt.name))
			}
		}
	}
	return collected
}

// generateOne validates the declaration of t and emits its Marshallable
// implementation into a new interfaceGenerator. When t requests the slice
// API, that API is emitted as well.
//
// Supported declarations are structs, newtypes on primitives and newtypes on
// arrays. Requesting the slice API on an array newtype aborts code generation.
// Any other kind of declaration is a programming error, since
// collectMarshallableTypes filters them out, and panics.
func (g *Generator) generateOne(t *marshallableType, fset *token.FileSet) *interfaceGenerator {
	i := newInterfaceGenerator(t.spec, t.recv, fset)
	switch ty := t.spec.Type.(type) {
	case *ast.StructType:
		i.validateStruct(t.spec, ty)
		i.emitMarshallableForStruct(ty)
		if t.slice != nil {
			i.emitMarshallableSliceForStruct(ty, t.slice)
		}
	case *ast.Ident:
		i.validatePrimitiveNewtype(ty)
		i.emitMarshallableForPrimitiveNewtype(ty)
		if t.slice != nil {
			i.emitMarshallableSliceForPrimitiveNewtype(ty, t.slice)
		}
	case *ast.ArrayType:
		i.validateArrayNewtype(t.spec.Name, ty)
		// NOTE(review): validateArrayNewtype is presumably what guarantees the
		// element type is an *ast.Ident, making the assertion below safe —
		// confirm against its implementation.
		i.emitMarshallableForArrayNewtype(t.spec.Name, ty, ty.Elt.(*ast.Ident))
		if t.slice != nil {
			abortAt(fset.Position(t.slice.comment.Slash), "Array type marked as '+marshal slice:...', but this is not supported. Perhaps fold one of the dimensions?")
		}
	default:
		// This should've been filtered out by collectMarshallableTypes.
		panic(fmt.Sprintf("Unexpected type %+v", ty))
	}
	return i
}

// generateOneTestSuite builds the test suite for the automatically generated
// implementation of type t. The slice API description t.slice, which may be
// nil, is passed through to the test emitter.
func (g *Generator) generateOneTestSuite(t *marshallableType) *testGenerator {
	suite := newTestGenerator(t.spec, t.recv)
	suite.emitTests(t.slice)
	return suite
}

// Run is the entry point to code generation using g.
//
// Run parses all input source files specified in g. It generates Marshallable
// implementations and tests for every type marked with +marshal, and writes
// them to g's output files. It returns the first error encountered while
// parsing the inputs or writing the outputs. It panics if generated code
// references an import that g does not know about.
func (g *Generator) Run() error {
	// Parse our input source files into ASTs and token sets.
	asts, fsets, err := g.parse()
	if err != nil {
		return err
	}

	if len(asts) != len(fsets) {
		panic("ASTs and FileSets don't match")
	}

	// Map of imports in source files; key = local package name, value = import
	// statement.
	is := make(map[string]importStmt)
	for fileIdx, a := range asts {
		// Collect all imports from the source files. We may need to copy some
		// of these to the generated code if they're referenced. This has to be
		// done before the loop below because we need to process all ASTs before
		// we start requesting imports to be copied one by one as we encounter
		// them in each generated source.
		for name, stmt := range g.collectImports(a, fsets[fileIdx]) {
			is[name] = stmt
		}
	}

	var impls []*interfaceGenerator
	var ts []*testGenerator
	// Set of Marshallable types referenced by generated code.
	ms := make(map[string]struct{})
	for fileIdx, a := range asts {
		fset := fsets[fileIdx]
		// Collect type declarations marked for code generation and generate
		// Marshallable interfaces.
		collected := g.collectMarshallableTypes(a, fset)
		sortedTypes := make([]*marshallableType, 0, len(collected))
		for _, t := range collected {
			sortedTypes = append(sortedTypes, t)
		}
		// Map iteration order is random; sort so the output is deterministic.
		sort.Slice(sortedTypes, func(x, y int) bool {
			// Sort by type name, which should be unique within a package.
			return sortedTypes[x].spec.Name.String() < sortedTypes[y].spec.Name.String()
		})
		for _, t := range sortedTypes {
			impl := g.generateOne(t, fset)
			// Collect Marshallable types referenced by the generated code.
			for ref := range impl.ms {
				ms[ref] = struct{}{}
			}
			impls = append(impls, impl)
			// Collect imports referenced by the generated code and add them to
			// the list of imports we need to copy to the generated code.
			for name := range impl.is {
				if !g.imports.markUsed(name) {
					panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'. Either go-marshal needs to add an import to the generated file, or a package in an input source file has a package name differ from the final component of its path, which go-marshal doesn't know how to detect; use an import alias to work around this limitation.", impl.typeName(), name))
				}
			}
			ts = append(ts, g.generateOneTestSuite(t))
		}
	}

	// Write output file header. These include things like package name and
	// import statements.
	if err := g.writeHeader(); err != nil {
		return err
	}

	// Write type checks for referenced marshallable types to output file.
	if err := g.writeTypeChecks(ms); err != nil {
		return err
	}

	// Write generated interfaces to output file.
	for _, impl := range impls {
		if err := impl.write(g.output); err != nil {
			return err
		}
	}

	// Write generated tests to test file.
	return g.writeTests(ts)
}

// writeTests outputs tests for the generated interface implementations to Go
// source files. Two files are written:
//
//   - g.outputTestUC always compiles. It contains only an empty Example, so the
//     test package is never empty.
//   - g.outputTest carries the input files' build tags, the merged imports of
//     all test suites, and the test functions from ts.
//
// It returns the first write error encountered.
func (g *Generator) writeTests(ts []*testGenerator) error {
	var b sourceBuffer

	// Write the unconditional test file. This file is always compiled,
	// regardless of what build tags were specified on the original input
	// files. We use this file to guarantee we never end up with an empty test
	// file, as that causes the build to fail with "no tests/benchmarks/examples
	// found".
	//
	// There's no easy way to determine ahead of time if we'll end up with an
	// empty build file since build constraints can arbitrarily cause some of
	// the original types to be not defined. We also have no way to tell bazel
	// to omit the entire test suite since the output files are already defined
	// before go-marshal is called.
	b.emit("// Automatically generated marshal tests. See tools/go_marshal.\n\n")
	b.emit("package %s\n\n", g.pkg)
	b.emit("func Example() {\n")
	b.inIndent(func() {
		b.emit("// This example is intentionally empty, and ensures this package contains at\n")
		b.emit("// least one testable entity. go-marshal is forced to emit a test package if the\n")
		// No trailing space before the newline, so the generated file stays
		// gofmt-clean.
		b.emit("// input package is marked marshallable, but emitting no testable entities\n")
		b.emit("// results in a build failure.\n")
	})
	b.emit("}\n")
	if err := b.write(g.outputTestUC); err != nil {
		return err
	}

	// Now generate the real test file that contains the real types we
	// processed. These need to be conditionally compiled according to the build
	// tags, as the original types may not be defined under all build
	// configurations.

	b.reset()
	b.emit("// Automatically generated marshal tests. See tools/go_marshal.\n\n")

	// Emit build tags.
	if t := tags.Aggregate(g.inputs); len(t) > 0 {
		b.emit(strings.Join(t.Lines(), "\n"))
		b.emit("\n\n")
	}

	b.emit("package %s\n\n", g.pkg)
	if err := b.write(g.outputTest); err != nil {
		return err
	}

	// Collect and write test import statements.
	imports := newImportTable()
	for _, t := range ts {
		imports.merge(t.imports)
	}

	if err := imports.write(g.outputTest); err != nil {
		return err
	}

	// Write test functions.
	for _, t := range ts {
		if err := t.write(g.outputTest); err != nil {
			return err
		}
	}
	return nil
}