From c248418dbec8c31d81fd8f8fedcdbd1fb93a85f1 Mon Sep 17 00:00:00 2001 From: Ayke van Laethem Date: Wed, 27 May 2020 00:17:28 +0200 Subject: [PATCH] compiler: fix a few crashes due to named types There were a few cases left where a named type would cause a crash in the compiler. While going through enough code would have found them eventually, I specifically looked for the `Type().(` pattern: a Type() call that is then used in a type assert. Most of those were indeed bugs, although for some I couldn't come up with a reproducer so I left them as-is. --- compiler/channel.go | 6 +++--- compiler/compiler.go | 8 ++++---- testdata/channel.go | 11 +++++++++++ testdata/coroutines.go | 12 ++++++++++++ testdata/slice.go | 5 +++++ 5 files changed, 35 insertions(+), 7 deletions(-) diff --git a/compiler/channel.go b/compiler/channel.go index 3686c98f..a2532e48 100644 --- a/compiler/channel.go +++ b/compiler/channel.go @@ -12,7 +12,7 @@ import ( ) func (b *builder) createMakeChan(expr *ssa.MakeChan) llvm.Value { - elementSize := b.targetData.TypeAllocSize(b.getLLVMType(expr.Type().(*types.Chan).Elem())) + elementSize := b.targetData.TypeAllocSize(b.getLLVMType(expr.Type().Underlying().(*types.Chan).Elem())) elementSizeValue := llvm.ConstInt(b.uintptrType, elementSize, false) bufSize := b.getValue(expr.Size) b.createChanBoundsCheck(elementSize, bufSize, expr.Size.Type().Underlying().(*types.Basic), expr.Pos()) @@ -47,7 +47,7 @@ func (b *builder) createChanSend(instr *ssa.Send) { // createChanRecv emits a pseudo chan receive operation. It is lowered to the // actual channel receive operation during goroutine lowering. func (b *builder) createChanRecv(unop *ssa.UnOp) llvm.Value { - valueType := b.getLLVMType(unop.X.Type().(*types.Chan).Elem()) + valueType := b.getLLVMType(unop.X.Type().Underlying().(*types.Chan).Elem()) ch := b.getValue(unop.X) // Allocate memory to receive into. @@ -117,7 +117,7 @@ func (b *builder) createSelect(expr *ssa.Select) llvm.Value { switch state.Dir { case types.RecvOnly: // Make sure the receive buffer is big enough and has the correct alignment. - llvmType := b.getLLVMType(state.Chan.Type().(*types.Chan).Elem()) + llvmType := b.getLLVMType(state.Chan.Type().Underlying().(*types.Chan).Elem()) if size := b.targetData.TypeAllocSize(llvmType); size > recvbufSize { recvbufSize = size } diff --git a/compiler/compiler.go b/compiler/compiler.go index 37858ea3..53cafed5 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -1053,7 +1053,7 @@ func (b *builder) createInstruction(instr ssa.Instruction) { // goroutine: // * The function context, for closures. // * The function pointer (for tasks). - funcPtr, context := b.decodeFuncValue(b.getValue(instr.Call.Value), instr.Call.Value.Type().(*types.Signature)) + funcPtr, context := b.decodeFuncValue(b.getValue(instr.Call.Value), instr.Call.Value.Type().Underlying().(*types.Signature)) params = append(params, context) // context parameter switch b.Scheduler() { case "none", "coroutines": @@ -1516,7 +1516,7 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) { index := b.getValue(expr.Index) // Check bounds. - arrayLen := expr.X.Type().(*types.Array).Len() + arrayLen := expr.X.Type().Underlying().(*types.Array).Len() arrayLenLLVM := llvm.ConstInt(b.uintptrType, uint64(arrayLen), false) b.createLookupBoundsCheck(arrayLenLLVM, index, expr.Index.Type()) @@ -1628,8 +1628,8 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) { } // Bounds checking. - lenType := expr.Len.Type().(*types.Basic) - capType := expr.Cap.Type().(*types.Basic) + lenType := expr.Len.Type().Underlying().(*types.Basic) + capType := expr.Cap.Type().Underlying().(*types.Basic) b.createSliceBoundsCheck(maxSize, sliceLen, sliceCap, sliceCap, lenType, capType, capType) // Allocate the backing array. diff --git a/testdata/channel.go b/testdata/channel.go index e1acea7e..6a7945e5 100644 --- a/testdata/channel.go +++ b/testdata/channel.go @@ -8,6 +8,8 @@ import ( var wg sync.WaitGroup +type intchan chan int + func main() { ch := make(chan int, 2) ch <- 1 @@ -40,6 +42,15 @@ func main() { _ = make(chan int, uint32(2)) _ = make(chan int, uint64(2)) + // Test that named channels don't crash the compiler. + named := make(intchan, 1) + named <- 3 + <-named + select { + case <-named: + default: + } + // Test bigger values ch2 := make(chan complex128) wg.Add(1) diff --git a/testdata/coroutines.go b/testdata/coroutines.go index 49fdfc28..bb8acdbc 100644 --- a/testdata/coroutines.go +++ b/testdata/coroutines.go @@ -68,6 +68,8 @@ func main() { m.Unlock() println("done") + startSimpleFunc(emptyFunc) + time.Sleep(2 * time.Millisecond) } @@ -100,6 +102,11 @@ func sleepFuncValue(fn func(int)) { go fn(8) } +func startSimpleFunc(fn simpleFunc) { + // Test that named function types don't crash the compiler. + go fn() +} + func nowait() { println("non-blocking goroutine") } @@ -115,3 +122,8 @@ func (i *myPrinter) Print() { time.Sleep(time.Millisecond) println("async interface method call") } + +type simpleFunc func() + +func emptyFunc() { +} diff --git a/testdata/slice.go b/testdata/slice.go index d20de76d..b5e54367 100644 --- a/testdata/slice.go +++ b/testdata/slice.go @@ -31,6 +31,7 @@ func main() { assert(len(make([]int, makeUint32(2), makeUint32(3))) == 2) assert(len(make([]int, makeUint64(2), makeUint64(3))) == 2) assert(len(make([]int, makeUintptr(2), makeUintptr(3))) == 2) + assert(len(make([]int, makeMyUint8(2), makeMyUint8(3))) == 2) // indexing into a slice with uncommon index types assert(foo[int(2)] == 4) @@ -131,6 +132,9 @@ func main() { var named MySlice assert(len(unnamed[:]) == 32) assert(len(named[:]) == 32) + for _, c := range named { + assert(c == 0) + } } func printslice(name string, s []int) { @@ -169,3 +173,4 @@ func makeUint16(x uint16) uint16 { return x } func makeUint32(x uint32) uint32 { return x } func makeUint64(x uint64) uint64 { return x } func makeUintptr(x uintptr) uintptr { return x } +func makeMyUint8(x myUint8) myUint8 { return x }