diff --git a/Makefile b/Makefile index 8d3ca881..b23cc069 100644 --- a/Makefile +++ b/Makefile @@ -265,6 +265,7 @@ TEST_PACKAGES_FAST = \ crypto/sha256 \ crypto/sha512 \ debug/macho \ + embed/internal/embedtest \ encoding \ encoding/ascii85 \ encoding/base32 \ diff --git a/builder/build.go b/builder/build.go index d68f36a3..cf9f5777 100644 --- a/builder/build.go +++ b/builder/build.go @@ -4,6 +4,7 @@ package builder import ( + "crypto/sha256" "crypto/sha512" "debug/elf" "encoding/binary" @@ -80,6 +81,7 @@ type packageAction struct { Config *compiler.Config CFlags []string FileHashes map[string]string // hash of every file that's part of the package + EmbeddedFiles map[string]string // hash of all the //go:embed files in the package Imports map[string]string // map from imported package to action ID hash OptLevel int // LLVM optimization level (0-3) SizeLevel int // LLVM optimization for size level (0-2) @@ -225,6 +227,7 @@ func Build(pkgName, outpath string, config *compileopts.Config, action func(Buil config.Options.GlobalValues["runtime"]["buildVersion"] = version } + var embedFileObjects []*compileJob for _, pkg := range lprogram.Sorted() { pkg := pkg // necessary to avoid a race condition @@ -234,6 +237,47 @@ func Build(pkgName, outpath string, config *compileopts.Config, action func(Buil } sort.Strings(undefinedGlobals) + // Make compile jobs to load files to be embedded in the output binary. + var actionIDDependencies []*compileJob + allFiles := map[string][]*loader.EmbedFile{} + for _, files := range pkg.EmbedGlobals { + for _, file := range files { + allFiles[file.Name] = append(allFiles[file.Name], file) + } + } + for name, files := range allFiles { + name := name + files := files + job := &compileJob{ + description: "make object file for " + name, + run: func(job *compileJob) error { + // Read the file contents in memory. + path := filepath.Join(pkg.Dir, name) + data, err := os.ReadFile(path) + if err != nil { + return err + } + + // Hash the file. 
+ sum := sha256.Sum256(data) + hexSum := hex.EncodeToString(sum[:16]) + + for _, file := range files { + file.Size = uint64(len(data)) + file.Hash = hexSum + if file.NeedsData { + file.Data = data + } + } + + job.result, err = createEmbedObjectFile(string(data), hexSum, name, pkg.OriginalDir(), dir, compilerConfig) + return err + }, + } + actionIDDependencies = append(actionIDDependencies, job) + embedFileObjects = append(embedFileObjects, job) + } + // Action ID jobs need to know the action ID of all the jobs the package // imports. var importedPackages []*compileJob @@ -243,6 +287,7 @@ func Build(pkgName, outpath string, config *compileopts.Config, action func(Buil return fmt.Errorf("package %s imports %s but couldn't find dependency", pkg.ImportPath, imported.Path()) } importedPackages = append(importedPackages, job) + actionIDDependencies = append(actionIDDependencies, job) } // Create a job that will calculate the action ID for a package compile @@ -250,7 +295,7 @@ func Build(pkgName, outpath string, config *compileopts.Config, action func(Buil // package. packageActionIDJob := &compileJob{ description: "calculate cache key for package " + pkg.ImportPath, - dependencies: importedPackages, + dependencies: actionIDDependencies, run: func(job *compileJob) error { // Create a cache key: a hash from the action ID below that contains all // the parameters for the build. 
@@ -262,6 +307,7 @@ func Build(pkgName, outpath string, config *compileopts.Config, action func(Buil
 					Config:           compilerConfig,
 					CFlags:           pkg.CFlags,
 					FileHashes:       make(map[string]string, len(pkg.FileHashes)),
+					EmbeddedFiles:    make(map[string]string, len(allFiles)),
 					Imports:          make(map[string]string, len(pkg.Pkg.Imports())),
 					OptLevel:         optLevel,
 					SizeLevel:        sizeLevel,
@@ -270,6 +316,9 @@ func Build(pkgName, outpath string, config *compileopts.Config, action func(Buil
 				for filePath, hash := range pkg.FileHashes {
 					actionID.FileHashes[filePath] = hex.EncodeToString(hash)
 				}
+				for name, files := range allFiles {
+					actionID.EmbeddedFiles[name] = files[0].Hash
+				}
 				for i, imported := range pkg.Pkg.Imports() {
 					actionID.Imports[imported.Path()] = importedPackages[i].result
 				}
@@ -668,6 +717,9 @@ func Build(pkgName, outpath string, config *compileopts.Config, action func(Buil
 	// Add libc dependencies, if they exist.
 	linkerDependencies = append(linkerDependencies, libcDependencies...)
 
+	// Add embedded files.
+	linkerDependencies = append(linkerDependencies, embedFileObjects...)
+
 	// Strip debug information with -no-debug.
 	if !config.Debug() {
 		for _, tag := range config.BuildTags() {
@@ -920,6 +972,118 @@ func Build(pkgName, outpath string, config *compileopts.Config, action func(Buil
 	})
 }
 
+// createEmbedObjectFile creates a new object file with the given contents, for
+// the embed package.
+//
+// The contents (data) end up in a constant global named "embed/file_" followed
+// by hexSum. sourceFile and sourceDir are only used for the debug information
+// of that global. The object file is created as a new temporary file in
+// tmpdir; the path to this file is returned.
+func createEmbedObjectFile(data, hexSum, sourceFile, sourceDir, tmpdir string, compilerConfig *compiler.Config) (string, error) {
+	// TODO: this works for small files, but can be a problem for larger files.
+	// For larger files, it seems more appropriate to generate the object file
+	// manually without going through LLVM.
+	// On the other hand, generating DWARF like we do here can be difficult
+	// without assistance from LLVM.
+
+	// Create new LLVM module just for this file.
+	ctx := llvm.NewContext()
+	defer ctx.Dispose()
+	mod := ctx.NewModule("data")
+	defer mod.Dispose()
+
+	// Create data global.
+	value := ctx.ConstString(data, false)
+	globalName := "embed/file_" + hexSum
+	global := llvm.AddGlobal(mod, value.Type(), globalName)
+	global.SetInitializer(value)
+	global.SetLinkage(llvm.LinkOnceODRLinkage)
+	global.SetGlobalConstant(true)
+	global.SetUnnamedAddr(true)
+	global.SetAlignment(1)
+	if compilerConfig.GOOS != "darwin" {
+		// MachO doesn't support COMDATs, while COFF requires it (to avoid
+		// "duplicate symbol" errors). ELF works either way.
+		// Therefore, only use a COMDAT on non-MachO systems (aka non-MacOS).
+		global.SetComdat(mod.Comdat(globalName))
+	}
+
+	// Add DWARF debug information to this global, so that it is
+	// correctly counted when compiling with the -size= flag.
+	dibuilder := llvm.NewDIBuilder(mod)
+	dibuilder.CreateCompileUnit(llvm.DICompileUnit{
+		Language:  0xb, // DW_LANG_C99 (0xc, off-by-one?)
+		File:      sourceFile,
+		Dir:       sourceDir,
+		Producer:  "TinyGo",
+		Optimized: false,
+	})
+	ditype := dibuilder.CreateArrayType(llvm.DIArrayType{
+		SizeInBits:  uint64(len(data)) * 8,
+		AlignInBits: 8,
+		ElementType: dibuilder.CreateBasicType(llvm.DIBasicType{
+			Name:       "byte",
+			SizeInBits: 8,
+			Encoding:   llvm.DW_ATE_unsigned_char,
+		}),
+		Subscripts: []llvm.DISubrange{
+			{
+				Lo:    0,
+				Count: int64(len(data)),
+			},
+		},
+	})
+	difile := dibuilder.CreateFile(sourceFile, sourceDir)
+	diglobalexpr := dibuilder.CreateGlobalVariableExpression(difile, llvm.DIGlobalVariableExpression{
+		Name:        globalName,
+		File:        difile,
+		Line:        1,
+		Type:        ditype,
+		Expr:        dibuilder.CreateExpression(nil),
+		AlignInBits: 8,
+	})
+	global.AddMetadata(0, diglobalexpr)
+	mod.AddNamedMetadataOperand("llvm.module.flags",
+		ctx.MDNode([]llvm.Metadata{
+			llvm.ConstInt(ctx.Int32Type(), 2, false).ConstantAsMetadata(), // Warning on mismatch
+			ctx.MDString("Debug Info Version"),
+			llvm.ConstInt(ctx.Int32Type(), 3, false).ConstantAsMetadata(),
+		}),
+	)
+	mod.AddNamedMetadataOperand("llvm.module.flags",
+		ctx.MDNode([]llvm.Metadata{
+			llvm.ConstInt(ctx.Int32Type(), 7, false).ConstantAsMetadata(), // Max on mismatch
+			ctx.MDString("Dwarf Version"),
+			llvm.ConstInt(ctx.Int32Type(), 4, false).ConstantAsMetadata(),
+		}),
+	)
+	dibuilder.Finalize()
+	dibuilder.Destroy()
+
+	// Write this LLVM module out as an object file.
+	machine, err := compiler.NewTargetMachine(compilerConfig)
+	if err != nil {
+		return "", err
+	}
+	defer machine.Dispose()
+	// Emit the object code before creating the output file, so that a failed
+	// emit does not leave an empty object file behind.
+	buf, err := machine.EmitToMemoryBuffer(mod, llvm.ObjectFile)
+	if err != nil {
+		return "", err
+	}
+	defer buf.Dispose()
+	outfile, err := os.CreateTemp(tmpdir, "embed-"+hexSum+"-*.o")
+	if err != nil {
+		return "", err
+	}
+	if _, err := outfile.Write(buf.Bytes()); err != nil {
+		// Don't leave a partially written object file behind.
+		outfile.Close()
+		os.Remove(outfile.Name())
+		return "", err
+	}
+	if err := outfile.Close(); err != nil {
+		os.Remove(outfile.Name())
+		return "", err
+	}
+	return outfile.Name(), nil
+}
+
 // optimizeProgram runs a series of optimizations and transformations that are
 // needed to convert a program to its final form. Some transformations are not
 // optional and must be run as the compiler expects them to run.
diff --git a/builder/sizes.go b/builder/sizes.go
index d9e430f1..ef4389f6 100644
--- a/builder/sizes.go
+++ b/builder/sizes.go
@@ -117,7 +117,8 @@ var (
 	// alloc: heap allocations during init interpretation
 	// pack: data created when storing a constant in an interface for example
 	// string: buffer behind strings
-	packageSymbolRegexp = regexp.MustCompile(`\$(alloc|pack|string)(\.[0-9]+)?$`)
+	// embedfsfiles, embedfsslice, embedslice: data created for //go:embed globals
+	packageSymbolRegexp = regexp.MustCompile(`\$(alloc|embedfsfiles|embedfsslice|embedslice|pack|string)(\.[0-9]+)?$`)
 
 	// Reflect sidetables. Created by the reflect lowering pass.
 	// See src/reflect/sidetables.go.
diff --git a/compiler/compiler.go b/compiler/compiler.go
index f200b985..cb5e0795 100644
--- a/compiler/compiler.go
+++ b/compiler/compiler.go
@@ -9,6 +9,7 @@ import (
 	"go/token"
 	"go/types"
 	"math/bits"
+	"path"
 	"path/filepath"
 	"sort"
 	"strconv"
@@ -76,6 +77,7 @@ type compilerContext struct {
 	program          *ssa.Program
 	diagnostics      []error
 	astComments      map[string]*ast.CommentGroup
+	embedGlobals     map[string][]*loader.EmbedFile
 	pkg              *types.Package
 	packageDir       string // directory for this package
 	runtimePkg       *types.Package
@@ -250,6 +252,7 @@ func Sizes(machine llvm.TargetMachine) types.Sizes {
 func CompilePackage(moduleName string, pkg *loader.Package, ssaPkg *ssa.Package, machine llvm.TargetMachine, config *Config, dumpSSA bool) (llvm.Module, []error) {
 	c := newCompilerContext(moduleName, machine, config, dumpSSA)
 	c.packageDir = pkg.OriginalDir()
+	c.embedGlobals = pkg.EmbedGlobals
 	c.pkg = pkg.Pkg
 	c.runtimePkg = ssaPkg.Prog.ImportedPackage("runtime").Pkg
 	c.program = ssaPkg.Prog
@@ -815,7 +818,9 @@ func (c *compilerContext) createPackage(irbuilder llvm.Builder, pkg *ssa.Package
 			// Global variable.
 			info := c.getGlobalInfo(member)
 			global := c.getGlobal(member)
-			if !info.extern {
+			if files, ok := c.embedGlobals[member.Name()]; ok {
+				c.createEmbedGlobal(member, global, files)
+			} else if !info.extern {
 				global.SetInitializer(llvm.ConstNull(global.Type().ElementType()))
 				global.SetVisibility(llvm.HiddenVisibility)
 				if info.section != "" {
@@ -849,6 +854,158 @@ func (c *compilerContext) createPackage(irbuilder llvm.Builder, pkg *ssa.Package
 	}
 }
 
+// createEmbedGlobal creates an initializer for a //go:embed global variable.
+// The kind of initializer depends on the type of the global:
+//   - string: a string pointing to the contents of the (single) file
+//   - byte slice: a slice backed by a private copy of the (single) file
+//   - struct (assumed to be embed.FS): a sorted list of all files and their
+//     parent directories
+//
+// Invalid combinations (such as multiple files for a string) are reported as
+// a diagnostic.
+func (c *compilerContext) createEmbedGlobal(member *ssa.Global, global llvm.Value, files []*loader.EmbedFile) {
+	switch typ := member.Type().(*types.Pointer).Elem().Underlying().(type) {
+	case *types.Basic:
+		// String type.
+		if typ.Kind() != types.String {
+			// This is checked at the AST level, so should be unreachable.
+			panic("expected a string type")
+		}
+		if len(files) != 1 {
+			c.addError(member.Pos(), fmt.Sprintf("//go:embed for a string should be given exactly one file, got %d", len(files)))
+			return
+		}
+		strObj := c.getEmbedFileString(files[0])
+		global.SetInitializer(strObj)
+		global.SetVisibility(llvm.HiddenVisibility)
+
+	case *types.Slice:
+		if typ.Elem().Underlying().(*types.Basic).Kind() != types.Byte {
+			// This is checked at the AST level, so should be unreachable.
+			panic("expected a byte slice")
+		}
+		if len(files) != 1 {
+			c.addError(member.Pos(), fmt.Sprintf("//go:embed for a byte slice should be given exactly one file, got %d", len(files)))
+			return
+		}
+		file := files[0]
+		bufferValue := c.ctx.ConstString(string(file.Data), false)
+		bufferGlobal := llvm.AddGlobal(c.mod, bufferValue.Type(), c.pkg.Path()+"$embedslice")
+		bufferGlobal.SetInitializer(bufferValue)
+		bufferGlobal.SetLinkage(llvm.InternalLinkage)
+		bufferGlobal.SetAlignment(1)
+		slicePtr := llvm.ConstInBoundsGEP(bufferGlobal, []llvm.Value{
+			llvm.ConstInt(c.uintptrType, 0, false),
+			llvm.ConstInt(c.uintptrType, 0, false),
+		})
+		sliceLen := llvm.ConstInt(c.uintptrType, file.Size, false)
+		sliceObj := c.ctx.ConstStruct([]llvm.Value{slicePtr, sliceLen, sliceLen}, false)
+		global.SetInitializer(sliceObj)
+		global.SetVisibility(llvm.HiddenVisibility)
+
+	case *types.Struct:
+		// Assume this is an embed.FS struct:
+		// https://cs.opensource.google/go/go/+/refs/tags/go1.18.2:src/embed/embed.go;l=148
+		// It looks like this:
+		//   type FS struct {
+		//       files *[]file
+		//   }
+
+		// Make a slice of the files, as they will appear in the binary. They
+		// are sorted in a special way to allow for binary searches, see
+		// src/embed/embed.go for details.
+		dirset := map[string]struct{}{}
+		var allFiles []*loader.EmbedFile
+		for _, file := range files {
+			allFiles = append(allFiles, file)
+			dirname := file.Name
+			for {
+				dirname, _ = path.Split(path.Clean(dirname))
+				if dirname == "" {
+					break
+				}
+				if _, ok := dirset[dirname]; ok {
+					break
+				}
+				dirset[dirname] = struct{}{}
+				allFiles = append(allFiles, &loader.EmbedFile{
+					Name: dirname,
+				})
+			}
+		}
+		sort.Slice(allFiles, func(i, j int) bool {
+			dir1, name1 := path.Split(path.Clean(allFiles[i].Name))
+			dir2, name2 := path.Split(path.Clean(allFiles[j].Name))
+			if dir1 != dir2 {
+				return dir1 < dir2
+			}
+			return name1 < name2
+		})
+
+		// Make the backing array for the []files slice. This is a LLVM global.
+		embedFileStructType := c.getLLVMType(typ.Field(0).Type().(*types.Pointer).Elem().(*types.Slice).Elem())
+		var fileStructs []llvm.Value
+		for _, file := range allFiles {
+			fileStruct := llvm.ConstNull(embedFileStructType)
+			name := c.createConst(ssa.NewConst(constant.MakeString(file.Name), types.Typ[types.String]))
+			fileStruct = llvm.ConstInsertValue(fileStruct, name, []uint32{0}) // "name" field
+			if file.Hash != "" {
+				data := c.getEmbedFileString(file)
+				fileStruct = llvm.ConstInsertValue(fileStruct, data, []uint32{1}) // "data" field
+			}
+			fileStructs = append(fileStructs, fileStruct)
+		}
+		sliceDataInitializer := llvm.ConstArray(embedFileStructType, fileStructs)
+		sliceDataGlobal := llvm.AddGlobal(c.mod, sliceDataInitializer.Type(), c.pkg.Path()+"$embedfsfiles")
+		sliceDataGlobal.SetInitializer(sliceDataInitializer)
+		sliceDataGlobal.SetLinkage(llvm.InternalLinkage)
+		sliceDataGlobal.SetGlobalConstant(true)
+		sliceDataGlobal.SetUnnamedAddr(true)
+		sliceDataGlobal.SetAlignment(c.targetData.ABITypeAlignment(sliceDataInitializer.Type()))
+
+		// Create the slice object itself.
+		// Because embed.FS refers to it as *[]embed.file instead of a plain
+		// []embed.file, we have to store this as a global.
+		slicePtr := llvm.ConstInBoundsGEP(sliceDataGlobal, []llvm.Value{
+			llvm.ConstInt(c.uintptrType, 0, false),
+			llvm.ConstInt(c.uintptrType, 0, false),
+		})
+		sliceLen := llvm.ConstInt(c.uintptrType, uint64(len(fileStructs)), false)
+		sliceInitializer := c.ctx.ConstStruct([]llvm.Value{slicePtr, sliceLen, sliceLen}, false)
+		sliceGlobal := llvm.AddGlobal(c.mod, sliceInitializer.Type(), c.pkg.Path()+"$embedfsslice")
+		sliceGlobal.SetInitializer(sliceInitializer)
+		sliceGlobal.SetLinkage(llvm.InternalLinkage)
+		sliceGlobal.SetGlobalConstant(true)
+		sliceGlobal.SetUnnamedAddr(true)
+		sliceGlobal.SetAlignment(c.targetData.ABITypeAlignment(sliceInitializer.Type()))
+
+		// Define the embed.FS struct. It has only one field: the files (as a
+		// *[]embed.file).
+		globalInitializer := llvm.ConstNull(c.getLLVMType(member.Type().(*types.Pointer).Elem()))
+		globalInitializer = llvm.ConstInsertValue(globalInitializer, sliceGlobal, []uint32{0})
+		global.SetInitializer(globalInitializer)
+		global.SetVisibility(llvm.HiddenVisibility)
+		global.SetAlignment(c.targetData.ABITypeAlignment(globalInitializer.Type()))
+	}
+}
+
+// getEmbedFileString returns the (constant) string object with the contents of
+// the given file. This is a llvm.Value of a regular Go string.
+func (c *compilerContext) getEmbedFileString(file *loader.EmbedFile) llvm.Value {
+	dataGlobalName := "embed/file_" + file.Hash
+	dataGlobal := c.mod.NamedGlobal(dataGlobalName)
+	if dataGlobal.IsNil() {
+		dataGlobalType := llvm.ArrayType(c.ctx.Int8Type(), int(file.Size))
+		dataGlobal = llvm.AddGlobal(c.mod, dataGlobalType, dataGlobalName)
+	}
+	strPtr := llvm.ConstInBoundsGEP(dataGlobal, []llvm.Value{
+		llvm.ConstInt(c.uintptrType, 0, false),
+		llvm.ConstInt(c.uintptrType, 0, false),
+	})
+	strLen := llvm.ConstInt(c.uintptrType, file.Size, false)
+	return llvm.ConstNamedStruct(c.getLLVMRuntimeType("_string"), []llvm.Value{strPtr, strLen})
+}
+
 // createFunction builds the LLVM IR implementation for this function.
The
 // function must not yet be defined, otherwise this function will create a
 // diagnostic.
diff --git a/loader/loader.go b/loader/loader.go
index e9de2ec2..f0221942 100644
--- a/loader/loader.go
+++ b/loader/loader.go
@@ -7,6 +7,7 @@ import (
 	"errors"
 	"fmt"
 	"go/ast"
+	"go/constant"
 	"go/parser"
 	"go/scanner"
 	"go/token"
@@ -15,10 +16,12 @@ import (
 	"io/ioutil"
 	"os"
 	"os/exec"
+	"path"
 	"path/filepath"
 	"runtime"
 	"strconv"
 	"strings"
+	"unicode"
 
 	"github.com/tinygo-org/tinygo/cgo"
 	"github.com/tinygo-org/tinygo/compileopts"
@@ -61,6 +64,9 @@ type PackageJSON struct {
 	CgoFiles []string
 	CFiles   []string
 
+	// Embedded files
+	EmbedFiles []string
+
 	// Dependency information
 	Imports   []string
 	ImportMap map[string]string
@@ -77,13 +83,25 @@ type PackageJSON struct {
 type Package struct {
 	PackageJSON
 
-	program    *Program
-	Files      []*ast.File
-	FileHashes map[string][]byte
-	CFlags     []string // CFlags used during CGo preprocessing (only set if CGo is used)
-	CGoHeaders []string // text above 'import "C"' lines
-	Pkg        *types.Package
-	info       types.Info
+	program      *Program
+	Files        []*ast.File
+	FileHashes   map[string][]byte
+	CFlags       []string // CFlags used during CGo preprocessing (only set if CGo is used)
+	CGoHeaders   []string // text above 'import "C"' lines
+	EmbedGlobals map[string][]*EmbedFile
+	Pkg          *types.Package
+	info         types.Info
+}
+
+// EmbedFile is a single file that is matched by a //go:embed directive. The
+// loader only fills in the Name and NeedsData fields: the Size, Hash, and Data
+// fields are not set by the loader.
+type EmbedFile struct {
+	Name      string
+	Size      uint64
+	Hash      string // hash of the file (as a hex string)
+	NeedsData bool   // true if this file is embedded as a byte slice
+	Data      []byte // contents of this file (only if NeedsData is set)
 }
 
 // Load loads the given package with all dependencies (including the runtime
@@ -137,8 +155,9 @@ func Load(config *compileopts.Config, inputPkgs []string, clangHeaders string, t
 	decoder := json.NewDecoder(buf)
 	for {
 		pkg := &Package{
-			program:    p,
-			FileHashes: make(map[string][]byte),
+			program:      p,
+			FileHashes:   make(map[string][]byte),
+			EmbedGlobals: make(map[string][]*EmbedFile),
 			info: types.Info{
 				Types:      make(map[ast.Expr]types.TypeAndValue),
 				Defs:       make(map[*ast.Ident]types.Object),
@@ -357,15 +376,15 @@ func (p *Package) Check() error {
 		return nil // already typechecked
 	}
 
+	// Prepare some state used during type checking.
 	var typeErrors []error
 	checker := p.program.typeChecker // make a copy, because it will be modified
 	checker.Error = func(err error) {
 		typeErrors = append(typeErrors, err)
 	}
-
-	// Do typechecking of the package.
 	checker.Importer = p
 
+	// Do typechecking of the package.
 	packageName := p.ImportPath
 	if p == p.program.MainPkg() {
 		if p.Name != "main" {
@@ -382,6 +401,12 @@ func (p *Package) Check() error {
 		return Errors{p, typeErrors}
 	}
 	p.Pkg = typesPkg
+
+	p.extractEmbedLines(checker.Error)
+	if len(typeErrors) != 0 {
+		return Errors{p, typeErrors}
+	}
+
 	return nil
 }
 
@@ -440,6 +465,272 @@ func (p *Package) parseFiles() ([]*ast.File, error) {
 	return files, nil
 }
 
+// extractEmbedLines finds all //go:embed lines in the package and matches them
+// against EmbedFiles from `go list`. The result is stored in p.EmbedGlobals.
+// Problems are not returned but reported through addError.
+func (p *Package) extractEmbedLines(addError func(error)) {
+	for _, file := range p.Files {
+		// Check for an `import "embed"` line at the start of the file.
+		// //go:embed lines are only valid if the given file itself imports the
+		// embed package. It is not valid if it is only imported in a separate
+		// Go file.
+		hasEmbed := false
+		for _, importSpec := range file.Imports {
+			if importSpec.Path.Value == `"embed"` {
+				hasEmbed = true
+			}
+		}
+
+		for _, decl := range file.Decls {
+			switch decl := decl.(type) {
+			case *ast.GenDecl:
+				if decl.Tok != token.VAR {
+					continue
+				}
+				for _, spec := range decl.Specs {
+					spec := spec.(*ast.ValueSpec)
+					var doc *ast.CommentGroup
+					if decl.Lparen == token.NoPos {
+						// Plain 'var' declaration, like:
+						//   //go:embed hello.txt
+						//   var hello string
+						doc = decl.Doc
+					} else {
+						// Bigger 'var' declaration like:
+						//   var (
+						//       //go:embed hello.txt
+						//       hello string
+						//   )
+						doc = spec.Doc
+					}
+					if doc == nil {
+						continue
+					}
+
+					// Look for //go:embed comments.
+					var allPatterns []string
+					for _, comment := range doc.List {
+						if comment.Text != "//go:embed" && !strings.HasPrefix(comment.Text, "//go:embed ") {
+							continue
+						}
+						if !hasEmbed {
+							addError(types.Error{
+								Fset: p.program.fset,
+								Pos:  comment.Pos() + 2,
+								Msg:  "//go:embed only allowed in Go files that import \"embed\"",
+							})
+							// Continue, because otherwise we might run into
+							// issues below.
+							continue
+						}
+						// The patterns start right after the //go:embed prefix.
+						patternsPos := comment.Slash + token.Pos(len("//go:embed"))
+						patterns, err := p.parseGoEmbed(comment.Text[len("//go:embed"):], patternsPos)
+						if err != nil {
+							addError(err)
+							continue
+						}
+						if len(patterns) == 0 {
+							addError(types.Error{
+								Fset: p.program.fset,
+								Pos:  comment.Pos() + 2,
+								Msg:  "usage: //go:embed pattern...",
+							})
+							continue
+						}
+						for _, pattern := range patterns {
+							// Check that the pattern is well-formed.
+							// It must be valid: the Go toolchain has already
+							// checked for invalid patterns. But let's check
+							// anyway to be sure.
+							if _, err := path.Match(pattern, ""); err != nil {
+								addError(types.Error{
+									Fset: p.program.fset,
+									Pos:  comment.Pos(),
+									Msg:  "invalid pattern syntax",
+								})
+								continue
+							}
+							allPatterns = append(allPatterns, pattern)
+						}
+					}
+
+					if len(allPatterns) != 0 {
+						// This is a //go:embed global. Do a few more checks.
+						if len(spec.Names) != 1 {
+							addError(types.Error{
+								Fset: p.program.fset,
+								Pos:  spec.Names[1].NamePos,
+								Msg:  "//go:embed cannot apply to multiple vars",
+							})
+						}
+						if spec.Values != nil {
+							addError(types.Error{
+								Fset: p.program.fset,
+								Pos:  spec.Values[0].Pos(),
+								Msg:  "//go:embed cannot apply to var with initializer",
+							})
+						}
+						globalName := spec.Names[0].Name
+						globalObj := p.Pkg.Scope().Lookup(globalName)
+						if globalObj == nil {
+							// The blank identifier is not part of the package
+							// scope, so there is no global to initialize.
+							continue
+						}
+						globalType := globalObj.Type()
+						valid, byteSlice := isValidEmbedType(globalType)
+						if !valid {
+							// The type is omitted if the type is inferred from
+							// an initializer, so fall back to the variable name.
+							typePos := spec.Names[0].NamePos
+							if spec.Type != nil {
+								typePos = spec.Type.Pos()
+							}
+							addError(types.Error{
+								Fset: p.program.fset,
+								Pos:  typePos,
+								Msg:  "//go:embed cannot apply to var of type " + globalType.String(),
+							})
+						}
+
+						// Match all //go:embed patterns against the embed files
+						// provided by `go list`.
+						for _, name := range p.EmbedFiles {
+							for _, pattern := range allPatterns {
+								if matchPattern(pattern, name) {
+									p.EmbedGlobals[globalName] = append(p.EmbedGlobals[globalName], &EmbedFile{
+										Name:      name,
+										NeedsData: byteSlice,
+									})
+									break
+								}
+							}
+						}
+					}
+				}
+			}
+		}
+	}
+}
+
+// matchPattern returns true if (and only if) the given pattern would match the
+// filename. The pattern could also match a parent directory of name, in which
+// case hidden files do not match.
+func matchPattern(pattern, name string) bool {
+	// Match this file.
+	matched, _ := path.Match(pattern, name)
+	if matched {
+		return true
+	}
+
+	// Match parent directories.
+	dir := name
+	for {
+		dir, _ = path.Split(dir)
+		if dir == "" {
+			return false
+		}
+		dir = path.Clean(dir)
+		if matched, _ := path.Match(pattern, dir); matched {
+			// Pattern matches the directory.
+			suffix := name[len(dir):]
+			if strings.Contains(suffix, "/_") || strings.Contains(suffix, "/.") {
+				// Pattern matches a hidden file.
+				// Hidden files are included when listed directly as a
+				// pattern, but not when they are part of a directory tree.
+				// Source:
+				// > If a pattern names a directory, all files in the
+				// > subtree rooted at that directory are embedded
+				// > (recursively), except that files with names beginning
+				// > with ‘.’ or ‘_’ are excluded.
+				return false
+			}
+			return true
+		}
+	}
+}
+
+// parseGoEmbed is like strings.Fields but for a //go:embed line. It parses
+// regular fields and quoted fields (that may contain spaces).
+//
+// The args parameter is the text that follows "//go:embed" and pos is the
+// position of the first byte of args. It is used to calculate the position
+// of malformed patterns.
+func (p *Package) parseGoEmbed(args string, pos token.Pos) (patterns []string, err error) {
+	// Strip surrounding white space, but keep track of how much is removed at
+	// the start so that reported positions still match the source file.
+	trimmed := strings.TrimLeftFunc(args, unicode.IsSpace)
+	pos += token.Pos(len(args) - len(trimmed))
+	args = strings.TrimSpace(trimmed)
+	initialLen := len(args)
+	for args != "" {
+		patternPos := pos + token.Pos(initialLen-len(args))
+		switch args[0] {
+		case '`', '"', '\\':
+			// Parse the next pattern using the Go scanner.
+			// This is perhaps a bit overkill, but it does correctly implement
+			// parsing of the various Go strings.
+			var sc scanner.Scanner
+			fset := &token.FileSet{}
+			file := fset.AddFile("", 0, len(args))
+			sc.Init(file, []byte(args), nil, 0)
+			_, tok, lit := sc.Scan()
+			if tok != token.STRING || sc.ErrorCount != 0 {
+				// Not a (valid) Go string literal.
+				return nil, types.Error{
+					Fset: p.program.fset,
+					Pos:  patternPos,
+					Msg:  "invalid quoted string in //go:embed",
+				}
+			}
+			pattern := constant.StringVal(constant.MakeFromLiteral(lit, tok, 0))
+			patterns = append(patterns, pattern)
+			args = strings.TrimLeftFunc(args[len(lit):], unicode.IsSpace)
+		default:
+			// The value is just a regular value.
+			// Split it at the first white space.
+			index := strings.IndexFunc(args, unicode.IsSpace)
+			if index < 0 {
+				index = len(args)
+			}
+			pattern := args[:index]
+			patterns = append(patterns, pattern)
+			args = strings.TrimLeftFunc(args[len(pattern):], unicode.IsSpace)
+		}
+		if _, err := path.Match(patterns[len(patterns)-1], ""); err != nil {
+			return nil, types.Error{
+				Fset: p.program.fset,
+				Pos:  patternPos,
+				Msg:  "invalid pattern syntax",
+			}
+		}
+	}
+	return patterns, nil
+}
+
+// isValidEmbedType returns whether the given Go type can be used as a
+// //go:embed type. This is only true for embed.FS, strings, and byte slices.
+// The second return value indicates that this is a byte slice, and therefore
+// the contents of the file needs to be passed to the compiler.
+func isValidEmbedType(typ types.Type) (valid, byteSlice bool) {
+	if typ.Underlying() == types.Typ[types.String] {
+		// string type
+		return true, false
+	}
+	if sliceType, ok := typ.Underlying().(*types.Slice); ok {
+		if elemType, ok := sliceType.Elem().Underlying().(*types.Basic); ok && elemType.Kind() == types.Byte {
+			// byte slice type
+			return true, true
+		}
+	}
+	if namedType, ok := typ.(*types.Named); ok && namedType.String() == "embed.FS" {
+		// embed.FS type
+		return true, false
+	}
+	return false, false
+}
+
 // Import implements types.Importer. It loads and parses packages it encounters
 // along the way, if needed.
 func (p *Package) Import(to string) (*types.Package, error) {
diff --git a/main_test.go b/main_test.go
index e192d475..3351b6c0 100644
--- a/main_test.go
+++ b/main_test.go
@@ -48,6 +48,7 @@ func TestBuild(t *testing.T) {
 		"calls.go",
 		"cgo/",
 		"channel.go",
+		"embed/",
 		"float.go",
 		"gc.go",
 		"goroutines.go",
diff --git a/testdata/embed/a/b/.hidden b/testdata/embed/a/b/.hidden
new file mode 100644
index 00000000..e69de29b
diff --git a/testdata/embed/a/b/bar.txt b/testdata/embed/a/b/bar.txt
new file mode 100644
index 00000000..5716ca59
--- /dev/null
+++ b/testdata/embed/a/b/bar.txt
@@ -0,0 +1 @@
+bar
diff --git a/testdata/embed/a/b/foo.txt b/testdata/embed/a/b/foo.txt
new file mode 100644
index 00000000..257cc564
--- /dev/null
+++ b/testdata/embed/a/b/foo.txt
@@ -0,0 +1 @@
+foo
diff --git a/testdata/embed/embed.go b/testdata/embed/embed.go
new file mode 100644
index 00000000..da658718
--- /dev/null
+++ b/testdata/embed/embed.go
@@ -0,0 +1,46 @@
+package main
+
+import (
+	"embed"
+	"strings"
+)
+
+//go:embed a hello.txt
+var files embed.FS
+
+var (
+	//go:embed "hello.*"
+	helloString string
+
+	//go:embed hello.txt
+	helloBytes []byte
+)
+
+// A test to check that hidden files are not included when matching a directory.
+//go:embed a/b/.hidden
+var hidden string
+
+func main() {
+	println("string:", strings.TrimSpace(helloString))
+	println("bytes:", strings.TrimSpace(string(helloBytes)))
+	println("files:")
+	readFiles(".")
+}
+
+func readFiles(dir string) {
+	entries, err := files.ReadDir(dir)
+	if err != nil {
+		println(err.Error())
+		return
+	}
+	for _, entry := range entries {
+		entryPath := entry.Name()
+		if dir != "." {
+			entryPath = dir + "/" + entryPath
+		}
+		println("-", entryPath)
+		if entry.IsDir() {
+			readFiles(entryPath)
+		}
+	}
+}
diff --git a/testdata/embed/hello.txt b/testdata/embed/hello.txt
new file mode 100644
index 00000000..a0423896
--- /dev/null
+++ b/testdata/embed/hello.txt
@@ -0,0 +1 @@
+hello world!
diff --git a/testdata/embed/out.txt b/testdata/embed/out.txt
new file mode 100644
index 00000000..3b216da7
--- /dev/null
+++ b/testdata/embed/out.txt
@@ -0,0 +1,8 @@
+string: hello world!
+bytes: hello world!
+files:
+- a
+- a/b
+- a/b/bar.txt
+- a/b/foo.txt
+- hello.txt