From ff346c28617f2bfa525b460d4f98fc58a966ea13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juan=20Pablo=20Villaf=C3=A1=C3=B1ez?= Date: Fri, 12 Apr 2024 10:30:54 +0200 Subject: [PATCH] fix: additional guarantees for concurrent calls --- ocis-pkg/runner/grouprunner.go | 101 +++++++++++++++++++++++++-------- ocis-pkg/runner/runner.go | 33 ++++++++--- 2 files changed, 102 insertions(+), 32 deletions(-) diff --git a/ocis-pkg/runner/grouprunner.go b/ocis-pkg/runner/grouprunner.go index f77f769b6..4b225206b 100644 --- a/ocis-pkg/runner/grouprunner.go +++ b/ocis-pkg/runner/grouprunner.go @@ -2,6 +2,7 @@ package runner import ( "context" + "sync" ) // GroupRunner represent a group of tasks that need to run together. @@ -20,13 +21,17 @@ import ( // This means that, if a task finishes naturally, the rest of the task will // asked to stop as well. type GroupRunner struct { - runners map[string]*Runner + runners sync.Map + runnersCount int + isRunning bool + runningMutex sync.Mutex } // NewGroup will create a GroupRunner func NewGroup() *GroupRunner { return &GroupRunner{ - runners: make(map[string]*Runner), + runners: sync.Map{}, + runningMutex: sync.Mutex{}, } } @@ -34,12 +39,22 @@ func NewGroup() *GroupRunner { // // It's mandatory that each runner in the group has an unique id, otherwise // there will be issues +// Adding new runners once the group starts will cause a panic func (gr *GroupRunner) Add(r *Runner) { - if _, ok := gr.runners[r.ID]; ok { - // a runner already exist with that id - panic("Trying to add a runner with an existing Id in the group") + gr.runningMutex.Lock() + defer gr.runningMutex.Unlock() + + if gr.isRunning { + panic("Adding a new runner after the group starts is forbidden") + } + + // LoadOrStore will try to store the runner + if _, loaded := gr.runners.LoadOrStore(r.ID, r); loaded { + // there is already a runner with the same id, which is forbidden + panic("Trying to add a runner with an existing Id in the group") + } else { + gr.runnersCount++ } - gr.runners[r.ID] = r } // Run will execute all the tasks in the group at the same time. @@ -61,12 +76,22 @@ func (gr *GroupRunner) Add(r *Runner) { // Note that it is NOT expected for the finished task's stopper to be called // in this case. func (gr *GroupRunner) Run(ctx context.Context) []*Result { + // Set the flag inside the runningMutex to ensure we don't read the old value + // in the `Add` method and add a new runner when this method is being executed + // Note that if multiple `Run` or `RunAsync` happens, the underlying runners + // will panic + gr.runningMutex.Lock() + gr.isRunning = true + gr.runningMutex.Unlock() + results := make(map[string]*Result) - ch := make(chan *Result, len(gr.runners)) // no need to block writing results - for _, runner := range gr.runners { - runner.RunAsync(ch) - } + ch := make(chan *Result, gr.runnersCount) // no need to block writing results + gr.runners.Range(func(_, value any) bool { + r := value.(*Runner) + r.RunAsync(ch) + return true + }) // wait for a result or for the context to be done select { @@ -77,30 +102,32 @@ func (gr *GroupRunner) Run(ctx context.Context) []*Result { } // interrupt the rest of the runners - for _, runner := range gr.runners { - if _, ok := results[runner.ID]; !ok { + gr.runners.Range(func(_, value any) bool { + r := value.(*Runner) + if _, ok := results[r.ID]; !ok { select { - case <-runner.Finished(): + case <-r.Finished(): // No data should be sent through the channel, so we'd be // here only if the channel is closed. This means the task // has finished and we don't need to interrupt. We do // nothing in this case default: - runner.Interrupt() + r.Interrupt() } } - } + return true + }) // Having notified that the context has been finished, we still need to // wait for the rest of the results - for i := len(results); i < len(gr.runners); i++ { + for i := len(results); i < gr.runnersCount; i++ { result := <-ch results[result.RunnerID] = result } close(ch) - values := make([]*Result, 0, len(gr.runners)) + values := make([]*Result, 0, gr.runnersCount) for _, val := range results { values = append(values, val) } @@ -112,9 +139,33 @@ func (gr *GroupRunner) Run(ctx context.Context) []*Result { // as it's available. // Note that this method will finish as soon as all the tasks are running. func (gr *GroupRunner) RunAsync(ch chan<- *Result) { - for _, runner := range gr.runners { - runner.RunAsync(ch) - } + // Set the flag inside the runningMutex to ensure we don't read the old value + // in the `Add` method and add a new runner when this method is being executed + // Note that if multiple `Run` or `RunAsync` happens, the underlying runners + // will panic + gr.runningMutex.Lock() + gr.isRunning = true + gr.runningMutex.Unlock() + + // we need a secondary channel to receive the first result so we can + // interrupt the rest of the tasks + interCh := make(chan *Result, gr.runnersCount) + gr.runners.Range(func(_, value any) bool { + r := value.(*Runner) + r.RunAsync(interCh) + return true + }) + + go func() { + result := <-interCh + gr.Interrupt() + + ch <- result + for i := 1; i < gr.runnersCount; i++ { + result = <-interCh + ch <- result + } + }() } // Interrupt will execute the stopper function of ALL the tasks, which should @@ -128,11 +179,13 @@ func (gr *GroupRunner) RunAsync(ch chan<- *Result) { // try to stop just one task. // If a task has finished, the corresponding stopper won't be called func (gr *GroupRunner) Interrupt() { - for _, runner := range gr.runners { + gr.runners.Range(func(_, value any) bool { + r := value.(*Runner) select { - case <-runner.Finished(): + case <-r.Finished(): default: - runner.Interrupt() + r.Interrupt() } - } + return true + }) } diff --git a/ocis-pkg/runner/runner.go b/ocis-pkg/runner/runner.go index 0f4084cf6..e80ad194d 100644 --- a/ocis-pkg/runner/runner.go +++ b/ocis-pkg/runner/runner.go @@ -2,6 +2,7 @@ package runner import ( "context" + "sync/atomic" ) // Runner represents the one executing a long running task, such as a server @@ -10,10 +11,12 @@ import ( // Result that it will generated will contain the same ID, so we can // know which runner provided which result. type Runner struct { - ID string - fn Runable - interrupt Stopper - finished chan struct{} + ID string + fn Runable + interrupt Stopper + running atomic.Bool + interrupted atomic.Bool + finished chan struct{} } // New will create a new runner. @@ -52,6 +55,12 @@ func New(id string, fn Runable, interrupt Stopper) *Runner { // - Use context.WithDeadline(...) or context.WithTimeout(...) to run the task // for a limited time func (r *Runner) Run(ctx context.Context) *Result { + if !r.running.CompareAndSwap(false, true) { + // If not swapped, the task is already running. + // Running the same task multiple times is a bug, so we panic + panic("Runner with id " + r.ID + " was running twice") + } + ch := make(chan *Result) go r.doTask(ch, true) @@ -60,10 +69,9 @@ func (r *Runner) Run(ctx context.Context) *Result { case result := <-ch: return result case <-ctx.Done(): - r.interrupt() + r.Interrupt() + return <-ch } - - return <-ch } // RunAsync will execute the task associated to this runner asynchronously. @@ -76,6 +84,12 @@ func (r *Runner) Run(ctx context.Context) *Result { // To interrupt the running task, the only option is to call the `Interrupt` // method at some point. func (r *Runner) RunAsync(ch chan<- *Result) { + if !r.running.CompareAndSwap(false, true) { + // If not swapped, the task is already running. + // Running the same task multiple times is a bug, so we panic + panic("Runner with id " + r.ID + " was running twice") + } + go r.doTask(ch, false) } @@ -83,8 +97,11 @@ func (r *Runner) RunAsync(ch chan<- *Result) { // in order for it to finish. // The stopper will be called immediately, although it's expected the // consequences to take a while (task might need a while to stop) +// This method will be called only once. Further calls won't do anything func (r *Runner) Interrupt() { - r.interrupt() + if r.interrupted.CompareAndSwap(false, true) { + r.interrupt() + } } // Finished will return a receive-only channel that can be used to know when