diff --git a/builder.go b/builder.go index efac14c..004b581 100644 --- a/builder.go +++ b/builder.go @@ -3,10 +3,12 @@ package godog import ( "bytes" "go/ast" + "go/build" "go/format" "go/parser" "go/token" "os" + "path" "path/filepath" "strconv" "strings" @@ -148,6 +150,37 @@ func (b *builder) registerContexts(f *ast.File) { } } +type visitFn func(node ast.Node) ast.Visitor + +func (fn visitFn) Visit(node ast.Node) ast.Visitor { + return fn(node) +} + +func (b *builder) usedPackages(f *ast.File) []string { + var refs []string + var visitor visitFn + visitor = visitFn(func(node ast.Node) ast.Visitor { + if node == nil { + return visitor + } + switch v := node.(type) { + case *ast.SelectorExpr: + xident, ok := v.X.(*ast.Ident) + if !ok { + break + } + if xident.Obj != nil { + // if the parser can resolve it, it's not a package ref + break + } + refs = append(refs, xident.Name) + } + return visitor + }) + ast.Walk(visitor, f) + return refs +} + func (b *builder) merge() ([]byte, error) { var buf bytes.Buffer if err := b.tpl.Execute(&buf, b); err != nil { @@ -158,7 +191,6 @@ func (b *builder) merge() ([]byte, error) { if err != nil { return nil, err } - // b.imports(f) b.deleteImports(f) b.files["main.go"] = f @@ -181,13 +213,26 @@ func (b *builder) merge() ([]byte, error) { return nil, err } + used := b.usedPackages(ret) + isUsed := func(p string) bool { + for _, ref := range used { + if p == ref { + return true + } + } + return false + } for _, spec := range b.imports { var name string + ipath := strings.Trim(spec.Path.Value, `\"`) + check := importPathToName(ipath) if spec.Name != nil { name = spec.Name.Name + check = spec.Name.Name + } + if isUsed(check) { + addImport(b.fset, ret, name, ipath) } - ipath, _ := strconv.Unquote(spec.Path.Value) - addImport(b.fset, ret, name, ipath) } buf.Reset() @@ -358,3 +403,20 @@ func importPath(s *ast.ImportSpec) string { } return "" } + +var importPathToName = importPathToNameGoPath + +// importPathToNameBasic assumes the package name is the base of import path. +func importPathToNameBasic(importPath string) (packageName string) { + return path.Base(importPath) +} + +// importPathToNameGoPath finds out the actual package name, as declared in its .go files. +// If there's a problem, it falls back to using importPathToNameBasic. +func importPathToNameGoPath(importPath string) (packageName string) { + if buildPkg, err := build.Import(importPath, "", 0); err == nil { + return buildPkg.Name + } else { + return importPathToNameBasic(importPath) + } +}