diff --git a/builder.go b/builder.go index 915e3f4..070acad 100644 --- a/builder.go +++ b/builder.go @@ -10,17 +10,17 @@ import ( "os" "path" "path/filepath" - "strconv" "strings" "text/template" ) type builder struct { - files map[string]*ast.File - fset *token.FileSet - Contexts []string - Internal bool - tpl *template.Template + files map[string]*ast.File + fset *token.FileSet + Contexts []string + Internal bool + SuiteName string + tpl *template.Template imports []*ast.ImportSpec } @@ -46,22 +46,26 @@ func newBuilderSkel() *builder { files: make(map[string]*ast.File), fset: token.NewFileSet(), tpl: template.Must(template.New("main").Parse(`package main -{{ if not .Internal }}import ( - "github.com/DATA-DOG/godog" -){{ end }} +import ( +{{ if not .Internal }} "github.com/DATA-DOG/godog"{{ end }} + "os" + "testing" +) -func main() { +const GodogSuiteName = "{{ .SuiteName }}" - {{ if not .Internal }}godog.{{ end }}Run(func (suite *{{ if not .Internal }}godog.{{ end }}Suite) { +func TestMain(m *testing.M) { + status := {{ if not .Internal }}godog.{{ end }}Run(func (suite *{{ if not .Internal }}godog.{{ end }}Suite) { {{range .Contexts}} {{ . }}(suite) {{end}} }) + os.Exit(status) }`)), } } -func newBuilder(buildPath string) (*builder, error) { +func doBuild(buildPath, dir string) error { b := newBuilderSkel() err := filepath.Walk(buildPath, func(path string, file os.FileInfo, err error) error { if file.IsDir() && file.Name() != "." { @@ -72,23 +76,89 @@ func newBuilder(buildPath string) (*builder, error) { if err != nil { return err } - b.register(f, path) + b.register(f, file.Name()) } return err }) - return b, err + if err != nil { + return err + } + + var buf bytes.Buffer + if err := b.tpl.Execute(&buf, b); err != nil { + return err + } + + f, err := parser.ParseFile(b.fset, "", &buf, 0) + if err != nil { + return err + } + + b.files["godog_test.go"] = f + + os.Mkdir(dir, 0755) + for name, node := range b.files { + f, err := os.Create(filepath.Join(dir, name)) + if err != nil { + return err + } + if err := format.Node(f, b.fset, node); err != nil { + return err + } + } + + return nil } -func (b *builder) register(f *ast.File, path string) { +func (b *builder) register(f *ast.File, name string) { // mark godog package as internal if f.Name.Name == "godog" && !b.Internal { b.Internal = true } + b.SuiteName = f.Name.Name b.deleteMainFunc(f) + f.Name.Name = "main" b.registerContexts(f) - b.deleteImports(f) - b.files[path] = f + b.files[name] = f +} + +func (b *builder) removeUnusedImports(f *ast.File) { + used := b.usedPackages(f) + isUsed := func(p string) bool { + for _, ref := range used { + if p == ref { + return true + } + } + return p == "_" + } + var decls []ast.Decl + for _, d := range f.Decls { + gen, ok := d.(*ast.GenDecl) + if ok && gen.Tok == token.IMPORT { + var specs []ast.Spec + for _, spec := range gen.Specs { + impspec := spec.(*ast.ImportSpec) + ipath := strings.Trim(impspec.Path.Value, `\"`) + check := importPathToName(ipath) + if impspec.Name != nil { + check = impspec.Name.Name + } + + if isUsed(check) { + specs = append(specs, spec) + } + } + + if len(specs) == 0 { + continue + } + gen.Specs = specs + } + decls = append(decls, d) + } + f.Decls = decls } func (b *builder) deleteImports(f *ast.File) { @@ -111,17 +181,24 @@ func (b *builder) deleteImports(f *ast.File) { func (b *builder) deleteMainFunc(f *ast.File) { var decls []ast.Decl + var hadMain bool for _, d := range f.Decls { fun, ok := d.(*ast.FuncDecl) if !ok { decls = append(decls, d) continue } - if fun.Name.Name != "main" { + if fun.Name.Name != "TestMain" { decls = append(decls, fun) + } else { + hadMain = true } } f.Decls = decls + + if hadMain { + b.removeUnusedImports(f) + } } func (b *builder) registerContexts(f *ast.File) { @@ -181,68 +258,6 @@ func (b *builder) usedPackages(f *ast.File) []string { return refs } -func (b *builder) merge() ([]byte, error) { - var buf bytes.Buffer - if err := b.tpl.Execute(&buf, b); err != nil { - return nil, err - } - - f, err := parser.ParseFile(b.fset, "", &buf, 0) - if err != nil { - return nil, err - } - b.deleteImports(f) - b.files["main.go"] = f - - pkg, _ := ast.NewPackage(b.fset, b.files, nil, nil) - pkg.Name = "main" - - ret, err := ast.MergePackageFiles(pkg, 0), nil - if err != nil { - return nil, err - } - - // @TODO: we reread the file, probably something goes wrong with position - buf.Reset() - if err = format.Node(&buf, b.fset, ret); err != nil { - return nil, err - } - - ret, err = parser.ParseFile(b.fset, "", buf.Bytes(), 0) - if err != nil { - return nil, err - } - - used := b.usedPackages(ret) - isUsed := func(p string) bool { - for _, ref := range used { - if p == ref { - return true - } - } - return p == "_" - } - for _, spec := range b.imports { - var name string - ipath := strings.Trim(spec.Path.Value, `\"`) - check := importPathToName(ipath) - if spec.Name != nil { - name = spec.Name.Name - check = spec.Name.Name - } - if isUsed(check) { - addImport(b.fset, ret, name, ipath) - } - } - - buf.Reset() - if err := format.Node(&buf, b.fset, ret); err != nil { - return nil, err - } - - return buf.Bytes(), nil -} - // Build creates a runnable Godog executable file // from current package source and test source files. // @@ -253,165 +268,15 @@ func (b *builder) merge() ([]byte, error) { // Currently, to manage imports we use "golang.org/x/tools/imports" // package, but that may be replaced in order to have // no external dependencies -func Build() ([]byte, error) { - b, err := newBuilder(".") - if err != nil { - return nil, err - } - - return b.merge() +func Build(dir string) error { + return doBuild(".", dir) } -// taken from https://github.com/golang/tools/blob/master/go/ast/astutil/imports.go#L17 -func addImport(fset *token.FileSet, f *ast.File, name, ipath string) { - newImport := &ast.ImportSpec{ - Path: &ast.BasicLit{ - Kind: token.STRING, - Value: strconv.Quote(ipath), - }, - } - if name != "" { - newImport.Name = &ast.Ident{Name: name} - } - - // Find an import decl to add to. - // The goal is to find an existing import - // whose import path has the longest shared - // prefix with ipath. - var ( - bestMatch = -1 // length of longest shared prefix - lastImport = -1 // index in f.Decls of the file's final import decl - impDecl *ast.GenDecl // import decl containing the best match - impIndex = -1 // spec index in impDecl containing the best match - ) - for i, decl := range f.Decls { - gen, ok := decl.(*ast.GenDecl) - if ok && gen.Tok == token.IMPORT { - lastImport = i - // Do not add to import "C", to avoid disrupting the - // association with its doc comment, breaking cgo. - if declImports(gen, "C") { - continue - } - - // Match an empty import decl if that's all that is available. - if len(gen.Specs) == 0 && bestMatch == -1 { - impDecl = gen - } - - // Compute longest shared prefix with imports in this group. - for j, spec := range gen.Specs { - impspec := spec.(*ast.ImportSpec) - n := matchLen(importPath(impspec), ipath) - if n > bestMatch { - bestMatch = n - impDecl = gen - impIndex = j - } - } - } - } - - // If no import decl found, add one after the last import. - if impDecl == nil { - impDecl = &ast.GenDecl{ - Tok: token.IMPORT, - } - if lastImport >= 0 { - impDecl.TokPos = f.Decls[lastImport].End() - } else { - // There are no existing imports. - // Our new import goes after the package declaration and after - // the comment, if any, that starts on the same line as the - // package declaration. - impDecl.TokPos = f.Package - - file := fset.File(f.Package) - pkgLine := file.Line(f.Package) - for _, c := range f.Comments { - if file.Line(c.Pos()) > pkgLine { - break - } - impDecl.TokPos = c.End() - } - } - f.Decls = append(f.Decls, nil) - copy(f.Decls[lastImport+2:], f.Decls[lastImport+1:]) - f.Decls[lastImport+1] = impDecl - } - - // Insert new import at insertAt. - insertAt := 0 - if impIndex >= 0 { - // insert after the found import - insertAt = impIndex + 1 - } - impDecl.Specs = append(impDecl.Specs, nil) - copy(impDecl.Specs[insertAt+1:], impDecl.Specs[insertAt:]) - impDecl.Specs[insertAt] = newImport - pos := impDecl.Pos() - if insertAt > 0 { - // Assign same position as the previous import, - // so that the sorter sees it as being in the same block. - pos = impDecl.Specs[insertAt-1].Pos() - } - if newImport.Name != nil { - newImport.Name.NamePos = pos - } - newImport.Path.ValuePos = pos - newImport.EndPos = pos - - // Clean up parens. impDecl contains at least one spec. - if len(impDecl.Specs) == 1 { - // Remove unneeded parens. - impDecl.Lparen = token.NoPos - } else if !impDecl.Lparen.IsValid() { - // impDecl needs parens added. - impDecl.Lparen = impDecl.Specs[0].Pos() - } - - f.Imports = append(f.Imports, newImport) -} - -func declImports(gen *ast.GenDecl, path string) bool { - if gen.Tok != token.IMPORT { - return false - } - for _, spec := range gen.Specs { - impspec := spec.(*ast.ImportSpec) - if importPath(impspec) == path { - return true - } - } - return false -} - -func matchLen(x, y string) int { - n := 0 - for i := 0; i < len(x) && i < len(y) && x[i] == y[i]; i++ { - if x[i] == '/' { - n++ - } - } - return n -} - -func importPath(s *ast.ImportSpec) string { - return strings.Trim(s.Path.Value, `\"`) -} - -var importPathToName = importPathToNameGoPath - -// importPathToNameBasic assumes the package name is the base of import path. -func importPathToNameBasic(importPath string) (packageName string) { - return path.Base(importPath) -} - -// importPathToNameGoPath finds out the actual package name, as declared in its .go files. +// importPathToName finds out the actual package name, as declared in its .go files. // If there's a problem, it falls back to using importPathToNameBasic. -func importPathToNameGoPath(importPath string) (packageName string) { +func importPathToName(importPath string) (packageName string) { if buildPkg, err := build.Import(importPath, "", 0); err == nil { return buildPkg.Name } - return importPathToNameBasic(importPath) + return path.Base(importPath) } diff --git a/builder_test.go b/builder_test.go index 6338874..5d7c8c5 100644 --- a/builder_test.go +++ b/builder_test.go @@ -1,228 +1 @@ package godog - -import ( - "fmt" - "go/parser" - "go/token" - "runtime" - "strings" - "testing" -) - -var builderMainFile = ` -package main -import "fmt" -func main() { - fmt.Println("hello") -}` - -var builderPackAliases = ` -package main -import ( - a "fmt" - b "fmt" -) -func Tester() { - a.Println("a") - b.Println("b") -}` - -var builderAnonymousImport = ` -package main -import ( - _ "github.com/go-sql-driver/mysql" -) -` - -var builderContextSrc = ` -package main -import ( - "github.com/DATA-DOG/godog" -) - -func myContext(s *godog.Suite) { - -} -` - -var builderLibrarySrc = ` -package lib -import "fmt" -func test() { - fmt.Println("hello") -} -` - -var builderInternalPackageSrc = ` -package godog -import "fmt" -func test() { - fmt.Println("hello") -} -` - -func (b *builder) registerMulti(contents []string) error { - for i, c := range contents { - f, err := parser.ParseFile(token.NewFileSet(), "", []byte(c), 0) - if err != nil { - return err - } - b.register(f, fmt.Sprintf("path%d", i)) - } - return nil -} - -func (b *builder) cleanSpacing(src string) string { - var lines []string - for _, ln := range strings.Split(src, "\n") { - if ln == "" { - continue - } - lines = append(lines, strings.TrimSpace(ln)) - } - return strings.Join(lines, "\n") -} - -func TestUsualSourceFileMerge(t *testing.T) { - if strings.HasPrefix(runtime.Version(), "go1.1") { - t.Skip("skipping this test for go1.1") - } - b := newBuilderSkel() - err := b.registerMulti([]string{ - builderMainFile, builderPackAliases, builderAnonymousImport, - }) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - - data, err := b.merge() - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - expected := `package main - -import ( - a "fmt" - b "fmt" - "github.com/DATA-DOG/godog" - _ "github.com/go-sql-driver/mysql" -) - -func main() { - godog.Run(func(suite *godog.Suite) { - - }) -} -func Tester() { - a.Println("a") - b.Println("b") -}` - - actual := string(data) - if b.cleanSpacing(expected) != b.cleanSpacing(actual) { - t.Fatalf("expected output does not match: %s", actual) - } -} - -func TestShouldCallContextOnMerged(t *testing.T) { - b := newBuilderSkel() - err := b.registerMulti([]string{ - builderMainFile, builderContextSrc, - }) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - - data, err := b.merge() - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - expected := `package main - -import "github.com/DATA-DOG/godog" - -func main() { - godog.Run(func(suite *godog.Suite) { - myContext(suite) - }) -} - -func myContext(s *godog.Suite) { -}` - - actual := string(data) - // log.Println("actual:", actual) - // log.Println("expected:", expected) - if b.cleanSpacing(expected) != b.cleanSpacing(actual) { - t.Fatalf("expected output does not match: %s", actual) - } -} - -func TestBuildLibraryPackage(t *testing.T) { - b := newBuilderSkel() - err := b.registerMulti([]string{ - builderLibrarySrc, - }) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - - data, err := b.merge() - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - expected := `package main -import ( - "fmt" - "github.com/DATA-DOG/godog" -) - -func main() { - godog.Run(func(suite *godog.Suite) { - - }) -} - -func test() { - fmt.Println( - "hello", - ) -}` - - actual := string(data) - if b.cleanSpacing(expected) != b.cleanSpacing(actual) { - t.Fatalf("expected output does not match: %s", actual) - } -} - -func TestBuildInternalPackage(t *testing.T) { - b := newBuilderSkel() - err := b.registerMulti([]string{ - builderInternalPackageSrc, - }) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - - data, err := b.merge() - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - expected := `package main -import "fmt" - -func main() { - Run(func(suite *Suite) { - - }) -} - -func test() { - fmt.Println("hello") -}` - - actual := string(data) - if b.cleanSpacing(expected) != b.cleanSpacing(actual) { - t.Fatalf("expected output does not match: %s", actual) - } -} diff --git a/cmd/godog/main.go b/cmd/godog/main.go index 94032e2..eb15bee 100644 --- a/cmd/godog/main.go +++ b/cmd/godog/main.go @@ -3,8 +3,10 @@ package main import ( "fmt" "io" + "log" "os" "os/exec" + "path/filepath" "regexp" "strconv" "syscall" @@ -23,26 +25,29 @@ func buildAndRun() (int, error) { stdout := ansicolor.NewAnsiColorWriter(os.Stdout) stderr := ansicolor.NewAnsiColorWriter(statusOutputFilter(os.Stderr)) - builtFile := fmt.Sprintf("%s/%dgodog.go", os.TempDir(), time.Now().UnixNano()) - - buf, err := godog.Build() + dir := fmt.Sprintf(filepath.Join("%s", "%dgodogs"), os.TempDir(), time.Now().UnixNano()) + err := godog.Build(dir) if err != nil { - return status, err + return 1, err } - w, err := os.Create(builtFile) + defer os.RemoveAll(dir) + + wd, err := os.Getwd() if err != nil { - return status, err + return 1, err } - defer os.Remove(builtFile) + bin := filepath.Join(wd, "godog.test") - if _, err = w.Write(buf); err != nil { - w.Close() - return status, err + cmdb := exec.Command("go", "test", "-c", "-o", bin) + cmdb.Dir = dir + if dat, err := cmdb.CombinedOutput(); err != nil { + log.Println(string(dat)) + return 1, err } - w.Close() + defer os.Remove(bin) - cmd := exec.Command("go", append([]string{"run", builtFile}, os.Args[1:]...)...) + cmd := exec.Command(bin, os.Args[1:]...) cmd.Stdout = stdout cmd.Stderr = stderr diff --git a/run.go b/run.go index 9b6e404..b89855f 100644 --- a/run.go +++ b/run.go @@ -52,7 +52,7 @@ func (r *runner) run() (failed bool) { // // contextInitializer must be able to register // the step definitions and event handlers. -func Run(contextInitializer func(suite *Suite)) { +func Run(contextInitializer func(suite *Suite)) int { var vers, defs, sof bool var tags, format string var concurrency int @@ -63,12 +63,12 @@ func Run(contextInitializer func(suite *Suite)) { switch { case vers: fmt.Println(cl("Godog", green) + " version is " + cl(Version, yellow)) - return + return 0 case defs: s := &Suite{} contextInitializer(s) s.printStepDefinitions() - return + return 0 } paths := flagSet.Args() @@ -99,6 +99,7 @@ func Run(contextInitializer func(suite *Suite)) { } if failed := r.run(); failed { - os.Exit(1) + return 1 } + return 0 } diff --git a/suite_test.go b/suite_test.go index ba37e6e..c8fb7d1 100644 --- a/suite_test.go +++ b/suite_test.go @@ -1,13 +1,21 @@ package godog import ( + "flag" "fmt" + "os" "strconv" "strings" + "testing" "gopkg.in/cucumber/gherkin-go.v3" ) +func TestMain(m *testing.M) { + flag.Parse() + os.Exit(m.Run()) +} + func SuiteContext(s *Suite) { c := &suiteContext{}