From cdff0bd3ee55e25797fbf7f722d90940b55a9d96 Mon Sep 17 00:00:00 2001 From: Jaden Weiss Date: Fri, 18 Oct 2019 12:04:15 -0400 Subject: [PATCH] add blocking select --- compiler/channel.go | 45 ++++++--- src/runtime/chan.go | 236 ++++++++++++++++++++++++++++++++++++++----- testdata/channel.go | 52 +++++++++- testdata/channel.txt | 1 + 4 files changed, 289 insertions(+), 45 deletions(-) diff --git a/compiler/channel.go b/compiler/channel.go index 9f033098..803075f1 100644 --- a/compiler/channel.go +++ b/compiler/channel.go @@ -136,7 +136,7 @@ func (c *Compiler) emitSelect(frame *Frame, expr *ssa.Select) llvm.Value { recvbuf := llvm.Undef(c.i8ptrType) if hasReceives { allocaType := llvm.ArrayType(c.ctx.Int8Type(), int(recvbufSize)) - recvbufAlloca := c.builder.CreateAlloca(allocaType, "select.recvbuf.alloca") + recvbufAlloca, _, _ := c.createTemporaryAlloca(allocaType, "select.recvbuf.alloca") recvbufAlloca.SetAlignment(recvbufAlign) recvbuf = c.builder.CreateGEP(recvbufAlloca, []llvm.Value{ llvm.ConstInt(c.ctx.Int32Type(), 0, false), @@ -146,7 +146,7 @@ func (c *Compiler) emitSelect(frame *Frame, expr *ssa.Select) llvm.Value { // Create the states slice (allocated on the stack). statesAllocaType := llvm.ArrayType(chanSelectStateType, len(selectStates)) - statesAlloca := c.builder.CreateAlloca(statesAllocaType, "select.states.alloca") + statesAlloca, statesI8, statesSize := c.createTemporaryAlloca(statesAllocaType, "select.states.alloca") for i, state := range selectStates { // Set each slice element to the appropriate channel. gep := c.builder.CreateGEP(statesAlloca, []llvm.Value{ @@ -161,19 +161,36 @@ func (c *Compiler) emitSelect(frame *Frame, expr *ssa.Select) llvm.Value { }, "select.states") statesLen := llvm.ConstInt(c.uintptrType, uint64(len(selectStates)), false) - // Convert the 'blocking' flag on this select into a LLVM value. - blockingInt := uint64(0) - if expr.Blocking { - blockingInt = 1 - } - blockingValue := llvm.ConstInt(c.ctx.Int1Type(), blockingInt, false) - // Do the select in the runtime. - results := c.createRuntimeCall("chanSelect", []llvm.Value{ - recvbuf, - statesPtr, statesLen, statesLen, // []chanSelectState - blockingValue, - }, "") + var results llvm.Value + if expr.Blocking { + // Stack-allocate operation structures. + // If these were simply created as a slice, they would heap-allocate. + chBlockAllocaType := llvm.ArrayType(c.getLLVMRuntimeType("channelBlockedList"), len(selectStates)) + chBlockAlloca, chBlockAllocaPtr, chBlockSize := c.createTemporaryAlloca(chBlockAllocaType, "select.block.alloca") + chBlockLen := llvm.ConstInt(c.uintptrType, uint64(len(selectStates)), false) + chBlockPtr := c.builder.CreateGEP(chBlockAlloca, []llvm.Value{ + llvm.ConstInt(c.ctx.Int32Type(), 0, false), + llvm.ConstInt(c.ctx.Int32Type(), 0, false), + }, "select.block") + + results = c.createRuntimeCall("chanSelect", []llvm.Value{ + recvbuf, + statesPtr, statesLen, statesLen, // []chanSelectState + chBlockPtr, chBlockLen, chBlockLen, // []channelBlockList + }, "select.result") + + // Terminate the lifetime of the operation structures. + c.emitLifetimeEnd(chBlockAllocaPtr, chBlockSize) + } else { + results = c.createRuntimeCall("tryChanSelect", []llvm.Value{ + recvbuf, + statesPtr, statesLen, statesLen, // []chanSelectState + }, "select.result") + } + + // Terminate the lifetime of the states alloca. + c.emitLifetimeEnd(statesI8, statesSize) // The result value does not include all the possible received values, // because we can't load them in advance. Instead, the *ssa.Extract diff --git a/src/runtime/chan.go b/src/runtime/chan.go index dd3e4b2b..a4cb4748 100644 --- a/src/runtime/chan.go +++ b/src/runtime/chan.go @@ -37,11 +37,86 @@ func chanDebug(ch *channel) { } } +// channelBlockedList is a list of channel operations on a specific channel which are currently blocked. +type channelBlockedList struct { + // next is a pointer to the next blocked channel operation on the same channel. + next *channelBlockedList + + // t is the task associated with this channel operation. + // If this channel operation is not part of a select, then the pointer field of the state holds the data buffer. + // If this channel operation is part of a select, then the pointer field of the state holds the recieve buffer. + // If this channel operation is a receive, then the data field should be set to zero when resuming due to channel closure. + t *task + + // s is a pointer to the channel select state corresponding to this operation. + // This will be nil if and only if this channel operation is not part of a select statement. + // If this is a send operation, then the send buffer can be found in this select state. + s *chanSelectState + + // allSelectOps is a slice containing all of the channel operations involved with this select statement. + // Before resuming the task, all other channel operations on this select statement should be canceled by removing them from their corresponding lists. + allSelectOps []channelBlockedList +} + +// remove takes the current list of blocked channel operations and removes the specified operation. +// This returns the resulting list, or nil if the resulting list is empty. +// A nil receiver is treated as an empty list. +func (b *channelBlockedList) remove(old *channelBlockedList) *channelBlockedList { + if b == old { + return b.next + } + c := b + for ; c != nil && c.next != old; c = c.next { + } + if c != nil { + c.next = old.next + } + return b +} + +// detatch removes all other channel operations that are part of the same select statement. +// If the input is not part of a select statement, this is a no-op. +// This must be called before resuming any task blocked on a channel operation in order to ensure that it is not placed on the runqueue twice. +func (b *channelBlockedList) detach() { + if b.allSelectOps == nil { + // nothing to do + return + } + for i, v := range b.allSelectOps { + // cancel all other channel operations that are part of this select statement + if &b.allSelectOps[i] == b { + continue + } + if v.s.ch == nil { + continue + } + v.s.ch.blocked = v.s.ch.blocked.remove(&b.allSelectOps[i]) + if v.s.ch.blocked == nil { + if v.s.value == nil { + // recv operation + if v.s.ch.state != chanStateClosed { + v.s.ch.state = chanStateEmpty + } + } else { + // send operation + if v.s.ch.bufUsed == 0 { + // unbuffered channel + v.s.ch.state = chanStateEmpty + } else { + // buffered channel + v.s.ch.state = chanStateBuf + } + } + } + chanDebug(v.s.ch) + } +} + type channel struct { elementSize uintptr // the size of one value in this channel bufSize uintptr // size of buffer (in elements) state chanState - blocked *task + blocked *channelBlockedList bufHead uintptr // head index of buffer (next push index) bufTail uintptr // tail index of buffer (next pop index) bufUsed uintptr // number of elements currently in buffer @@ -58,6 +133,63 @@ func chanMake(elementSize uintptr, bufSize uintptr) *channel { } } +// resumeRX resumes the next receiver and returns the destination pointer. +// If the ok value is true, then the caller is expected to store a value into this pointer. +func (ch *channel) resumeRX(ok bool) unsafe.Pointer { + // pop a blocked goroutine off the stack + var b *channelBlockedList + b, ch.blocked = ch.blocked, ch.blocked.next + + // get destination pointer + dst := b.t.state().ptr + + if !ok { + // the result value is zero + memzero(dst, ch.elementSize) + b.t.state().data = 0 + } + + if b.s != nil { + // tell the select op which case resumed + b.t.state().ptr = unsafe.Pointer(b.s) + + // detach associated operations + b.detach() + } + + // push task onto runqueue + runqueuePushBack(b.t) + + return dst +} + +// resumeTX resumes the next sender and returns the source pointer. +// The caller is expected to read from the value in this pointer before yielding. +func (ch *channel) resumeTX() unsafe.Pointer { + // pop a blocked goroutine off the stack + var b *channelBlockedList + b, ch.blocked = ch.blocked, ch.blocked.next + + // get source pointer + src := b.t.state().ptr + + if b.s != nil { + // use state's source pointer + src = b.s.value + + // tell the select op which case resumed + b.t.state().ptr = unsafe.Pointer(b.s) + + // detach associated operations + b.detach() + } + + // push task onto runqueue + runqueuePushBack(b.t) + + return src +} + // push value to end of channel if space is available // returns whether there was space for the value in the buffer func (ch *channel) push(value unsafe.Pointer) bool { @@ -151,12 +283,10 @@ func (ch *channel) trySend(value unsafe.Pointer) bool { return false case chanStateRecv: // unblock reciever - receiver := unblockChain(&ch.blocked, nil) + dst := ch.resumeRX(true) // copy value to reciever - receiverState := receiver.state() - memcpy(receiverState.ptr, value, ch.elementSize) - receiverState.data = 1 // commaOk = true + memcpy(dst, value, ch.elementSize) // change state to empty if there are no more receivers if ch.blocked == nil { @@ -191,9 +321,11 @@ func (ch *channel) tryRecv(value unsafe.Pointer) (bool, bool) { // try to pop the value directly from the buffer if ch.pop(value) { // unblock next sender if applicable - if sender := unblockChain(&ch.blocked, nil); sender != nil { + if ch.blocked != nil { + src := ch.resumeTX() + // push sender's value into buffer - ch.push(sender.state().ptr) + ch.push(src) if ch.blocked == nil { // last sender unblocked - update state @@ -207,10 +339,12 @@ func (ch *channel) tryRecv(value unsafe.Pointer) (bool, bool) { } return true, true - } else if sender := unblockChain(&ch.blocked, nil); sender != nil { + } else if ch.blocked != nil { // unblock next sender if applicable + src := ch.resumeTX() + // copy sender's value - memcpy(value, sender.state().ptr, ch.elementSize) + memcpy(value, src, ch.elementSize) if ch.blocked == nil { // last sender unblocked - update state @@ -294,7 +428,10 @@ func chanSend(ch *channel, value unsafe.Pointer) { ch.state = chanStateSend senderState := sender.state() senderState.ptr = value - ch.blocked, senderState.next = sender, ch.blocked + ch.blocked = &channelBlockedList{ + next: ch.blocked, + t: sender, + } chanDebug(ch) yield() senderState.ptr = nil @@ -320,8 +457,11 @@ func chanRecv(ch *channel, value unsafe.Pointer) bool { receiver := getCoroutine() ch.state = chanStateRecv receiverState := receiver.state() - receiverState.ptr, receiverState.data = value, 0 - ch.blocked, receiverState.next = receiver, ch.blocked + receiverState.ptr, receiverState.data = value, 1 + ch.blocked = &channelBlockedList{ + next: ch.blocked, + t: receiver, + } chanDebug(ch) yield() ok := receiverState.data == 1 @@ -348,15 +488,9 @@ func chanClose(ch *channel) { runtimePanic("close channel during send") case chanStateRecv: // unblock all receivers with the zero value - for rx := unblockChain(&ch.blocked, nil); rx != nil; rx = unblockChain(&ch.blocked, nil) { - // get receiver state - state := rx.state() - - // store the zero value - memzero(state.ptr, ch.elementSize) - - // set the comma-ok value to false (channel closed) - state.data = 0 + ch.state = chanStateClosed + for ch.blocked != nil { + ch.resumeRX(false) } case chanStateEmpty, chanStateBuf: // Easy case. No available sender or receiver. @@ -371,7 +505,60 @@ func chanClose(ch *channel) { // // TODO: do this in a round-robin fashion (as specified in the Go spec) instead // of picking the first one that can proceed. -func chanSelect(recvbuf unsafe.Pointer, states []chanSelectState, blocking bool) (uintptr, bool) { +func chanSelect(recvbuf unsafe.Pointer, states []chanSelectState, ops []channelBlockedList) (uintptr, bool) { + if selected, ok := tryChanSelect(recvbuf, states); selected != ^uintptr(0) { + // one channel was immediately ready + return selected, ok + } + + // construct blocked operations + for i, v := range states { + ops[i] = channelBlockedList{ + next: v.ch.blocked, + t: getCoroutine(), + s: &states[i], + allSelectOps: ops, + } + v.ch.blocked = &ops[i] + if v.value == nil { + // recv + switch v.ch.state { + case chanStateEmpty: + v.ch.state = chanStateRecv + case chanStateRecv: + // already in correct state + default: + runtimePanic("invalid channel state") + } + } else { + // send + switch v.ch.state { + case chanStateEmpty: + v.ch.state = chanStateSend + case chanStateSend: + // already in correct state + case chanStateBuf: + // already in correct state + default: + runtimePanic("invalid channel state") + } + } + chanDebug(v.ch) + } + + // expose rx buffer + getCoroutine().state().ptr = recvbuf + getCoroutine().state().data = 1 + + // wait for one case to fire + yield() + + // figure out which one fired and return the ok value + return (uintptr(getCoroutine().state().ptr) - uintptr(unsafe.Pointer(&states[0]))) / unsafe.Sizeof(chanSelectState{}), getCoroutine().state().data != 0 +} + +// tryChanSelect is like chanSelect, but it does a non-blocking select operation. +func tryChanSelect(recvbuf unsafe.Pointer, states []chanSelectState) (uintptr, bool) { // See whether we can receive from one of the channels. for i, state := range states { if state.value == nil { @@ -389,8 +576,5 @@ func chanSelect(recvbuf unsafe.Pointer, states []chanSelectState, blocking bool) } } - if !blocking { - return ^uintptr(0), false - } - panic("unimplemented: blocking select") + return ^uintptr(0), false } diff --git a/testdata/channel.go b/testdata/channel.go index 3a5d9d1a..a504b690 100644 --- a/testdata/channel.go +++ b/testdata/channel.go @@ -1,8 +1,8 @@ package main import ( - "time" "runtime" + "time" ) // waitGroup is a small type reimplementing some of the behavior of sync.WaitGroup @@ -91,7 +91,7 @@ func main() { println("sum(100):", sum) // Test simple selects. - go selectDeadlock() // cannot use waitGroup here - never terminates + go selectDeadlock() // cannot use waitGroup here - never terminates wg.add(1) go selectNoOp() wg.wait() @@ -117,11 +117,10 @@ func main() { ch = make(chan int) wg.add(1) go func(ch chan int) { + runtime.Gosched() ch <- 55 wg.done() }(ch) - // not defined behavior, but we cant really fix this until select has been fixed - time.Sleep(time.Millisecond) select { case make(chan int) <- 3: println("unreachable") @@ -147,7 +146,6 @@ func main() { ch = make(chan int) wg.add(1) go fastreceiver(ch) - time.Sleep(time.Millisecond) select { case ch <- 235: println("select send") @@ -188,6 +186,50 @@ func main() { count++ } println("hybrid buffered channel recieve:", count) + + // test blocking selects + ch = make(chan int) + sch1 := make(chan int) + sch2 := make(chan int) + sch3 := make(chan int) + wg.add(3) + go func() { + defer wg.done() + time.Sleep(time.Millisecond) + sch1 <- 1 + }() + go func() { + defer wg.done() + time.Sleep(time.Millisecond) + sch2 <- 2 + }() + go func() { + defer wg.done() + // merge sch2 and sch3 into ch + for i := 0; i < 2; i++ { + var v int + select { + case v = <-sch1: + case v = <-sch2: + } + select { + case sch3 <- v: + panic("sent to unused channel") + case ch <- v: + } + } + }() + sum = 0 + for i := 0; i < 2; i++ { + select { + case sch3 <- sum: + panic("sent to unused channel") + case v := <-ch: + sum += v + } + } + wg.wait() + println("blocking select sum:", sum) } func send(ch chan<- int) { diff --git a/testdata/channel.txt b/testdata/channel.txt index 5415434b..b7036f2b 100644 --- a/testdata/channel.txt +++ b/testdata/channel.txt @@ -31,3 +31,4 @@ closed buffered channel recieve: 3 closed buffered channel recieve: 4 closed buffered channel recieve: 0 hybrid buffered channel recieve: 2 +blocking select sum: 3