WaitGroup, the Primitive That Keeps Getting Reworked

source link: https://colobu.com/2022/08/30/waitgroup-to-love-to-toss/

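WaitGroup is a concurrency primitive that Go programs frequently use to orchestrate tasks. It exposes only a few simple methods and looks easy to use. Under the hood, however, its implementation has been changed several times, mostly to optimize the atomic operations on its fields.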
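
For readers who have not used it recently, here is a minimal usage sketch of those few methods. The function name `sumSquares` and the workload are made up for illustration; only `Add`, `Done` and `Wait` come from the `sync` package.

```go
package main

import (
	"fmt"
	"sync"
)

// sumSquares starts one goroutine per input and uses a WaitGroup
// to wait until all of them have finished.
func sumSquares(nums []int) int {
	var (
		wg    sync.WaitGroup
		mu    sync.Mutex
		total int
	)
	for _, n := range nums {
		wg.Add(1) // register the task before starting the goroutine
		go func(n int) {
			defer wg.Done() // decrement the counter when the task ends
			mu.Lock()
			total += n * n
			mu.Unlock()
		}(n)
	}
	wg.Wait() // block until the counter drops back to zero
	return total
}

func main() {
	fmt.Println(sumSquares([]int{1, 2, 3, 4})) // prints 30
}
```

Everything discussed in the rest of the article happens behind these three calls: the counter that `Add` and `Done` change, and the waiter count that `Wait` maintains.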

The original WaitGroup implementation

The earliest WaitGroup implementation looked like this:

type WaitGroup struct {
	m       Mutex
	counter int32
	waiters int32
	sema    *uint32
}

func (wg *WaitGroup) Add(delta int) {
	v := atomic.AddInt32(&wg.counter, int32(delta))
	if v < 0 {
		panic("sync: negative WaitGroup count")
	}
	if v > 0 || atomic.LoadInt32(&wg.waiters) == 0 {
		return
	}
	wg.m.Lock()
	for i := int32(0); i < wg.waiters; i++ {
		runtime_Semrelease(wg.sema)
	}
	wg.waiters = 0
	wg.sema = nil
	wg.m.Unlock()
}

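The meaning of each field is clear, but the implementation is still a little rough. For example, sema is a pointer.

Later, the counter and waiters fields were merged into one 64-bit value. Because 64-bit atomic operations require 8-byte alignment, the code has to locate the 8-byte-aligned offset inside state1. sema also stopped being a pointer.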

type WaitGroup struct {
	// 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
	// 64-bit atomic operations require 64-bit alignment, but 32-bit
	// compilers do not ensure it. So we allocate 12 bytes and then use
	// the aligned 8 bytes in them as state.
	state1 [12]byte
	sema   uint32
}

func (wg *WaitGroup) state() *uint64 {
	if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
		return (*uint64)(unsafe.Pointer(&wg.state1))
	} else {
		return (*uint64)(unsafe.Pointer(&wg.state1[4]))
	}
}

Later still, the implementation became the following and stayed stable for a long time:

type WaitGroup struct {
	noCopy noCopy
	// 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
	// 64-bit atomic operations require 64-bit alignment, but 32-bit
	// compilers do not ensure it. So we allocate 12 bytes and then use
	// the aligned 8 bytes in them as state, and the other 4 as storage
	// for the sema.
	state1 [3]uint32
}

// state returns pointers to the state and sema fields stored within wg.state1.
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
	if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
		return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
	} else {
		return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
	}
}

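The state1 and sema fields were merged into a single field, state1, an array of three uint32 values, each 4 bytes wide. Either the first element is 8-byte aligned or the second one is. The aligned 8 bytes hold the state, and the remaining 4 bytes serve as sema.

There is nothing wrong with this implementation, but it is a bit convoluted: you have to check the alignment of state1 before you know which part holds counter and waiters and which part is sema.

A question for the reader: what is the largest number of waiters a WaitGroup can have?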
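
The alignment check can be tried outside the `sync` package. The following is only a sketch: `splitState`, `viewAt` and `aligned8` are names invented here, and the 16-byte backing array is a trick to obtain one 8-byte-aligned view and one that is only 4-byte aligned, so that both branches run.

```go
package main

import (
	"fmt"
	"unsafe"
)

// splitState mirrors the alignment check on a caller-supplied [3]uint32:
// it returns the 8-byte-aligned pair of words as the state and the
// leftover word as the semaphore slot.
func splitState(buf *[3]uint32) (statep *uint64, semap *uint32) {
	if uintptr(unsafe.Pointer(buf))%8 == 0 {
		return (*uint64)(unsafe.Pointer(buf)), &buf[2]
	}
	return (*uint64)(unsafe.Pointer(&buf[1])), &buf[0]
}

// viewAt reinterprets three consecutive words of backing, starting at
// word off (0 or 1), as a [3]uint32.
func viewAt(backing *[4]uint32, off int) *[3]uint32 {
	return (*[3]uint32)(unsafe.Pointer(&backing[off]))
}

// aligned8 reports whether p sits on an 8-byte boundary.
func aligned8(p *uint64) bool {
	return uintptr(unsafe.Pointer(p))%8 == 0
}

func main() {
	var backing [4]uint32
	for off := 0; off < 2; off++ {
		buf := viewAt(&backing, off)
		statep, semap := splitState(buf)
		fmt.Printf("view at word %d: state aligned=%v, sema is first word=%v\n",
			off, aligned8(statep), semap == &buf[0])
	}
}
```

Whichever view starts on an 8-byte boundary keeps the semaphore in its last word; the other view keeps it in its first word. In both cases the state pointer ends up 8-byte aligned.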
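
One way to reason about this question is to look at how the two numbers share the 64-bit state. The helpers `pack` and `unpack` below are hypothetical and not part of the `sync` package; they only reproduce the layout described in the struct comments (counter in the high 32 bits, waiter count in the low 32 bits).

```go
package main

import (
	"fmt"
	"math"
)

// pack combines a counter and a waiter count into one 64-bit word:
// counter in the high 32 bits, waiters in the low 32 bits.
func pack(counter int32, waiters uint32) uint64 {
	return uint64(uint32(counter))<<32 | uint64(waiters)
}

// unpack splits the 64-bit word back into its two halves.
func unpack(state uint64) (counter int32, waiters uint32) {
	return int32(state >> 32), uint32(state)
}

func main() {
	c, w := unpack(pack(3, 5))
	fmt.Println(c, w) // prints 3 5

	// The waiter count has only 32 bits of room.
	_, maxWaiters := unpack(pack(0, math.MaxUint32))
	fmt.Println(maxWaiters) // prints 4294967295
}
```

Under this layout the waiter field can represent at most 2^32 - 1 waiters; one more would carry into the counter half.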

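Changes in Go 1.18

In Go 1.18, WaitGroup changed again. On 64-bit architectures the compiler guarantees that uint64 fields are 8-byte aligned, so the state can be declared as a plain uint64: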

type WaitGroup struct {
	noCopy noCopy
	// 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
	// 64-bit atomic operations require 64-bit alignment, but 32-bit
	// compilers only guarantee that 64-bit fields are 32-bit aligned.
	// For this reason on 32 bit architectures we need to check in state()
	// if state1 is aligned or not, and dynamically "swap" the field order if
	// needed.
	state1 uint64
	state2 uint32
}

Of course, to stay compatible with 32-bit architectures, the alignment still has to be checked:

func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
	if unsafe.Alignof(wg.state1) == 8 || uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
		// state1 is 64-bit aligned: nothing to do.
		return &wg.state1, &wg.state2
	} else {
		// state1 is 32-bit aligned but not 64-bit aligned: this means that
		// (&state1)+4 is 64-bit aligned.
		state := (*[3]uint32)(unsafe.Pointer(&wg.state1))
		return (*uint64)(unsafe.Pointer(&state[1])), &state[0]
	}
}
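
The first half of that condition is worth a closer look: `unsafe.Alignof` is evaluated at compile time, so on a 64-bit target the comparison is a constant and the compiler can tell that the fallback branch is never taken. The sketch below shows this with a hypothetical `pair` struct that copies the field layout; `alignedByType`, `statePointers` and `aligned8` are names invented for the example.

```go
package main

import (
	"fmt"
	"unsafe"
)

// pair is a hypothetical struct with the same field layout as the
// Go 1.18 WaitGroup state.
type pair struct {
	state1 uint64
	state2 uint32
}

// alignedByType is a compile-time constant: true where uint64 is
// 8-byte aligned by its type (64-bit targets), false on 32-bit targets.
const alignedByType = unsafe.Alignof(uint64(0)) == 8

// statePointers follows the same logic as state() above.
func statePointers(p *pair) (statep *uint64, semap *uint32) {
	if alignedByType || uintptr(unsafe.Pointer(&p.state1))%8 == 0 {
		return &p.state1, &p.state2
	}
	// Only reachable on 32-bit targets when p is not 8-byte aligned.
	words := (*[3]uint32)(unsafe.Pointer(&p.state1))
	return (*uint64)(unsafe.Pointer(&words[1])), &words[0]
}

// aligned8 reports whether ptr sits on an 8-byte boundary.
func aligned8(ptr *uint64) bool {
	return uintptr(unsafe.Pointer(ptr))%8 == 0
}

func main() {
	var p pair
	statep, _ := statePointers(&p)
	fmt.Println("aligned by type:", alignedByType)
	fmt.Println("state pointer 8-byte aligned:", aligned8(statep))
}
```

Because the constant short-circuits the `||`, the pointer arithmetic is needed only on targets where it can actually matter.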

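Overall, on linux/amd64 this change brings a performance improvement of 9% to 30%.

Changes in Go 1.20

The optimization was not finished yet. In Go 1.19, Russ Cox implemented atomic.Uint64, which is 8-byte aligned on both 64-bit and 32-bit architectures. Why? Because it carries a special privilege: align64.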

// An Uint64 is an atomic uint64. The zero value is zero.
type Uint64 struct {
	_ noCopy
	_ align64
	v uint64
}

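On 64-bit architectures nothing special is needed. On 32-bit architectures, when the Go compiler sees this field it automatically aligns the struct to 8 bytes. This is a convention between the compiler and the standard library: adding align64 to a struct defined in your own package has no effect.
If you do want your own struct to be 8-byte aligned, you can use the following technique: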

import "sync/atomic"

type T struct {
	_ [0]atomic.Int64 // takes 0 bytes, but makes the struct 8-byte aligned
	x uint64          // x is 8-byte aligned
}
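
The effect of the zero-length array can be observed with `unsafe.Alignof` and `unsafe.Offsetof`. This is a sketch that assumes Go 1.19 or later; `padded`, `holder` and the two helper functions are hypothetical names. On a 64-bit target the numbers are the same with or without the trick, so the difference only shows when building for a 32-bit target such as GOARCH=386.

```go
package main

import (
	"fmt"
	"sync/atomic"
	"unsafe"
)

// padded uses the zero-length array trick from the text above.
type padded struct {
	_ [0]atomic.Int64 // occupies no space, raises the alignment to 8
	x uint64
}

// holder places a 4-byte field in front of a padded value, so any
// alignment padding shows up in the offset of p.
type holder struct {
	flag uint32
	p    padded
}

// paddedAlign returns the alignment the compiler assigns to padded.
func paddedAlign() uintptr {
	return unsafe.Alignof(padded{})
}

// paddedOffset returns the offset of the padded field inside holder.
func paddedOffset() uintptr {
	return unsafe.Offsetof(holder{}.p)
}

func main() {
	fmt.Println("alignment of padded:", paddedAlign())
	fmt.Println("offset of holder.p: ", paddedOffset())
}
```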

With that in place, the WaitGroup implementation can be simplified to:

type WaitGroup struct {
	noCopy noCopy
	state  atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.
	sema   uint32
}

A separate state() method is no longer needed either; the state field is used directly (race-detector code removed):

func (wg *WaitGroup) Add(delta int) {
	state := wg.state.Add(uint64(delta) << 32)
	v := int32(state >> 32)
	w := uint32(state)
	if v < 0 {
		panic("sync: negative WaitGroup counter")
	}
	if w != 0 && delta > 0 && v == int32(delta) {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	if v > 0 || w == 0 {
		return
	}
	// This goroutine has set counter to 0 when waiters > 0.
	// Now there can't be concurrent mutations of state:
	// - Adds must not happen concurrently with Wait,
	// - Wait does not increment waiters if it sees counter == 0.
	// Still do a cheap sanity check to detect WaitGroup misuse.
	if wg.state.Load() != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// Reset waiters count to 0.
	wg.state.Store(0)
	for ; w != 0; w-- {
		runtime_Semrelease(&wg.sema, false, 0)
	}
}
