Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow wildcard domains in config #42

Merged
merged 2 commits into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions cmd/geoblock/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package main

import (
"bytes"
"os"
"time"

Expand Down Expand Up @@ -51,6 +52,15 @@ func autoUpdate(resolver *iprange.Resolver) {
}
}

// loadConfig reads the configuration file from the given path and returns it.
func loadConfig(path string) (*config.Configuration, error) {
file, err := os.ReadFile(path) // #nosec G304
if err != nil {
return nil, err
}
return config.ReadConfig(bytes.NewReader(file))
}

// hasChanged returns true if the two file infos are different. It only checks
// the size and the modification time.
func hasChanged(a, b os.FileInfo) bool {
Expand Down Expand Up @@ -78,7 +88,7 @@ func autoReload(engine *rules.Engine, path string) {
}
prevStat = stat

cfg, err := config.LoadConfig(path)
cfg, err := loadConfig(path)
if err != nil {
log.Errorf("Cannot read configuration file: %v", err)
continue
Expand Down Expand Up @@ -110,7 +120,7 @@ func main() {
configureLogger(options.logLevel)

log.Info("Loading configuration file")
cfg, err := config.LoadConfig(options.configPath)
cfg, err := loadConfig(options.configPath)
if err != nil {
log.Fatalf("Cannot read configuration file: %v", err)
}
Expand Down
26 changes: 21 additions & 5 deletions pkg/config/loader.go → pkg/config/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,27 @@
package config

import (
"os"
"io"
"regexp"

"github.com/go-playground/validator/v10"
"gopkg.in/yaml.v3"
)

// DomainNameRegex matches a valid domain name as per RFC 1035. It also allows
// labels to be a single `*` wildcard.
var domainNameRegex = regexp.MustCompile(
`^(\*|[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)(\.(\*|[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?))*$`,
)

func isDomainNameField(field validator.FieldLevel) bool {
domain, ok := field.Field().Interface().(string)
if !ok {
return false
}
return domainNameRegex.MatchString(domain)
}

// isCIDRField checks if the value of the given field is a valid CIDR.
func isCIDRField(field validator.FieldLevel) bool {
cidr, ok := field.Field().Interface().(CIDR)
Expand All @@ -26,7 +41,8 @@ func read(data []byte) (*Configuration, error) {
}

validate := validator.New()
validate.RegisterValidation("cidr", isCIDRField) // #nosec G104
validate.RegisterValidation("cidr", isCIDRField) // #nosec G104
validate.RegisterValidation("domain", isDomainNameField) // #nosec G104

if err := validate.Struct(config); err != nil {
return nil, err
Expand All @@ -35,9 +51,9 @@ func read(data []byte) (*Configuration, error) {
return &config, nil
}

// LoadConfig reads the configuration from the given file.
func LoadConfig(filename string) (*Configuration, error) {
data, err := os.ReadFile(filename) // #nosec G304
// ReadConfig reads the configuration from the given reader and returns it.
func ReadConfig(reader io.Reader) (*Configuration, error) {
data, err := io.ReadAll(reader)
if err != nil {
return nil, err
}
Expand Down
217 changes: 217 additions & 0 deletions pkg/config/reader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
package config_test

import (
"errors"
"net"
"reflect"
"strings"
"testing"

"github.com/danroc/geoblock/pkg/config"
)

const validConfig = `
access_control:
default_policy: allow
rules:
- networks:
- "10.0.0.0/8"
- "127.0.0.0/8"
domains:
- "example.com"
- "*.example.com"
methods:
- GET
- POST
countries:
- US
- FR
autonomous_systems:
- 1234
- 5678
policy: allow

- policy: deny
`

const invalidLeadingDot = `
access_control:
default_policy: allow
rules:
- domains:
- ".example.com"
policy: allow
`

const invalidWildcardLocation = `
access_control:
default_policy: allow
rules:
- domains:
- "*example.com"
policy: allow
`

const invalidDomainChar = `
access_control:
default_policy: allow
rules:
- domains:
- "example?.com"
policy: allow
`

const invalidLeadingDash = `
access_control:
default_policy: allow
rules:
- domains:
- "-example.com"
policy: allow
`

const invalidTrailingDash = `
access_control:
default_policy: allow
rules:
- domains:
- "example-.com"
policy: allow
`

const invalidDomainString = `
access_control:
default_policy: allow
rules:
- domains:
- false
policy: allow
`

const invalidNetworkString = `
access_control:
default_policy: allow
rules:
- networks:
- "invalid"
policy: allow
`

const invalidNetworkNumber = `
access_control:
default_policy: allow
rules:
- networks:
- 10
policy: allow
`

const invalidNetworkRange = `
access_control:
default_policy: allow
rules:
- networks:
- 300.300.300.300/50
policy: allow
`

func TestReadConfigValid(t *testing.T) {
tests := []struct {
name string
data string
expected *config.Configuration
}{
{
"valid configuration",
validConfig,
&config.Configuration{
AccessControl: config.AccessControl{
DefaultPolicy: "allow",
Rules: []config.AccessControlRule{
{
Policy: "allow",
Networks: []config.CIDR{
{
IPNet: &net.IPNet{
IP: net.IP{10, 0, 0, 0},
Mask: net.CIDRMask(8, 32),
},
},
{
IPNet: &net.IPNet{
IP: net.IP{127, 0, 0, 0},
Mask: net.CIDRMask(8, 32),
},
},
},
Domains: []string{
"example.com",
"*.example.com",
},
Methods: []string{"GET", "POST"},
Countries: []string{"US", "FR"},
AutonomousSystems: []uint32{1234, 5678},
},
{
Policy: "deny",
Networks: nil,
Domains: nil,
Methods: nil,
Countries: nil,
AutonomousSystems: nil,
},
},
},
},
},
}

for _, test := range tests {
reader := strings.NewReader(test.data)
cfg, err := config.ReadConfig(reader)
if err != nil {
t.Errorf("%s: unexpected error: %v", test.name, err)
}
if !reflect.DeepEqual(*cfg, *test.expected) {
t.Errorf("%s: expected %v, got %v", test.name, test.expected, cfg)
}
}
}

func TestReadConfigErr(t *testing.T) {
tests := []struct {
name string
data string
}{
{"invalid leading dot", invalidLeadingDot},
{"invalid wildcard location", invalidWildcardLocation},
{"invalid domain character", invalidDomainChar},
{"invalid leading dash", invalidLeadingDash},
{"invalid trailing dash", invalidTrailingDash},
{"invalid network string", invalidNetworkString},
{"invalid network number", invalidNetworkNumber},
{"invalid network range", invalidNetworkRange},
{"invalid domain string", invalidDomainString},
}

for _, test := range tests {
reader := strings.NewReader(test.data)
_, err := config.ReadConfig(reader)
if err == nil {
t.Errorf("%s: expected an error but got nil", test.name)
}
}
}

type errReader struct{}

func (r *errReader) Read(p []byte) (n int, err error) {
return 0, errors.New("read error")
}

func TestReadConfigErrReader(t *testing.T) {
_, err := config.ReadConfig(&errReader{})
if err == nil {
t.Error("expected an error but got nil")
}
}
2 changes: 1 addition & 1 deletion pkg/config/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (n *CIDR) UnmarshalYAML(unmarshal func(interface{}) error) error {
type AccessControlRule struct {
Policy string `yaml:"policy" validate:"required,oneof=allow deny"`
Networks []CIDR `yaml:"networks,omitempty" validate:"dive,cidr"`
Domains []string `yaml:"domains,omitempty" validate:"dive,fqdn"`
Domains []string `yaml:"domains,omitempty" validate:"dive,domain"`
Methods []string `yaml:"methods,omitempty" validate:"dive,oneof=GET HEAD POST PUT DELETE PATCH"`
Countries []string `yaml:"countries,omitempty" validate:"dive,iso3166_1_alpha2"`
AutonomousSystems []uint32 `yaml:"autonomous_systems,omitempty" validate:"dive,numeric"`
Expand Down