From 031581d700fa7dbca3fa5d1e3d0305af4f4c8899 Mon Sep 17 00:00:00 2001 From: Daniel Rocha <68558152+danroc@users.noreply.github.com> Date: Thu, 14 Nov 2024 20:34:14 +0100 Subject: [PATCH] refactor: use `atomic.Value` instead of `sync.RWMutex` (#6) --- pkg/database/database.go | 40 +++++++++++++++++----------------------- 1 file changed, 17 insertions(+), 23 deletions(-) diff --git a/pkg/database/database.go b/pkg/database/database.go index 75e5fa4..4429b98 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -8,7 +8,7 @@ import ( "slices" "sort" "strings" - "sync" + "sync/atomic" "github.com/danroc/geoblock/pkg/utils" ) @@ -56,22 +56,16 @@ func parseRecords(records [][]string) ([]Entry, error) { return entries, nil } -// sortEntries sorts the entries by their start IP. -func sortEntries(entries []Entry) { - slices.SortFunc(entries, func(a, b Entry) int { - return utils.CompareIP(a.StartIP, b.StartIP) - }) -} - // Database represents a database of IP ranges. type Database struct { - entries []Entry - mu sync.RWMutex + entries atomic.Value // []Entry } // NewDatabase creates a new database from the given URL. func NewDatabase() *Database { - return &Database{} + db := &Database{} + db.entries.Store([]Entry{}) + return db } // Update updates the database with the data from the given reader. @@ -91,12 +85,12 @@ func (db *Database) Update(reader io.Reader) error { // The entries must be sorted by their start IP to allow binary search. The // sort is done in-place. - sortEntries(entries) + slices.SortFunc(entries, func(a, b Entry) int { + return utils.CompareIP(a.StartIP, b.StartIP) + }) - // Update the database with the new entries. - db.mu.Lock() - db.entries = entries - db.mu.Unlock() + // This atomically updates the database entries. + db.entries.Store(entries) return nil } @@ -104,9 +98,6 @@ func (db *Database) Update(reader io.Reader) error { // Find returns the data associated with the entry that contains the given IP. // If the IP is not found, nil is returned. func (db *Database) Find(ip net.IP) []string { - db.mu.RLock() - defer db.mu.RUnlock() - // If the given IP address is invalid, we return nil to indidate that the // IP cannot be found in the database. It is up to the caller to validate // the IP address before calling this method. @@ -114,12 +105,15 @@ func (db *Database) Find(ip net.IP) []string { return nil } + // Atomically load the database entries. + entries := db.entries.Load().([]Entry) + // Find the first entry whose start-IP is greater than the given IP. The // search cannot be done the other way around (i.e., search for the first // entry whose start-IP is less than or equal to the given IP) because it - // would return the first entry in most of the cases. - i := sort.Search(len(db.entries), func(i int) bool { - return utils.CompareIP(db.entries[i].StartIP, ip) > 0 + // would return the first entry of the list in most of the cases. + i := sort.Search(len(entries), func(i int) bool { + return utils.CompareIP(entries[i].StartIP, ip) > 0 }) // Not found: the start-IP of the first entry is greater than the given IP. @@ -128,7 +122,7 @@ func (db *Database) Find(ip net.IP) []string { } // The last entry whose start-IP is less than or equal to the given IP. - match := db.entries[i-1] + match := entries[i-1] // From the search, it's guaranteed that the start-IP of the match is less // than or equal to the given IP. So, the IP only needs to be compared to