いろいろ備忘録日記

主に .NET とか Go とか Python絡みのメモを公開しています。

Goメモ-62 (sync.WaitGroupとerrgroupパッケージ)

概要

忘れない内にメモメモ。

複数の非同期処理の完了を待ち合わせたりするときに、標準ライブラリの sync.WaitGroup をよく利用しますが

WaitGroupさんは完了待機するにはとても便利なんですが、エラー伝播する場合は自分で作り込む必要があります。

非同期処理でエラーは付き物なので、もうちょっと便利なものがあればいいなーって思ってたら以下のものがありました。

godoc.org

WaitGroupの機能に加えて

  • 発生したエラーを教えてくれる
  • ctx.Contextをサポート

という機能を持っています。めっちゃ便利ですねこのパッケージ。

有名なパッケージみたいなので、知ってる人多いと思いますが、以下サンプルです。

インストール

普通に go get するだけです。

go get golang.org/x/sync/errgroup

サンプル

sync.WaitGroupでエラー伝播

以下は、sync.WaitGroupを使って非同期処理側で発生したエラーを取得するサンプル。

package cmpwaitgroup

import (
    "fmt"
    "math/rand"
    "sync"
    "time"

    "github.com/devlights/try-golang/output"
    "github.com/devlights/try-golang/util/enumerable"
)

func init() {
    rand.Seed(time.Now().UnixNano())
}

// ErrWithWaitGroup は、標準ライブラリ sync.WaitGroup でエラー情報を呼び元に伝播させるサンプルです.
func ErrWithWaitGroup() error {
    var (
        loopRange = enumerable.NewRange(1, 6)
        waitGrp   = sync.WaitGroup{}
        errorCh   = make(chan error)
    )

    // ----------------------------------------------------------------------------------------
    // sync.WaitGroup は、待ち合わせを担当するためのものであるため
    // 非同期処理側で発生したエラーを収集しておくような機能は持っていない
    // そのため、エラーが発生した場合の処理をユーザ側で作り込む必要がある
    // ----------------------------------------------------------------------------------------
    for loopRange.Next() {
        waitGrp.Add(1)

        go func(i int) {
            defer waitGrp.Done()

            prefix := fmt.Sprintf("[go func %02d]", i)
            output.Stderrl(prefix, "start")
            defer output.Stderrl(prefix, "end")

            err := randomErr(prefix)
            if err != nil {
                output.Stderrl(prefix, "\tERROR!!")
                errorCh <- err
            }

        }(loopRange.Current())
    }

    go func() {
        waitGrp.Wait()
        close(errorCh)
    }()

    for err := range errorCh {
        output.Stdoutl("[err]", err)
    }

    return nil
}

func randomErr(message string) error {
    i := rand.Intn(100)
    if i > 30 {
        return fmt.Errorf("randomErr [%d][%s]", i, message)
    }

    return nil
}

実行すると以下のようになります。

$ make run
ENTER EXAMPLE NAME: errgrp_error_with_waitgroup
[Name] "errgrp_error_with_waitgroup"
[err]                randomErr [52][[go func 02]]
[err]                randomErr [78][[go func 03]]
[err]                randomErr [52][[go func 04]]
[err]                randomErr [87][[go func 05]]
[go func 01]         start
[go func 01]         end
[go func 02]         start
[go func 02]           ERROR!!
[go func 02]         end
[go func 03]         start
[go func 03]           ERROR!!
[go func 03]         end
[go func 04]         start
[go func 04]           ERROR!!
[go func 04]         end
[go func 05]         start
[go func 05]           ERROR!!
[go func 05]         end

errgroup.Groupでエラー伝播

以下は、errgroup.Groupを使って非同期処理側で発生したエラーを取得するサンプル。

package cmpwaitgroup

import (
    "fmt"

    "github.com/devlights/try-golang/output"
    "github.com/devlights/try-golang/util/enumerable"
    "golang.org/x/sync/errgroup"
)

// ErrWithErrGroup は、拡張ライブラリ golang.org/x/sync/errgroup でエラー情報を呼び元に伝播させるサンプルです.
//
// https://pkg.go.dev/golang.org/x/sync/errgroup?tab=doc#example-Group-JustErrors
func ErrWithErrGroup() error {
    var (
        loopRange = enumerable.NewRange(1, 6)
        waitGrp   = errgroup.Group{}
    )

    // ----------------------------------------------------------------------------------------
    // errgroup.Group は、sync.WaitGroup のように待ち合わせを行う機能に加えて
    // 発生したエラーを収集し、呼び元に返すことが可能となっている
    // 返してくれるエラーは、最初に発生したエラー情報となっている
    //
    // 利用方法は、sync.WaitGroup とは少し異なり Go(func() error) メソッドに
    // 非同期実行部分を渡して処理する形となっている. 内部で goroutine 化して呼び出してくれるので
    // 呼び元で go を付与する必要はない.
    //
    // 待ち合わせを実施したい箇所で、Wait() メソッドを呼び出すことにより非同期処理全部が完了するまで
    // 呼び元をブロックする。
    // ----------------------------------------------------------------------------------------
    for loopRange.Next() {
        i := loopRange.Current()
        waitGrp.Go(func() error {
            prefix := fmt.Sprintf("[go func %02d]", i)
            output.Stderrl(prefix, "start")
            defer output.Stderrl(prefix, "end")

            err := randomErr(prefix)
            if err != nil {
                output.Stderrl(prefix, "\tERROR!!")
            }

            return err
        })
    }

    // 複数の goroutine にて、複数のエラーが発生している場合でも取得できるのは最初に発生したエラーとなる
    if err := waitGrp.Wait(); err != nil {
        output.Stdoutl("[err]", err)
    }

    return nil
}

実行すると以下のようになります。

$ make run
[Name] "errgrp_error_with_errgroup"
[err]                randomErr [98][[go func 02]]
[go func 03]         start
[go func 03]         end
[go func 02]         start
[go func 02]            ERROR!!
[go func 02]         end
[go func 04]         start
[go func 04]         end
[go func 01]         start
[go func 01]            ERROR!!
[go func 05]         start
[go func 01]         end
[go func 05]            ERROR!!
[go func 05]         end

[err] と出力されている行が一行しか出ていないのは、errgroupさんの仕様が最初にエラーが発生したものを通知するというものだからです。

errgroupでctx.Contextを使ったサンプル

以下、errgroup.WithContext() を利用したサンプルです。ctx.Contextを利用できるので、エラーが発生したら残りの処理をキャンセルが可能になります。

package withcontext

import (
    "context"
    "fmt"
    "time"

    "github.com/devlights/try-golang/output"
    "github.com/devlights/try-golang/util/enumerable"
    "golang.org/x/sync/errgroup"
)

// ErrGroupWithContext は、拡張ライブラリ golang.org/x/sync/errgroup で ctx.Context を含めた利用方法についてのサンプルです.
//
// https://pkg.go.dev/golang.org/x/sync/errgroup?tab=doc#example-Group-Parallel
func ErrGroupWithContext() error {
    // 利用するコンテキスト関連
    var (
        rootCtx           = context.Background()
        errGrp, errGrpCtx = errgroup.WithContext(rootCtx)
    )

    // その他の情報
    var (
        loopRange = enumerable.NewRange(1, 6)
    )

    // ----------------------------------------------------------------------------------------
    // errgroup.WithContext(ctx.Context) を利用することで、コンテキスト情報も管理することが可能となる
    // ここで取得した ctx.Context は、以下の場合にキャンセル状態となる。つまり、 <-ctx.Done() が通るようになる
    //   - どれかの非同期処理が最初に non-nil な戻り値を返したとき
    //   - 最初に Wait() が返ったとき
    // なので、非同期処理内でこのコンテキストを見張ることにより、どこかの処理でエラーが発生した場合に
    // まだ処理が始まっていない or 現在処理中の処理 をまとめてキャンセルすることができる
    // (現在処理中のものをキャンセルするためには、定周期で ctx.Done() を確認するポーリング処理を作り込む必要がある)
    // ----------------------------------------------------------------------------------------
    for loopRange.Next() {
        i := loopRange.Current()

        errGrp.Go(func() error {
            prefix := fmt.Sprintf("[go func %02d]", i)

            // キャンセルすることを確認したいので、意図的に少しだけ隙間を空ける
            time.Sleep(1 * time.Microsecond)

            select {
            case <-errGrpCtx.Done():
                // だれかが初めにエラーを返した時点でこのコンテキストがキャンセルされる
                // main-goroutine側はWait() を呼び出しているため、この Wait() が return した
                // タイミングでもコンテキストはキャンセルされる.
                output.Stderrl(prefix, "CANCEL!!")
                return errGrpCtx.Err()
            default:
                output.Stderrl(prefix, "start")
                defer output.Stderrl(prefix, "end")

                err := raiseErr(prefix)
                if err != nil {
                    output.Stderrl(prefix, "\tERROR!!")
                }

                return err
            }
        })
    }

    if err := errGrp.Wait(); err != nil {
        output.Stdoutl("[err]", err)
    }

    return nil
}

func raiseErr(message string) error {
    return fmt.Errorf("raiseErr [%s]", message)
}

実行すると以下のようになります。

$ make run
[Name] "errgrp_with_context"
[err]                raiseErr [[go func 02]]
[go func 02]         start
[go func 02]           ERROR!!
[go func 02]         end
[go func 05]         CANCEL!!
[go func 01]         CANCEL!!
[go func 03]         CANCEL!!
[go func 04]         CANCEL!!

errgroupでパイプライン処理

errgroupパッケージのexampleにあったのでついでに。別にerrgroupでなくても出来るのですが、パイプライン中でエラーが発生した場合に残りの作業を止めるってのはerrgroupパッケージ使ったほうが楽ですね。

以下のサンプルでは、指定したディレクトリ配下の goファイル のmd5チェックサムを算出して出力しています。errgroupのexampleとほぼ同じですが、、w

package pipeline

import (
    "context"
    "crypto/md5"
    "fmt"
    "io/ioutil"
    "os"
    "path/filepath"
    "strings"

    "github.com/devlights/try-golang/output"
    "golang.org/x/sync/errgroup"
)

type (
    md5result struct {
        path     string
        checkSum [md5.Size]byte
        name     string
    }
)

// ErrGroupWithPipeline は、拡張ライブラリ golang.org/x/sync/errgroup でパイプライン処理を行っているサンプルです.
//
// https://pkg.go.dev/golang.org/x/sync/errgroup?tab=doc#example-Group-Pipeline
func ErrGroupWithPipeline() error {
    // 利用するコンテキスト関連
    var (
        rootCtx           = context.Background()
        errGrp, errGrpCtx = errgroup.WithContext(rootCtx)
    )

    var (
        filePathCh = make(chan string)
        md5Ch      = make(chan md5result)
    )

    // 1st ステージ
    // 配下の *.go ファイルをリストアップ
    errGrp.Go(func() error {
        defer close(filePathCh)
        return filepath.Walk(".", func(path string, info os.FileInfo, err error) error {
            if err != nil {
                return err
            }

            if info.IsDir() {
                return nil
            }

            if strings.HasSuffix(info.Name(), ".go") {
                filePathCh <- path
            }

            select {
            case <-errGrpCtx.Done():
                return errGrpCtx.Err()
            default:
                return nil
            }
        })
    })

    // 2nd ステージ
    // リストアップされたファイルを順次 md5 checksum していく
    // 10個のgoroutineを並行処理させる
    for i := 0; i < 10; i++ {
        goroutineIndex := i + 1
        errGrp.Go(func() error {
            var (
                name  = fmt.Sprintf("goroutine-%02d", goroutineIndex)
                count = 0
            )

            for p := range filePathCh {
                data, err := ioutil.ReadFile(p)
                if err != nil {
                    return err
                }

                checksum := md5.Sum(data)
                result := md5result{
                    path:     p,
                    checkSum: checksum,
                    name:     name,
                }

                select {
                case md5Ch <- result:
                    count++
                case <-errGrpCtx.Done():
                    return errGrpCtx.Err()
                }
            }

            return nil
        })
    }

    // 3rd ステージ
    // 1st, 2nd の処理完了を検知して結果用のチャネルである md5ch を閉じる
    go func() {
        _ = errGrp.Wait()
        close(md5Ch)
    }()

    // final ステージ
    // 結果出力
    for r := range md5Ch {
        cs := fmt.Sprintf("%x", r.checkSum)
        output.Stdoutl(r.name, cs, r.path)
    }

    // エラー判定
    // Wait() は、複数回呼んでも構わない.
    // 上の呼び出しは、処理の区切りを判定するために最初にエラーが返ったタイミング、もしくは、全部処理が終わったことを
    // 検知するためのもの。以下は、再度呼び出してエラーがあれば出力するためのもの
    if err := errGrp.Wait(); err != nil {
        output.Stdoutl("err", err)
    }

    return nil
}

実行すると以下のようになります。

$ make run
[Name] "errgrp_with_pipeline"
goroutine-04         448c1c63332f6da79ebe33e0bd9b8b7a books\bootcamp\doc.go
goroutine-06         0ae6280d37033b6bb5003003f3bd8197 books\doc.go
goroutine-07         480b37c58c529d205187870c8df4bec5 books\examples.go
goroutine-05         a5d5a4cc0de9a52c6cf319c2b316263f books\concurrency\doc.go
goroutine-08         e8bf439432e53029734351b1400005f4 books\go101\doc.go
goroutine-09         7f3027cf39166de947185615893924e6 books\startingGo\doc.go
goroutine-10         38a1e06cc4bf4126020fbd7820f8a395 builder\builder.go
goroutine-02         3337eec13f4861ce8c732e4e23fdf581 builder\doc.go
・
・
・
割愛

ちゃんとそれぞれのgoroutineがバラバラで 1st ステージ から流れてきたデータを処理してくれていますね。

1stステージと2ndステージの間でfan-outさせています。3rdステージがミソで、これ入れないとデッドロックします。

ついでに結果をExcelに出力するサンプル

ついでなので、Excelに最終結果を出力するようにしたサンプルです。

package main

import (
    "context"
    "crypto/md5"
    "flag"
    "fmt"
    "io/ioutil"
    "os"
    "path/filepath"
    "regexp"

    "github.com/devlights/goxcel"
    "golang.org/x/sync/errgroup"
)

type (
    args struct {
        directory string
        pattern   string
        output    string
    }

    md5result struct {
        path     string
        checkSum [md5.Size]byte
        name     string
    }
)

var (
    cmdArgs = args{}
)

const (
    _NumberOf2ndStageGoroutines = 10
)

func main() {
    os.Exit(run())
}

func run() int {

    flag.StringVar(&cmdArgs.directory, "d", ".", "対象ディレクトリ")
    flag.StringVar(&cmdArgs.pattern, "p", ".*", "対象ファイルパターン")
    flag.StringVar(&cmdArgs.output, "o", "", "出力ファイルパス")
    flag.Parse()

    if cmdArgs.output == "" {
        flag.Usage()
        return 2
    }

    if cmdArgs.directory == "" {
        cmdArgs.directory = "."
    }

    if cmdArgs.pattern == "" {
        cmdArgs.pattern = "*.*"
    }

    var (
        rootCtx           = context.Background()
        errGrp, errGrpCtx = errgroup.WithContext(rootCtx)
    )

    var (
        filePathCh = make(chan string)
        md5Ch      = make(chan md5result)
    )

    // 1st stage
    start1stStage(errGrp, errGrpCtx, filePathCh)

    // 2nd stage
    start2ndStage(errGrp, errGrpCtx, filePathCh, md5Ch)

    // 3rd stage
    start3rdStage(errGrp, md5Ch)

    // final stage
    execFinalStage(md5Ch)

    if err := errGrp.Wait(); err != nil {
        fmt.Println(err)
        return 1
    }

    return 0
}

func start1stStage(errGrp *errgroup.Group, ctx context.Context, filePathCh chan<- string) {

    errGrp.Go(func() error {
        defer close(filePathCh)
        return filepath.Walk(cmdArgs.directory, func(path string, info os.FileInfo, err error) error {
            if err != nil {
                return err
            }

            if info.IsDir() {
                return nil
            }

            match, _ := regexp.Match(cmdArgs.pattern, []byte(info.Name()))
            if match {
                filePathCh <- path
            }

            select {
            case <-ctx.Done():
                return ctx.Err()
            default:
                return nil
            }
        })
    })
}

func start2ndStage(errGrp *errgroup.Group, ctx context.Context, filePathCh <-chan string, md5Ch chan<- md5result) {

    for i := 0; i < _NumberOf2ndStageGoroutines; i++ {
        goroutineIndex := i + 1
        errGrp.Go(func() error {
            var (
                name  = fmt.Sprintf("goroutine-%02d", goroutineIndex)
                count = 0
            )

            for p := range filePathCh {
                data, err := ioutil.ReadFile(p)
                if err != nil {
                    return err
                }

                checksum := md5.Sum(data)
                result := md5result{
                    path:     p,
                    checkSum: checksum,
                    name:     name,
                }

                select {
                case <-ctx.Done():
                    return ctx.Err()
                case md5Ch <- result:
                    count++
                }
            }

            return nil
        })
    }
}

func start3rdStage(errGrp *errgroup.Group, md5Ch chan md5result) {
    go func() {
        _ = errGrp.Wait()
        close(md5Ch)
    }()
}

func execFinalStage(md5Ch <-chan md5result) {

    quitGoxcelFn, _ := goxcel.InitGoxcel()
    defer quitGoxcelFn()

    g, gr, _ := goxcel.NewGoxcel()
    defer gr()

    _ = g.SetDisplayAlerts(false)
    _ = g.SetVisible(false)

    wbs, _ := g.Workbooks()
    wb, wbr, _ := wbs.Add()
    defer wbr()

    wss, _ := wb.WorkSheets()
    ws, _ := wss.Item(1)

    row := 1
    for r := range md5Ch {
        fileNameCell, _ := ws.Cells(row, 1)
        _ = fileNameCell.SetValue(r.path)

        md5ChecksumCell, _ := ws.Cells(row, 2)
        _ = md5ChecksumCell.SetValue(fmt.Sprintf("%x", r.checkSum))

        row++
    }

    _ = wb.SaveAs(cmdArgs.output)
}

過去の記事については、以下のページからご参照下さい。

  • いろいろ備忘録日記まとめ

devlights.github.io

サンプルコードは、以下の場所で公開しています。

  • いろいろ備忘録日記サンプルソース置き場

github.com

github.com

github.com