3

優化重構 Worker Pool 程式碼

 2 years ago
source link: https://blog.wu-boy.com/2022/06/refactor-worker-pool-source-code/
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

優化重構 Worker Pool 程式碼

 Posted on June 7, 2022  |  3 minutes  |  557 words  |  Appleboy
logo

最近看到 Go 語言一段代碼,我覺得有很大的優化空間,也將此優化的過程跟想法分享給大家,也許每個人優化的方向不同,大家可以把代碼整個看完後,不要繼續往下看,先想看看是否有優化的空間,下述程式碼本身沒有任何問題,執行過程不會出現任何錯誤。

先說明底下案例在做什麼,相信大家都有聽過在 Go 語言內要實現 Worker Pools 機制相當簡單,看到 ExecuteAll 函式就是讓開發者可以自訂同時間開多少個 Goroutine 來平行執行工作,第二個參數可以自訂義工作內容是什麼。

package executor

import (
  "context"
  "fmt"
  "runtime"
  "sync"
)

type TaskFunc func(ctx context.Context) error

func ExecuteAll(numCPU int, tasks ...TaskFunc) error {
  var err error
  ctx, cancel := context.WithCancel(context.Background())
  defer cancel()

  wg := sync.WaitGroup{}
  wg.Add(len(tasks))

  if numCPU == 0 {
    numCPU = runtime.NumCPU()
  }
  fmt.Println("numCPU:", numCPU)
  queue := make(chan TaskFunc, numCPU)

  // Spawn the executer
  for i := 0; i < numCPU; i++ {
    go func() {
      for task := range queue {
        fmt.Println("get task")
        if err == nil {
          taskErr := task(ctx)
          if taskErr != nil {
            err = taskErr
            cancel()
          }
        }
        wg.Done()
      }
    }()
  }

  // Add tasks to queue
  for _, task := range tasks {
    queue <- task
  }
  close(queue)

  // wait for all task done
  wg.Wait()
  return err
}

三大優化方向

大家看完上述程式碼,是否心裡已經有想法該怎麼優化,或者是有看出什麼問題?首先我看到第一個疑問

wg := sync.WaitGroup{}
wg.Add(len(tasks))

為什麼是從 Task 數量來放進去 WatiGroup,理論上我們是要控制開多少個 Goroutine,而不是將 Task 數量全部執行完畢,才結束程式。

第二個問題就是這段代碼會 blocking 在最下面的讀取 Task 塞入 Queue 變數上,大家看到底下代碼,宣告的是根據想要開多少 Goroutine 的 buffer 大小 Channel。舉例假設使用 4 core,然後 100 個 Task,每個 Task 執行需要 10 秒,此時塞 4 個 Task 進去 Queue 後,會被順利讀取出來 4 個 task,接著 Queue 又被塞滿 4 個 task 後,就無法再繼續將新的 Task 放入,故程式就會被 blocking。

 queue := make(chan TaskFunc, numCPU)
//
// 中間省略一堆代碼
//
//
// Add tasks to queue
for _, task := range tasks {
  queue <- task
}
close(queue)

// wait for all task done
wg.Wait()

先看看讀取 Task 的 goroutine for 迴圈,由於只要有一個 Task 執行錯誤,就會將錯誤設定給全域變數 err,但是可以看到如果有 1 萬的 Task,此迴圈後續還是將每個 Task 都讀取出來,完全沒有使用到 Context 重要的 Channel 功能。更多 Context 用法可以參考這篇『用 10 分鐘了解 Go 語言 context package 使用場景及介紹

  go func() {
    for task := range queue {
      fmt.Println("get task")
      if err == nil {
        taskErr := task(ctx)
        if taskErr != nil {
          err = taskErr
          cancel()
        }
      }
      wg.Done()
    }
  }()

改寫 sync.WaitGroup 使用方式

根據上面提到的三個問題,底下來一一解決,首先這段程式碼目的是開多個平行化處理的 Goroutine,故結束前必須要等待全部 Goroutine 執行完成才讓主程式繼續往下走,所以使用 sync.WaitGroup 可以改成根據目前設定多少平行處理來決定

if numCPU == 0 {
  numCPU = runtime.NumCPU()
}

wg := sync.WaitGroup{}
wg.Add(numCPU)

改寫 buffer channel 大小

上面有提到 Channel 大小原本使用要同步處理多少工作當作 Buffer 大小,但是只要 Task 數量大於 Buffer 大小,就會出現 blocking,故這邊可以改成底下

queue := make(chan TaskFunc, len(tasks))

// Add tasks to queue
for _, task := range tasks {
  queue <- task
}
close(queue)

將 Buffer 大小改成跟 Task 數量一致,藉此透過 for 迴圈先將 Task 塞到 Channel 內,並關閉 Channel 即可。

讀取 Task 流程

此函式目的就是平行跑多個 Task,遇到任何錯誤,就中斷流程,並返回錯誤訊息,故需要透過 Context Cancel 特性來改寫原本流程

for i := 0; i < numCPU; i++ {
  go func() {
    defer wg.Done()
    for {
      select {
      case task, ok := <-queue:
        if ctx.Err() != nil || !ok {
          return
        }
        fmt.Println("get task")
        if e := task(ctx); e != nil {
          err = e
          cancel()
        }
      case <-ctx.Done():
        return
      }
    }
  }()
}

當 Task 出現錯誤時,會將錯誤訊息放到全域變數 err 內,並且執行 cancel(),此時 for 在讀取下一個 Job 時,就可以透過 <-ctx.Done()ctx.Err() 方式來終止程式執行,這樣才不會多跑了很多次迴圈

Worker Pool 網路上寫法千奇百種,優化的方式每個人想的也是不一樣,透過這樣的練習可以加深自己對於 Go Channel 特性。原本的程式碼都可以正常執行沒問題,只是看到覺得有幾個地方可以優化,故寫在這邊紀錄重構想法,可以讓剛入門 Go 語言的朋友們參考。


See also


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK