源码面前无秘密 | Golang标准库 sync.WaitGroup

2,592 阅读5分钟

本文章基于 go1.14.6

用法

  1. 新的goroutine创建前调用WaitGroup.Add(1)并在执行结束时调用WaitGroup.Done()
  2. 在阻塞等待的goroutine内调用WaitGroup.Wait(),当调用返回后,它所等待的goroutine一定执行完毕了

使用示例

package main

import (
	"fmt"
	"sync"
)

func main() {
	wg := sync.WaitGroup{}

	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func() {
			fmt.Println("hello world!")
			wg.Done()
		}()
	}

	wg.Wait()
	fmt.Println("main done!")
}

以上代码展示了一个比较典型的sync.WaitGroup使用场景,即一个goroutine等待其他若干个goroutine的结束.

就以上代码而言,sync.WaitGroup确保了在for中的所有goroutine执行fmt.Println("hello world!")结束后,再执行fmt.Println("main done!"),然后退出程序.

源码阅读

以下源码位与 /src/sync/waitgroup.go

想要了解sync.WaitGroup的实现以及使用过程中可能遇到的坑,就必须对源码进行深入理解.sync.WaitGroup的源码并不多但会考虑很多并发情况,总体难度适中,很适合go初学者作为go源码阅读的起点.

结构体

type WaitGroup struct {
    // 这是sync库中常用的一个结构体,内部是空实现.
    // 当结构体被初始化并使用后,如果对改值进行copy则go vet检查代码时会警告
	noCopy noCopy

	// 关于这个成员变量,官方解释很复杂,对于初学者而言,只需要知道其中3个uint32的值分别存储了:
    // 1. 被Add/Done方法操作的计数器,Add表示加,Done表示减,当计数器为0时,Wait会立即返回
    // 2. 正在Wait()处阻塞的goroutine个数
    // 3. 用于等待与唤醒的信号量,信号量在之后有具体介绍
	state1 [3]uint32
}

方法

state

func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
	if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
		return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
	} else {
		return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
	}
}

该方法解析WaitGroup结构体中的state1返回两个指针

其中statep所指的uint64变量中高32位存储了计数器,低32位存储了此时正在Wait处阻塞的goroutine个数.

semap指向用于唤醒和等待的信号量,具体见后续讲解.

Add/Done

Done其实就是WaitGroup.Add(-1)的别名,故不单独讲解

func (wg *WaitGroup) Add(delta int) {
	statep, semap := wg.state()
	if race.Enabled {
		//...新手遇到race.Enabled代码块时直接跳过即可,仅用于-race编译时,与主逻辑无关
	}
    
    //给 计数器 原子增加delta
	state := atomic.AddUint64(statep, uint64(delta)<<32)
	v := int32(state >> 32) //v是本次增加后计数器的值
	w := uint32(state)      //w是此时正在Wait()处出阻塞等待的goroutine个数
	if race.Enabled && delta > 0 && v == int32(delta) {
		//...新手遇到race.Enabled代码块时直接跳过即可,仅用于-race编译时,与主逻辑无关
	}
    
    //在正常情况下,每次创建goroutine,计数器v加1;
    //该goroutine个数执行完毕调用Done使计数器v减1; 所以计数器v一定>=0
	if v < 0 {
		panic("sync: negative WaitGroup counter")
	}
    
    //此处检查一种情况: 当第一个Add()与Wait()并发调用到此处时,如果Wait()被先执行,
    //此时计数器v为0,Wait()会直接返回;
    //如果Add()先执行,此时计数器v>0之后的Wait()执行时会被阻塞等待;
    //此时出现了由于数据竞争导致的结果不一致,所以panic提示调用者.
	if w != 0 && delta > 0 && v == int32(delta) {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
    
    //当计数器v>0时,表示不应该对正在Wait()的goroutine执行任何操作,应直接返回
    //当w==0是表示没有正在Wait()的goroutine,应直接返回
	if v > 0 || w == 0 {
		return
	}
    
    //注意: 在没有错误的情况下,代码执行到此处时(v==0 && w>0)
    
	// This goroutine has set counter to 0 when waiters > 0.
	// Now there can't be concurrent mutations of state:
	// - Adds must not happen concurrently with Wait,
	// - Wait does not increment waiters if it sees counter == 0.
	// Still do a cheap sanity check to detect WaitGroup misuse.
    
    // 此处检测两种并发情况:
    // 1. 当w从>0的值变为0时,禁止新的Add()调用,因为此时新建的goroutine可能不会被Wait
    // 2. 当v==0时,一定不会有新的Wait()使得w++,因为在Wait()内部看到v==0就会立即返回,详见下文
    // 以上任何一种情况出发时,*statep都会被更新,导致(state!=*statep)
	if *statep != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
    
    // 代码执行到此处时,我们要根据w的数量,依次唤醒在Wait()处等待的goroutine,所以此时
    // 可以将w也置为0,即v=0,w=0,所以*statep就可以赋值为0
	*statep = 0
	for ; w != 0; w-- {
        //runtime_Semrelease是runtime包的内建函数,
        //可以简单理解为唤醒runtime_Semacquire处阻塞等待的goroutine
		runtime_Semrelease(semap, false, 0)
	}
}

func (wg *WaitGroup) Done() {
	wg.Add(-1)
}

Wait

func (wg *WaitGroup) Wait() {
	statep, semap := wg.state()
	if race.Enabled {
		//...新手遇到race.Enabled代码块时直接跳过即可,仅用于-race编译时,与主逻辑无关
	}
	for {
		state := atomic.LoadUint64(statep)
		v := int32(state >> 32) //计数器
        w := uint32(state)      //正在Wait()的goroutine数量
		if v == 0 {
			// 计数器v==0就不需要等待直接返回即可
			if race.Enabled {
				//...新手遇到race.Enabled代码块时直接跳过即可,仅用于-race编译时,与主逻辑无关
			}
			return
		}
		// 尝试原子地将w+1,如果失败就在for循环内不停尝试,类似乐观锁
		if atomic.CompareAndSwapUint64(statep, state, state+1) {
			if race.Enabled && w == 0 {
				//...新手遇到race.Enabled代码块时直接跳过即可,仅用于-race编译时,与主逻辑无关
			}
            //此处v>0,所以需要等待v减为0;当v通过Add()减为0时,会唤醒此处的等待
			runtime_Semacquire(semap)
            
            //简单检测一个并发问题: 当Wait()被唤醒后,应满足v==0&&w==0
            //否则一定出现了Wait()返回前,Add()被并发调用的问题
			if *statep != 0 {
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			if race.Enabled {
				//...新手遇到race.Enabled代码块时直接跳过即可,仅用于-race编译时,与主逻辑无关
			}
			return
		}
	}
}

总结

  1. WaitGroup被初始化且使用(例如调用Add())后,不能做值拷贝,如果需要传递因采用地址传递方式
  2. 在Wait()执行后不应再调用Add()增加计数器,只能调用Done()使计数器减为0