Skip to content

Commit

Permalink
refactor: rename some packages and functions (#41)
Browse files Browse the repository at this point in the history
* refactor: rename `schema` to `config`

* refactor: rename `schema.ReadFile` to `config.LoadConfig`

* chore: remove unused files

* chore: fix linting error

* chore: rename `ReservedAS0` to `AS0`

* chore: inline `Any` function

* chore: remove extra space

* refactor: rename `database` to `iprange` and `utils` to `iputils`

* chore: update package description
  • Loading branch information
danroc authored Nov 22, 2024
1 parent 97605be commit f298ee9
Show file tree
Hide file tree
Showing 18 changed files with 218 additions and 602 deletions.
22 changes: 11 additions & 11 deletions cmd/geoblock/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ import (

log "github.com/sirupsen/logrus"

"github.com/danroc/geoblock/pkg/database"
"github.com/danroc/geoblock/pkg/config"
"github.com/danroc/geoblock/pkg/iprange"
"github.com/danroc/geoblock/pkg/rules"
"github.com/danroc/geoblock/pkg/schema"
"github.com/danroc/geoblock/pkg/server"
)

Expand Down Expand Up @@ -41,7 +41,7 @@ func getOptions() *appOptions {
}

// autoUpdate updates the databases at regular intervals.
func autoUpdate(resolver *database.Resolver) {
func autoUpdate(resolver *iprange.Resolver) {
for range time.Tick(autoUpdateInterval) {
if err := resolver.Update(); err != nil {
log.Errorf("Cannot update databases: %v", err)
Expand All @@ -59,15 +59,15 @@ func hasChanged(a, b os.FileInfo) bool {

// autoReload watches the configuration file for changes and updates the engine
// when it happens.
func autoReload(engine *rules.Engine, config string) {
prevStat, err := os.Stat(config)
func autoReload(engine *rules.Engine, path string) {
prevStat, err := os.Stat(path)
if err != nil {
log.Errorf("Cannot watch configuration file: %v", err)
return
}

for range time.Tick(autoReloadInterval) {
stat, err := os.Stat(config)
stat, err := os.Stat(path)
if err != nil {
log.Errorf("Cannot watch configuration file: %v", err)
continue
Expand All @@ -78,13 +78,13 @@ func autoReload(engine *rules.Engine, config string) {
}
prevStat = stat

config, err := schema.ReadFile(config)
cfg, err := config.LoadConfig(path)
if err != nil {
log.Errorf("Cannot read configuration file: %v", err)
continue
}

engine.UpdateConfig(&config.AccessControl)
engine.UpdateConfig(&cfg.AccessControl)
log.Info("Configuration reloaded")
}
}
Expand All @@ -110,20 +110,20 @@ func main() {
configureLogger(options.logLevel)

log.Info("Loading configuration file")
config, err := schema.ReadFile(options.configPath)
cfg, err := config.LoadConfig(options.configPath)
if err != nil {
log.Fatalf("Cannot read configuration file: %v", err)
}

log.Info("Initializing database resolver")
resolver, err := database.NewResolver()
resolver, err := iprange.NewResolver()
if err != nil {
log.Fatalf("Cannot initialize database resolver: %v", err)
}

var (
address = ":" + options.serverPort
engine = rules.NewEngine(&config.AccessControl)
engine = rules.NewEngine(&cfg.AccessControl)
server = server.NewServer(address, engine, resolver)
)

Expand Down
45 changes: 45 additions & 0 deletions pkg/config/loader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Package config contains the schema and helper functions to work with the
// configuration file.
package config

import (
"os"

"github.com/go-playground/validator/v10"
"gopkg.in/yaml.v3"
)

// isCIDRField checks if the value of the given field is a valid CIDR.
func isCIDRField(field validator.FieldLevel) bool {
cidr, ok := field.Field().Interface().(CIDR)
if !ok || cidr.IPNet == nil {
return false
}
return true
}

// read reads the configuration from the giver bytes slice.
func read(data []byte) (*Configuration, error) {
var config Configuration
if err := yaml.Unmarshal(data, &config); err != nil {
return nil, err
}

validate := validator.New()
validate.RegisterValidation("cidr", isCIDRField) // #nosec G104

if err := validate.Struct(config); err != nil {
return nil, err
}

return &config, nil
}

// LoadConfig reads the configuration from the given file.
func LoadConfig(filename string) (*Configuration, error) {
data, err := os.ReadFile(filename) // #nosec G304
if err != nil {
return nil, err
}
return read(data)
}
2 changes: 1 addition & 1 deletion pkg/schema/schema.go → pkg/config/schema.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package schema
package config

import "net"

Expand Down
16 changes: 8 additions & 8 deletions pkg/database/database.go → pkg/iprange/database.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Package database provides a database of IP ranges and their associated data.
package database
// Package iprange provides a database of IP ranges and their associated data.
package iprange

import (
"encoding/csv"
Expand All @@ -10,7 +10,7 @@ import (
"strings"
"sync/atomic"

"github.com/danroc/geoblock/pkg/utils"
"github.com/danroc/geoblock/pkg/utils/iputils"
)

// Entry represents an IP range and its associated data.
Expand Down Expand Up @@ -40,11 +40,11 @@ func parseRecords(records [][]string) ([]Entry, error) {
)

if startIP == nil {
return nil, &utils.ErrInvalidIP{Address: record[0]}
return nil, &iputils.ErrInvalidIP{Address: record[0]}
}

if endIP == nil {
return nil, &utils.ErrInvalidIP{Address: record[1]}
return nil, &iputils.ErrInvalidIP{Address: record[1]}
}

entries = append(entries, Entry{
Expand Down Expand Up @@ -86,7 +86,7 @@ func (db *Database) Update(reader io.Reader) error {
// The entries must be sorted by their start IP to allow binary search. The
// sort is done in-place.
slices.SortFunc(entries, func(a, b Entry) int {
return utils.CompareIP(a.StartIP, b.StartIP)
return iputils.CompareIP(a.StartIP, b.StartIP)
})

// This atomically updates the database entries.
Expand All @@ -113,7 +113,7 @@ func (db *Database) Find(ip net.IP) []string {
// entry whose start-IP is less than or equal to the given IP) because it
// would return the first entry of the list in most of the cases.
i := sort.Search(len(entries), func(i int) bool {
return utils.CompareIP(entries[i].StartIP, ip) > 0
return iputils.CompareIP(entries[i].StartIP, ip) > 0
})

// Not found: the start-IP of the first entry is greater than the given IP.
Expand All @@ -127,7 +127,7 @@ func (db *Database) Find(ip net.IP) []string {
// From the search, it's guaranteed that the start-IP of the match is less
// than or equal to the given IP. So, the IP only needs to be compared to
// the end-IP of the match.
if utils.CompareIP(ip, match.EndIP) <= 0 {
if iputils.CompareIP(ip, match.EndIP) <= 0 {
return match.Data
}

Expand Down
10 changes: 5 additions & 5 deletions pkg/database/database_test.go → pkg/iprange/database_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package database_test
package iprange_test

import (
"errors"
"net"
"strings"
"testing"

"github.com/danroc/geoblock/pkg/database"
"github.com/danroc/geoblock/pkg/iprange"
)

const (
Expand Down Expand Up @@ -49,7 +49,7 @@ func TestNewDatabase(t *testing.T) {

for _, test := range tests {
reader := strings.NewReader(test.data)
err := database.NewDatabase().Update(reader)
err := iprange.NewDatabase().Update(reader)
if test.err && err == nil {
t.Errorf("%s: expected an error but got nil", test.name)
}
Expand All @@ -67,15 +67,15 @@ func (r *errorReader) Read(p []byte) (n int, err error) {

func TestNewDatabaseReadErr(t *testing.T) {
reader := &errorReader{}
err := database.NewDatabase().Update(reader)
err := iprange.NewDatabase().Update(reader)
if err == nil {
t.Fatalf("Expected an error but got nil")
}
}

func TestFind(t *testing.T) {
reader := strings.NewReader(csvData1)
db := database.NewDatabase()
db := iprange.NewDatabase()
if err := db.Update(reader); err != nil {
t.Fatalf("Expected no error but got %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/database/export_test.go → pkg/iprange/export_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package database
package iprange

var (
StrToASN = strToASN
Expand Down
12 changes: 6 additions & 6 deletions pkg/database/resolver.go → pkg/iprange/resolver.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package database
package iprange

import (
"errors"
"net"
"net/http"
"strconv"

"github.com/danroc/geoblock/pkg/utils"
"github.com/danroc/geoblock/pkg/utils/iputils"
)

// URLs of the CSV IP location databases.
Expand All @@ -17,8 +17,8 @@ const (
ASNIPv6URL = "https://cdn.jsdelivr.net/npm/@ip-location-db/geolite2-asn/geolite2-asn-ipv6.csv"
)

// ReservedAS0 is the ASN used when the ASN is unknown. Its value is 0.
const ReservedAS0 uint32 = 0
// AS0 represents the default ASN value for unknown addresses.
const AS0 uint32 = 0

// Resolution contains the result of resolving an IP address.
type Resolution struct {
Expand Down Expand Up @@ -101,7 +101,7 @@ func strIndex(data []string, index int) string {
func strToASN(s string) uint32 {
asn, err := strconv.ParseUint(s, 10, 32)
if err != nil {
return ReservedAS0
return AS0
}
return uint32(asn)
}
Expand Down Expand Up @@ -130,7 +130,7 @@ func resolve(ip net.IP, countryDB *Database, asnDB *Database) *Resolution {
// The Organization field is present for informational purposes only. It is not
// used by the rules engine.
func (r *Resolver) Resolve(ip net.IP) *Resolution {
if utils.IsIPv4(ip) {
if iputils.IsIPv4(ip) {
return resolve(ip, r.countryDBv4, r.asnDBv4)
}
return resolve(ip, r.countryDBv6, r.asnDBv6)
Expand Down
30 changes: 15 additions & 15 deletions pkg/database/resolver_test.go → pkg/iprange/resolver_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package database_test
package iprange_test

import (
"bytes"
Expand All @@ -7,7 +7,7 @@ import (
"net/http"
"testing"

"github.com/danroc/geoblock/pkg/database"
"github.com/danroc/geoblock/pkg/iprange"
)

func TestStrIndex(t *testing.T) {
Expand All @@ -26,7 +26,7 @@ func TestStrIndex(t *testing.T) {

for _, tt := range tests {
t.Run("", func(t *testing.T) {
result := database.StrIndex(tt.data, tt.index)
result := iprange.StrIndex(tt.data, tt.index)
if result != tt.expected {
t.Errorf("got %q, want %q", result, tt.expected)
}
Expand All @@ -42,14 +42,14 @@ func TestStrToASN(t *testing.T) {
{"12345", 12345},
{"0", 0},
{"4294967295", 4294967295},
{"invalid", database.ReservedAS0},
{"", database.ReservedAS0},
{"-1", database.ReservedAS0},
{"invalid", iprange.AS0},
{"", iprange.AS0},
{"-1", iprange.AS0},
}

for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := database.StrToASN(tt.input)
result := iprange.StrToASN(tt.input)
if result != tt.expected {
t.Errorf("got %d, want %d", result, tt.expected)
}
Expand All @@ -69,10 +69,10 @@ func newDummyRT() http.RoundTripper {
return &mockRT{
respond: func(req *http.Request) (*http.Response, error) {
body := map[string]string{
database.CountryIPv4URL: "1.0.0.0,1.0.2.2,US\n1.1.0.0,1.1.2.2,FR\n",
database.CountryIPv6URL: "1:0::,1:1::,US\n1:2::,1:3::,FR\n",
database.ASNIPv4URL: "1.0.0.0,1.0.2.2,1,Test1\n1.1.0.0,1.1.2.2,2,Test2\n",
database.ASNIPv6URL: "1:0::,1:1::,3,Test3\n1:2::,1:3::,4,Test4\n",
iprange.CountryIPv4URL: "1.0.0.0,1.0.2.2,US\n1.1.0.0,1.1.2.2,FR\n",
iprange.CountryIPv6URL: "1:0::,1:1::,US\n1:2::,1:3::,FR\n",
iprange.ASNIPv4URL: "1.0.0.0,1.0.2.2,1,Test1\n1.1.0.0,1.1.2.2,2,Test2\n",
iprange.ASNIPv6URL: "1:0::,1:1::,3,Test3\n1:2::,1:3::,4,Test4\n",
}[req.URL.String()]

return &http.Response{
Expand Down Expand Up @@ -100,7 +100,7 @@ func withRT(rt http.RoundTripper, f func()) {

func TestNewResolverError(t *testing.T) {
withRT(newErrRT(), func() {
_, err := database.NewResolver()
_, err := iprange.NewResolver()
if err == nil {
t.Fatal("expected an error, got nil")
}
Expand All @@ -117,12 +117,12 @@ func TestResolverResolve(t *testing.T) {
}{
{"1.0.1.1", "US", "Test1", 1},
{"1.1.1.1", "FR", "Test2", 2},
{"1.2.1.1", "", "", database.ReservedAS0},
{"1.2.1.1", "", "", iprange.AS0},
{"1:0::", "US", "Test3", 3},
{"1:2::", "FR", "Test4", 4},
{"1:4::", "", "", database.ReservedAS0},
{"1:4::", "", "", iprange.AS0},
}
r, _ := database.NewResolver()
r, _ := iprange.NewResolver()
for _, tt := range tests {
t.Run(tt.ip, func(t *testing.T) {
result := r.Resolve(net.ParseIP(tt.ip))
Expand Down
Loading

0 comments on commit f298ee9

Please sign in to comment.