Skip to content

Commit

Permalink
Fix concurrent map access in task queues (#717)
Browse files Browse the repository at this point in the history
Signed-off-by: Mikhail Scherba <mikhail.scherba@flant.com>
  • Loading branch information
miklezzzz authored Feb 18, 2025
1 parent f632bb6 commit f37bbec
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
3 changes: 1 addition & 2 deletions pkg/shell-operator/manager_events_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ func (m *ManagerEventsHandler) Start() {

m.taskQueues.DoWithLock(func(tqs *queue.TaskQueueSet) {
for _, resTask := range tailTasks {
q := tqs.GetByName(resTask.GetQueueName())
if q == nil {
if q := tqs.Queues[resTask.GetQueueName()]; q == nil {
log.Error("Possible bug!!! Got task for queue but queue is not created yet.",
slog.String("queueName", resTask.GetQueueName()),
slog.String("description", resTask.GetDescription()))
Expand Down
25 changes: 19 additions & 6 deletions pkg/task/queue/queue_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,13 @@ type TaskQueueSet struct {
ctx context.Context
cancel context.CancelFunc

m sync.Mutex
m sync.RWMutex
Queues map[string]*TaskQueue
}

func NewTaskQueueSet() *TaskQueueSet {
return &TaskQueueSet{
Queues: make(map[string]*TaskQueue),
m: sync.Mutex{},
MainName: MainQueueName,
}
}
Expand All @@ -45,23 +44,31 @@ func (tqs *TaskQueueSet) WithMetricStorage(mstor *metricstorage.MetricStorage) {
}

func (tqs *TaskQueueSet) Stop() {
tqs.m.RLock()
if tqs.cancel != nil {
tqs.cancel()
}

tqs.m.RUnlock()
}

func (tqs *TaskQueueSet) StartMain() {
tqs.GetByName(tqs.MainName).Start()
}

func (tqs *TaskQueueSet) Start() {
tqs.m.RLock()
for _, q := range tqs.Queues {
q.Start()
}

tqs.m.RUnlock()
}

func (tqs *TaskQueueSet) Add(queue *TaskQueue) {
tqs.m.Lock()
tqs.Queues[queue.Name] = queue
tqs.m.Unlock()
}

func (tqs *TaskQueueSet) NewNamedQueue(name string, handler func(task.Task) TaskResult) {
Expand All @@ -76,6 +83,8 @@ func (tqs *TaskQueueSet) NewNamedQueue(name string, handler func(task.Task) Task
}

func (tqs *TaskQueueSet) GetByName(name string) *TaskQueue {
tqs.m.RLock()
defer tqs.m.RUnlock()
ts, exists := tqs.Queues[name]
if exists {
return ts
Expand Down Expand Up @@ -105,9 +114,9 @@ func (tqs *TaskQueueSet) Iterate(doFn func(queue *TaskQueue)) {
if doFn == nil {
return
}
tqs.m.Lock()
defer tqs.m.Unlock()

tqs.m.RLock()
defer tqs.m.RUnlock()
if len(tqs.Queues) == 0 {
return
}
Expand All @@ -126,13 +135,14 @@ func (tqs *TaskQueueSet) Iterate(doFn func(queue *TaskQueue)) {
}

func (tqs *TaskQueueSet) Remove(name string) {
tqs.m.Lock()
ts, exists := tqs.Queues[name]
if exists {
ts.Stop()
}
tqs.m.Lock()
defer tqs.m.Unlock()

delete(tqs.Queues, name)
tqs.m.Unlock()
}

func (tqs *TaskQueueSet) WaitStopWithTimeout(timeout time.Duration) {
Expand All @@ -145,15 +155,18 @@ func (tqs *TaskQueueSet) WaitStopWithTimeout(timeout time.Duration) {
select {
case <-checkTick.C:
stopped := true
tqs.m.RLock()
for _, q := range tqs.Queues {
if q.Status != "stop" {
stopped = false
break
}
}
tqs.m.RUnlock()
if stopped {
return
}

case <-timeoutTick.C:
return
}
Expand Down

0 comments on commit f37bbec

Please sign in to comment.