diff --git a/compiler/map.go b/compiler/map.go index e78bef2a..9a5f23fa 100644 --- a/compiler/map.go +++ b/compiler/map.go @@ -14,8 +14,20 @@ import ( // initializing an appropriately sized object. func (c *Compiler) emitMakeMap(frame *Frame, expr *ssa.MakeMap) (llvm.Value, error) { mapType := expr.Type().Underlying().(*types.Map) - llvmKeyType := c.getLLVMType(mapType.Key().Underlying()) + keyType := mapType.Key().Underlying() llvmValueType := c.getLLVMType(mapType.Elem().Underlying()) + var llvmKeyType llvm.Type + if t, ok := keyType.(*types.Basic); ok && t.Info()&types.IsString != 0 { + // String keys. + llvmKeyType = c.getLLVMType(keyType) + } else if hashmapIsBinaryKey(keyType) { + // Trivially comparable keys. + llvmKeyType = c.getLLVMType(keyType) + } else { + // All other keys. Implemented as map[interface{}]valueType for ease of + // implementation. + llvmKeyType = c.getLLVMRuntimeType("_interface") + } keySize := c.targetData.TypeAllocSize(llvmKeyType) valueSize := c.targetData.TypeAllocSize(llvmValueType) llvmKeySize := llvm.ConstInt(c.ctx.Int8Type(), keySize, false) @@ -43,6 +55,7 @@ func (c *Compiler) emitMapLookup(keyType, valueType types.Type, m, key llvm.Valu // Do the lookup. How it is done depends on the key type. var commaOkValue llvm.Value + keyType = keyType.Underlying() if t, ok := keyType.(*types.Basic); ok && t.Info()&types.IsString != 0 { // key is a string params := []llvm.Value{m, key, mapValuePtr} @@ -58,8 +71,14 @@ func (c *Compiler) emitMapLookup(keyType, valueType types.Type, m, key llvm.Valu commaOkValue = c.createRuntimeCall("hashmapBinaryGet", params, "") c.emitLifetimeEnd(mapKeyPtr, mapKeySize) } else { - // Not trivially comparable using memcmp. - return llvm.Value{}, c.makeError(pos, "only strings, bools, ints, pointers or structs of bools/ints are supported as map keys, but got: "+keyType.String()) + // Not trivially comparable using memcmp. Make it an interface instead. + itfKey := key + if _, ok := keyType.(*types.Interface); !ok { + // Not already an interface, so convert it to an interface now. + itfKey = c.parseMakeInterface(key, keyType, pos) + } + params := []llvm.Value{m, itfKey, mapValuePtr} + commaOkValue = c.createRuntimeCall("hashmapInterfaceGet", params, "") } // Load the resulting value from the hashmap. The value is set to the zero @@ -93,7 +112,14 @@ func (c *Compiler) emitMapUpdate(keyType types.Type, m, key, value llvm.Value, p c.createRuntimeCall("hashmapBinarySet", params, "") c.emitLifetimeEnd(keyPtr, keySize) } else { - c.addError(pos, "only strings, bools, ints, pointers or structs of bools/ints are supported as map keys, but got: "+keyType.String()) + // Key is not trivially comparable, so compare it as an interface instead. + itfKey := key + if _, ok := keyType.(*types.Interface); !ok { + // Not already an interface, so convert it to an interface first. + itfKey = c.parseMakeInterface(key, keyType, pos) + } + params := []llvm.Value{m, itfKey, valuePtr} + c.createRuntimeCall("hashmapInterfaceSet", params, "") } c.emitLifetimeEnd(valuePtr, valueSize) } @@ -113,7 +139,16 @@ func (c *Compiler) emitMapDelete(keyType types.Type, m, key llvm.Value, pos toke c.emitLifetimeEnd(keyPtr, keySize) return nil } else { - return c.makeError(pos, "only strings, bools, ints, pointers or structs of bools/ints are supported as map keys, but got: "+keyType.String()) + // Key is not trivially comparable, so compare it as an interface + // instead. + itfKey := key + if _, ok := keyType.(*types.Interface); !ok { + // Not already an interface, so convert it to an interface first. + itfKey = c.parseMakeInterface(key, keyType, pos) + } + params := []llvm.Value{m, itfKey} + c.createRuntimeCall("hashmapInterfaceDelete", params, "") + return nil } } diff --git a/src/runtime/hashmap.go b/src/runtime/hashmap.go index 2cccebcd..f9e164c4 100644 --- a/src/runtime/hashmap.go +++ b/src/runtime/hashmap.go @@ -6,6 +6,7 @@ package runtime // https://golang.org/src/runtime/map.go import ( + "reflect" "unsafe" ) @@ -318,3 +319,74 @@ func hashmapStringDelete(m *hashmap, key string) { hash := hashmapStringHash(key) hashmapDelete(m, unsafe.Pointer(&key), hash, hashmapStringEqual) } + +// Hashmap with interface keys (for everything else). + +func hashmapInterfaceHash(itf interface{}) uint32 { + x := reflect.ValueOf(itf) + if x.Type() == 0 { + return 0 // nil interface + } + + value := (*_interface)(unsafe.Pointer(&itf)).value + ptr := value + if x.Type().Size() <= unsafe.Sizeof(uintptr(0)) { + // Value fits in pointer, so it's directly stored in the pointer. + ptr = unsafe.Pointer(&value) + } + + switch x.Type().Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return hashmapHash(ptr, x.Type().Size()) + case reflect.Bool, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return hashmapHash(ptr, x.Type().Size()) + case reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: + // It should be possible to just has the contents. However, NaN != NaN + // so if you're using lots of NaNs as map keys (you shouldn't) then hash + // time may become exponential. To fix that, it would be better to + // return a random number instead: + // https://research.swtch.com/randhash + return hashmapHash(ptr, x.Type().Size()) + case reflect.String: + return hashmapStringHash(x.String()) + case reflect.Chan, reflect.Ptr, reflect.UnsafePointer: + // It might seem better to just return the pointer, but that won't + // result in an evenly distributed hashmap. Instead, hash the pointer + // like most other types. + return hashmapHash(ptr, x.Type().Size()) + case reflect.Array: + var hash uint32 + for i := 0; i < x.Len(); i++ { + hash |= hashmapInterfaceHash(x.Index(i).Interface()) + } + return hash + case reflect.Struct: + var hash uint32 + for i := 0; i < x.NumField(); i++ { + hash |= hashmapInterfaceHash(x.Field(i).Interface()) + } + return hash + default: + runtimePanic("comparing un-comparable type") + return 0 // unreachable + } +} + +func hashmapInterfaceEqual(x, y unsafe.Pointer, n uintptr) bool { + return *(*interface{})(x) == *(*interface{})(y) +} + +func hashmapInterfaceSet(m *hashmap, key interface{}, value unsafe.Pointer) { + hash := hashmapInterfaceHash(key) + hashmapSet(m, unsafe.Pointer(&key), value, hash, hashmapInterfaceEqual) +} + +func hashmapInterfaceGet(m *hashmap, key interface{}, value unsafe.Pointer) bool { + hash := hashmapInterfaceHash(key) + return hashmapGet(m, unsafe.Pointer(&key), value, hash, hashmapInterfaceEqual) +} + +func hashmapInterfaceDelete(m *hashmap, key interface{}) { + hash := hashmapInterfaceHash(key) + hashmapDelete(m, unsafe.Pointer(&key), hash, hashmapInterfaceEqual) +} diff --git a/testdata/map.go b/testdata/map.go index c780c013..78834dd1 100644 --- a/testdata/map.go +++ b/testdata/map.go @@ -24,6 +24,11 @@ var testMapArrayKey = map[ArrayKey]int{ } var testmapIntInt = map[int]int{1: 1, 2: 4, 3: 9} +type namedFloat struct { + s string + f float32 +} + func main() { m := map[string]int{"answer": 42, "foo": 3} readMap(m, "answer") @@ -48,6 +53,44 @@ func main() { testMapArrayKey[arrKey] = 5555 println(testMapArrayKey[arrKey]) + // test maps with interface keys + itfMap := map[interface{}]int{ + 3.14: 3, + 8: 8, + uint8(8): 80, + "eight": 800, + [2]int{5, 2}: 52, + true: 1, + } + println("itfMap[3]:", itfMap[3]) // doesn't exist + println("itfMap[3.14]:", itfMap[3.14]) + println("itfMap[8]:", itfMap[8]) + println("itfMap[uint8(8)]:", itfMap[uint8(8)]) + println(`itfMap["eight"]:`, itfMap["eight"]) + println(`itfMap[[2]int{5, 2}]:`, itfMap[[2]int{5, 2}]) + println("itfMap[true]:", itfMap[true]) + delete(itfMap, 8) + println("itfMap[8]:", itfMap[8]) + + // test map with float keys + floatMap := map[float32]int{ + 42: 84, + } + println("floatMap[42]:", floatMap[42]) + println("floatMap[43]:", floatMap[43]) + delete(floatMap, 42) + println("floatMap[42]:", floatMap[42]) + + // test maps with struct keys + structMap := map[namedFloat]int{ + namedFloat{"tau", 6.28}: 5, + } + println(`structMap[{"tau", 6.28}]:`, structMap[namedFloat{"tau", 6.28}]) + println(`structMap[{"Tau", 6.28}]:`, structMap[namedFloat{"Tau", 6.28}]) + println(`structMap[{"tau", 3.14}]:`, structMap[namedFloat{"tau", 3.14}]) + delete(structMap, namedFloat{"tau", 6.28}) + println(`structMap[{"tau", 6.28}]:`, structMap[namedFloat{"tau", 6.28}]) + // test preallocated map squares := make(map[int]int, 200) testBigMap(squares, 100) @@ -79,7 +122,7 @@ func testBigMap(squares map[int]int, n int) { if len(squares) != i { println("unexpected length:", len(squares), "at i =", i) } - squares[i] = i*i + squares[i] = i * i for j := 0; j <= i; j++ { if v, ok := squares[j]; !ok || v != j*j { if !ok { diff --git a/testdata/map.txt b/testdata/map.txt index 66636d11..bffd0307 100644 --- a/testdata/map.txt +++ b/testdata/map.txt @@ -54,5 +54,20 @@ true false 0 42 4321 5555 +itfMap[3]: 0 +itfMap[3.14]: 3 +itfMap[8]: 8 +itfMap[uint8(8)]: 80 +itfMap["eight"]: 800 +itfMap[[2]int{5, 2}]: 52 +itfMap[true]: 1 +itfMap[8]: 0 +floatMap[42]: 84 +floatMap[43]: 0 +floatMap[42]: 0 +structMap[{"tau", 6.28}]: 5 +structMap[{"Tau", 6.28}]: 0 +structMap[{"tau", 3.14}]: 0 +structMap[{"tau", 6.28}]: 0 tested preallocated map tested growing of a map