From 6c3e36f4f98807801b50c0ca7f94db02f744b75f Mon Sep 17 00:00:00 2001 From: Matt Heon Date: Wed, 12 Feb 2025 15:46:05 -0500 Subject: [PATCH] Add SyncMap package and use it for graph stop/remove This greatly simplifies the locking around these two functions, and things end up looking a lot more elegant. This should prevent the race flakes we were seeing before. Fixes #25289 Signed-off-by: Matt Heon --- libpod/container_graph.go | 54 ++++++++++---------------- pkg/syncmap/syncmap.go | 82 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 103 insertions(+), 33 deletions(-) create mode 100644 pkg/syncmap/syncmap.go diff --git a/libpod/container_graph.go b/libpod/container_graph.go index 25f2489cd8..1c8078d894 100644 --- a/libpod/container_graph.go +++ b/libpod/container_graph.go @@ -11,6 +11,7 @@ import ( "github.com/containers/podman/v5/libpod/define" "github.com/containers/podman/v5/pkg/parallel" + "github.com/containers/podman/v5/pkg/syncmap" "github.com/sirupsen/logrus" ) @@ -290,18 +291,16 @@ func startNode(ctx context.Context, node *containerNode, setError bool, ctrError // Contains all details required for traversing the container graph. type nodeTraversal struct { - // Protects reads and writes to the two maps. - lock sync.Mutex // Optional. but *MUST* be locked. // Should NOT be changed once a traversal is started. pod *Pod // Function to execute on the individual container being acted on. // Should NOT be changed once a traversal is started. actionFunc func(ctr *Container, pod *Pod) error - // Shared list of errors for all containers currently acted on. - ctrErrors map[string]error - // Shared list of what containers have been visited. - ctrsVisited map[string]bool + // Shared set of errors for all containers currently acted on. + ctrErrors *syncmap.SyncMap[string, error] + // Shared set of what containers have been visited. 
+ ctrsVisited *syncmap.SyncMap[string, bool] } // Perform a traversal of the graph in an inwards direction - meaning from nodes @@ -311,9 +310,7 @@ func traverseNodeInwards(node *containerNode, nodeDetails *nodeTraversal, setErr node.lock.Lock() // If we already visited this node, we're done. - nodeDetails.lock.Lock() - visited := nodeDetails.ctrsVisited[node.id] - nodeDetails.lock.Unlock() + visited := nodeDetails.ctrsVisited.Exists(node.id) if visited { node.lock.Unlock() return @@ -322,10 +319,8 @@ func traverseNodeInwards(node *containerNode, nodeDetails *nodeTraversal, setErr // Someone who depends on us failed. // Mark us as failed and recurse. if setError { - nodeDetails.lock.Lock() - nodeDetails.ctrsVisited[node.id] = true - nodeDetails.ctrErrors[node.id] = fmt.Errorf("a container that depends on container %s could not be stopped: %w", node.id, define.ErrCtrStateInvalid) - nodeDetails.lock.Unlock() + nodeDetails.ctrsVisited.Put(node.id, true) + nodeDetails.ctrErrors.Put(node.id, fmt.Errorf("a container that depends on container %s could not be stopped: %w", node.id, define.ErrCtrStateInvalid)) node.lock.Unlock() @@ -343,9 +338,7 @@ func traverseNodeInwards(node *containerNode, nodeDetails *nodeTraversal, setErr for _, dep := range node.dependedOn { // The container that depends on us hasn't been removed yet. // OK to continue on - nodeDetails.lock.Lock() - ok := nodeDetails.ctrsVisited[dep.id] - nodeDetails.lock.Unlock() + ok := nodeDetails.ctrsVisited.Exists(dep.id) if !ok { node.lock.Unlock() return @@ -355,9 +348,7 @@ func traverseNodeInwards(node *containerNode, nodeDetails *nodeTraversal, setErr ctrErrored := false if err := nodeDetails.actionFunc(node.container, nodeDetails.pod); err != nil { ctrErrored = true - nodeDetails.lock.Lock() - nodeDetails.ctrErrors[node.id] = err - nodeDetails.lock.Unlock() + nodeDetails.ctrErrors.Put(node.id, err) } // Mark as visited *only after* finished with operation. 
@@ -367,9 +358,7 @@ func traverseNodeInwards(node *containerNode, nodeDetails *nodeTraversal, setErr // Same with the node lock - we don't want to release it until we are // marked as visited. if !ctrErrored { - nodeDetails.lock.Lock() - nodeDetails.ctrsVisited[node.id] = true - nodeDetails.lock.Unlock() + nodeDetails.ctrsVisited.Put(node.id, true) node.lock.Unlock() } @@ -385,9 +374,7 @@ func traverseNodeInwards(node *containerNode, nodeDetails *nodeTraversal, setErr // and perform its operation before it was marked failed by the // traverseNodeInwards triggered by this process. if ctrErrored { - nodeDetails.lock.Lock() - nodeDetails.ctrsVisited[node.id] = true - nodeDetails.lock.Unlock() + nodeDetails.ctrsVisited.Put(node.id, true) node.lock.Unlock() } @@ -404,8 +391,8 @@ func stopContainerGraph(ctx context.Context, graph *ContainerGraph, pod *Pod, ti nodeDetails := new(nodeTraversal) nodeDetails.pod = pod - nodeDetails.ctrErrors = make(map[string]error) - nodeDetails.ctrsVisited = make(map[string]bool) + nodeDetails.ctrErrors = syncmap.New[string, error]() + nodeDetails.ctrsVisited = syncmap.New[string, bool]() traversalFunc := func(ctr *Container, pod *Pod) error { ctr.lock.Lock() @@ -452,7 +439,7 @@ func stopContainerGraph(ctx context.Context, graph *ContainerGraph, pod *Pod, ti <-doneChan } - return nodeDetails.ctrErrors, nil + return nodeDetails.ctrErrors.Underlying(), nil } // Remove all containers in the given graph @@ -466,10 +453,10 @@ func removeContainerGraph(ctx context.Context, graph *ContainerGraph, pod *Pod, nodeDetails := new(nodeTraversal) nodeDetails.pod = pod - nodeDetails.ctrErrors = make(map[string]error) - nodeDetails.ctrsVisited = make(map[string]bool) + nodeDetails.ctrErrors = syncmap.New[string, error]() + nodeDetails.ctrsVisited = syncmap.New[string, bool]() - ctrNamedVolumes := make(map[string]*ContainerNamedVolume) + ctrNamedVolumes := syncmap.New[string, *ContainerNamedVolume]() traversalFunc := func(ctr *Container, pod *Pod) error { 
ctr.lock.Lock() @@ -480,7 +467,7 @@ func removeContainerGraph(ctx context.Context, graph *ContainerGraph, pod *Pod, } for _, vol := range ctr.config.NamedVolumes { - ctrNamedVolumes[vol.Name] = vol + ctrNamedVolumes.Put(vol.Name, vol) } if pod != nil && pod.state.InfraContainerID == ctr.ID() { @@ -524,5 +511,6 @@ func removeContainerGraph(ctx context.Context, graph *ContainerGraph, pod *Pod, <-doneChan } - return ctrNamedVolumes, nodeDetails.ctrsVisited, nodeDetails.ctrErrors, nil + // Safe to use Underlying as the SyncMap passes out of scope as we return + return ctrNamedVolumes.Underlying(), nodeDetails.ctrsVisited.Underlying(), nodeDetails.ctrErrors.Underlying(), nil } diff --git a/pkg/syncmap/syncmap.go b/pkg/syncmap/syncmap.go new file mode 100644 index 0000000000..9439fab487 --- /dev/null +++ b/pkg/syncmap/syncmap.go @@ -0,0 +1,82 @@ +package syncmap + +import ( + "maps" + "sync" +) + +// A SyncMap is a map of a string to a generified value which is locked for safe +// access from multiple threads. +// It is effectively a generic version of Golang's standard library sync.Map. +// Admittedly, that has optimizations for multithreading performance that we do +// not here; thus, SyncMap should not be used in truly performance sensitive +// areas, but places where code cleanliness is more important than raw +// performance. +// SyncMap should always be passed by reference, not by value, to ensure thread +// safety is maintained. +type SyncMap[K comparable, V any] struct { + data map[K]V + lock sync.Mutex +} + +// New generates a new, empty SyncMap +func New[K comparable, V any]() *SyncMap[K, V] { + toReturn := new(SyncMap[K, V]) + toReturn.data = make(map[K]V) + + return toReturn +} + +// Put adds an entry into the map +func (m *SyncMap[K, V]) Put(key K, value V) { + m.lock.Lock() + defer m.lock.Unlock() + + m.data[key] = value +} + +// Get retrieves an entry from the map. 
+// Semantics match Golang map semantics - the bool represents whether the key +// exists, and the zero value of V will be returned if the key does not exist. +func (m *SyncMap[K, V]) Get(key K) (V, bool) { + m.lock.Lock() + defer m.lock.Unlock() + + value, exists := m.data[key] + + return value, exists +} + +// Exists returns true if a key exists in the map. +func (m *SyncMap[K, V]) Exists(key K) bool { + m.lock.Lock() + defer m.lock.Unlock() + + _, ok := m.data[key] + + return ok +} + +// Delete removes an entry from the map. +func (m *SyncMap[K, V]) Delete(key K) { + m.lock.Lock() + defer m.lock.Unlock() + + delete(m.data, key) +} + +// ToMap returns a shallow copy of the underlying data of the SyncMap. +func (m *SyncMap[K, V]) ToMap() map[K]V { + m.lock.Lock() + defer m.lock.Unlock() + + return maps.Clone(m.data) +} + +// Underlying returns a reference to the underlying storage of the SyncMap. +// Once Underlying has been called, the SyncMap is NO LONGER THREAD SAFE. +// If thread safety is still required, the shallow-copy offered by ToMap() +// should be used instead. +func (m *SyncMap[K, V]) Underlying() map[K]V { + return m.data +}