From 324281a7772d25b072dc5de2997d6e85537c3d1e Mon Sep 17 00:00:00 2001 From: Jason Cameron Date: Tue, 28 Jan 2025 18:05:17 -0500 Subject: [PATCH 1/6] feat: Add whitelist functionality (#31) Adds a whitelist of IPs which bypasses the CIDR checks. This commit introduces a new `whitelist` directive to the configuration, allowing specific IPs to be excluded from blocking. The whitelist is checked before any CIDR matching. WIP: Still need to integrate with other middleware. --- config.go | 11 ++++++++ middleware.go | 2 +- plugin.go | 7 ++++- utils/ip/ip.go | 30 +++++++++++++++++----- utils/ip/ip_test.go | 12 ++++----- utils/ip/whitelist.go | 59 +++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 106 insertions(+), 15 deletions(-) create mode 100644 utils/ip/whitelist.go diff --git a/config.go b/config.go index 2d378d2..d5cd09c 100644 --- a/config.go +++ b/config.go @@ -52,6 +52,10 @@ func (m *Defender) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { } Message := d.Val() m.Message = Message + case "whitelist": + for d.NextArg() { + m.WhiteList = append(m.WhiteList, d.Val()) + } default: return d.Errf("unknown subdirective '%s'", d.Val()) } @@ -131,6 +135,13 @@ func (m *Defender) Validate() error { } } + // Check if the whitelist is valid + for _, ip := range m.WhiteList { + if net.ParseIP(ip) == nil { + return fmt.Errorf("invalid IP address %q in whitelist", ip) + } + } + return nil } diff --git a/middleware.go b/middleware.go index 5ca5fe6..74f3bd5 100644 --- a/middleware.go +++ b/middleware.go @@ -25,7 +25,7 @@ func (m Defender) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht } m.log.Debug("Ranges", zap.Strings("ranges", m.Ranges)) // Check if the client IP is in any of the ranges using the optimized checker - if m.ipChecker.IPInRanges(r.Context(), clientIP) { + if m.ipChecker.ReqAllowed(r.Context(), clientIP) { m.log.Debug("IP is in ranges", zap.String("ip", clientIP.String())) return m.responder.ServeHTTP(w, r, next) } else { diff --git a/plugin.go b/plugin.go index 3020f93..8b8f61f 100644 --- a/plugin.go +++ b/plugin.go @@ -61,6 +61,11 @@ type Defender struct { // Default: All predefined ranges if empty Ranges []string `json:"ranges,omitempty"` + // An optional whitelist of IP addresses to exclude from blocking. If empty, no IPs are whitelisted. + // NOTE: this only supports IP addresses, not ranges. + // Default: [] + WhiteList []string `json:"whitelist,omitempty"` + // Message specifies the custom response message for 'custom' responder type. // Required only when using 'custom' responder. Message string `json:"message,omitempty"` @@ -86,7 +91,7 @@ func (m *Defender) Provision(ctx caddy.Context) error { } // ensure to keep AFTER the ranges are checked (above) - m.ipChecker = ip.NewIPChecker(m.Ranges, m.log) + m.ipChecker = ip.NewIPChecker(m.Ranges, m.WhiteList, m.log) return nil } diff --git a/utils/ip/ip.go b/utils/ip/ip.go index d1902fa..0c87d49 100644 --- a/utils/ip/ip.go +++ b/utils/ip/ip.go @@ -14,12 +14,13 @@ import ( ) type IPChecker struct { - table *bart.Table[struct{}] - cache *sturdyc.Client[string] - log *zap.Logger + table *bart.Table[struct{}] + cache *sturdyc.Client[string] + whitelist *Whitelist + log *zap.Logger } -func NewIPChecker(cidrRanges []string, log *zap.Logger) *IPChecker { +func NewIPChecker(cidrRanges, whitelistedIPs []string, log *zap.Logger) *IPChecker { const ( capacity = 10000 numShards = 10 @@ -30,6 +31,13 @@ func NewIPChecker(cidrRanges []string, log *zap.Logger) *IPChecker { retryBaseDelay = 10 * time.Millisecond ) + whitelist, err := NewWhitelist(whitelistedIPs) + if err != nil { + log.Warn("Invalid whitelist IP", + zap.Strings("whitelist", whitelistedIPs), + zap.Error(err)) + } + cache := sturdyc.New[string]( capacity, numShards, @@ -45,10 +53,18 @@ func NewIPChecker(cidrRanges []string, log *zap.Logger) *IPChecker { ) return &IPChecker{ - table: buildTable(cidrRanges, log), - cache: cache, - log: log, + table: buildTable(cidrRanges, log), + cache: cache, + log: log, + whitelist: whitelist, + } +} + +func (c *IPChecker) ReqAllowed(ctx context.Context, clientIP net.IP) bool { + if c.whitelist.Allowed(clientIP.String()) { + return true } + return c.IPInRanges(ctx, clientIP) } func (c *IPChecker) IPInRanges(ctx context.Context, clientIP net.IP) bool { diff --git a/utils/ip/ip_test.go b/utils/ip/ip_test.go index 35ae419..b80c8f9 100644 --- a/utils/ip/ip_test.go +++ b/utils/ip/ip_test.go @@ -46,7 +46,7 @@ func TestIPInRanges(t *testing.T) { data.IPRanges = predefinedCIDRs // Create a new IPChecker with valid CIDRs - checker := NewIPChecker(validCIDRs, testLogger) + checker := NewIPChecker(validCIDRs, []string{}, testLogger) tests := []struct { name string @@ -93,7 +93,7 @@ func TestIPInRanges(t *testing.T) { func TestIPInRangesCache(t *testing.T) { // Create a new IPChecker with valid CIDRs - checker := NewIPChecker(validCIDRs, testLogger) + checker := NewIPChecker(validCIDRs, []string{}, testLogger) // Test IP clientIP := net.ParseIP("192.168.1.100") @@ -110,7 +110,7 @@ func TestIPInRangesCache(t *testing.T) { func TestIPInRangesCacheExpiration(t *testing.T) { // Create a new IPChecker with a short cache TTL for testing - checker := NewIPChecker(validCIDRs, testLogger) + checker := NewIPChecker(validCIDRs, []string{}, testLogger) // Test IP clientIP := net.ParseIP("192.168.1.100") @@ -130,7 +130,7 @@ func TestIPInRangesCacheExpiration(t *testing.T) { func TestIPInRangesInvalidCIDR(t *testing.T) { // Create a new IPChecker with invalid CIDRs - checker := NewIPChecker(invalidCIDRs, testLogger) + checker := NewIPChecker(invalidCIDRs, []string{}, testLogger) // Test IP clientIP := net.ParseIP("192.168.1.100") @@ -143,7 +143,7 @@ func TestIPInRangesInvalidCIDR(t *testing.T) { func TestIPInRangesInvalidIP(t *testing.T) { // Create a new IPChecker with valid CIDRs - checker := NewIPChecker(validCIDRs, testLogger) + checker := NewIPChecker(validCIDRs, []string{}, testLogger) // Test invalid IP clientIP := net.IP([]byte{1, 2, 3}) // Invalid IP @@ -216,7 +216,7 @@ func TestPredefinedCIDRGroups(t *testing.T) { return nil })) - checker := NewIPChecker(tt.groups, logger) + checker := NewIPChecker(tt.groups, []string{}, logger) clientIP := net.ParseIP(tt.ip) assert.NotNil(t, clientIP, "Failed to parse IP") diff --git a/utils/ip/whitelist.go b/utils/ip/whitelist.go new file mode 100644 index 0000000..f30cdbf --- /dev/null +++ b/utils/ip/whitelist.go @@ -0,0 +1,59 @@ +package ip + +import ( + "fmt" + "net" +) + +type Whitelist struct { + ips map[string]struct{} +} + +// NewWhitelist initializes a new Whitelist from IP strings. +func NewWhitelist(ipStrings []string) (*Whitelist, error) { + wl := &Whitelist{ + ips: make(map[string]struct{}), + } + for _, ipStr := range ipStrings { + ip := net.ParseIP(ipStr) + if ip == nil { + return nil, fmt.Errorf("invalid IP address: %s", ipStr) + } + ip16 := ip.To16() + if ip16 == nil { + return nil, fmt.Errorf("invalid IP address: %s", ipStr) + } + wl.ips[ip16.String()] = struct{}{} + } + return wl, nil +} + +// Allowed checks if the remote address is in the whitelist. +func (wl *Whitelist) Allowed(remoteAddr string) bool { + host, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + // Handle cases where there's no port + host = remoteAddr + } + + ip := net.ParseIP(host) + if ip == nil { + return false // Invalid IP format + } + + ip16 := ip.To16() + if ip16 == nil { + return false // Shouldn't happen if ParseIP succeeded + } + + _, ok := wl.ips[ip16.String()] + return ok +} + +// Example usage in Caddy middleware: +// func (m YourModule) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { +// if !m.whitelist.Allowed(r.RemoteAddr) { +// return caddyhttp.Error(http.StatusForbidden, nil) +// } +// return next.ServeHTTP(w, r) +// } From 09684fd0c04f154d00899a509957840fdb15cfe7 Mon Sep 17 00:00:00 2001 From: Jason Cameron Date: Tue, 28 Jan 2025 18:09:59 -0500 Subject: [PATCH 2/6] chore: Format caddyfile config docs --- config.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/config.go b/config.go index d5cd09c..8e3d0fd 100644 --- a/config.go +++ b/config.go @@ -18,11 +18,13 @@ var responderTypes = []string{"block", "garbage", "custom", "ratelimit"} // UnmarshalCaddyfile sets up the handler from Caddyfile tokens. Syntax: // // defender { -// # IP ranges to block -// ranges -// # Custom message to return to the client when using "custom" middleware (optional) -// message -// } +// # IP ranges to block +// ranges +// # Whitelisted IP addresses to allow to bypass ranges (optional) +// whitelist +// # Custom message to return to the client when using "custom" middleware (optional) +// message +// } func (m *Defender) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { d.Next() // consume directive name From fc9f5954df092f9d31d0ab73d72d8d010cccc79a Mon Sep 17 00:00:00 2001 From: Jason Cameron Date: Tue, 28 Jan 2025 18:12:55 -0500 Subject: [PATCH 3/6] docs: add whitelist example --- examples/whitelist/Caddyfile | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 examples/whitelist/Caddyfile diff --git a/examples/whitelist/Caddyfile b/examples/whitelist/Caddyfile new file mode 100644 index 0000000..c24edb2 --- /dev/null +++ b/examples/whitelist/Caddyfile @@ -0,0 +1,15 @@ +{ + auto_https off + order defender after header + debug +} + +:80 { + bind 127.0.0.1 ::1 + # Everything in AWS besides my EC2 instance is blocked from accessing this site. + defender block { + ranges aws + whitelist 169.254.169.254 # my ec2's public IP. + } + respond "This is what a human sees" +} From b178d19805c65a224fcb3e6c4fa95f783b2e771e Mon Sep 17 00:00:00 2001 From: Jason Cameron Date: Tue, 28 Jan 2025 18:17:59 -0500 Subject: [PATCH 4/6] Fix: Correct middleware logic and add test case (#31) The middleware logic was inverted, causing requests from blocked ranges to be allowed and requests from allowed ranges to be blocked. This commit corrects the logic to ensure that requests from blocked ranges are blocked and requests from allowed ranges or whitelisted IPs are allowed. --- examples/whitelist/Caddyfile | 10 ++++++++++ middleware.go | 5 +++-- utils/ip/ip.go | 3 ++- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/examples/whitelist/Caddyfile b/examples/whitelist/Caddyfile index c24edb2..8d19662 100644 --- a/examples/whitelist/Caddyfile +++ b/examples/whitelist/Caddyfile @@ -13,3 +13,13 @@ } respond "This is what a human sees" } + +:81 { + bind 127.0.0.1 ::1 + # My localhost ipv6 is blocked but not my ipv4 + defender block { + ranges private + whitelist 127.0.0.1 + } + respond "This is what a ipv4 human sees" +} diff --git a/middleware.go b/middleware.go index 74f3bd5..533a8b8 100644 --- a/middleware.go +++ b/middleware.go @@ -26,10 +26,11 @@ func (m Defender) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht m.log.Debug("Ranges", zap.Strings("ranges", m.Ranges)) // Check if the client IP is in any of the ranges using the optimized checker if m.ipChecker.ReqAllowed(r.Context(), clientIP) { + m.log.Debug("IP is not in ranges", zap.String("ip", clientIP.String())) + + } else { m.log.Debug("IP is in ranges", zap.String("ip", clientIP.String())) return m.responder.ServeHTTP(w, r, next) - } else { - m.log.Debug("IP is not in ranges", zap.String("ip", clientIP.String())) } // IP is not in any of the ranges, proceed to the next handler diff --git a/utils/ip/ip.go b/utils/ip/ip.go index 0c87d49..7006a2f 100644 --- a/utils/ip/ip.go +++ b/utils/ip/ip.go @@ -62,9 +62,10 @@ func NewIPChecker(cidrRanges, whitelistedIPs []string, log *zap.Logger) *IPCheck func (c *IPChecker) ReqAllowed(ctx context.Context, clientIP net.IP) bool { if c.whitelist.Allowed(clientIP.String()) { + c.log.Debug("IP is whitelisted", zap.String("ip", clientIP.String())) return true } - return c.IPInRanges(ctx, clientIP) + return !c.IPInRanges(ctx, clientIP) } func (c *IPChecker) IPInRanges(ctx context.Context, clientIP net.IP) bool { From 701c0d7689edb469681e988e3056c4acccff9028 Mon Sep 17 00:00:00 2001 From: Jason Cameron Date: Tue, 28 Jan 2025 19:54:23 -0500 Subject: [PATCH 5/6] fix(middleware): correct IP whitelist logic and refactor implementation (#31) * Fix inverted middleware logic that incorrectly handled blocked/allowed IP ranges * Refactor whitelist package to use netip.Addr for more efficient IP handling * Add comprehensive test coverage for whitelist functionality * Move whitelist code to dedicated package * Update config validation to handle default ranges properly * Rename WhiteList field to Whitelist for consistency --- config.go | 14 +-- config_test.go | 40 ++++++++ plugin.go | 6 +- utils/ip/ip.go | 26 +++--- utils/ip/ip_test.go | 32 +++++-- utils/ip/whitelist.go | 59 ------------ utils/ip/whitelist/whitelist.go | 44 +++++++++ utils/ip/whitelist/whitelist_test.go | 133 +++++++++++++++++++++++++++ 8 files changed, 262 insertions(+), 92 deletions(-) delete mode 100644 utils/ip/whitelist.go create mode 100644 utils/ip/whitelist/whitelist.go create mode 100644 utils/ip/whitelist/whitelist_test.go diff --git a/config.go b/config.go index 8e3d0fd..f4235f8 100644 --- a/config.go +++ b/config.go @@ -8,6 +8,7 @@ import ( "github.com/caddyserver/caddy/v2/modules/caddyhttp" "github.com/jasonlovesdoggo/caddy-defender/ranges/data" "github.com/jasonlovesdoggo/caddy-defender/responders" + "github.com/jasonlovesdoggo/caddy-defender/utils/ip/whitelist" "net" "reflect" "slices" @@ -56,7 +57,7 @@ func (m *Defender) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { m.Message = Message case "whitelist": for d.NextArg() { - m.WhiteList = append(m.WhiteList, d.Val()) + m.Whitelist = append(m.Whitelist, d.Val()) } default: return d.Errf("unknown subdirective '%s'", d.Val()) @@ -118,10 +119,6 @@ func (m *Defender) Validate() error { if m.responder == nil { return fmt.Errorf("responder not configured") } - if len(m.Ranges) == 0 { - // set the default ranges to be all of the predefined ranges - return fmt.Errorf("no ranges specified, this is required") - } for _, ipRange := range m.Ranges { // Check if the range is a predefined key (e.g., "openai") @@ -138,10 +135,9 @@ func (m *Defender) Validate() error { } // Check if the whitelist is valid - for _, ip := range m.WhiteList { - if net.ParseIP(ip) == nil { - return fmt.Errorf("invalid IP address %q in whitelist", ip) - } + err := whitelist.ValidateWhitelist(m.Whitelist) + if err != nil { + return err } return nil diff --git a/config_test.go b/config_test.go index 3cca4ec..ac3d7f2 100644 --- a/config_test.go +++ b/config_test.go @@ -2,8 +2,13 @@ package caddydefender import ( "encoding/json" + "github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2/caddytest" + "github.com/jasonlovesdoggo/caddy-defender/ranges/data" "github.com/jasonlovesdoggo/caddy-defender/responders" + "maps" + "slices" + "sort" "testing" "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" @@ -166,6 +171,7 @@ func TestValidation(t *testing.T) { def := Defender{ RawResponder: "block", Ranges: []string{"10.0.0.0/8"}, + Whitelist: []string{"126.39.0.3"}, responder: &responders.BlockResponder{}, } require.NoError(t, def.Validate()) @@ -186,6 +192,40 @@ func TestValidation(t *testing.T) { } require.ErrorContains(t, def.Validate(), "invalid IP range") }) + + t.Run("invalid whitelist IP", func(t *testing.T) { + def := Defender{ + RawResponder: "block", + Whitelist: []string{"invalid"}, + responder: &responders.BlockResponder{}, + } + require.ErrorContains(t, def.Validate(), "invalid IP address") + }) + + t.Run("Missing ranges", func(t *testing.T) { + def := Defender{ + RawResponder: "block", + responder: &responders.BlockResponder{}, + } + err := def.Provision(caddy.Context{Context: caddy.ActiveContext()}) + if err != nil { + return + } + + // We must sort the ranges to compare them as the order is not guaranteed + sort.Slice(def.Ranges, func(i, j int) bool { + return def.Ranges[i] < def.Ranges[j] + }) + + defaultRanges := slices.Collect(maps.Keys(data.IPRanges)) + + sort.Slice(defaultRanges, func(i, j int) bool { + return defaultRanges[i] < defaultRanges[j] + }) + + require.Equal(t, defaultRanges, def.Ranges) + }) + } func TestDefenderValidation(t *testing.T) { diff --git a/plugin.go b/plugin.go index 8b8f61f..4955c1f 100644 --- a/plugin.go +++ b/plugin.go @@ -64,7 +64,7 @@ type Defender struct { // An optional whitelist of IP addresses to exclude from blocking. If empty, no IPs are whitelisted. // NOTE: this only supports IP addresses, not ranges. // Default: [] - WhiteList []string `json:"whitelist,omitempty"` + Whitelist []string `json:"whitelist,omitempty"` // Message specifies the custom response message for 'custom' responder type. // Required only when using 'custom' responder. @@ -86,12 +86,12 @@ func (m *Defender) Provision(ctx caddy.Context) error { if len(m.Ranges) == 0 { // set the default ranges to be all of the predefined ranges - m.log.Debug("no ranges specified, this is required") + m.log.Debug("no ranges specified, defaulting to all predefined ranges") m.Ranges = slices.Collect(maps.Keys(data.IPRanges)) } // ensure to keep AFTER the ranges are checked (above) - m.ipChecker = ip.NewIPChecker(m.Ranges, m.WhiteList, m.log) + m.ipChecker = ip.NewIPChecker(m.Ranges, m.Whitelist, m.log) return nil } diff --git a/utils/ip/ip.go b/utils/ip/ip.go index 7006a2f..cee823e 100644 --- a/utils/ip/ip.go +++ b/utils/ip/ip.go @@ -3,6 +3,7 @@ package ip import ( "context" "fmt" + Whitelist "github.com/jasonlovesdoggo/caddy-defender/utils/ip/whitelist" "net" "net/netip" "time" @@ -16,7 +17,7 @@ import ( type IPChecker struct { table *bart.Table[struct{}] cache *sturdyc.Client[string] - whitelist *Whitelist + whitelist *Whitelist.Whitelist log *zap.Logger } @@ -31,7 +32,7 @@ func NewIPChecker(cidrRanges, whitelistedIPs []string, log *zap.Logger) *IPCheck retryBaseDelay = 10 * time.Millisecond ) - whitelist, err := NewWhitelist(whitelistedIPs) + whitelist, err := Whitelist.NewWhitelist(whitelistedIPs) if err != nil { log.Warn("Invalid whitelist IP", zap.Strings("whitelist", whitelistedIPs), @@ -61,15 +62,7 @@ func NewIPChecker(cidrRanges, whitelistedIPs []string, log *zap.Logger) *IPCheck } func (c *IPChecker) ReqAllowed(ctx context.Context, clientIP net.IP) bool { - if c.whitelist.Allowed(clientIP.String()) { - c.log.Debug("IP is whitelisted", zap.String("ip", clientIP.String())) - return true - } - return !c.IPInRanges(ctx, clientIP) -} - -func (c *IPChecker) IPInRanges(ctx context.Context, clientIP net.IP) bool { - // Convert to netip.Addr first to handle IPv4-mapped IPv6 addresses + // convert net.IP to netip.Addr ipAddr, err := ipToAddr(clientIP) if err != nil { c.log.Warn("Invalid IP address format", @@ -78,6 +71,17 @@ func (c *IPChecker) IPInRanges(ctx context.Context, clientIP net.IP) bool { return false } + // Check if the IP is whitelisted + if c.whitelist.Whitelisted(ipAddr) { + c.log.Debug("IP is whitelisted", zap.String("ip", clientIP.String())) + return true + } + // Check if the IP is in the blocked ranges + return !c.IPInRanges(ctx, ipAddr) +} + +func (c *IPChecker) IPInRanges(ctx context.Context, ipAddr netip.Addr) bool { + // Convert to netip.Addr first to handle IPv4-mapped IPv6 addresses // Use the normalized string representation for cache keys cacheKey := ipAddr.String() diff --git a/utils/ip/ip_test.go b/utils/ip/ip_test.go index b80c8f9..7e57e75 100644 --- a/utils/ip/ip_test.go +++ b/utils/ip/ip_test.go @@ -84,8 +84,10 @@ func TestIPInRanges(t *testing.T) { t.Run(tt.name, func(t *testing.T) { clientIP := net.ParseIP(tt.ip) assert.NotNil(t, clientIP, "Failed to parse IP") + ipAddr, err := ipToAddr(clientIP) + assert.NoError(t, err, "Failed to convert IP to netip.Addr") - result := checker.IPInRanges(context.Background(), clientIP) + result := checker.IPInRanges(context.Background(), ipAddr) assert.Equal(t, tt.expected, result, "Unexpected result for IP %s", tt.ip) }) } @@ -98,13 +100,14 @@ func TestIPInRangesCache(t *testing.T) { // Test IP clientIP := net.ParseIP("192.168.1.100") assert.NotNil(t, clientIP, "Failed to parse IP") - + ipAddr, err := ipToAddr(clientIP) + assert.NoError(t, err, "Failed to convert IP to netip.Addr") // First call (not cached) - result := checker.IPInRanges(context.Background(), clientIP) + result := checker.IPInRanges(context.Background(), ipAddr) assert.True(t, result, "Expected IP to be in range (first call)") // Second call (cached) - result = checker.IPInRanges(context.Background(), clientIP) + result = checker.IPInRanges(context.Background(), ipAddr) assert.True(t, result, "Expected IP to be in range (second call)") } @@ -115,16 +118,18 @@ func TestIPInRangesCacheExpiration(t *testing.T) { // Test IP clientIP := net.ParseIP("192.168.1.100") assert.NotNil(t, clientIP, "Failed to parse IP") + ipAddr, err := ipToAddr(clientIP) + assert.NoError(t, err, "Failed to convert IP to netip.Addr") // First call (not cached) - result := checker.IPInRanges(context.Background(), clientIP) + result := checker.IPInRanges(context.Background(), ipAddr) assert.True(t, result, "Expected IP to be in range (first call)") // Wait for cache to expire time.Sleep(100 * time.Millisecond) // Second call (cache expired) - result = checker.IPInRanges(context.Background(), clientIP) + result = checker.IPInRanges(context.Background(), ipAddr) assert.True(t, result, "Expected IP to be in range (second call, cache expired)") } @@ -135,9 +140,10 @@ func TestIPInRangesInvalidCIDR(t *testing.T) { // Test IP clientIP := net.ParseIP("192.168.1.100") assert.NotNil(t, clientIP, "Failed to parse IP") - + ipAddr, err := ipToAddr(clientIP) + assert.NoError(t, err, "Failed to convert IP to netip.Addr") // Call with invalid CIDRs - result := checker.IPInRanges(context.Background(), clientIP) + result := checker.IPInRanges(context.Background(), ipAddr) assert.False(t, result, "Expected IP to not be in range due to invalid CIDRs") } @@ -149,8 +155,11 @@ func TestIPInRangesInvalidIP(t *testing.T) { clientIP := net.IP([]byte{1, 2, 3}) // Invalid IP assert.NotNil(t, clientIP, "Failed to create invalid IP") + ipAddr, err := ipToAddr(clientIP) + assert.Error(t, err, "Failed to convert IP to netip.Addr") + // Call with invalid IP - result := checker.IPInRanges(context.Background(), clientIP) + result := checker.IPInRanges(context.Background(), ipAddr) assert.False(t, result, "Expected IP to not be in range due to invalid IP") } @@ -220,7 +229,10 @@ func TestPredefinedCIDRGroups(t *testing.T) { clientIP := net.ParseIP(tt.ip) assert.NotNil(t, clientIP, "Failed to parse IP") - result := checker.IPInRanges(context.Background(), clientIP) + ipAddr, err := ipToAddr(clientIP) + assert.NoError(t, err, "Failed to convert IP to netip.Addr") + + result := checker.IPInRanges(context.Background(), ipAddr) assert.Equal(t, tt.expected, result, "Unexpected result for IP %s", tt.ip) // Verify error logging for problematic cases diff --git a/utils/ip/whitelist.go b/utils/ip/whitelist.go deleted file mode 100644 index f30cdbf..0000000 --- a/utils/ip/whitelist.go +++ /dev/null @@ -1,59 +0,0 @@ -package ip - -import ( - "fmt" - "net" -) - -type Whitelist struct { - ips map[string]struct{} -} - -// NewWhitelist initializes a new Whitelist from IP strings. -func NewWhitelist(ipStrings []string) (*Whitelist, error) { - wl := &Whitelist{ - ips: make(map[string]struct{}), - } - for _, ipStr := range ipStrings { - ip := net.ParseIP(ipStr) - if ip == nil { - return nil, fmt.Errorf("invalid IP address: %s", ipStr) - } - ip16 := ip.To16() - if ip16 == nil { - return nil, fmt.Errorf("invalid IP address: %s", ipStr) - } - wl.ips[ip16.String()] = struct{}{} - } - return wl, nil -} - -// Allowed checks if the remote address is in the whitelist. -func (wl *Whitelist) Allowed(remoteAddr string) bool { - host, _, err := net.SplitHostPort(remoteAddr) - if err != nil { - // Handle cases where there's no port - host = remoteAddr - } - - ip := net.ParseIP(host) - if ip == nil { - return false // Invalid IP format - } - - ip16 := ip.To16() - if ip16 == nil { - return false // Shouldn't happen if ParseIP succeeded - } - - _, ok := wl.ips[ip16.String()] - return ok -} - -// Example usage in Caddy middleware: -// func (m YourModule) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { -// if !m.whitelist.Allowed(r.RemoteAddr) { -// return caddyhttp.Error(http.StatusForbidden, nil) -// } -// return next.ServeHTTP(w, r) -// } diff --git a/utils/ip/whitelist/whitelist.go b/utils/ip/whitelist/whitelist.go new file mode 100644 index 0000000..b3be30d --- /dev/null +++ b/utils/ip/whitelist/whitelist.go @@ -0,0 +1,44 @@ +package whitelist + +import ( + "fmt" + "net/netip" +) + +// Whitelist holds the allowed IP addresses. +type Whitelist struct { + ips map[netip.Addr]struct{} // Use netip.Addr for efficient IP handling +} + +// NewWhitelist initializes a new Whitelist from IP strings. +func NewWhitelist(ipStrings []string) (*Whitelist, error) { + wl := &Whitelist{ + ips: make(map[netip.Addr]struct{}), + } + for _, ipStr := range ipStrings { + ip, err := netip.ParseAddr(ipStr) + if err != nil { + return nil, fmt.Errorf("invalid IP address: %s", ipStr) + } + wl.ips[ip] = struct{}{} + } + return wl, nil +} + +// Whitelisted checks if the remote address is in the whitelist. +func (wl *Whitelist) Whitelisted(ip netip.Addr) bool { + // Check if the IP is in the whitelist + _, ok := wl.ips[ip] + return ok +} + +// ValidateWhitelist checks if a list of IP strings are valid. +func ValidateWhitelist(ipStrings []string) error { + for _, ipStr := range ipStrings { + _, err := netip.ParseAddr(ipStr) + if err != nil { + return fmt.Errorf("invalid IP address: %s", ipStr) + } + } + return nil +} diff --git a/utils/ip/whitelist/whitelist_test.go b/utils/ip/whitelist/whitelist_test.go new file mode 100644 index 0000000..13da13d --- /dev/null +++ b/utils/ip/whitelist/whitelist_test.go @@ -0,0 +1,133 @@ +package whitelist + +import ( + "net/netip" + "testing" +) + +func TestNewWhitelist(t *testing.T) { + tests := []struct { + name string + ipStrings []string + expectError bool + }{ + { + name: "Valid IPv4 and IPv6", + ipStrings: []string{"192.168.1.1", "2001:db8::1"}, + expectError: false, + }, + { + name: "Invalid IP", + ipStrings: []string{"invalid-ip"}, + expectError: true, + }, + { + name: "Mixed valid and invalid IPs", + ipStrings: []string{"192.168.1.1", "invalid-ip"}, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + wl, err := NewWhitelist(tt.ipStrings) + if tt.expectError { + if err == nil { + t.Error("Expected error for invalid IPs, but got nil") + } + if wl != nil { + t.Error("Expected whitelist to be nil on error, but it was not") + } + } else { + if err != nil { + t.Errorf("Unexpected error for valid IPs: %v", err) + } + if wl == nil { + t.Error("Expected whitelist to be non-nil, but got nil") + } + } + }) + } +} + +func TestWhitelisted(t *testing.T) { + wl, err := NewWhitelist([]string{"192.168.1.1", "2001:db8::1"}) + if err != nil { + t.Fatalf("Failed to create whitelist: %v", err) + } + + tests := []struct { + name string + ip netip.Addr + expected bool + }{ + { + name: "IPv4 in whitelist", + ip: netip.MustParseAddr("192.168.1.1"), + expected: true, + }, + { + name: "IPv6 in whitelist", + ip: netip.MustParseAddr("2001:db8::1"), + expected: true, + }, + { + name: "IPv4 not in whitelist", + ip: netip.MustParseAddr("192.168.1.2"), + expected: false, + }, + { + name: "IPv6 not in whitelist", + ip: netip.MustParseAddr("2001:db8::2"), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := wl.Whitelisted(tt.ip) + if result != tt.expected { + t.Errorf("Expected %v for IP %v, but got %v", tt.expected, tt.ip, result) + } + }) + } +} + +func TestValidateWhitelist(t *testing.T) { + tests := []struct { + name string + ipStrings []string + expectError bool + }{ + { + name: "Valid IPv4 and IPv6", + ipStrings: []string{"192.168.1.1", "2001:db8::1"}, + expectError: false, + }, + { + name: "Invalid IP", + ipStrings: []string{"invalid-ip"}, + expectError: true, + }, + { + name: "Mixed valid and invalid IPs", + ipStrings: []string{"192.168.1.1", "invalid-ip"}, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateWhitelist(tt.ipStrings) + if tt.expectError { + if err == nil { + t.Error("Expected error for invalid IPs, but got nil") + } + } else { + if err != nil { + t.Errorf("Unexpected error for valid IPs: %v", err) + } + } + }) + } +} From 5d428be0281cd226fb47249421ae1c5dd8248028 Mon Sep 17 00:00:00 2001 From: Jason Cameron Date: Tue, 28 Jan 2025 20:01:04 -0500 Subject: [PATCH 6/6] docs: Correct indentation in Caddyfile examples Corrected the indentation of the `whitelist` directive within the `defender` block in the Caddyfile examples. This improves readability and consistency. --- examples/whitelist/Caddyfile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/whitelist/Caddyfile b/examples/whitelist/Caddyfile index 8d19662..319557b 100644 --- a/examples/whitelist/Caddyfile +++ b/examples/whitelist/Caddyfile @@ -9,17 +9,17 @@ # Everything in AWS besides my EC2 instance is blocked from accessing this site. defender block { ranges aws - whitelist 169.254.169.254 # my ec2's public IP. + whitelist 169.254.169.254 # my ec2's public IP. } respond "This is what a human sees" } :81 { bind 127.0.0.1 ::1 - # My localhost ipv6 is blocked but not my ipv4 + # My localhost ipv6 is blocked but not my ipv4 defender block { ranges private - whitelist 127.0.0.1 + whitelist 127.0.0.1 } respond "This is what a ipv4 human sees" }