Skip to content

Commit

Permalink
Reimplement Map.All without using Map.Iter
Browse files Browse the repository at this point in the history
Implement the iteration logic natively inside All instead of wrapping
Map.Iter. All should be the preferred way to iterate over a Map when
using go1.23 or later. This reimplementation reduces its overhead.

goos: darwin
goarch: amd64
pkg: github.com/aristanetworks/gomap
cpu: Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
       │  before.txt  │              after.txt              │
       │    sec/op    │   sec/op     vs base                │
All-16   157.60n ± 2%   56.70n ± 1%  -64.02% (p=0.000 n=10)

       │ before.txt │             after.txt              │
       │    B/op    │   B/op     vs base                 │
All-16   96.00 ± 0%   0.00 ± 0%  -100.00% (p=0.000 n=10)

       │ before.txt │              after.txt              │
       │ allocs/op  │ allocs/op   vs base                 │
All-16   1.000 ± 0%   0.000 ± 0%  -100.00% (p=0.000 n=10)
  • Loading branch information
aaronbee committed Sep 18, 2024
1 parent 941e0d8 commit d348323
Show file tree
Hide file tree
Showing 3 changed files with 289 additions and 7 deletions.
106 changes: 99 additions & 7 deletions iter.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,112 @@ import "iter"

// All returns an iterator over key-value pairs from m.
func (m *Map[K, E]) All() iter.Seq2[K, E] {
return func(yield func(K, E) bool) {
for it := m.Iter(); it.Next(); {
if !yield(it.Key(), it.Elem()) {
return m.iterFunc
}

// iterFunc is the iterator body returned by All. It visits the entries of m
// one bucket (and overflow chain) at a time, beginning at a randomly chosen
// bucket and a randomly chosen slot offset inside each bucket, and calls
// yield once for every live key/elem pair it finds.
//
// Iteration ends when yield returns false or when the walk has wrapped
// around to the bucket it started from. A nil or empty map yields nothing.
func (m *Map[K, E]) iterFunc(yield func(k K, e E) bool) {
if m == nil || m.count == 0 {
return
}
// Remember we have an iterator.
// Can run concurrently with another m.Iter().
// The flags are set with an atomic OR; presumably the rest of the map
// consults them while growing — confirm against the grow code.
atomicOr(&m.flags, iterator|oldIterator)

var (
// r supplies the randomness for both the start bucket and the slot offset.
r = rand64()
// buckets is the bucket array as it was when iteration started; its
// length is compared with len(m.buckets) below.
buckets = m.buckets
startBucket = int(r & m.bucketMask())
// nextBucket is the index of the bucket to load once the current
// chain is exhausted.
nextBucket = startBucket
// b is the bucket being scanned; nil means "load the next bucket".
b *bucket[K, E]
// checkBucket is noCheck, or the new-bucket index that keys read from
// a not-yet-evacuated old bucket must hash to in order to be yielded.
checkBucket int
// offset rotates the order in which the slots of every bucket are
// visited; it is taken from the top bucketCntBits bits of r.
offset = uint8(r >> (64 - bucketCntBits))
// wrapped records that nextBucket has gone past the last bucket and
// restarted at 0.
wrapped bool
)

for {
if b == nil {
// The previous chain is finished (or this is the first pass).
// Having wrapped back to the start bucket means every bucket
// has been visited.
if nextBucket == startBucket && wrapped {
return
}

if m.growing() && len(buckets) == len(m.buckets) {
// Iterator was started in the middle of a grow, and the grow isn't done yet.
// If the bucket we're looking at hasn't been filled in yet (i.e. the old
// bucket hasn't been evacuated) then we need to iterate through the old
// bucket and only return the ones that will be migrated to this bucket.
oldbucket := uint64(nextBucket) & m.oldbucketmask()
b = &(*m.oldbuckets)[oldbucket]
if !evacuated(b) {
checkBucket = nextBucket
} else {
b = &buckets[nextBucket]
checkBucket = noCheck
}
} else {
b = &buckets[nextBucket]
checkBucket = noCheck
}
// Advance the cursor, wrapping to bucket 0 after the last one.
nextBucket++
if nextBucket == len(buckets) {
nextBucket = 0
wrapped = true
}
}

// Scan every slot of b, starting at the random offset.
for i := uint8(0); i < bucketCnt; i++ {
// The mask wraps the rotated index back into the bucket; this
// assumes bucketCnt is a power of two — TODO confirm.
offi := (i + offset) & (bucketCnt - 1)
if isEmpty(b.tophash[offi]) || b.tophash[offi] == evacuatedEmpty {
// TODO: emptyRest is hard to use here, as we start iterating
// in the middle of a bucket. It's feasible, just tricky.
continue
}
k := b.keys[offi]
if checkBucket != noCheck && !m.sameSizeGrow() {
// Special case: iterator was started during a grow to a larger size
// and the grow is not done yet. We're working on a bucket whose
// oldbucket has not been evacuated yet. Or at least, it wasn't
// evacuated when we started the bucket. So we're iterating
// through the oldbucket, skipping any keys that will go
// to the other new bucket (each oldbucket expands to two
// buckets during a grow).
// If the item in the oldbucket is not destined for
// the current new bucket in the iteration, skip it.
// NOTE(review): this uses the map's current bucketMask(), not a
// mask derived from the buckets snapshot taken above; confirm the
// two cannot diverge while checkBucket != noCheck (e.g. if yield
// mutates the map).
hash := m.hash(m.seed, k)
if int(hash&m.bucketMask()) != checkBucket {
continue
}
}
if b.tophash[offi] != evacuatedX && b.tophash[offi] != evacuatedY {
// This is the golden data, we can return it.
if !yield(k, b.elems[offi]) {
return
}
} else {
// The hash table has grown since the iterator was started.
// The golden data for this key is now somewhere else.
// Check the current hash table for the data.
// This code handles the case where the key
// has been deleted, updated, or deleted and reinserted.
// NOTE: we need to regrab the key as it has potentially been
// updated to an equal() but not identical key (e.g. +0.0 vs -0.0).
rk, re := m.mapaccessK(k)
if rk == nil {
continue // key has been deleted
}
if !yield(*rk, *re) {
return
}
}
}
// Continue with the overflow bucket; a nil overflow makes the next
// pass of the outer loop load the next bucket.
b = b.overflow
}
}

// Keys returns an iterator over keys in m.
func (m *Map[K, E]) Keys() iter.Seq[K] {
return func(yield func(K) bool) {
for it := m.Iter(); it.Next(); {
if !yield(it.Key()) {
for k := range m.All() {
if !yield(k) {
return
}
}
Expand All @@ -35,8 +127,8 @@ func (m *Map[K, E]) Keys() iter.Seq[K] {
// Values returns an iterator over values in m.
func (m *Map[K, E]) Values() iter.Seq[E] {
return func(yield func(E) bool) {
for it := m.Iter(); it.Next(); {
if !yield(it.Elem()) {
for _, v := range m.All() {
if !yield(v) {
return
}
}
Expand Down
187 changes: 187 additions & 0 deletions iter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package gomap
import (
"hash/maphash"
"maps"
"sync"
"testing"
)

Expand Down Expand Up @@ -92,3 +93,189 @@ func TestRangeFuncs(t *testing.T) {
}
})
}

// TestGetAllRace runs two goroutines that look up every key with Get and
// two goroutines that repeatedly start (and immediately abandon) an All
// iteration, all against the same 100-entry map, and waits for them to
// finish. No goroutine writes to the map.
func TestGetAllRace(t *testing.T) {
	const size = 100

	sameInt := func(a int, b int) bool { return a == b }
	m := NewHint[int, int](size, sameInt, intHash)
	for key := 0; key < size; key++ {
		m.Set(key, key)
	}

	// lookups checks that every key is present and maps to itself.
	lookups := func() {
		for want := 0; want < size; want++ {
			got, found := m.Get(want)
			if found && got == want {
				continue
			}
			t.Errorf("expected: %d got: %d, %t", want, got, found)
		}
	}

	// ranges starts size iterations and leaves each one at its first pair;
	// reaching the statement after the inner loop means All yielded nothing.
	ranges := func() {
	rounds:
		for round := 0; round < size; round++ {
			for range m.All() {
				continue rounds
			}
			t.Error("Should have iterated at least once, but didn't.")
		}
	}

	var wg sync.WaitGroup
	for _, worker := range []func(){lookups, lookups, ranges, ranges} {
		wg.Add(1)
		go func(run func()) {
			run()
			wg.Done()
		}(worker)
	}
	wg.Wait()
}

// TestAll stores the keys 0..8 (each mapped to itself) and checks that
// ranging over All produces exactly those pairs: nothing unknown, no wrong
// values, and nothing left over.
func TestAll(t *testing.T) {
	const entries = 9

	m := New[uint64, uint64](
		func(a, b uint64) bool { return a == b },
		badIntHash,
	)

	// pending holds the pairs that All still has to produce.
	pending := make(map[uint64]uint64, entries)
	for n := uint64(0); n < entries; n++ {
		m.Set(n, n)
		pending[n] = n
	}

	for key, got := range m.All() {
		want, known := pending[key]
		switch {
		case !known:
			t.Errorf("unexpected value in m: [%d: %d]", key, got)
		case want != got:
			t.Errorf("wrong value for key %d. Expected: %d Got: %d", key, want, got)
		default:
			delete(pending, key)
		}
	}

	if len(pending) != 0 {
		t.Errorf("Values not found in m: %v", pending)
	}
}

// TestAllDuringResize mutates the map from inside a range over All. On the
// first yielded pair it inserts extra keys (intended to trigger a resize)
// and deletes two of the initial keys that have not been yielded yet. The
// iteration must then produce every surviving initial key, may produce any
// of the added keys, and must not produce anything else.
func TestAllDuringResize(t *testing.T) {
	m := New[uint64, uint64](
		func(a, b uint64) bool { return a == b },
		badIntHash,
	)

	// Insert numbers that initially hash to the same bucket, but will
	// be split into different buckets on resize. Evens will end up in
	// bucket[0], odds end up in bucket[1] thanks to badIntHash.
	initial := map[uint64]uint64{0: 0, 1: 1, 2: 2, 3: 3}
	for k, e := range initial {
		m.Set(k, e)
	}
	additional := map[uint64]uint64{100: 100, 101: 101, 102: 102, 103: 103, 104: 104}

	// remove deletes key from m and from the set of keys the iteration is
	// still required to produce.
	remove := func(key uint64) {
		m.Delete(key)
		delete(initial, key)
	}

	first := true
	// Start the iteration.
	for k, v := range m.All() {
		if first {
			first = false
			// Add some additional data to cause a resize. Distinct
			// variable names keep the outer k visible below.
			for ak, ae := range additional {
				m.Set(ak, ae)
			}
			// Remove one even and one odd initial key, neither of which
			// is the key that was just yielded.
			if k == 0 {
				remove(2)
			} else {
				remove(0)
			}
			if k == 1 {
				remove(3)
			} else {
				remove(1)
			}
		}
		if k != v {
			t.Errorf("expected key == elem, but got: %d != %d", k, v)
			t.Error(m.debugString())
		}
		if _, ok := initial[k]; ok {
			// Mark the key that was looked up as seen. Deleting by v
			// would strike the wrong entry whenever k != v.
			delete(initial, k)
			continue
		}
		if _, ok := additional[k]; ok {
			t.Logf("Saw key from additional: %d", k)
			continue
		}
		t.Errorf("Unexpected value from iter: %d", k)
	}
	for k := range initial {
		t.Errorf("iter missing key: %d", k)
	}
}

// TestAllDuringGrow fills the map with the keys 0..26 (each mapped to
// itself) and checks that ranging over All produces every key exactly once
// with a matching value.
func TestAllDuringGrow(t *testing.T) {
	m := New[uint64, uint64](
		func(a, b uint64) bool { return a == b },
		badIntHash,
	)

	// Insert exactly 27 numbers so we end up in the middle of a grow.
	expected := make(map[uint64]uint64, 27)
	for i := uint64(0); i < 27; i++ {
		expected[i] = i
		m.Set(i, i)
	}

	for k, v := range m.All() {
		t.Logf("Key: %d", k)
		if k != v {
			t.Errorf("expected key == elem, but got: %d != %d", k, v)
			t.Error(m.debugString())
		}

		if _, ok := expected[k]; ok {
			// Mark the key that was looked up as seen. Deleting by v
			// would strike the wrong entry whenever k != v. A key that
			// is yielded twice falls through to the error below.
			delete(expected, k)
			continue
		}
		t.Errorf("Unexpected value from iter: %d", k)
	}
	for k := range expected {
		t.Errorf("iter missing key: %d", k)
	}
}

// BenchmarkAll measures a complete range over All on a three-entry
// string-keyed map, reporting allocations per iteration.
func BenchmarkAll(b *testing.B) {
	// The comparator parameters are named so they do not shadow b.
	sameString := func(x, y string) bool { return x == y }
	m := New[string, int](
		sameString,
		maphash.String,
		KeyElem[string, int]{"one", 1},
		KeyElem[string, int]{"two", 2},
		KeyElem[string, int]{"three", 3},
	)

	b.ReportAllocs()
	b.ResetTimer()
	for remaining := b.N; remaining > 0; remaining-- {
		// Drain the iterator; the pairs themselves are not needed.
		for range m.All() {
		}
	}
}
3 changes: 3 additions & 0 deletions map.go
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,9 @@ search:

// Iter instantiates an Iterator to explore the elements of the Map.
// Ordering is undefined and is intentionally randomized.
//
// Prefer [Map.All] over Iter when using go1.23 or later as it works
// with for-range loops and has less overhead.
func (m *Map[K, E]) Iter() *Iterator[K, E] {
// Iter() is a small function to encourage the compiler to inline
// it into its caller and let `it` be kept on the stack.
Expand Down

0 comments on commit d348323

Please sign in to comment.