import (
"bytes"
+ "flag"
+ "fmt"
"go/ast"
"go/format"
"go/parser"
var fset = token.NewFileSet()
func main() {
- for filename, action := range filemap {
- // parse src
- srcFilename := filepath.FromSlash(runtime.GOROOT() + "/src/" + srcDir + "/" + filename)
- file, err := parser.ParseFile(fset, srcFilename, nil, parser.ParseComments)
- if err != nil {
- log.Fatal(err)
+ flag.Parse()
+
+ // process provided filenames, if any
+ if flag.NArg() > 0 {
+ for _, filename := range flag.Args() {
+ fmt.Println("generating", filename)
+ generate(filename, filemap[filename])
}
+ return
+ }
- // fix package name
- file.Name.Name = strings.ReplaceAll(file.Name.Name, "types2", "types")
+ // otherwise process per filemap below
+ for filename, action := range filemap {
+ generate(filename, action)
+ }
+}
- // rewrite AST as needed
- if action != nil {
- action(file)
- }
+func generate(filename string, action action) {
+ // parse src
+ srcFilename := filepath.FromSlash(runtime.GOROOT() + "/src/" + srcDir + "/" + filename)
+ file, err := parser.ParseFile(fset, srcFilename, nil, parser.ParseComments)
+ if err != nil {
+ log.Fatal(err)
+ }
- // format AST
- var buf bytes.Buffer
- buf.WriteString("// Code generated by \"go run generator.go\"; DO NOT EDIT.\n\n")
- if err := format.Node(&buf, fset, file); err != nil {
- log.Fatal(err)
- }
+ // fix package name
+ file.Name.Name = strings.ReplaceAll(file.Name.Name, "types2", "types")
- // write dst
- dstFilename := filepath.FromSlash(runtime.GOROOT() + "/src/" + dstDir + "/" + filename)
- if err := os.WriteFile(dstFilename, buf.Bytes(), 0o644); err != nil {
- log.Fatal(err)
- }
+ // rewrite AST as needed
+ if action != nil {
+ action(file)
+ }
+
+ // format AST
+ var buf bytes.Buffer
+ buf.WriteString("// Code generated by \"go run generator.go\"; DO NOT EDIT.\n\n")
+ if err := format.Node(&buf, fset, file); err != nil {
+ log.Fatal(err)
+ }
+
+ // write dst
+ dstFilename := filepath.FromSlash(runtime.GOROOT() + "/src/" + dstDir + "/" + filename)
+ if err := os.WriteFile(dstFilename, buf.Bytes(), 0o644); err != nil {
+ log.Fatal(err)
}
}