From 3a6ef380414f31c6634daa8d78ec8c135a293c4b Mon Sep 17 00:00:00 2001 From: Ayke van Laethem Date: Wed, 22 Aug 2018 04:50:24 +0200 Subject: [PATCH] Preliminary implementation of a hashmap, unfinished Missing features: * keys other than strings * more than 8 values in the hashmap * growing a map when needed * initial size hint * delete(m, key) * iterators (for range) * initializing global maps * ...more? --- README.markdown | 2 +- compiler.go | 151 +++++++++++++++++++++++++++++++----- src/examples/hello/hello.go | 7 ++ src/runtime/hashmap.go | 133 +++++++++++++++++++++++++++++++ src/runtime/runtime.go | 18 +++++ 5 files changed, 290 insertions(+), 21 deletions(-) create mode 100644 src/runtime/hashmap.go diff --git a/README.markdown b/README.markdown index cf409f95..9b01bfaf 100644 --- a/README.markdown +++ b/README.markdown @@ -46,11 +46,11 @@ Currently supported features: * standard library (but most packages won't work due to missing language features) * slices (partially) + * maps (very rough, unfinished) Not yet supported: * float, complex, etc. - * maps * garbage collection * defer * closures diff --git a/compiler.go b/compiler.go index 2f9f3165..5b29b1e7 100644 --- a/compiler.go +++ b/compiler.go @@ -431,6 +431,8 @@ func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) { } case *types.Interface: return c.mod.GetTypeByName("interface"), nil + case *types.Map: + return llvm.PointerType(c.mod.GetTypeByName("runtime.hashmap"), 0), nil case *types.Named: if _, ok := typ.Underlying().(*types.Struct); ok { llvmType := c.mod.GetTypeByName(typ.Obj().Pkg().Path() + "." + typ.Obj().Name()) @@ -878,6 +880,36 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) error { blockJump := frame.blocks[instr.Block().Succs[0]] c.builder.CreateBr(blockJump) return nil + case *ssa.MapUpdate: + m, err := c.parseExpr(frame, instr.Map) + if err != nil { + return err + } + key, err := c.parseExpr(frame, instr.Key) + if err != nil { + return err + } + value, err := c.parseExpr(frame, instr.Value) + if err != nil { + return err + } + mapType := instr.Map.Type().Underlying().(*types.Map) + switch keyType := mapType.Key().Underlying().(type) { + case *types.Basic: + if keyType.Kind() == types.String { + valueAlloca := c.builder.CreateAlloca(value.Type(), "hashmap.value") + c.builder.CreateStore(value, valueAlloca) + valuePtr := c.builder.CreateBitCast(valueAlloca, c.i8ptrType, "hashmap.valueptr") + params := []llvm.Value{m, key, valuePtr} + fn := c.mod.NamedFunction("runtime.hashmapSet") + c.builder.CreateCall(fn, params, "") + return nil + } else { + return errors.New("todo: map update key type: " + keyType.String()) + } + default: + return errors.New("todo: map update key type: " + keyType.String()) + } case *ssa.Panic: value, err := c.parseExpr(frame, instr.X) if err != nil { @@ -1292,13 +1324,6 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { if expr.CommaOk { return llvm.Value{}, errors.New("todo: lookup with comma-ok") } - if _, ok := expr.X.Type().(*types.Map); ok { - return llvm.Value{}, errors.New("todo: lookup in map") - } - // Value type must be a string, which is a basic type. - if expr.X.Type().(*types.Basic).Kind() != types.String { - panic("lookup on non-string?") - } value, err := c.parseExpr(frame, expr.X) if err != nil { return llvm.Value{}, nil @@ -1307,21 +1332,50 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { if err != nil { return llvm.Value{}, nil } - - // Bounds check. - // LLVM optimizes this away in most cases. - if frame.fn.llvmFn.Name() != "runtime.lookupBoundsCheck" { - length, err := c.parseBuiltin(frame, []ssa.Value{expr.X}, "len") - if err != nil { - return llvm.Value{}, err // shouldn't happen + switch xType := expr.X.Type().(type) { + case *types.Basic: + // Value type must be a string, which is a basic type. + if xType.Kind() != types.String { + panic("lookup on non-string?") } - c.builder.CreateCall(c.mod.NamedFunction("runtime.lookupBoundsCheck"), []llvm.Value{length, index}, "") - } - // Lookup byte - buf := c.builder.CreateExtractValue(value, 1, "") - bufPtr := c.builder.CreateGEP(buf, []llvm.Value{index}, "") - return c.builder.CreateLoad(bufPtr, ""), nil + // Bounds check. + // LLVM optimizes this away in most cases. + if frame.fn.llvmFn.Name() != "runtime.lookupBoundsCheck" { + length, err := c.parseBuiltin(frame, []ssa.Value{expr.X}, "len") + if err != nil { + return llvm.Value{}, err // shouldn't happen + } + c.builder.CreateCall(c.mod.NamedFunction("runtime.lookupBoundsCheck"), []llvm.Value{length, index}, "") + } + + // Lookup byte + buf := c.builder.CreateExtractValue(value, 1, "") + bufPtr := c.builder.CreateGEP(buf, []llvm.Value{index}, "") + return c.builder.CreateLoad(bufPtr, ""), nil + case *types.Map: + switch keyType := xType.Key().Underlying().(type) { + case *types.Basic: + if keyType.Kind() == types.String { + llvmValueType, err := c.getLLVMType(expr.Type()) + if err != nil { + return llvm.Value{}, err + } + mapValueAlloca := c.builder.CreateAlloca(llvmValueType, "hashmap.value") + mapValuePtr := c.builder.CreateBitCast(mapValueAlloca, c.i8ptrType, "hashmap.valueptr") + params := []llvm.Value{value, index, mapValuePtr} + fn := c.mod.NamedFunction("runtime.hashmapGet") + c.builder.CreateCall(fn, params, "") + return c.builder.CreateLoad(mapValueAlloca, ""), nil + } else { + return llvm.Value{}, errors.New("todo: map lookup key type: " + keyType.String()) + } + default: + return llvm.Value{}, errors.New("todo: map lookup key type: " + keyType.String()) + } + default: + panic("unknown lookup type: " + expr.String()) + } case *ssa.MakeInterface: val, err := c.parseExpr(frame, expr.X) if err != nil { @@ -1359,6 +1413,63 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { itf := llvm.ConstNamedStruct(c.mod.GetTypeByName("interface"), []llvm.Value{llvm.ConstInt(llvm.Int32Type(), uint64(itfTypeNum), false), llvm.Undef(c.i8ptrType)}) itf = c.builder.CreateInsertValue(itf, itfValue, 1, "") return itf, nil + case *ssa.MakeMap: + mapType := expr.Type().Underlying().(*types.Map) + llvmKeyType, err := c.getLLVMType(mapType.Key().Underlying()) + if err != nil { + return llvm.Value{}, err + } + llvmValueType, err := c.getLLVMType(mapType.Elem().Underlying()) + if err != nil { + return llvm.Value{}, err + } + switch keyType := mapType.Key().Underlying().(type) { + case *types.Basic: + if keyType.Kind() == types.String { + // Create hashmap + llvmType := c.mod.GetTypeByName("runtime.hashmap") + size := llvm.ConstInt(c.uintptrType, c.targetData.TypeAllocSize(llvmType), false) + buf := c.builder.CreateCall(c.allocFunc, []llvm.Value{size}, "") + buf = c.builder.CreateBitCast(buf, llvm.PointerType(llvmType, 0), "") + + // Set keySize + keySize := c.targetData.TypeAllocSize(llvmKeyType) + keyIndices := []llvm.Value{ + llvm.ConstInt(llvm.Int32Type(), 0, false), + llvm.ConstInt(llvm.Int32Type(), 3, false), // keySize uint8 + } + keySizePtr := c.builder.CreateGEP(buf, keyIndices, "hashmap.keySize") + c.builder.CreateStore(llvm.ConstInt(llvm.Int8Type(), keySize, false), keySizePtr) + + // Set valueSize + valueSize := c.targetData.TypeAllocSize(llvmValueType) + valueIndices := []llvm.Value{ + llvm.ConstInt(llvm.Int32Type(), 0, false), + llvm.ConstInt(llvm.Int32Type(), 4, false), // valueSize uint8 + } + valueSizePtr := c.builder.CreateGEP(buf, valueIndices, "hashmap.valueSize") + c.builder.CreateStore(llvm.ConstInt(llvm.Int8Type(), valueSize, false), valueSizePtr) + + // Create initial bucket + bucketType := c.mod.GetTypeByName("runtime.hashmapBucket") + bucketSize := c.targetData.TypeAllocSize(bucketType) + keySize*8 + valueSize*8 + bucketSizeValue := llvm.ConstInt(c.uintptrType, bucketSize, false) + bucket := c.builder.CreateCall(c.allocFunc, []llvm.Value{bucketSizeValue}, "") + + // Set initial bucket + bucketIndices := []llvm.Value{ + llvm.ConstInt(llvm.Int32Type(), 0, false), + llvm.ConstInt(llvm.Int32Type(), 1, false), // buckets unsafe.Pointer + } + bucketsElementPtr := c.builder.CreateGEP(buf, bucketIndices, "hashmap.buckets") + c.builder.CreateStore(bucket, bucketsElementPtr) + return buf, nil + } else { + return llvm.Value{}, errors.New("todo: map key type: " + keyType.String()) + } + default: + return llvm.Value{}, errors.New("todo: map key type: " + keyType.String()) + } case *ssa.Phi: t, err := c.getLLVMType(expr.Type()) if err != nil { diff --git a/src/examples/hello/hello.go b/src/examples/hello/hello.go index a54b0fbc..bc85296d 100644 --- a/src/examples/hello/hello.go +++ b/src/examples/hello/hello.go @@ -23,6 +23,9 @@ func main() { println("sumrange(100) =", sumrange(100)) println("strlen foo:", strlen("foo")) + m := map[string]int{"answer": 42, "foo": 3} + readMap(m, "answer") + foo := []int{1, 2, 4, 5} println("len/cap foo:", len(foo), cap(foo)) println("foo[3]:", foo[3]) @@ -46,6 +49,10 @@ func runFunc(f func(int), arg int) { f(arg) } +func readMap(m map[string]int, key string) { + println("map read:", key, "=", m[key]) +} + func hello(n int) { println("hello from function pointer:", n) } diff --git a/src/runtime/hashmap.go b/src/runtime/hashmap.go new file mode 100644 index 00000000..ce8b356e --- /dev/null +++ b/src/runtime/hashmap.go @@ -0,0 +1,133 @@ +package runtime + +// This is a hashmap implementation for the map[T]T type. +// It is very rougly based on the implementation of the Go hashmap: +// +// https://golang.org/src/runtime/hashmap.go + +import ( + "unsafe" +) + +// The underlying hashmap structure for Go. +type hashmap struct { + next *hashmap // hashmap after evacuate (for iterators) + buckets unsafe.Pointer // pointer to array of buckets + count uint + keySize uint8 // maybe this can store the key type as well? E.g. keysize == 5 means string? + valueSize uint8 + bucketBits uint8 +} + +// A hashmap bucket. A bucket is a container of 8 key/value pairs: first the +// following two entries, then the 8 keys, then the 8 values. This somewhat odd +// ordering is to make sure the keys and values are well aligned when one of +// them is smaller than the system word size. +type hashmapBucket struct { + tophash [8]uint8 + next *hashmapBucket // next bucket (if there are more than 8 in a chain) + // Followed by the actual keys, and then the actual values. These are + // allocated but as they're of variable size they can't be shown here. +} + +// Get FNV-1a hash of this string. +// +// https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function#FNV-1a_hash +func stringhash(s *string) uint32 { + var result uint32 = 2166136261 // FNV offset basis + for i := 0; i < len(*s); i++ { + result ^= uint32((*s)[i]) + result *= 16777619 // FNV prime + } + return result +} + +// Set a specified key to a given value. Grow the map if necessary. +func hashmapSet(m *hashmap, key string, value unsafe.Pointer) { + hash := stringhash(&key) + numBuckets := uintptr(1) << m.bucketBits + bucketNumber := (uintptr(hash) & (numBuckets - 1)) + bucketSize := unsafe.Sizeof(hashmapBucket{}) + uintptr(m.keySize)*8 + uintptr(m.valueSize)*8 + bucketAddr := uintptr(m.buckets) + bucketSize*bucketNumber + bucket := (*hashmapBucket)(unsafe.Pointer(bucketAddr)) + + tophash := uint8(hash >> 24) + if tophash < 1 { + // 0 means empty slot, so make it bigger. + tophash += 1 + } + + // See whether the key already exists somewhere. + var emptySlotKey *string + var emptySlotValue unsafe.Pointer + var emptySlotTophash *byte + for bucket != nil { + for i := uintptr(0); i < 8; i++ { + slotKeyOffset := unsafe.Sizeof(hashmapBucket{}) + uintptr(m.keySize)*uintptr(i) + slotKey := (*string)(unsafe.Pointer(bucketAddr + slotKeyOffset)) + slotValueOffset := unsafe.Sizeof(hashmapBucket{}) + uintptr(m.keySize)*8 + uintptr(m.valueSize)*uintptr(i) + slotValue := unsafe.Pointer(bucketAddr + slotValueOffset) + if bucket.tophash[i] == 0 && emptySlotKey == nil { + // Found an empty slot, store it for if we couldn't find an + // existing slot. + emptySlotKey = slotKey + emptySlotValue = slotValue + emptySlotTophash = &bucket.tophash[i] + } + if bucket.tophash[i] == tophash { + // Could be an existing value that's the same. + if key == *slotKey { + // found same key, replace it + memcpy(slotValue, value, uintptr(m.valueSize)) + return + } + } + } + bucket = bucket.next + } + if emptySlotKey != nil { + *emptySlotKey = key + memcpy(emptySlotValue, value, uintptr(m.valueSize)) + *emptySlotTophash = tophash + return + } + panic("todo: hashmap: grow bucket") +} + +// Get the value of a specified key, or zero the value if not found. +func hashmapGet(m *hashmap, key string, value unsafe.Pointer) { + hash := stringhash(&key) + numBuckets := uintptr(1) << m.bucketBits + bucketNumber := (uintptr(hash) & (numBuckets - 1)) + bucketSize := unsafe.Sizeof(hashmapBucket{}) + uintptr(m.keySize)*8 + uintptr(m.valueSize)*8 + bucketAddr := uintptr(m.buckets) + bucketSize*bucketNumber + bucket := (*hashmapBucket)(unsafe.Pointer(bucketAddr)) + + tophash := uint8(hash >> 24) + if tophash < 1 { + // 0 means empty slot, so make it bigger. + tophash += 1 + } + + // Try to find the key. + for bucket != nil { + for i := uintptr(0); i < 8; i++ { + slotKeyOffset := unsafe.Sizeof(hashmapBucket{}) + uintptr(m.keySize)*uintptr(i) + slotKey := (*string)(unsafe.Pointer(bucketAddr + slotKeyOffset)) + slotValueOffset := unsafe.Sizeof(hashmapBucket{}) + uintptr(m.keySize)*8 + uintptr(m.valueSize)*uintptr(i) + slotValue := unsafe.Pointer(bucketAddr + slotValueOffset) + if bucket.tophash[i] == tophash { + // This could be the key we're looking for. + if key == *slotKey { + // Found the key, copy it. + memcpy(value, slotValue, uintptr(m.valueSize)) + return + } + } + } + bucket = bucket.next + } + + // Did not find the key. + memzero(value, uintptr(m.valueSize)) +} diff --git a/src/runtime/runtime.go b/src/runtime/runtime.go index 1f85963c..386bd847 100644 --- a/src/runtime/runtime.go +++ b/src/runtime/runtime.go @@ -1,5 +1,9 @@ package runtime +import ( + "unsafe" +) + const Compiler = "tgo" // The bitness of the CPU (e.g. 8, 32, 64). Set by the compiler as a constant. @@ -29,6 +33,20 @@ func stringequal(x, y string) bool { return true } +// Copy size bytes from src to dst. The memory areas must not overlap. +func memcpy(dst, src unsafe.Pointer, size uintptr) { + for i := uintptr(0); i < size; i++ { + *(*uint8)(unsafe.Pointer(uintptr(dst) + i)) = *(*uint8)(unsafe.Pointer(uintptr(src) + i)) + } +} + +// Set the given number of bytes to zero. +func memzero(ptr unsafe.Pointer, size uintptr) { + for i := uintptr(0); i < size; i++ { + *(*byte)(unsafe.Pointer(uintptr(ptr) + size)) = 0 + } +} + func _panic(message interface{}) { printstring("panic: ") printitf(message)