From 1f0651c84c7f871d5de802ebc036176ff1b67753 Mon Sep 17 00:00:00 2001 From: Ayke van Laethem Date: Sat, 21 Apr 2018 00:26:45 +0200 Subject: [PATCH] Implement string out of bounds checks --- src/runtime/runtime.go | 9 +++++++++ tgo.go | 41 +++++++++++++++++++++++++++++------------ 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/src/runtime/runtime.go b/src/runtime/runtime.go index 944325e7..72c6a206 100644 --- a/src/runtime/runtime.go +++ b/src/runtime/runtime.go @@ -56,3 +56,12 @@ func _panic(message interface{}) { printnl() C.exit(1) } + +func boundsCheck(s string, outOfRange bool) { + if outOfRange { + // printstring() here is safe as this function is excluded from bounds + // checking. + printstring("panic: runtime error: index out of range\n") + C.exit(1) + } +} diff --git a/tgo.go b/tgo.go index c2a615e2..7c917f01 100644 --- a/tgo.go +++ b/tgo.go @@ -44,6 +44,7 @@ type Compiler struct { interfaceType llvm.Type typeassertType llvm.Type panicFunc llvm.Value + boundsCheckFunc llvm.Value printstringFunc llvm.Value printintFunc llvm.Value printbyteFunc llvm.Value @@ -101,6 +102,9 @@ func NewCompiler(pkgName, triple string) (*Compiler, error) { panicType := llvm.FunctionType(llvm.VoidType(), []llvm.Type{c.interfaceType}, false) c.panicFunc = llvm.AddFunction(c.mod, "runtime._panic", panicType) + boundsCheckType := llvm.FunctionType(llvm.VoidType(), []llvm.Type{c.stringType, llvm.Int1Type()}, false) + c.boundsCheckFunc = llvm.AddFunction(c.mod, "runtime.boundsCheck", boundsCheckType) + printstringType := llvm.FunctionType(llvm.VoidType(), []llvm.Type{c.stringType}, false) c.printstringFunc = llvm.AddFunction(c.mod, "runtime.printstring", printstringType) printintType := llvm.FunctionType(llvm.VoidType(), []llvm.Type{c.intType}, false) @@ -476,13 +480,12 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) error { } } -func (c *Compiler) parseBuiltin(frame *Frame, instr *ssa.CallCommon, call *ssa.Builtin) (llvm.Value, error) { - fmt.Printf(" builtin: %v\n", call) - name := call.Name() +func (c *Compiler) parseBuiltin(frame *Frame, args []ssa.Value, callName string) (llvm.Value, error) { + fmt.Printf(" builtin: %v\n", callName) - switch name { + switch callName { case "print", "println": - for i, arg := range instr.Args { + for i, arg := range args { if i >= 1 { c.builder.CreateCall(c.printspaceFunc, nil, "") } @@ -507,17 +510,16 @@ func (c *Compiler) parseBuiltin(frame *Frame, instr *ssa.CallCommon, call *ssa.B return llvm.Value{}, errors.New("unknown arg type: " + fmt.Sprintf("%#v", typ)) } } - if name == "println" { + if callName == "println" { c.builder.CreateCall(c.printnlFunc, nil, "") } return llvm.Value{}, nil // print() or println() returns void case "len": - arg := instr.Args[0] - value, err := c.parseExpr(frame, arg) + value, err := c.parseExpr(frame, args[0]) if err != nil { return llvm.Value{}, err } - switch typ := arg.Type().(type) { + switch typ := args[0].Type().(type) { case *types.Basic: switch typ.Kind() { case types.String: @@ -529,7 +531,7 @@ func (c *Compiler) parseBuiltin(frame *Frame, instr *ssa.CallCommon, call *ssa.B return llvm.Value{}, errors.New("todo: len: unknown type") } default: - return llvm.Value{}, errors.New("todo: builtin: " + name) + return llvm.Value{}, errors.New("todo: builtin: " + callName) } } @@ -564,7 +566,7 @@ func (c *Compiler) parseCall(frame *Frame, instr *ssa.Call) (llvm.Value, error) switch call := instr.Common().Value.(type) { case *ssa.Builtin: - return c.parseBuiltin(frame, instr.Common(), call) + return c.parseBuiltin(frame, instr.Common().Args, call.Name()) case *ssa.Function: return c.parseFunctionCall(frame, instr.Common(), call) default: @@ -682,7 +684,22 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { if err != nil { return llvm.Value{}, nil } - // TODO: out-of-bounds checking + + // Bounds check + // TODO: inline, and avoid if possible + if frame.llvmFn.Name() != "runtime.boundsCheck" { + constZero := llvm.ConstInt(c.intType, 0, false) + isNegative := c.builder.CreateICmp(llvm.IntSLT, index, constZero, "") // index < 0 + strlen, err := c.parseBuiltin(frame, []ssa.Value{expr.X}, "len") + if err != nil { + return llvm.Value{}, err // shouldn't happen + } + isTooBig := c.builder.CreateICmp(llvm.IntSGE, index, strlen, "") // index >= len(value) + isOverflow := c.builder.CreateOr(isNegative, isTooBig, "") + c.builder.CreateCall(c.boundsCheckFunc, []llvm.Value{value, isOverflow}, "") + } + + // Lookup byte buf := c.builder.CreateExtractValue(value, 1, "") bufPtr := c.builder.CreateGEP(buf, []llvm.Value{index}, "") return c.builder.CreateLoad(bufPtr, ""), nil