いろいろ備忘録日記

主に .NET とか Go とか Flutter とか Python絡みのメモを公開しています。

Goメモ-500 (シンプルな循環式バリア)(CyclicBarrier)

関連記事

Goメモ-498 (シンプルなラッチ)(CountdownLatch, CountdownEvent) - いろいろ備忘録日記

Goメモ-499 (シンプルなゲート)(Gate, カウント1のラッチ) - いろいろ備忘録日記

GitHub - devlights/blog-summary: ブログ「いろいろ備忘録日記」のまとめ

概要

以下、自分用のメモです。

前回、ラッチとゲートについてのサンプルをメモしたので、ついでにバリアもメモ。

サンプル

cyclicbarrier.go

package main

import (
    "context"
    "sync"
)

// CyclicBarrier は、指定された数のgoroutineが特定のポイントで待ち合わせることができる同期プリミティブです.
// すべてのgoroutineが到達するか、コンテキストがキャンセルされるまでブロックします.
type (
    CyclicBarrier struct {
        parties int        // 待ち合わせが必要なgoroutineの数
        waiting int        // 現在待機中のgoroutine数
        mutex   sync.Mutex // 内部状態を保護するためのmutex
        cond    *sync.Cond // 条件変数
        ctx     context.Context
        cancel  context.CancelFunc
        barrier chan struct{} // バリアチャネル
    }
)

// NewCyclicBarrier は、新しいCyclicBarrierを作成します.
// partiesには、同期が必要なgoroutineの数を指定します.
func NewCyclicBarrier(parties int) *CyclicBarrier {
    if parties <= 0 {
        panic("parties must be greater than 0")
    }

    var (
        ctx, cancel = context.WithCancel(context.Background())
        barrier     = &CyclicBarrier{
            parties: parties,
            ctx:     ctx,
            cancel:  cancel,
            barrier: make(chan struct{}),
        }
    )
    barrier.cond = sync.NewCond(&barrier.mutex)

    return barrier
}

// Await は、他のgoroutineが到達するのを待機します.
// すべてのgoroutineが到達すると、バリアが解放され、カウンターがリセットされます.
// コンテキストがキャンセルされた場合はエラーを返します.
func (me *CyclicBarrier) Await() error {
    me.mutex.Lock()
    defer me.mutex.Unlock()

    // コンテキストが既にキャンセルされているかチェック
    if me.ctx.Err() != nil {
        return me.ctx.Err()
    }

    var (
        generation = me.barrier // 現在の世代を記録(バリア条件が満了した場合、次のチャネルに切り替わるため)
    )
    me.waiting++
    if me.waiting == me.parties {
        // 最後のgoroutineが到達
        me.waiting = 0

        close(me.barrier)
        me.barrier = make(chan struct{}) // 新しい世代のためのチャネルを作成
        me.cond.Broadcast()              // 待機解除

        return nil
    }

    // 他のgoroutineを待つ
    for generation == me.barrier && me.ctx.Err() == nil {
        me.cond.Wait()
    }

    if me.ctx.Err() != nil {
        return me.ctx.Err()
    }

    return nil
}

// Reset は、バリアをリセットし、待機中のすべてのgoroutineをキャンセルします.
func (me *CyclicBarrier) Reset() {
    me.mutex.Lock()
    defer me.mutex.Unlock()

    // 現在待機しているgoroutineを解除
    me.cancel()

    var (
        ctx, cancel = context.WithCancel(context.Background())
    )
    me.ctx = ctx
    me.cancel = cancel
    me.waiting = 0

    // 世代入れ替え
    close(me.barrier)
    me.barrier = make(chan struct{})
    me.cond.Broadcast()
}

// GetNumberWaiting は、現在待機中のgoroutineの数を返します.
func (me *CyclicBarrier) GetNumberWaiting() int {
    me.mutex.Lock()
    defer me.mutex.Unlock()

    return me.waiting
}

// GetParties は、同期に必要なgoroutineの数を返します.
func (me *CyclicBarrier) GetParties() int {
    return me.parties
}

main.go

package main

import (
    "context"
    "errors"
    "log"
    "sync"
    "time"
)

const (
    MainTimeout = 20 * time.Second
    ProcTimeout = 10 * time.Second
)

var (
    ErrMainTooSlow = errors.New("(MAIN) TOO SLOW")
    ErrProcTooSlow = errors.New("(PROC) TOO SLOW")
)

func init() {
    log.SetFlags(log.Ltime)
}

func main() {
    var (
        rootCtx          = context.Background()
        mainCtx, mainCxl = context.WithTimeoutCause(rootCtx, MainTimeout, ErrMainTooSlow)
        procCtx          = run(mainCtx)
        err              error
    )
    defer mainCxl()

    select {
    case <-mainCtx.Done():
        err = context.Cause(mainCtx)
    case <-procCtx.Done():
        if err = context.Cause(procCtx); errors.Is(err, context.Canceled) {
            err = nil
        }
    }

    if err != nil {
        log.Fatal(err)
    }
}

func run(pCtx context.Context) context.Context {
    var (
        ctx, cxl = context.WithCancelCause(pCtx)
    )

    go func() {
        cxl(proc(ctx))
    }()
    go func() {
        <-time.After(ProcTimeout)
        cxl(ErrProcTooSlow)
    }()

    return ctx
}

func proc(_ context.Context) error {
    const (
        WORKER_COUNT = 3
    )
    var (
        barrier = NewCyclicBarrier(WORKER_COUNT)
        wg      sync.WaitGroup
    )

    // 3つのワーカーを起動し、全員揃ったら先に進むを繰り返す
    for i := 0; i < WORKER_COUNT; i++ {
        wg.Add(1)
        go worker(i+1, &wg, barrier)
    }

    wg.Wait()

    return nil
}

func worker(id int, wg *sync.WaitGroup, barrier *CyclicBarrier) {
    defer wg.Done()

    for i := 0; i < 3; i++ {
        log.Printf("Worker-[%2d] 準備作業 %2d週目", id, i+1)
        time.Sleep(time.Duration(id) * time.Second)

        log.Printf("Worker-[%2d] 待機開始", id)
        {
            if err := barrier.Await(); err != nil {
                log.Printf("Worker-[%2d] エラー: %v", id, err)
                return
            }
        }
        log.Printf("Worker-[%2d] 待機解除", id)
    }
}

Taskfile.yml

# https://taskfile.dev

version: '3'

tasks:
  default:
    cmds:
      - go run .

実行

$ task
task: [default] go run .
08:07:01 Worker-[ 3] 準備作業  1週目
08:07:01 Worker-[ 1] 準備作業  1週目
08:07:01 Worker-[ 2] 準備作業  1週目
08:07:02 Worker-[ 1] 待機開始
08:07:03 Worker-[ 2] 待機開始
08:07:04 Worker-[ 3] 待機開始
08:07:04 Worker-[ 3] 待機解除
08:07:04 Worker-[ 3] 準備作業  2週目
08:07:04 Worker-[ 2] 待機解除
08:07:04 Worker-[ 2] 準備作業  2週目
08:07:04 Worker-[ 1] 待機解除
08:07:04 Worker-[ 1] 準備作業  2週目
08:07:05 Worker-[ 1] 待機開始
08:07:06 Worker-[ 2] 待機開始
08:07:07 Worker-[ 3] 待機開始
08:07:07 Worker-[ 3] 待機解除
08:07:07 Worker-[ 3] 準備作業  3週目
08:07:07 Worker-[ 2] 待機解除
08:07:07 Worker-[ 2] 準備作業  3週目
08:07:07 Worker-[ 1] 待機解除
08:07:07 Worker-[ 1] 準備作業  3週目
08:07:08 Worker-[ 1] 待機開始
08:07:09 Worker-[ 2] 待機開始
08:07:10 Worker-[ 3] 待機開始
08:07:10 Worker-[ 3] 待機解除
08:07:10 Worker-[ 1] 待機解除
08:07:10 Worker-[ 2] 待機解除

ゲームでよくある所定の人数分キャラがその場所に立ったら、上にブーンって登ってまた戻って来るエレベータみたいな感じ。

参考情報

github.com

Goのおすすめ書籍


過去の記事については、以下のページからご参照下さい。

サンプルコードは、以下の場所で公開しています。