diff --git a/README.markdown b/README.markdown index 6516abc6..54e778a0 100644 --- a/README.markdown +++ b/README.markdown @@ -45,12 +45,12 @@ Currently supported features: * interface methods * standard library (but most packages won't work due to missing language features) + * slices (partially) Not yet supported: * float, complex, etc. * maps - * slices * garbage collection * defer * closures diff --git a/compiler.go b/compiler.go index f35434f6..df103f75 100644 --- a/compiler.go +++ b/compiler.go @@ -37,7 +37,7 @@ type Compiler struct { intType llvm.Type i8ptrType llvm.Type // for convenience uintptrType llvm.Type - stringLenType llvm.Type + lenType llvm.Type allocFunc llvm.Value freeFunc llvm.Value coroIdFunc llvm.Value @@ -89,13 +89,13 @@ func NewCompiler(pkgName, triple string, dumpSSA bool) (*Compiler, error) { // Depends on platform (32bit or 64bit), but fix it here for now. c.intType = llvm.Int32Type() - c.stringLenType = llvm.Int32Type() + c.lenType = llvm.Int32Type() c.uintptrType = c.targetData.IntPtrType() c.i8ptrType = llvm.PointerType(llvm.Int8Type(), 0) // Go string: tuple of (len, ptr) t := c.ctx.StructCreateNamed("string") - t.StructSetBody([]llvm.Type{c.stringLenType, c.i8ptrType}, false) + t.StructSetBody([]llvm.Type{c.lenType, c.i8ptrType}, false) allocType := llvm.FunctionType(c.i8ptrType, []llvm.Type{c.uintptrType}, false) c.allocFunc = llvm.AddFunction(c.mod, "runtime.alloc", allocType) @@ -485,6 +485,17 @@ func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) { } // make a function pointer of it return llvm.PointerType(llvm.FunctionType(returnType, paramTypes, false), 0), nil + case *types.Slice: + elemType, err := c.getLLVMType(typ.Elem()) + if err != nil { + return llvm.Type{}, err + } + members := []llvm.Type{ + llvm.PointerType(elemType, 0), + c.lenType, // len + c.lenType, // cap + } + return llvm.StructType(members, false), nil case *types.Struct: members := make([]llvm.Type, typ.NumFields()) for i := 0; i < typ.NumFields(); i++ { @@ -496,7 +507,7 @@ func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) { } return llvm.StructType(members, false), nil default: - return llvm.Type{}, errors.New("todo: unknown type: " + fmt.Sprintf("%#v", goType)) + return llvm.Type{}, errors.New("todo: unknown type: " + goType.String()) } } @@ -920,6 +931,35 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) error { func (c *Compiler) parseBuiltin(frame *Frame, args []ssa.Value, callName string) (llvm.Value, error) { switch callName { + case "cap": + value, err := c.parseExpr(frame, args[0]) + if err != nil { + return llvm.Value{}, err + } + switch args[0].Type().(type) { + case *types.Slice: + return c.builder.CreateExtractValue(value, 2, "cap"), nil + default: + return llvm.Value{}, errors.New("todo: cap: unknown type") + } + case "len": + value, err := c.parseExpr(frame, args[0]) + if err != nil { + return llvm.Value{}, err + } + switch typ := args[0].Type().(type) { + case *types.Basic: + switch typ.Kind() { + case types.String: + return c.builder.CreateExtractValue(value, 0, "len"), nil + default: + return llvm.Value{}, errors.New("todo: len: unknown basic type") + } + case *types.Slice: + return c.builder.CreateExtractValue(value, 1, "len"), nil + default: + return llvm.Value{}, errors.New("todo: len: unknown type") + } case "print", "println": for i, arg := range args { if i >= 1 { @@ -974,22 +1014,6 @@ func (c *Compiler) parseBuiltin(frame *Frame, args []ssa.Value, callName string) c.builder.CreateCall(c.mod.NamedFunction("runtime.printnl"), nil, "") } return llvm.Value{}, nil // print() or println() returns void - case "len": - value, err := c.parseExpr(frame, args[0]) - if err != nil { - return llvm.Value{}, err - } - switch typ := args[0].Type().(type) { - case *types.Basic: - switch typ.Kind() { - case types.String: - return c.builder.CreateExtractValue(value, 0, "len"), nil - default: - return llvm.Value{}, errors.New("todo: len: unknown basic type") - } - default: - return llvm.Value{}, errors.New("todo: len: unknown type") - } case "ssa:wrapnilchk": // TODO: do an actual nil check? return c.parseExpr(frame, args[0]) @@ -1213,34 +1237,46 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { return llvm.Value{}, err } - // Get buffer length - var buflen llvm.Value - typ := expr.X.Type().(*types.Pointer).Elem() - switch typ := typ.(type) { - case *types.Array: - buflen = llvm.ConstInt(llvm.Int32Type(), uint64(typ.Len()), false) + // Get buffer pointer and length + var bufptr, buflen llvm.Value + switch ptrTyp := expr.X.Type().(type) { + case *types.Pointer: + typ := expr.X.Type().(*types.Pointer).Elem() + switch typ := typ.(type) { + case *types.Array: + bufptr = val + buflen = llvm.ConstInt(llvm.Int32Type(), uint64(typ.Len()), false) + default: + return llvm.Value{}, errors.New("todo: indexaddr: " + typ.String()) + } + case *types.Slice: + bufptr = c.builder.CreateExtractValue(val, 0, "indexaddr.ptr") + buflen = c.builder.CreateExtractValue(val, 1, "indexaddr.len") default: - return llvm.Value{}, errors.New("todo: indexaddr: len") + return llvm.Value{}, errors.New("todo: indexaddr: " + ptrTyp.String()) } // Bounds check. // LLVM optimizes this away in most cases. - // TODO: runtime.boundsCheck is undefined in packages imported by + // TODO: runtime.lookupBoundsCheck is undefined in packages imported by // package runtime, so we have to remove it. This should be fixed. - boundsCheck := c.mod.NamedFunction("runtime.boundsCheck") - if !boundsCheck.IsNil() { - constZero := llvm.ConstInt(c.intType, 0, false) - isNegative := c.builder.CreateICmp(llvm.IntSLT, index, constZero, "") // index < 0 - isTooBig := c.builder.CreateICmp(llvm.IntSGE, index, buflen, "") // index >= len(value) - isOverflow := c.builder.CreateOr(isNegative, isTooBig, "") - c.builder.CreateCall(boundsCheck, []llvm.Value{isOverflow}, "") + lookupBoundsCheck := c.mod.NamedFunction("runtime.lookupBoundsCheck") + if !lookupBoundsCheck.IsNil() { + c.builder.CreateCall(lookupBoundsCheck, []llvm.Value{buflen, index}, "") } - indices := []llvm.Value{ - llvm.ConstInt(llvm.Int32Type(), 0, false), - index, + switch expr.X.Type().(type) { + case *types.Pointer: + indices := []llvm.Value{ + llvm.ConstInt(llvm.Int32Type(), 0, false), + index, + } + return c.builder.CreateGEP(bufptr, indices, ""), nil + case *types.Slice: + return c.builder.CreateGEP(bufptr, []llvm.Value{index}, ""), nil + default: + panic("unreachable") } - return c.builder.CreateGEP(val, indices, ""), nil case *ssa.Lookup: if expr.CommaOk { return llvm.Value{}, errors.New("todo: lookup with comma-ok") @@ -1263,16 +1299,12 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { // Bounds check. // LLVM optimizes this away in most cases. - if frame.fn.llvmFn.Name() != "runtime.boundsCheck" { - constZero := llvm.ConstInt(c.intType, 0, false) - isNegative := c.builder.CreateICmp(llvm.IntSLT, index, constZero, "") // index < 0 - strlen, err := c.parseBuiltin(frame, []ssa.Value{expr.X}, "len") + if frame.fn.llvmFn.Name() != "runtime.lookupBoundsCheck" { + length, err := c.parseBuiltin(frame, []ssa.Value{expr.X}, "len") if err != nil { return llvm.Value{}, err // shouldn't happen } - isTooBig := c.builder.CreateICmp(llvm.IntSGE, index, strlen, "") // index >= len(value) - isOverflow := c.builder.CreateOr(isNegative, isTooBig, "") - c.builder.CreateCall(c.mod.NamedFunction("runtime.boundsCheck"), []llvm.Value{isOverflow}, "") + c.builder.CreateCall(c.mod.NamedFunction("runtime.lookupBoundsCheck"), []llvm.Value{length, index}, "") } // Lookup byte @@ -1324,6 +1356,73 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { phi := c.builder.CreatePHI(t, "") frame.phis = append(frame.phis, Phi{expr, phi}) return phi, nil + case *ssa.Slice: + if expr.Max != nil { + return llvm.Value{}, errors.New("todo: full slice expressions (with max): " + expr.Type().String()) + } + value, err := c.parseExpr(frame, expr.X) + if err != nil { + return llvm.Value{}, err + } + switch typ := expr.X.Type().Underlying().(type) { + case *types.Pointer: // pointer to array + // slice an array + length := typ.Elem().(*types.Array).Len() + llvmLen := llvm.ConstInt(c.lenType, uint64(length), false) + var low, high llvm.Value + if expr.Low == nil { + low = llvm.ConstInt(c.lenType, 0, false) + } else { + low, err = c.parseExpr(frame, expr.Low) + if err != nil { + return llvm.Value{}, nil + } + } + if expr.High == nil { + high = llvmLen + } else { + high, err = c.parseExpr(frame, expr.High) + if err != nil { + return llvm.Value{}, nil + } + } + indices := []llvm.Value{ + llvm.ConstInt(llvm.Int32Type(), 0, false), + low, + } + slicePtr := c.builder.CreateGEP(value, indices, "slice.ptr") + sliceLen := c.builder.CreateSub(high, low, "slice.len") + sliceCap := c.builder.CreateSub(llvmLen, low, "slice.cap") + sliceTyp, err := c.getLLVMType(expr.Type()) + if err != nil { + return llvm.Value{}, err + } + + // This check is optimized away in most cases. + sliceBoundsCheck := c.mod.NamedFunction("runtime.sliceBoundsCheck") + c.builder.CreateCall(sliceBoundsCheck, []llvm.Value{llvmLen, low, high}, "") + + slice := llvm.ConstNamedStruct(sliceTyp, []llvm.Value{ + llvm.Undef(slicePtr.Type()), + llvm.Undef(c.lenType), + llvm.Undef(c.lenType), + }) + slice = c.builder.CreateInsertValue(slice, slicePtr, 0, "") + slice = c.builder.CreateInsertValue(slice, sliceLen, 1, "") + slice = c.builder.CreateInsertValue(slice, sliceCap, 2, "") + return slice, nil + case *types.Slice: + // slice a slice + return llvm.Value{}, errors.New("todo: slice a slice: " + typ.String()) + case *types.Basic: + // slice a string + if typ.Kind() != types.String { + return llvm.Value{}, errors.New("unknown slice type: " + typ.String()) + } + return llvm.Value{}, errors.New("todo: slice a string") + default: + return llvm.Value{}, errors.New("unknown slice type: " + typ.String()) + } case *ssa.TypeAssert: if !expr.CommaOk { return llvm.Value{}, errors.New("todo: type assert without comma-ok") @@ -1377,7 +1476,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { case *ssa.UnOp: return c.parseUnOp(frame, expr) default: - return llvm.Value{}, errors.New("todo: unknown expression: " + fmt.Sprintf("%#v", expr)) + return llvm.Value{}, errors.New("todo: unknown expression: " + expr.String()) } } @@ -1485,7 +1584,7 @@ func (c *Compiler) parseConst(expr *ssa.Const) (llvm.Value, error) { return llvm.ConstInt(llvmType, n, false), nil } else if typ.Kind() == types.String { str := constant.StringVal(expr.Value) - strLen := llvm.ConstInt(c.stringLenType, uint64(len(str)), false) + strLen := llvm.ConstInt(c.lenType, uint64(len(str)), false) global := llvm.AddGlobal(c.mod, llvm.ArrayType(llvm.Int8Type(), len(str)), ".str") global.SetInitializer(c.ctx.ConstString(str, false)) global.SetLinkage(llvm.PrivateLinkage) diff --git a/src/examples/hello/hello.go b/src/examples/hello/hello.go index f8389da7..a54b0fbc 100644 --- a/src/examples/hello/hello.go +++ b/src/examples/hello/hello.go @@ -23,6 +23,11 @@ func main() { println("sumrange(100) =", sumrange(100)) println("strlen foo:", strlen("foo")) + foo := []int{1, 2, 4, 5} + println("len/cap foo:", len(foo), cap(foo)) + println("foo[3]:", foo[3]) + println("sum foo:", sum(foo)) + thing := &Thing{"foo"} println("thing:", thing.String()) printItf(5) @@ -93,3 +98,11 @@ func sumrange(n int) int { } return sum } + +func sum(l []int) int { + sum := 0 + for _, n := range l { + sum += n + } + return sum +} diff --git a/src/runtime/runtime.go b/src/runtime/runtime.go index 228b1ede..fad61fed 100644 --- a/src/runtime/runtime.go +++ b/src/runtime/runtime.go @@ -19,11 +19,20 @@ func _panic(message interface{}) { abort() } -func boundsCheck(outOfRange bool) { - if outOfRange { +// Check for bounds in *ssa.IndexAddr and *ssa.Lookup. +func lookupBoundsCheck(length, index int) { + if index < 0 || index >= length { // printstring() here is safe as this function is excluded from bounds // checking. printstring("panic: runtime error: index out of range\n") abort() } } + +// Check for bounds in *ssa.Slice +func sliceBoundsCheck(length, low, high uint) { + if !(0 <= low && low <= high && high <= length) { + printstring("panic: runtime error: slice out of range\n") + abort() + } +}