fix: additional guarantees for concurrent calls

This commit is contained in:
Juan Pablo Villafáñez
2024-04-12 10:30:54 +02:00
parent 6ddc0addd3
commit ff346c2861
2 changed files with 102 additions and 32 deletions

View File

@@ -2,6 +2,7 @@ package runner
import (
"context"
"sync"
)
// GroupRunner represents a group of tasks that need to run together.
@@ -20,13 +21,17 @@ import (
// This means that, if a task finishes naturally, the rest of the tasks will
// be asked to stop as well.
type GroupRunner struct {
runners map[string]*Runner
runners sync.Map
runnersCount int
isRunning bool
runningMutex sync.Mutex
}
// NewGroup will create a GroupRunner
func NewGroup() *GroupRunner {
return &GroupRunner{
runners: make(map[string]*Runner),
runners: sync.Map{},
runningMutex: sync.Mutex{},
}
}
@@ -34,12 +39,22 @@ func NewGroup() *GroupRunner {
//
// It's mandatory that each runner in the group has a unique id, otherwise
// there will be issues
// Adding new runners once the group starts will cause a panic
func (gr *GroupRunner) Add(r *Runner) {
if _, ok := gr.runners[r.ID]; ok {
// a runner already exist with that id
panic("Trying to add a runner with an existing Id in the group")
gr.runningMutex.Lock()
defer gr.runningMutex.Unlock()
if gr.isRunning {
panic("Adding a new runner after the group starts is forbidden")
}
// LoadOrStore will try to store the runner
if _, loaded := gr.runners.LoadOrStore(r.ID, r); loaded {
// there is already a runner with the same id, which is forbidden
panic("Trying to add a runner with an existing Id in the group")
} else {
gr.runnersCount++
}
gr.runners[r.ID] = r
}
// Run will execute all the tasks in the group at the same time.
@@ -61,12 +76,22 @@ func (gr *GroupRunner) Add(r *Runner) {
// Note that it is NOT expected for the finished task's stopper to be called
// in this case.
func (gr *GroupRunner) Run(ctx context.Context) []*Result {
// Set the flag inside the runningMutex to ensure we don't read the old value
// in the `Add` method and add a new runner when this method is being executed
// Note that if `Run` or `RunAsync` is called more than once, the underlying
// runners will panic
gr.runningMutex.Lock()
gr.isRunning = true
gr.runningMutex.Unlock()
results := make(map[string]*Result)
ch := make(chan *Result, len(gr.runners)) // no need to block writing results
for _, runner := range gr.runners {
runner.RunAsync(ch)
}
ch := make(chan *Result, gr.runnersCount) // no need to block writing results
gr.runners.Range(func(_, value any) bool {
r := value.(*Runner)
r.RunAsync(ch)
return true
})
// wait for a result or for the context to be done
select {
@@ -77,30 +102,32 @@ func (gr *GroupRunner) Run(ctx context.Context) []*Result {
}
// interrupt the rest of the runners
for _, runner := range gr.runners {
if _, ok := results[runner.ID]; !ok {
gr.runners.Range(func(_, value any) bool {
r := value.(*Runner)
if _, ok := results[r.ID]; !ok {
select {
case <-runner.Finished():
case <-r.Finished():
// No data should be sent through the channel, so we'd be
// here only if the channel is closed. This means the task
// has finished and we don't need to interrupt. We do
// nothing in this case
default:
runner.Interrupt()
r.Interrupt()
}
}
}
return true
})
// Having notified that the context has been finished, we still need to
// wait for the rest of the results
for i := len(results); i < len(gr.runners); i++ {
for i := len(results); i < gr.runnersCount; i++ {
result := <-ch
results[result.RunnerID] = result
}
close(ch)
values := make([]*Result, 0, len(gr.runners))
values := make([]*Result, 0, gr.runnersCount)
for _, val := range results {
values = append(values, val)
}
@@ -112,9 +139,33 @@ func (gr *GroupRunner) Run(ctx context.Context) []*Result {
// as it's available.
// Note that this method will finish as soon as all the tasks are running.
func (gr *GroupRunner) RunAsync(ch chan<- *Result) {
for _, runner := range gr.runners {
runner.RunAsync(ch)
}
// Set the flag inside the runningMutex to ensure we don't read the old value
// in the `Add` method and add a new runner when this method is being executed
// Note that if `Run` or `RunAsync` is called more than once, the underlying
// runners will panic
gr.runningMutex.Lock()
gr.isRunning = true
gr.runningMutex.Unlock()
// we need a secondary channel to receive the first result so we can
// interrupt the rest of the tasks
interCh := make(chan *Result, gr.runnersCount)
gr.runners.Range(func(_, value any) bool {
r := value.(*Runner)
r.RunAsync(interCh)
return true
})
go func() {
result := <-interCh
gr.Interrupt()
ch <- result
for i := 1; i < gr.runnersCount; i++ {
result = <-interCh
ch <- result
}
}()
}
// Interrupt will execute the stopper function of ALL the tasks, which should
@@ -128,11 +179,13 @@ func (gr *GroupRunner) RunAsync(ch chan<- *Result) {
// try to stop just one task.
// If a task has finished, the corresponding stopper won't be called
func (gr *GroupRunner) Interrupt() {
for _, runner := range gr.runners {
gr.runners.Range(func(_, value any) bool {
r := value.(*Runner)
select {
case <-runner.Finished():
case <-r.Finished():
default:
runner.Interrupt()
r.Interrupt()
}
}
return true
})
}

View File

@@ -2,6 +2,7 @@ package runner
import (
"context"
"sync/atomic"
)
// Runner represents the one executing a long running task, such as a server
@@ -10,10 +11,12 @@ import (
// Result that it generates will contain the same ID, so we can
// know which runner provided which result.
type Runner struct {
ID string
fn Runable
interrupt Stopper
finished chan struct{}
ID string
fn Runable
interrupt Stopper
running atomic.Bool
interrupted atomic.Bool
finished chan struct{}
}
// New will create a new runner.
@@ -52,6 +55,12 @@ func New(id string, fn Runable, interrupt Stopper) *Runner {
// - Use context.WithDeadline(...) or context.WithTimeout(...) to run the task
// for a limited time
func (r *Runner) Run(ctx context.Context) *Result {
if !r.running.CompareAndSwap(false, true) {
// If not swapped, the task is already running.
// Running the same task multiple times is a bug, so we panic
panic("Runner with id " + r.ID + " was running twice")
}
ch := make(chan *Result)
go r.doTask(ch, true)
@@ -60,10 +69,9 @@ func (r *Runner) Run(ctx context.Context) *Result {
case result := <-ch:
return result
case <-ctx.Done():
r.interrupt()
r.Interrupt()
return <-ch
}
return <-ch
}
// RunAsync will execute the task associated to this runner asynchronously.
@@ -76,6 +84,12 @@ func (r *Runner) Run(ctx context.Context) *Result {
// To interrupt the running task, the only option is to call the `Interrupt`
// method at some point.
func (r *Runner) RunAsync(ch chan<- *Result) {
	// Atomically claim the running flag: the swap succeeds only for the
	// first caller that finds it still set to false.
	if claimed := r.running.CompareAndSwap(false, true); claimed {
		// Start the task in its own goroutine so this call does not wait
		// for the task to complete.
		go r.doTask(ch, false)
		return
	}

	// The flag was already set, so this runner has been started before.
	// Starting the same task again is treated as a programming error.
	panic("Runner with id " + r.ID + " was running twice")
}
@@ -83,8 +97,11 @@ func (r *Runner) RunAsync(ch chan<- *Result) {
// in order for it to finish.
// The stopper will be called immediately, although its effects are
// expected to take a while (the task might need a while to stop)
// The stopper will be called only once. Further calls to this method won't do anything
func (r *Runner) Interrupt() {
r.interrupt()
if r.interrupted.CompareAndSwap(false, true) {
r.interrupt()
}
}
// Finished will return a receive-only channel that can be used to know when