Skip to content

Commit

Permalink
refactor: use atomic.Value instead of sync.RWMutex (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
danroc authored Nov 14, 2024
1 parent fbcaa75 commit 031581d
Showing 1 changed file with 17 additions and 23 deletions.
40 changes: 17 additions & 23 deletions pkg/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"slices"
"sort"
"strings"
"sync"
"sync/atomic"

"github.com/danroc/geoblock/pkg/utils"
)
Expand Down Expand Up @@ -56,22 +56,16 @@ func parseRecords(records [][]string) ([]Entry, error) {
return entries, nil
}

// sortEntries sorts the entries by their start IP.
func sortEntries(entries []Entry) {
slices.SortFunc(entries, func(a, b Entry) int {
return utils.CompareIP(a.StartIP, b.StartIP)
})
}

// Database represents a database of IP ranges.
type Database struct {
entries []Entry
mu sync.RWMutex
entries atomic.Value // []Entry
}

// NewDatabase creates a new database from the given URL.
func NewDatabase() *Database {
return &Database{}
db := &Database{}
db.entries.Store([]Entry{})
return db
}

// Update updates the database with the data from the given reader.
Expand All @@ -91,35 +85,35 @@ func (db *Database) Update(reader io.Reader) error {

// The entries must be sorted by their start IP to allow binary search. The
// sort is done in-place.
sortEntries(entries)
slices.SortFunc(entries, func(a, b Entry) int {
return utils.CompareIP(a.StartIP, b.StartIP)
})

// Update the database with the new entries.
db.mu.Lock()
db.entries = entries
db.mu.Unlock()
// This atomically updates the database entries.
db.entries.Store(entries)

return nil
}

// Find returns the data associated with the entry that contains the given IP.
// If the IP is not found, nil is returned.
func (db *Database) Find(ip net.IP) []string {
db.mu.RLock()
defer db.mu.RUnlock()

// If the given IP address is invalid, we return nil to indidate that the
// IP cannot be found in the database. It is up to the caller to validate
// the IP address before calling this method.
if ip == nil {
return nil
}

// Atomically load the database entries.
entries := db.entries.Load().([]Entry)

// Find the first entry whose start-IP is greater than the given IP. The
// search cannot be done the other way around (i.e., search for the first
// entry whose start-IP is less than or equal to the given IP) because it
// would return the first entry in most of the cases.
i := sort.Search(len(db.entries), func(i int) bool {
return utils.CompareIP(db.entries[i].StartIP, ip) > 0
// would return the first entry of the list in most of the cases.
i := sort.Search(len(entries), func(i int) bool {
return utils.CompareIP(entries[i].StartIP, ip) > 0
})

// Not found: the start-IP of the first entry is greater than the given IP.
Expand All @@ -128,7 +122,7 @@ func (db *Database) Find(ip net.IP) []string {
}

// The last entry whose start-IP is less than or equal to the given IP.
match := db.entries[i-1]
match := entries[i-1]

// From the search, it's guaranteed that the start-IP of the match is less
// than or equal to the given IP. So, the IP only needs to be compared to
Expand Down

0 comments on commit 031581d

Please sign in to comment.