diff --git a/compiler.go b/compiler.go index a69655b6..8525aced 100644 --- a/compiler.go +++ b/compiler.go @@ -2241,25 +2241,26 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { slice = c.builder.CreateInsertValue(slice, sliceCap, 2, "") return slice, nil case *ssa.Next: + rangeVal := expr.Iter.(*ssa.Range).X + llvmRangeVal, err := c.parseExpr(frame, rangeVal) + if err != nil { + return llvm.Value{}, err + } + it, err := c.parseExpr(frame, expr.Iter) + if err != nil { + return llvm.Value{}, err + } if expr.IsString { - return llvm.Value{}, errors.New("todo: next: string") + fn := c.mod.NamedFunction("runtime.stringNext") + return c.builder.CreateCall(fn, []llvm.Value{llvmRangeVal, it}, "range.next"), nil } else { // map fn := c.mod.NamedFunction("runtime.hashmapNext") - it, err := c.parseExpr(frame, expr.Iter) - if err != nil { - return llvm.Value{}, err - } - rangeMap := expr.Iter.(*ssa.Range).X - m, err := c.parseExpr(frame, rangeMap) - if err != nil { - return llvm.Value{}, err - } - llvmKeyType, err := c.getLLVMType(rangeMap.Type().(*types.Map).Key()) + llvmKeyType, err := c.getLLVMType(rangeVal.Type().(*types.Map).Key()) if err != nil { return llvm.Value{}, err } - llvmValueType, err := c.getLLVMType(rangeMap.Type().(*types.Map).Elem()) + llvmValueType, err := c.getLLVMType(rangeVal.Type().(*types.Map).Elem()) if err != nil { return llvm.Value{}, err } @@ -2268,7 +2269,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { mapKeyPtr := c.builder.CreateBitCast(mapKeyAlloca, c.i8ptrType, "range.keyptr") mapValueAlloca := c.builder.CreateAlloca(llvmValueType, "range.value") mapValuePtr := c.builder.CreateBitCast(mapValueAlloca, c.i8ptrType, "range.valueptr") - ok := c.builder.CreateCall(fn, []llvm.Value{m, it, mapKeyPtr, mapValuePtr}, "range.next") + ok := c.builder.CreateCall(fn, []llvm.Value{llvmRangeVal, it, mapKeyPtr, mapValuePtr}, "range.next") tuple := llvm.Undef(llvm.StructType([]llvm.Type{llvm.Int1Type(), llvmKeyType, llvmValueType}, false)) tuple = c.builder.CreateInsertValue(tuple, ok, 0, "") @@ -2285,22 +2286,22 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { frame.phis = append(frame.phis, Phi{expr, phi}) return phi, nil case *ssa.Range: + var iteratorType llvm.Type switch typ := expr.X.Type().Underlying().(type) { - case *types.Basic: - // string - return llvm.Value{}, errors.New("todo: range: string") + case *types.Basic: // string + iteratorType = c.mod.GetTypeByName("runtime.stringIterator") case *types.Map: - iteratorType := c.mod.GetTypeByName("runtime.hashmapIterator") - it := c.builder.CreateAlloca(iteratorType, "range.it") - zero, err := getZeroValue(iteratorType) - if err != nil { - return llvm.Value{}, nil - } - c.builder.CreateStore(zero, it) - return it, nil + iteratorType = c.mod.GetTypeByName("runtime.hashmapIterator") default: panic("unknown type in range: " + typ.String()) } + it := c.builder.CreateAlloca(iteratorType, "range.it") + zero, err := getZeroValue(iteratorType) + if err != nil { + return llvm.Value{}, nil + } + c.builder.CreateStore(zero, it) + return it, nil case *ssa.Slice: if expr.Max != nil { return llvm.Value{}, errors.New("todo: full slice expressions (with max): " + expr.Type().String()) diff --git a/src/runtime/string.go b/src/runtime/string.go index d8c2dee4..540c6410 100644 --- a/src/runtime/string.go +++ b/src/runtime/string.go @@ -12,6 +12,12 @@ type _string struct { length lenType } +// The iterator state for a range over a string. +type stringIterator struct { + byteindex lenType + rangeindex lenType +} + // Return true iff the strings match. //go:nobounds func stringEqual(x, y string) bool { @@ -75,6 +81,18 @@ func stringFromUnicode(x rune) _string { return _string{ptr: (*byte)(unsafe.Pointer(&array)), length: length} } +// Iterate over a string. +// Returns (ok, key, value). +func stringNext(s string, it *stringIterator) (bool, int, rune) { + if len(s) <= int(it.byteindex) { + return false, 0, 0 + } + r, length := decodeUTF8(s, it.byteindex) + it.byteindex += length + it.rangeindex += 1 + return true, int(it.rangeindex), r +} + // Convert a Unicode code point into an array of bytes and its length. func encodeUTF8(x rune) ([4]byte, lenType) { // https://stackoverflow.com/questions/6240055/manually-converting-unicode-codepoints-into-utf-8-and-utf-16 @@ -102,3 +120,31 @@ func encodeUTF8(x rune) ([4]byte, lenType) { return [4]byte{0xef, 0xbf, 0xbd, 0}, 3 } } + +// Decode a single UTF-8 character from a string. +//go:nobounds +func decodeUTF8(s string, index lenType) (rune, lenType) { + remaining := lenType(len(s)) - index // must be >= 1 before calling this function + x := s[index] + switch { + case x&0x80 == 0x00: // 0xxxxxxx + return rune(x), 1 + case x&0xe0 == 0xc0: // 110xxxxx + if remaining < 2 { + return 0xfffd, 1 + } + return (rune(x&0x1f) << 6) | (rune(s[index+1]) & 0x3f), 2 + case x&0xf0 == 0xe0: // 1110xxxx + if remaining < 3 { + return 0xfffd, 1 + } + return (rune(x&0x0f) << 12) | ((rune(s[index+1]) & 0x3f) << 6) | (rune(s[index+2]) & 0x3f), 3 + case x&0xf8 == 0xf0: // 11110xxx + if remaining < 4 { + return 0xfffd, 1 + } + return (rune(x&0x07) << 18) | ((rune(s[index+1]) & 0x3f) << 12) | ((rune(s[index+2]) & 0x3f) << 6) | (rune(s[index+3]) & 0x3f), 4 + default: + return 0xfffd, 1 + } +}