いろいろ備忘録日記

主に .NET とか Go とか Flutter とか Python絡みのメモを公開しています。

Goメモ-498 (シンプルなラッチ)(CountdownLatch, CountdownEvent)

関連記事

GitHub - devlights/blog-summary: ブログ「いろいろ備忘録日記」のまとめ

概要

以下、自分用のメモです。

たまに、.NETにあるCountdownEventやJavaのCountDownLatchのような機構が欲しいときがあります。

Goの標準ライブラリには無いから、よく以下のような構造体を何回も作っている気がするので、ついでにここにメモとして残しておきます。

サンプル

countdownlatch.go

package main

import (
    "sync"
    "sync/atomic"
)

// CountdownLatch は、C#のCountdownEventやJavaのCountDownLatchと同様の機能を提供する構造体です.
type CountdownLatch struct {
    count atomic.Int32
    mutex sync.Mutex
    cond  *sync.Cond
}

// NewCountdownLatch は、指定されたカウント数でCountdownLatchを初期化します.
func NewCountdownLatch(initialCount int) *CountdownLatch {
    if initialCount < 0 {
        panic("初期カウントは0以上である必要があります")
    }

    var (
        latch CountdownLatch
    )
    latch.count.Store(int32(initialCount))
    latch.cond = sync.NewCond(&latch.mutex)

    return &latch
}

// Signal は、カウントを1減らします.
// 戻り値として、カウントダウンが満了したかどうかを返します.
func (me *CountdownLatch) Signal() bool {
    return me.SignalCount(1)
}

// SignalCount は、指定された数だけカウントを減らします.
// 戻り値として、カウントダウンが満了したかどうかを返します.
func (me *CountdownLatch) SignalCount(count int) bool {
    if count <= 0 {
        return false
    }

    me.mutex.Lock()
    defer me.mutex.Unlock()

    newCount := me.count.Add(-int32(count))
    if newCount <= 0 {
        me.cond.Broadcast()
        return true
    }

    return false
}

// Wait は、カウントが0になるまでブロックします.
func (me *CountdownLatch) Wait() {
    me.mutex.Lock()
    defer me.mutex.Unlock()

    for me.count.Load() > 0 {
        me.cond.Wait()
    }
}

// CurrentCount は、現在のカウント値を返します.
func (me *CountdownLatch) CurrentCount() int {
    me.mutex.Lock()
    defer me.mutex.Unlock()

    return int(me.count.Load())
}

main.go

package main

import (
    "context"
    "errors"
    "log"
    "sync"
    "time"
)

const (
    MainTimeout = 20 * time.Second
    ProcTimeout = 10 * time.Second
)

var (
    ErrMainTooSlow = errors.New("(MAIN) TOO SLOW")
    ErrProcTooSlow = errors.New("(PROC) TOO SLOW")
)

func init() {
    log.SetFlags(0)
}

func main() {
    var (
        rootCtx          = context.Background()
        mainCtx, mainCxl = context.WithTimeoutCause(rootCtx, MainTimeout, ErrMainTooSlow)
        procCtx          = run(mainCtx)
        err              error
    )
    defer mainCxl()

    select {
    case <-mainCtx.Done():
        err = context.Cause(mainCtx)
    case <-procCtx.Done():
        if err = context.Cause(procCtx); errors.Is(err, context.Canceled) {
            err = nil
        }
    }

    if err != nil {
        log.Fatal(err)
    }
}

func run(pCtx context.Context) context.Context {
    var (
        ctx, cxl = context.WithCancelCause(pCtx)
    )

    go func() {
        cxl(proc(ctx))
    }()
    go func() {
        <-time.After(ProcTimeout)
        cxl(ErrProcTooSlow)
    }()

    return ctx
}

func proc(_ context.Context) error {
    const (
        latchCount = 3
    )
    var (
        latch = NewCountdownLatch(latchCount)
        wg    sync.WaitGroup
    )

    for i := range 5 {
        wg.Add(1)
        go func(i int) {
            defer wg.Done()

            log.Printf("[%2d] 待機開始", i)
            latch.Wait()
            log.Printf("[%2d] 待機解除", i)
        }(i)
    }

    for range 3 {
        <-time.After(time.Second)

        log.Printf("現在のカウント: %d\n", latch.CurrentCount())
        latch.Signal()
    }

    wg.Wait()

    return nil
}

Taskfile.yml

# https://taskfile.dev

version: '3'

tasks:
  default:
    cmds:
      - go run .

実行

$ task
[ 4] 待機開始
[ 0] 待機開始
[ 1] 待機開始
[ 3] 待機開始
[ 2] 待機開始
現在のカウント: 3
現在のカウント: 2
現在のカウント: 1
[ 2] 待機解除
[ 1] 待機解除
[ 3] 待機解除
[ 4] 待機解除
[ 0] 待機解除

参考情報

github.com

learn.microsoft.com

docs.oracle.com

Goのおすすめ書籍


過去の記事については、以下のページからご参照下さい。

サンプルコードは、以下の場所で公開しています。