diff --git a/pkg/database/database.go b/pkg/database/database.go index 8bec66c..efd7a03 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -3,6 +3,7 @@ package database import ( "encoding/csv" + "io" "net" "net/http" "slices" @@ -19,16 +20,6 @@ type Entry struct { Data []string } -// fetchCsv fetches a CSV file from the given URL and returns its records. -func fetchCsv(url string) ([][]string, error) { - resp, err := http.Get(url) // #nosec G107 - if err != nil { - return nil, err - } - defer resp.Body.Close() - return csv.NewReader(resp.Body).ReadAll() -} - // sanatizeData trims the leading and trailing spaces from the given strings. func sanatizeData(data []string) []string { sanitized := make([]string, len(data)) @@ -78,9 +69,9 @@ type Database struct { } // NewDatabase creates a new database from the given URL. -func NewDatabase(url string) (*Database, error) { +func NewDatabase(reader io.Reader) (*Database, error) { // Records are the raw data from the CSV file. - records, err := fetchCsv(url) + records, err := csv.NewReader(reader).ReadAll() if err != nil { return nil, err } @@ -98,6 +89,16 @@ func NewDatabase(url string) (*Database, error) { return &Database{entries: entries}, nil } +// NewDatabaseURL creates a new database from the given URL. +func NewDatabaseURL(url string) (*Database, error) { + resp, err := http.Get(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + return NewDatabase(resp.Body) +} + // Find returns the data associated with the entry that contains the given IP. // If the IP is not found, nil is returned. func (db *Database) Find(ip net.IP) []string { diff --git a/pkg/database/resolver.go b/pkg/database/resolver.go index d287ef9..be95d3d 100644 --- a/pkg/database/resolver.go +++ b/pkg/database/resolver.go @@ -34,22 +34,22 @@ type Resolver struct { // NewResolver creates a new IP resolver. func NewResolver() (*Resolver, error) { - countryDBv4, err := NewDatabase(countryIPv4URL) + countryDBv4, err := NewDatabaseURL(countryIPv4URL) if err != nil { return nil, err } - countryDBv6, err := NewDatabase(countryIPv6URL) + countryDBv6, err := NewDatabaseURL(countryIPv6URL) if err != nil { return nil, err } - asnDBv4, err := NewDatabase(asnIPv4URL) + asnDBv4, err := NewDatabaseURL(asnIPv4URL) if err != nil { return nil, err } - asnDBv6, err := NewDatabase(asnIPv6URL) + asnDBv6, err := NewDatabaseURL(asnIPv6URL) if err != nil { return nil, err }