怎么使用 context 控制 goroutine 的 cancel 的

是如何实现流程控制的

context 包除了上面介绍的 k-v 数据对外,还有三个函数,分别是:

context.WithCancel
context.WithDeadline
context.WithTimeout

这三个函数都会返回一个类型为CancelFunc的函数,通过源码可以知道,这几个函数里面返回的 cancel 分别是

// context.WithCancel
// c是cancelCtx
func() { c.cancel(true, Canceled) }

// context.WithDeadline
// c是timerCtx
WithCancel(parent)
// or
func() { c.cancel(true, Canceled) }

// context.WithTimeout
WithDeadline(parent, time.Now().Add(timeout))

从这里可以看出来,这三个函数的 cancel 是通过两个结构cancelCtxtimerCtx实现的

cancelCtx:

type cancelCtx struct {
    Context

    mu       sync.Mutex            // protects following fields
    done     chan struct{}         // created lazily, closed by first cancel call
    children map[canceler]struct{} // set to nil by the first cancel call
    err      error                 // set to non-nil by the first cancel call
}

timerCtx:(依赖了 cancelCtx)

type timerCtx struct {
    cancelCtx
    timer *time.Timer // Under cancelCtx.mu.

    deadline time.Time
}

那先看看 cancelCtx 是如何实现 cancel 的

cancelCtx 源码分析

// 实现了canceler接口
// 可以被cancel,如果他被cancel,那么所有的children都会被cancel
type cancelCtx struct {
    Context

    mu       sync.Mutex            // protects following fields
    done     chan struct{}         // created lazily, closed by first cancel call
    children map[canceler]struct{} // set to nil by the first cancel call
    err      error                 // set to non-nil by the first cancel call
}

func newCancelCtx(parent Context) cancelCtx {
    return cancelCtx{Context: parent}
}

// 返回一个阻塞的chan,并且赋给c.done
func (c *cancelCtx) Done() <-chan struct{} {
    c.mu.Lock()
    if c.done == nil {
        c.done = make(chan struct{})
    }
    d := c.done
    c.mu.Unlock()
    return d
}

// 返回cancelCtx中的err
func (c *cancelCtx) Err() error {
    c.mu.Lock()
    defer c.mu.Unlock()
    return c.err
}

func (c *cancelCtx) String() string {
    return fmt.Sprintf("%v.WithCancel", c.Context)
}

// cancel函数
// 关闭c.done这个阻塞chan,以及任何children的c.done
// 如果removeFromParent==true,将当前ctx和parent之间的链移除,即当前ctx不再是parent的child
// err 可以是context.Canceled,也可以是context.DeadlineExceeded
func (c *cancelCtx) cancel(removeFromParent bool, err error) {
    if err == nil {
        panic("context: internal error: missing cancel error")
    }
    c.mu.Lock()
    if c.err != nil {
        c.mu.Unlock()
        // err只有这个函数可以设置,所以如果err已经!=nil了,那么就已经cancel了
        return // already canceled
    }
    c.err = err

    // 接下来5行将c.done设置为一个已经close的chan
    // 所以c.Done不再阻塞了
    if c.done == nil {
        c.done = closedchan
    } else {
        close(c.done)
    }

    // 循环cancel所有的children
    for child := range c.children {
        // NOTE: acquiring the child's lock while holding parent's lock.
        child.cancel(false, err)
    }
    c.children = nil
    c.mu.Unlock()

    // 如果有需要,移除与parent的链
    if removeFromParent {
        removeChild(c.Context, c)
    }
}

这个实现的点就在于创建了一个阻塞的 chan: ctx.done(),然后使用ctx.cancel()ctx.done设置为非阻塞(close)

如何使用 WithCancel

下面这段代码中,gen 函数死循环做 n++然后赋给 dst,dst 是一个 channel,可以通过 range 循环获取。即下面这段代码在for n := range gen(ctx)case dst <- n: n++之间循环了 5 次。

然后 main 函数退出,触发了 defer,执行 cancel。根据上面说的,cancel 会 close 掉 Done,所以case <-ctx.Done(): return达到执行条件,gen 函数中创建的 goroutinego func() { for { ... } }执行结束

因为这个函数本身就只有这么一点点就退出了,所以看不出WithCancel的优点,不过,如果是大型引用,使用WithCancel可以及早释放一些不必要的 goroutine

package main

import (
    "context"
    "fmt"
)

func main() {
    // dst是一个channel,不停的遍历会递增
    gen := func(ctx context.Context) <-chan int {
        dst := make(chan int)
        n := 1
        go func() {
            for {
                select {
                case <-ctx.Done():
                    return // returning not to leak the goroutine
                case dst <- n:
                    n++
                }
            }
        }()
        return dst
    }

    // 创建ctx与cancel函数
    ctx, cancel := context.WithCancel(context.Background())
    defer cancel() // cancel when we are finished consuming integers

    // 遍历5次退出,执行defer
    for n := range gen(ctx) {
        fmt.Println(n)
        if n == 5 {
            break
        }
    }
}

WithCancel 源码分析

生成一个 cancelCtx{ },在 cancel 的时候会执行func() { c.cancel(true, Canceled) }

看一下 propagateCancel 函数

  • 假如 parent 是 context.Background(),那么会执行到else { go func() { .. } }
    • 这个时候,要么 parent 被 cancel 了,执行case <-parent.Done(): child.cancel(false, parent.Err()),把本 ctxcancel 然后退出
    • 要么等到了本 ctx 的 cancel:case <-child.Done()
  • 假如 parent 已经携带了控制信号
    • parent 也是一个 cancelCtx,然后调用func (c *cancelCtx) cancel(removeFromParent bool, err error)
    • parent 是一个 timerCtx,然后调用func (c *timerCtx) cancel(removeFromParent bool, err error)s
    • (当然这里也可能是一个 valueCtx,这里会递归找下去)
    • children 是怎么使用的?
func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
    c := newCancelCtx(parent)
    propagateCancel(parent, &c)
    return &c, func() { c.cancel(true, Canceled) }
}

func propagateCancel(parent Context, child canceler) {
    if parent.Done() == nil {
        return // parent is never canceled
    }
    if p, ok := parentCancelCtx(parent); ok {
        p.mu.Lock()
        if p.err != nil {
            // parent has already been canceled
            child.cancel(false, p.err)
        } else {
            if p.children == nil {
                p.children = make(map[canceler]struct{})
            }
            p.children[child] = struct{}{}
        }
        p.mu.Unlock()
    } else {
        go func() {
            select {
            case <-parent.Done():
                child.cancel(false, parent.Err())
            case <-child.Done():
            }
        }()
    }
}

func parentCancelCtx(parent Context) (*cancelCtx, bool) {
    for {
        switch c := parent.(type) {
        case *cancelCtx:
            return c, true
        case *timerCtx:
            return &c.cancelCtx, true
        case *valueCtx:
            parent = c.Context
        default:
            return nil, false
        }
    }
}

然后看看 timerCtx 是如何实现 cancel 的

timerCtx 源码解析

cancel 有两个实现,一个是 cancelCtx 的,一个是 TimerCtx 的

这里的 timerCtx 的 cancel 的逻辑是

type timerCtx struct {
    cancelCtx
    timer *time.Timer // Under cancelCtx.mu.

    deadline time.Time
}

// 返回当前ctx的deadline
func (c *timerCtx) Deadline() (deadline time.Time, ok bool) {
    return c.deadline, true
}

func (c *timerCtx) String() string {
    return fmt.Sprintf("%v.WithDeadline(%s [%s])", c.cancelCtx.Context, c.deadline, time.Until(c.deadline))
}

func (c *timerCtx) cancel(removeFromParent bool, err error) {
    c.cancelCtx.cancel(false, err)
    if removeFromParent {
        // Remove this timerCtx from its parent cancelCtx's children.
        removeChild(c.cancelCtx.Context, c)
    }
    c.mu.Lock()
    if c.timer != nil {
        c.timer.Stop()
        c.timer = nil
    }
    c.mu.Unlock()
}

简析

代码实现详解

实现代码解析

// WithDeadline returns a copy of the parent context with the deadline adjusted
// to be no later than d. If the parent's deadline is already earlier than d,
// WithDeadline(parent, d) is semantically equivalent to parent. The returned
// context's Done channel is closed when the deadline expires, when the returned
// cancel function is called, or when the parent context's Done channel is
// closed, whichever happens first.
//
// Canceling this context releases resources associated with it, so code should
// call cancel as soon as the operations running in this Context complete.
输入参数d是超时时间,
如果
func WithDeadline(parent Context, d time.Time) (Context, CancelFunc) {
    if cur, ok := parent.Deadline(); ok && cur.Before(d) {
        // The current deadline is already sooner than the new one.
        return WithCancel(parent)
    }
    c := &timerCtx{
        cancelCtx: newCancelCtx(parent),
        deadline:  d,
    }
    propagateCancel(parent, c)
    dur := time.Until(d)
    if dur <= 0 {
        c.cancel(true, DeadlineExceeded) // deadline has already passed
        return c, func() { c.cancel(true, Canceled) }
    }
    c.mu.Lock()
    defer c.mu.Unlock()
    if c.err == nil {
        c.timer = time.AfterFunc(dur, func() {
            c.cancel(true, DeadlineExceeded)
        })
    }
    return c, func() { c.cancel(true, Canceled) }
}

参考文章