all: add support for the embed package

Этот коммит содержится в:
Ayke van Laethem 2021-11-15 01:56:49 +01:00 коммит произвёл Ron Evans
родитель fd20f63ee3
коммит 87a4676137
12 изменённых файлов: 645 добавлений и 14 удалений

Просмотреть файл

@ -265,6 +265,7 @@ TEST_PACKAGES_FAST = \
crypto/sha256 \ crypto/sha256 \
crypto/sha512 \ crypto/sha512 \
debug/macho \ debug/macho \
embed/internal/embedtest \
encoding \ encoding \
encoding/ascii85 \ encoding/ascii85 \
encoding/base32 \ encoding/base32 \

Просмотреть файл

@ -4,6 +4,7 @@
package builder package builder
import ( import (
"crypto/sha256"
"crypto/sha512" "crypto/sha512"
"debug/elf" "debug/elf"
"encoding/binary" "encoding/binary"
@ -80,6 +81,7 @@ type packageAction struct {
Config *compiler.Config Config *compiler.Config
CFlags []string CFlags []string
FileHashes map[string]string // hash of every file that's part of the package FileHashes map[string]string // hash of every file that's part of the package
EmbeddedFiles map[string]string // hash of all the //go:embed files in the package
Imports map[string]string // map from imported package to action ID hash Imports map[string]string // map from imported package to action ID hash
OptLevel int // LLVM optimization level (0-3) OptLevel int // LLVM optimization level (0-3)
SizeLevel int // LLVM optimization for size level (0-2) SizeLevel int // LLVM optimization for size level (0-2)
@ -225,6 +227,7 @@ func Build(pkgName, outpath string, config *compileopts.Config, action func(Buil
config.Options.GlobalValues["runtime"]["buildVersion"] = version config.Options.GlobalValues["runtime"]["buildVersion"] = version
} }
var embedFileObjects []*compileJob
for _, pkg := range lprogram.Sorted() { for _, pkg := range lprogram.Sorted() {
pkg := pkg // necessary to avoid a race condition pkg := pkg // necessary to avoid a race condition
@ -234,6 +237,47 @@ func Build(pkgName, outpath string, config *compileopts.Config, action func(Buil
} }
sort.Strings(undefinedGlobals) sort.Strings(undefinedGlobals)
// Make compile jobs to load files to be embedded in the output binary.
var actionIDDependencies []*compileJob
allFiles := map[string][]*loader.EmbedFile{}
for _, files := range pkg.EmbedGlobals {
for _, file := range files {
allFiles[file.Name] = append(allFiles[file.Name], file)
}
}
for name, files := range allFiles {
name := name
files := files
job := &compileJob{
description: "make object file for " + name,
run: func(job *compileJob) error {
// Read the file contents in memory.
path := filepath.Join(pkg.Dir, name)
data, err := os.ReadFile(path)
if err != nil {
return err
}
// Hash the file.
sum := sha256.Sum256(data)
hexSum := hex.EncodeToString(sum[:16])
for _, file := range files {
file.Size = uint64(len(data))
file.Hash = hexSum
if file.NeedsData {
file.Data = data
}
}
job.result, err = createEmbedObjectFile(string(data), hexSum, name, pkg.OriginalDir(), dir, compilerConfig)
return err
},
}
actionIDDependencies = append(actionIDDependencies, job)
embedFileObjects = append(embedFileObjects, job)
}
// Action ID jobs need to know the action ID of all the jobs the package // Action ID jobs need to know the action ID of all the jobs the package
// imports. // imports.
var importedPackages []*compileJob var importedPackages []*compileJob
@ -243,6 +287,7 @@ func Build(pkgName, outpath string, config *compileopts.Config, action func(Buil
return fmt.Errorf("package %s imports %s but couldn't find dependency", pkg.ImportPath, imported.Path()) return fmt.Errorf("package %s imports %s but couldn't find dependency", pkg.ImportPath, imported.Path())
} }
importedPackages = append(importedPackages, job) importedPackages = append(importedPackages, job)
actionIDDependencies = append(actionIDDependencies, job)
} }
// Create a job that will calculate the action ID for a package compile // Create a job that will calculate the action ID for a package compile
@ -250,7 +295,7 @@ func Build(pkgName, outpath string, config *compileopts.Config, action func(Buil
// package. // package.
packageActionIDJob := &compileJob{ packageActionIDJob := &compileJob{
description: "calculate cache key for package " + pkg.ImportPath, description: "calculate cache key for package " + pkg.ImportPath,
dependencies: importedPackages, dependencies: actionIDDependencies,
run: func(job *compileJob) error { run: func(job *compileJob) error {
// Create a cache key: a hash from the action ID below that contains all // Create a cache key: a hash from the action ID below that contains all
// the parameters for the build. // the parameters for the build.
@ -262,6 +307,7 @@ func Build(pkgName, outpath string, config *compileopts.Config, action func(Buil
Config: compilerConfig, Config: compilerConfig,
CFlags: pkg.CFlags, CFlags: pkg.CFlags,
FileHashes: make(map[string]string, len(pkg.FileHashes)), FileHashes: make(map[string]string, len(pkg.FileHashes)),
EmbeddedFiles: make(map[string]string, len(allFiles)),
Imports: make(map[string]string, len(pkg.Pkg.Imports())), Imports: make(map[string]string, len(pkg.Pkg.Imports())),
OptLevel: optLevel, OptLevel: optLevel,
SizeLevel: sizeLevel, SizeLevel: sizeLevel,
@ -270,6 +316,9 @@ func Build(pkgName, outpath string, config *compileopts.Config, action func(Buil
for filePath, hash := range pkg.FileHashes { for filePath, hash := range pkg.FileHashes {
actionID.FileHashes[filePath] = hex.EncodeToString(hash) actionID.FileHashes[filePath] = hex.EncodeToString(hash)
} }
for name, files := range allFiles {
actionID.EmbeddedFiles[name] = files[0].Hash
}
for i, imported := range pkg.Pkg.Imports() { for i, imported := range pkg.Pkg.Imports() {
actionID.Imports[imported.Path()] = importedPackages[i].result actionID.Imports[imported.Path()] = importedPackages[i].result
} }
@ -668,6 +717,9 @@ func Build(pkgName, outpath string, config *compileopts.Config, action func(Buil
// Add libc dependencies, if they exist. // Add libc dependencies, if they exist.
linkerDependencies = append(linkerDependencies, libcDependencies...) linkerDependencies = append(linkerDependencies, libcDependencies...)
// Add embedded files.
linkerDependencies = append(linkerDependencies, embedFileObjects...)
// Strip debug information with -no-debug. // Strip debug information with -no-debug.
if !config.Debug() { if !config.Debug() {
for _, tag := range config.BuildTags() { for _, tag := range config.BuildTags() {
@ -920,6 +972,112 @@ func Build(pkgName, outpath string, config *compileopts.Config, action func(Buil
}) })
} }
// createEmbedObjectFile creates a new object file with the given contents, for
// the embed package.
func createEmbedObjectFile(data, hexSum, sourceFile, sourceDir, tmpdir string, compilerConfig *compiler.Config) (string, error) {
// TODO: this works for small files, but can be a problem for larger files.
// For larger files, it seems more appropriate to generate the object file
// manually without going through LLVM.
// On the other hand, generating DWARF like we do here can be difficult
// without assistance from LLVM.
// Create new LLVM module just for this file.
ctx := llvm.NewContext()
defer ctx.Dispose()
mod := ctx.NewModule("data")
defer mod.Dispose()
// Create data global.
value := ctx.ConstString(data, false)
globalName := "embed/file_" + hexSum
global := llvm.AddGlobal(mod, value.Type(), globalName)
global.SetInitializer(value)
global.SetLinkage(llvm.LinkOnceODRLinkage)
global.SetGlobalConstant(true)
global.SetUnnamedAddr(true)
global.SetAlignment(1)
if compilerConfig.GOOS != "darwin" {
// MachO doesn't support COMDATs, while COFF requires it (to avoid
// "duplicate symbol" errors). ELF works either way.
// Therefore, only use a COMDAT on non-MachO systems (aka non-MacOS).
global.SetComdat(mod.Comdat(globalName))
}
// Add DWARF debug information to this global, so that it is
// correctly counted when compiling with the -size= flag.
dibuilder := llvm.NewDIBuilder(mod)
dibuilder.CreateCompileUnit(llvm.DICompileUnit{
Language: 0xb, // DW_LANG_C99 (0xc, off-by-one?)
File: sourceFile,
Dir: sourceDir,
Producer: "TinyGo",
Optimized: false,
})
ditype := dibuilder.CreateArrayType(llvm.DIArrayType{
SizeInBits: uint64(len(data)) * 8,
AlignInBits: 8,
ElementType: dibuilder.CreateBasicType(llvm.DIBasicType{
Name: "byte",
SizeInBits: 8,
Encoding: llvm.DW_ATE_unsigned_char,
}),
Subscripts: []llvm.DISubrange{
{
Lo: 0,
Count: int64(len(data)),
},
},
})
difile := dibuilder.CreateFile(sourceFile, sourceDir)
diglobalexpr := dibuilder.CreateGlobalVariableExpression(difile, llvm.DIGlobalVariableExpression{
Name: globalName,
File: difile,
Line: 1,
Type: ditype,
Expr: dibuilder.CreateExpression(nil),
AlignInBits: 8,
})
global.AddMetadata(0, diglobalexpr)
mod.AddNamedMetadataOperand("llvm.module.flags",
ctx.MDNode([]llvm.Metadata{
llvm.ConstInt(ctx.Int32Type(), 2, false).ConstantAsMetadata(), // Warning on mismatch
ctx.MDString("Debug Info Version"),
llvm.ConstInt(ctx.Int32Type(), 3, false).ConstantAsMetadata(),
}),
)
mod.AddNamedMetadataOperand("llvm.module.flags",
ctx.MDNode([]llvm.Metadata{
llvm.ConstInt(ctx.Int32Type(), 7, false).ConstantAsMetadata(), // Max on mismatch
ctx.MDString("Dwarf Version"),
llvm.ConstInt(ctx.Int32Type(), 4, false).ConstantAsMetadata(),
}),
)
dibuilder.Finalize()
dibuilder.Destroy()
// Write this LLVM module out as an object file.
machine, err := compiler.NewTargetMachine(compilerConfig)
if err != nil {
return "", err
}
defer machine.Dispose()
outfile, err := os.CreateTemp(tmpdir, "embed-"+hexSum+"-*.o")
if err != nil {
return "", err
}
defer outfile.Close()
buf, err := machine.EmitToMemoryBuffer(mod, llvm.ObjectFile)
if err != nil {
return "", err
}
defer buf.Dispose()
_, err = outfile.Write(buf.Bytes())
if err != nil {
return "", err
}
return outfile.Name(), outfile.Close()
}
// optimizeProgram runs a series of optimizations and transformations that are // optimizeProgram runs a series of optimizations and transformations that are
// needed to convert a program to its final form. Some transformations are not // needed to convert a program to its final form. Some transformations are not
// optional and must be run as the compiler expects them to run. // optional and must be run as the compiler expects them to run.

Просмотреть файл

@ -117,7 +117,7 @@ var (
// alloc: heap allocations during init interpretation // alloc: heap allocations during init interpretation
// pack: data created when storing a constant in an interface for example // pack: data created when storing a constant in an interface for example
// string: buffer behind strings // string: buffer behind strings
packageSymbolRegexp = regexp.MustCompile(`\$(alloc|pack|string)(\.[0-9]+)?$`) packageSymbolRegexp = regexp.MustCompile(`\$(alloc|embedfsfiles|embedfsslice|embedslice|pack|string)(\.[0-9]+)?$`)
// Reflect sidetables. Created by the reflect lowering pass. // Reflect sidetables. Created by the reflect lowering pass.
// See src/reflect/sidetables.go. // See src/reflect/sidetables.go.

Просмотреть файл

@ -9,6 +9,7 @@ import (
"go/token" "go/token"
"go/types" "go/types"
"math/bits" "math/bits"
"path"
"path/filepath" "path/filepath"
"sort" "sort"
"strconv" "strconv"
@ -76,6 +77,7 @@ type compilerContext struct {
program *ssa.Program program *ssa.Program
diagnostics []error diagnostics []error
astComments map[string]*ast.CommentGroup astComments map[string]*ast.CommentGroup
embedGlobals map[string][]*loader.EmbedFile
pkg *types.Package pkg *types.Package
packageDir string // directory for this package packageDir string // directory for this package
runtimePkg *types.Package runtimePkg *types.Package
@ -250,6 +252,7 @@ func Sizes(machine llvm.TargetMachine) types.Sizes {
func CompilePackage(moduleName string, pkg *loader.Package, ssaPkg *ssa.Package, machine llvm.TargetMachine, config *Config, dumpSSA bool) (llvm.Module, []error) { func CompilePackage(moduleName string, pkg *loader.Package, ssaPkg *ssa.Package, machine llvm.TargetMachine, config *Config, dumpSSA bool) (llvm.Module, []error) {
c := newCompilerContext(moduleName, machine, config, dumpSSA) c := newCompilerContext(moduleName, machine, config, dumpSSA)
c.packageDir = pkg.OriginalDir() c.packageDir = pkg.OriginalDir()
c.embedGlobals = pkg.EmbedGlobals
c.pkg = pkg.Pkg c.pkg = pkg.Pkg
c.runtimePkg = ssaPkg.Prog.ImportedPackage("runtime").Pkg c.runtimePkg = ssaPkg.Prog.ImportedPackage("runtime").Pkg
c.program = ssaPkg.Prog c.program = ssaPkg.Prog
@ -815,7 +818,9 @@ func (c *compilerContext) createPackage(irbuilder llvm.Builder, pkg *ssa.Package
// Global variable. // Global variable.
info := c.getGlobalInfo(member) info := c.getGlobalInfo(member)
global := c.getGlobal(member) global := c.getGlobal(member)
if !info.extern { if files, ok := c.embedGlobals[member.Name()]; ok {
c.createEmbedGlobal(member, global, files)
} else if !info.extern {
global.SetInitializer(llvm.ConstNull(global.Type().ElementType())) global.SetInitializer(llvm.ConstNull(global.Type().ElementType()))
global.SetVisibility(llvm.HiddenVisibility) global.SetVisibility(llvm.HiddenVisibility)
if info.section != "" { if info.section != "" {
@ -849,6 +854,150 @@ func (c *compilerContext) createPackage(irbuilder llvm.Builder, pkg *ssa.Package
} }
} }
// createEmbedGlobal creates an initializer for a //go:embed global variable.
func (c *compilerContext) createEmbedGlobal(member *ssa.Global, global llvm.Value, files []*loader.EmbedFile) {
switch typ := member.Type().(*types.Pointer).Elem().Underlying().(type) {
case *types.Basic:
// String type.
if typ.Kind() != types.String {
// This is checked at the AST level, so should be unreachable.
panic("expected a string type")
}
if len(files) != 1 {
c.addError(member.Pos(), fmt.Sprintf("//go:embed for a string should be given exactly one file, got %d", len(files)))
return
}
strObj := c.getEmbedFileString(files[0])
global.SetInitializer(strObj)
global.SetVisibility(llvm.HiddenVisibility)
case *types.Slice:
if typ.Elem().Underlying().(*types.Basic).Kind() != types.Byte {
// This is checked at the AST level, so should be unreachable.
panic("expected a byte slice")
}
if len(files) != 1 {
c.addError(member.Pos(), fmt.Sprintf("//go:embed for a string should be given exactly one file, got %d", len(files)))
return
}
file := files[0]
bufferValue := c.ctx.ConstString(string(file.Data), false)
bufferGlobal := llvm.AddGlobal(c.mod, bufferValue.Type(), c.pkg.Path()+"$embedslice")
bufferGlobal.SetInitializer(bufferValue)
bufferGlobal.SetLinkage(llvm.InternalLinkage)
bufferGlobal.SetAlignment(1)
slicePtr := llvm.ConstInBoundsGEP(bufferGlobal, []llvm.Value{
llvm.ConstInt(c.uintptrType, 0, false),
llvm.ConstInt(c.uintptrType, 0, false),
})
sliceLen := llvm.ConstInt(c.uintptrType, file.Size, false)
sliceObj := c.ctx.ConstStruct([]llvm.Value{slicePtr, sliceLen, sliceLen}, false)
global.SetInitializer(sliceObj)
global.SetVisibility(llvm.HiddenVisibility)
case *types.Struct:
// Assume this is an embed.FS struct:
// https://cs.opensource.google/go/go/+/refs/tags/go1.18.2:src/embed/embed.go;l=148
// It looks like this:
// type FS struct {
// files *file
// }
// Make a slice of the files, as they will appear in the binary. They
// are sorted in a special way to allow for binary searches, see
// src/embed/embed.go for details.
dirset := map[string]struct{}{}
var allFiles []*loader.EmbedFile
for _, file := range files {
allFiles = append(allFiles, file)
dirname := file.Name
for {
dirname, _ = path.Split(path.Clean(dirname))
if dirname == "" {
break
}
if _, ok := dirset[dirname]; ok {
break
}
dirset[dirname] = struct{}{}
allFiles = append(allFiles, &loader.EmbedFile{
Name: dirname,
})
}
}
sort.Slice(allFiles, func(i, j int) bool {
dir1, name1 := path.Split(path.Clean(allFiles[i].Name))
dir2, name2 := path.Split(path.Clean(allFiles[j].Name))
if dir1 != dir2 {
return dir1 < dir2
}
return name1 < name2
})
// Make the backing array for the []files slice. This is a LLVM global.
embedFileStructType := c.getLLVMType(typ.Field(0).Type().(*types.Pointer).Elem().(*types.Slice).Elem())
var fileStructs []llvm.Value
for _, file := range allFiles {
fileStruct := llvm.ConstNull(embedFileStructType)
name := c.createConst(ssa.NewConst(constant.MakeString(file.Name), types.Typ[types.String]))
fileStruct = llvm.ConstInsertValue(fileStruct, name, []uint32{0}) // "name" field
if file.Hash != "" {
data := c.getEmbedFileString(file)
fileStruct = llvm.ConstInsertValue(fileStruct, data, []uint32{1}) // "data" field
}
fileStructs = append(fileStructs, fileStruct)
}
sliceDataInitializer := llvm.ConstArray(embedFileStructType, fileStructs)
sliceDataGlobal := llvm.AddGlobal(c.mod, sliceDataInitializer.Type(), c.pkg.Path()+"$embedfsfiles")
sliceDataGlobal.SetInitializer(sliceDataInitializer)
sliceDataGlobal.SetLinkage(llvm.InternalLinkage)
sliceDataGlobal.SetGlobalConstant(true)
sliceDataGlobal.SetUnnamedAddr(true)
sliceDataGlobal.SetAlignment(c.targetData.ABITypeAlignment(sliceDataInitializer.Type()))
// Create the slice object itself.
// Because embed.FS refers to it as *[]embed.file instead of a plain
// []embed.file, we have to store this as a global.
slicePtr := llvm.ConstInBoundsGEP(sliceDataGlobal, []llvm.Value{
llvm.ConstInt(c.uintptrType, 0, false),
llvm.ConstInt(c.uintptrType, 0, false),
})
sliceLen := llvm.ConstInt(c.uintptrType, uint64(len(fileStructs)), false)
sliceInitializer := c.ctx.ConstStruct([]llvm.Value{slicePtr, sliceLen, sliceLen}, false)
sliceGlobal := llvm.AddGlobal(c.mod, sliceInitializer.Type(), c.pkg.Path()+"$embedfsslice")
sliceGlobal.SetInitializer(sliceInitializer)
sliceGlobal.SetLinkage(llvm.InternalLinkage)
sliceGlobal.SetGlobalConstant(true)
sliceGlobal.SetUnnamedAddr(true)
sliceGlobal.SetAlignment(c.targetData.ABITypeAlignment(sliceInitializer.Type()))
// Define the embed.FS struct. It has only one field: the files (as a
// *[]embed.file).
globalInitializer := llvm.ConstNull(c.getLLVMType(member.Type().(*types.Pointer).Elem()))
globalInitializer = llvm.ConstInsertValue(globalInitializer, sliceGlobal, []uint32{0})
global.SetInitializer(globalInitializer)
global.SetVisibility(llvm.HiddenVisibility)
global.SetAlignment(c.targetData.ABITypeAlignment(globalInitializer.Type()))
}
}
// getEmbedFileString returns the (constant) string object with the contents of
// the given file. This is a llvm.Value of a regular Go string.
func (c *compilerContext) getEmbedFileString(file *loader.EmbedFile) llvm.Value {
dataGlobalName := "embed/file_" + file.Hash
dataGlobal := c.mod.NamedGlobal(dataGlobalName)
if dataGlobal.IsNil() {
dataGlobalType := llvm.ArrayType(c.ctx.Int8Type(), int(file.Size))
dataGlobal = llvm.AddGlobal(c.mod, dataGlobalType, dataGlobalName)
}
strPtr := llvm.ConstInBoundsGEP(dataGlobal, []llvm.Value{
llvm.ConstInt(c.uintptrType, 0, false),
llvm.ConstInt(c.uintptrType, 0, false),
})
strLen := llvm.ConstInt(c.uintptrType, file.Size, false)
return llvm.ConstNamedStruct(c.getLLVMRuntimeType("_string"), []llvm.Value{strPtr, strLen})
}
// createFunction builds the LLVM IR implementation for this function. The // createFunction builds the LLVM IR implementation for this function. The
// function must not yet be defined, otherwise this function will create a // function must not yet be defined, otherwise this function will create a
// diagnostic. // diagnostic.

Просмотреть файл

@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"go/ast" "go/ast"
"go/constant"
"go/parser" "go/parser"
"go/scanner" "go/scanner"
"go/token" "go/token"
@ -15,10 +16,12 @@ import (
"io/ioutil" "io/ioutil"
"os" "os"
"os/exec" "os/exec"
"path"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
"unicode"
"github.com/tinygo-org/tinygo/cgo" "github.com/tinygo-org/tinygo/cgo"
"github.com/tinygo-org/tinygo/compileopts" "github.com/tinygo-org/tinygo/compileopts"
@ -61,6 +64,9 @@ type PackageJSON struct {
CgoFiles []string CgoFiles []string
CFiles []string CFiles []string
// Embedded files
EmbedFiles []string
// Dependency information // Dependency information
Imports []string Imports []string
ImportMap map[string]string ImportMap map[string]string
@ -77,13 +83,22 @@ type PackageJSON struct {
type Package struct { type Package struct {
PackageJSON PackageJSON
program *Program program *Program
Files []*ast.File Files []*ast.File
FileHashes map[string][]byte FileHashes map[string][]byte
CFlags []string // CFlags used during CGo preprocessing (only set if CGo is used) CFlags []string // CFlags used during CGo preprocessing (only set if CGo is used)
CGoHeaders []string // text above 'import "C"' lines CGoHeaders []string // text above 'import "C"' lines
Pkg *types.Package EmbedGlobals map[string][]*EmbedFile
info types.Info Pkg *types.Package
info types.Info
}
type EmbedFile struct {
Name string
Size uint64
Hash string // hash of the file (as a hex string)
NeedsData bool // true if this file is embedded as a byte slice
Data []byte // contents of this file (only if NeedsData is set)
} }
// Load loads the given package with all dependencies (including the runtime // Load loads the given package with all dependencies (including the runtime
@ -137,8 +152,9 @@ func Load(config *compileopts.Config, inputPkgs []string, clangHeaders string, t
decoder := json.NewDecoder(buf) decoder := json.NewDecoder(buf)
for { for {
pkg := &Package{ pkg := &Package{
program: p, program: p,
FileHashes: make(map[string][]byte), FileHashes: make(map[string][]byte),
EmbedGlobals: make(map[string][]*EmbedFile),
info: types.Info{ info: types.Info{
Types: make(map[ast.Expr]types.TypeAndValue), Types: make(map[ast.Expr]types.TypeAndValue),
Defs: make(map[*ast.Ident]types.Object), Defs: make(map[*ast.Ident]types.Object),
@ -357,15 +373,15 @@ func (p *Package) Check() error {
return nil // already typechecked return nil // already typechecked
} }
// Prepare some state used during type checking.
var typeErrors []error var typeErrors []error
checker := p.program.typeChecker // make a copy, because it will be modified checker := p.program.typeChecker // make a copy, because it will be modified
checker.Error = func(err error) { checker.Error = func(err error) {
typeErrors = append(typeErrors, err) typeErrors = append(typeErrors, err)
} }
// Do typechecking of the package.
checker.Importer = p checker.Importer = p
// Do typechecking of the package.
packageName := p.ImportPath packageName := p.ImportPath
if p == p.program.MainPkg() { if p == p.program.MainPkg() {
if p.Name != "main" { if p.Name != "main" {
@ -382,6 +398,12 @@ func (p *Package) Check() error {
return Errors{p, typeErrors} return Errors{p, typeErrors}
} }
p.Pkg = typesPkg p.Pkg = typesPkg
p.extractEmbedLines(checker.Error)
if len(typeErrors) != 0 {
return Errors{p, typeErrors}
}
return nil return nil
} }
@ -440,6 +462,249 @@ func (p *Package) parseFiles() ([]*ast.File, error) {
return files, nil return files, nil
} }
// extractEmbedLines finds all //go:embed lines in the package and matches them
// against EmbedFiles from `go list`.
func (p *Package) extractEmbedLines(addError func(error)) {
for _, file := range p.Files {
// Check for an `import "embed"` line at the start of the file.
// //go:embed lines are only valid if the given file itself imports the
// embed package. It is not valid if it is only imported in a separate
// Go file.
hasEmbed := false
for _, importSpec := range file.Imports {
if importSpec.Path.Value == `"embed"` {
hasEmbed = true
}
}
for _, decl := range file.Decls {
switch decl := decl.(type) {
case *ast.GenDecl:
if decl.Tok != token.VAR {
continue
}
for _, spec := range decl.Specs {
spec := spec.(*ast.ValueSpec)
var doc *ast.CommentGroup
if decl.Lparen == token.NoPos {
// Plain 'var' declaration, like:
// //go:embed hello.txt
// var hello string
doc = decl.Doc
} else {
// Bigger 'var' declaration like:
// var (
// //go:embed hello.txt
// hello string
// )
doc = spec.Doc
}
if doc == nil {
continue
}
// Look for //go:embed comments.
var allPatterns []string
for _, comment := range doc.List {
if comment.Text != "//go:embed" && !strings.HasPrefix(comment.Text, "//go:embed ") {
continue
}
if !hasEmbed {
addError(types.Error{
Fset: p.program.fset,
Pos: comment.Pos() + 2,
Msg: "//go:embed only allowed in Go files that import \"embed\"",
})
// Continue, because otherwise we might run into
// issues below.
continue
}
patterns, err := p.parseGoEmbed(comment.Text[len("//go:embed"):], comment.Slash)
if err != nil {
addError(err)
continue
}
if len(patterns) == 0 {
addError(types.Error{
Fset: p.program.fset,
Pos: comment.Pos() + 2,
Msg: "usage: //go:embed pattern...",
})
continue
}
for _, pattern := range patterns {
// Check that the pattern is well-formed.
// It must be valid: the Go toolchain has already
// checked for invalid patterns. But let's check
// anyway to be sure.
if _, err := path.Match(pattern, ""); err != nil {
addError(types.Error{
Fset: p.program.fset,
Pos: comment.Pos(),
Msg: "invalid pattern syntax",
})
continue
}
allPatterns = append(allPatterns, pattern)
}
}
if len(allPatterns) != 0 {
// This is a //go:embed global. Do a few more checks.
if len(spec.Names) != 1 {
addError(types.Error{
Fset: p.program.fset,
Pos: spec.Names[1].NamePos,
Msg: "//go:embed cannot apply to multiple vars",
})
}
if spec.Values != nil {
addError(types.Error{
Fset: p.program.fset,
Pos: spec.Values[0].Pos(),
Msg: "//go:embed cannot apply to var with initializer",
})
}
globalName := spec.Names[0].Name
globalType := p.Pkg.Scope().Lookup(globalName).Type()
valid, byteSlice := isValidEmbedType(globalType)
if !valid {
addError(types.Error{
Fset: p.program.fset,
Pos: spec.Type.Pos(),
Msg: "//go:embed cannot apply to var of type " + globalType.String(),
})
}
// Match all //go:embed patterns against the embed files
// provided by `go list`.
for _, name := range p.EmbedFiles {
for _, pattern := range allPatterns {
if matchPattern(pattern, name) {
p.EmbedGlobals[globalName] = append(p.EmbedGlobals[globalName], &EmbedFile{
Name: name,
NeedsData: byteSlice,
})
break
}
}
}
}
}
}
}
}
}
// matchPattern returns true if (and only if) the given pattern would match the
// filename. The pattern could also match a parent directory of name, in which
// case hidden files do not match.
func matchPattern(pattern, name string) bool {
// Match this file.
matched, _ := path.Match(pattern, name)
if matched {
return true
}
// Match parent directories.
dir := name
for {
dir, _ = path.Split(dir)
if dir == "" {
return false
}
dir = path.Clean(dir)
if matched, _ := path.Match(pattern, dir); matched {
// Pattern matches the directory.
suffix := name[len(dir):]
if strings.Contains(suffix, "/_") || strings.Contains(suffix, "/.") {
// Pattern matches a hidden file.
// Hidden files are included when listed directly as a
// pattern, but not when they are part of a directory tree.
// Source:
// > If a pattern names a directory, all files in the
// > subtree rooted at that directory are embedded
// > (recursively), except that files with names beginning
// > with . or _ are excluded.
return false
}
return true
}
}
}
// parseGoEmbed is like strings.Fields but for a //go:embed line. It parses
// regular fields and quoted fields (that may contain spaces).
func (p *Package) parseGoEmbed(args string, pos token.Pos) (patterns []string, err error) {
args = strings.TrimSpace(args)
initialLen := len(args)
for args != "" {
patternPos := pos + token.Pos(initialLen-len(args))
switch args[0] {
case '`', '"', '\\':
// Parse the next pattern using the Go scanner.
// This is perhaps a bit overkill, but it does correctly implement
// parsing of the various Go strings.
var sc scanner.Scanner
fset := &token.FileSet{}
file := fset.AddFile("", 0, len(args))
sc.Init(file, []byte(args), nil, 0)
_, tok, lit := sc.Scan()
if tok != token.STRING || sc.ErrorCount != 0 {
// Calculate start of token
return nil, types.Error{
Fset: p.program.fset,
Pos: patternPos,
Msg: "invalid quoted string in //go:embed",
}
}
pattern := constant.StringVal(constant.MakeFromLiteral(lit, tok, 0))
patterns = append(patterns, pattern)
args = strings.TrimLeftFunc(args[len(lit):], unicode.IsSpace)
default:
// The value is just a regular value.
// Split it at the first white space.
index := strings.IndexFunc(args, unicode.IsSpace)
if index < 0 {
index = len(args)
}
pattern := args[:index]
patterns = append(patterns, pattern)
args = strings.TrimLeftFunc(args[len(pattern):], unicode.IsSpace)
}
if _, err := path.Match(patterns[len(patterns)-1], ""); err != nil {
return nil, types.Error{
Fset: p.program.fset,
Pos: patternPos,
Msg: "invalid pattern syntax",
}
}
}
return patterns, nil
}
// isValidEmbedType returns whether the given Go type can be used as a
// //go:embed type. This is only true for embed.FS, strings, and byte slices.
// The second return value indicates that this is a byte slice, and therefore
// the contents of the file needs to be passed to the compiler.
func isValidEmbedType(typ types.Type) (valid, byteSlice bool) {
if typ.Underlying() == types.Typ[types.String] {
// string type
return true, false
}
if sliceType, ok := typ.Underlying().(*types.Slice); ok {
if elemType, ok := sliceType.Elem().Underlying().(*types.Basic); ok && elemType.Kind() == types.Byte {
// byte slice type
return true, true
}
}
if namedType, ok := typ.(*types.Named); ok && namedType.String() == "embed.FS" {
// embed.FS type
return true, false
}
return false, false
}
// Import implements types.Importer. It loads and parses packages it encounters // Import implements types.Importer. It loads and parses packages it encounters
// along the way, if needed. // along the way, if needed.
func (p *Package) Import(to string) (*types.Package, error) { func (p *Package) Import(to string) (*types.Package, error) {

Просмотреть файл

@ -48,6 +48,7 @@ func TestBuild(t *testing.T) {
"calls.go", "calls.go",
"cgo/", "cgo/",
"channel.go", "channel.go",
"embed/",
"float.go", "float.go",
"gc.go", "gc.go",
"goroutines.go", "goroutines.go",

0
testdata/embed/a/b/.hidden предоставленный Обычный файл
Просмотреть файл

1
testdata/embed/a/b/bar.txt предоставленный Обычный файл
Просмотреть файл

@ -0,0 +1 @@
bar

1
testdata/embed/a/b/foo.txt предоставленный Обычный файл
Просмотреть файл

@ -0,0 +1 @@
foo

46
testdata/embed/embed.go предоставленный Обычный файл
Просмотреть файл

@ -0,0 +1,46 @@
package main
import (
"embed"
"strings"
)
//go:embed a hello.txt
var files embed.FS
var (
//go:embed "hello.*"
helloString string
//go:embed hello.txt
helloBytes []byte
)
// A test to check that hidden files are not included when matching a directory.
//go:embed a/b/.hidden
var hidden string
func main() {
println("string:", strings.TrimSpace(helloString))
println("bytes:", strings.TrimSpace(string(helloBytes)))
println("files:")
readFiles(".")
}
func readFiles(dir string) {
entries, err := files.ReadDir(dir)
if err != nil {
println(err.Error())
return
}
for _, entry := range entries {
entryPath := entry.Name()
if dir != "." {
entryPath = dir + "/" + entryPath
}
println("-", entryPath)
if entry.IsDir() {
readFiles(entryPath)
}
}
}

1
testdata/embed/hello.txt предоставленный Обычный файл
Просмотреть файл

@ -0,0 +1 @@
hello world!

8
testdata/embed/out.txt предоставленный Обычный файл
Просмотреть файл

@ -0,0 +1,8 @@
string: hello world!
bytes: hello world!
files:
- a
- a/b
- a/b/bar.txt
- a/b/foo.txt
- hello.txt