diff --git a/config.go b/config.go index 2d378d2..f4235f8 100644 --- a/config.go +++ b/config.go @@ -8,6 +8,7 @@ import ( "github.com/caddyserver/caddy/v2/modules/caddyhttp" "github.com/jasonlovesdoggo/caddy-defender/ranges/data" "github.com/jasonlovesdoggo/caddy-defender/responders" + "github.com/jasonlovesdoggo/caddy-defender/utils/ip/whitelist" "net" "reflect" "slices" @@ -18,11 +19,13 @@ var responderTypes = []string{"block", "garbage", "custom", "ratelimit"} // UnmarshalCaddyfile sets up the handler from Caddyfile tokens. Syntax: // // defender { -// # IP ranges to block -// ranges -// # Custom message to return to the client when using "custom" middleware (optional) -// message -// } +// # IP ranges to block +// ranges +// # Whitelisted IP addresses to allow to bypass ranges (optional) +// whitelist +// # Custom message to return to the client when using "custom" middleware (optional) +// message +// } func (m *Defender) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { d.Next() // consume directive name @@ -52,6 +55,10 @@ func (m *Defender) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { } Message := d.Val() m.Message = Message + case "whitelist": + for d.NextArg() { + m.Whitelist = append(m.Whitelist, d.Val()) + } default: return d.Errf("unknown subdirective '%s'", d.Val()) } @@ -112,10 +119,6 @@ func (m *Defender) Validate() error { if m.responder == nil { return fmt.Errorf("responder not configured") } - if len(m.Ranges) == 0 { - // set the default ranges to be all of the predefined ranges - return fmt.Errorf("no ranges specified, this is required") - } for _, ipRange := range m.Ranges { // Check if the range is a predefined key (e.g., "openai") @@ -131,6 +134,12 @@ func (m *Defender) Validate() error { } } + // Check if the whitelist is valid + err := whitelist.ValidateWhitelist(m.Whitelist) + if err != nil { + return err + } + return nil } diff --git a/config_test.go b/config_test.go index 3cca4ec..ac3d7f2 100644 --- a/config_test.go +++ b/config_test.go @@ -2,8 +2,13 @@ package caddydefender import ( "encoding/json" + "github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2/caddytest" + "github.com/jasonlovesdoggo/caddy-defender/ranges/data" "github.com/jasonlovesdoggo/caddy-defender/responders" + "maps" + "slices" + "sort" "testing" "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" @@ -166,6 +171,7 @@ func TestValidation(t *testing.T) { def := Defender{ RawResponder: "block", Ranges: []string{"10.0.0.0/8"}, + Whitelist: []string{"126.39.0.3"}, responder: &responders.BlockResponder{}, } require.NoError(t, def.Validate()) @@ -186,6 +192,40 @@ func TestValidation(t *testing.T) { } require.ErrorContains(t, def.Validate(), "invalid IP range") }) + + t.Run("invalid whitelist IP", func(t *testing.T) { + def := Defender{ + RawResponder: "block", + Whitelist: []string{"invalid"}, + responder: &responders.BlockResponder{}, + } + require.ErrorContains(t, def.Validate(), "invalid IP address") + }) + + t.Run("Missing ranges", func(t *testing.T) { + def := Defender{ + RawResponder: "block", + responder: &responders.BlockResponder{}, + } + err := def.Provision(caddy.Context{Context: caddy.ActiveContext()}) + if err != nil { + return + } + + // We must sort the ranges to compare them as the order is not guaranteed + sort.Slice(def.Ranges, func(i, j int) bool { + return def.Ranges[i] < def.Ranges[j] + }) + + defaultRanges := slices.Collect(maps.Keys(data.IPRanges)) + + sort.Slice(defaultRanges, func(i, j int) bool { + return defaultRanges[i] < defaultRanges[j] + }) + + require.Equal(t, defaultRanges, def.Ranges) + }) + } func TestDefenderValidation(t *testing.T) { diff --git a/examples/whitelist/Caddyfile b/examples/whitelist/Caddyfile new file mode 100644 index 0000000..319557b --- /dev/null +++ b/examples/whitelist/Caddyfile @@ -0,0 +1,25 @@ +{ + auto_https off + order defender after header + debug +} + +:80 { + bind 127.0.0.1 ::1 + # Everything in AWS besides my EC2 instance is blocked from accessing this site. + defender block { + ranges aws + whitelist 169.254.169.254 # my ec2's public IP. + } + respond "This is what a human sees" +} + +:81 { + bind 127.0.0.1 ::1 + # My localhost ipv6 is blocked but not my ipv4 + defender block { + ranges private + whitelist 127.0.0.1 + } + respond "This is what a ipv4 human sees" +} diff --git a/middleware.go b/middleware.go index 5ca5fe6..533a8b8 100644 --- a/middleware.go +++ b/middleware.go @@ -25,11 +25,12 @@ func (m Defender) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht } m.log.Debug("Ranges", zap.Strings("ranges", m.Ranges)) // Check if the client IP is in any of the ranges using the optimized checker - if m.ipChecker.IPInRanges(r.Context(), clientIP) { + if m.ipChecker.ReqAllowed(r.Context(), clientIP) { + m.log.Debug("IP is not in ranges", zap.String("ip", clientIP.String())) + + } else { m.log.Debug("IP is in ranges", zap.String("ip", clientIP.String())) return m.responder.ServeHTTP(w, r, next) - } else { - m.log.Debug("IP is not in ranges", zap.String("ip", clientIP.String())) } // IP is not in any of the ranges, proceed to the next handler diff --git a/plugin.go b/plugin.go index 3020f93..4955c1f 100644 --- a/plugin.go +++ b/plugin.go @@ -61,6 +61,11 @@ type Defender struct { // Default: All predefined ranges if empty Ranges []string `json:"ranges,omitempty"` + // An optional whitelist of IP addresses to exclude from blocking. If empty, no IPs are whitelisted. + // NOTE: this only supports IP addresses, not ranges. + // Default: [] + Whitelist []string `json:"whitelist,omitempty"` + // Message specifies the custom response message for 'custom' responder type. // Required only when using 'custom' responder. Message string `json:"message,omitempty"` @@ -81,12 +86,12 @@ func (m *Defender) Provision(ctx caddy.Context) error { if len(m.Ranges) == 0 { // set the default ranges to be all of the predefined ranges - m.log.Debug("no ranges specified, this is required") + m.log.Debug("no ranges specified, defaulting to all predefined ranges") m.Ranges = slices.Collect(maps.Keys(data.IPRanges)) } // ensure to keep AFTER the ranges are checked (above) - m.ipChecker = ip.NewIPChecker(m.Ranges, m.log) + m.ipChecker = ip.NewIPChecker(m.Ranges, m.Whitelist, m.log) return nil } diff --git a/utils/ip/ip.go b/utils/ip/ip.go index d1902fa..cee823e 100644 --- a/utils/ip/ip.go +++ b/utils/ip/ip.go @@ -3,6 +3,7 @@ package ip import ( "context" "fmt" + Whitelist "github.com/jasonlovesdoggo/caddy-defender/utils/ip/whitelist" "net" "net/netip" "time" @@ -14,12 +15,13 @@ import ( ) type IPChecker struct { - table *bart.Table[struct{}] - cache *sturdyc.Client[string] - log *zap.Logger + table *bart.Table[struct{}] + cache *sturdyc.Client[string] + whitelist *Whitelist.Whitelist + log *zap.Logger } -func NewIPChecker(cidrRanges []string, log *zap.Logger) *IPChecker { +func NewIPChecker(cidrRanges, whitelistedIPs []string, log *zap.Logger) *IPChecker { const ( capacity = 10000 numShards = 10 @@ -30,6 +32,13 @@ func NewIPChecker(cidrRanges []string, log *zap.Logger) *IPChecker { retryBaseDelay = 10 * time.Millisecond ) + whitelist, err := Whitelist.NewWhitelist(whitelistedIPs) + if err != nil { + log.Warn("Invalid whitelist IP", + zap.Strings("whitelist", whitelistedIPs), + zap.Error(err)) + } + cache := sturdyc.New[string]( capacity, numShards, @@ -45,14 +54,15 @@ func NewIPChecker(cidrRanges []string, log *zap.Logger) *IPChecker { ) return &IPChecker{ - table: buildTable(cidrRanges, log), - cache: cache, - log: log, + table: buildTable(cidrRanges, log), + cache: cache, + log: log, + whitelist: whitelist, } } -func (c *IPChecker) IPInRanges(ctx context.Context, clientIP net.IP) bool { - // Convert to netip.Addr first to handle IPv4-mapped IPv6 addresses +func (c *IPChecker) ReqAllowed(ctx context.Context, clientIP net.IP) bool { + // convert net.IP to netip.Addr ipAddr, err := ipToAddr(clientIP) if err != nil { c.log.Warn("Invalid IP address format", @@ -61,6 +71,17 @@ func (c *IPChecker) IPInRanges(ctx context.Context, clientIP net.IP) bool { return false } + // Check if the IP is whitelisted + if c.whitelist.Whitelisted(ipAddr) { + c.log.Debug("IP is whitelisted", zap.String("ip", clientIP.String())) + return true + } + // Check if the IP is in the blocked ranges + return !c.IPInRanges(ctx, ipAddr) +} + +func (c *IPChecker) IPInRanges(ctx context.Context, ipAddr netip.Addr) bool { + // Convert to netip.Addr first to handle IPv4-mapped IPv6 addresses // Use the normalized string representation for cache keys cacheKey := ipAddr.String() diff --git a/utils/ip/ip_test.go b/utils/ip/ip_test.go index 35ae419..7e57e75 100644 --- a/utils/ip/ip_test.go +++ b/utils/ip/ip_test.go @@ -46,7 +46,7 @@ func TestIPInRanges(t *testing.T) { data.IPRanges = predefinedCIDRs // Create a new IPChecker with valid CIDRs - checker := NewIPChecker(validCIDRs, testLogger) + checker := NewIPChecker(validCIDRs, []string{}, testLogger) tests := []struct { name string @@ -84,8 +84,10 @@ func TestIPInRanges(t *testing.T) { t.Run(tt.name, func(t *testing.T) { clientIP := net.ParseIP(tt.ip) assert.NotNil(t, clientIP, "Failed to parse IP") + ipAddr, err := ipToAddr(clientIP) + assert.NoError(t, err, "Failed to convert IP to netip.Addr") - result := checker.IPInRanges(context.Background(), clientIP) + result := checker.IPInRanges(context.Background(), ipAddr) assert.Equal(t, tt.expected, result, "Unexpected result for IP %s", tt.ip) }) } @@ -93,64 +95,71 @@ func TestIPInRanges(t *testing.T) { func TestIPInRangesCache(t *testing.T) { // Create a new IPChecker with valid CIDRs - checker := NewIPChecker(validCIDRs, testLogger) + checker := NewIPChecker(validCIDRs, []string{}, testLogger) // Test IP clientIP := net.ParseIP("192.168.1.100") assert.NotNil(t, clientIP, "Failed to parse IP") - + ipAddr, err := ipToAddr(clientIP) + assert.NoError(t, err, "Failed to convert IP to netip.Addr") // First call (not cached) - result := checker.IPInRanges(context.Background(), clientIP) + result := checker.IPInRanges(context.Background(), ipAddr) assert.True(t, result, "Expected IP to be in range (first call)") // Second call (cached) - result = checker.IPInRanges(context.Background(), clientIP) + result = checker.IPInRanges(context.Background(), ipAddr) assert.True(t, result, "Expected IP to be in range (second call)") } func TestIPInRangesCacheExpiration(t *testing.T) { // Create a new IPChecker with a short cache TTL for testing - checker := NewIPChecker(validCIDRs, testLogger) + checker := NewIPChecker(validCIDRs, []string{}, testLogger) // Test IP clientIP := net.ParseIP("192.168.1.100") assert.NotNil(t, clientIP, "Failed to parse IP") + ipAddr, err := ipToAddr(clientIP) + assert.NoError(t, err, "Failed to convert IP to netip.Addr") // First call (not cached) - result := checker.IPInRanges(context.Background(), clientIP) + result := checker.IPInRanges(context.Background(), ipAddr) assert.True(t, result, "Expected IP to be in range (first call)") // Wait for cache to expire time.Sleep(100 * time.Millisecond) // Second call (cache expired) - result = checker.IPInRanges(context.Background(), clientIP) + result = checker.IPInRanges(context.Background(), ipAddr) assert.True(t, result, "Expected IP to be in range (second call, cache expired)") } func TestIPInRangesInvalidCIDR(t *testing.T) { // Create a new IPChecker with invalid CIDRs - checker := NewIPChecker(invalidCIDRs, testLogger) + checker := NewIPChecker(invalidCIDRs, []string{}, testLogger) // Test IP clientIP := net.ParseIP("192.168.1.100") assert.NotNil(t, clientIP, "Failed to parse IP") - + ipAddr, err := ipToAddr(clientIP) + assert.NoError(t, err, "Failed to convert IP to netip.Addr") // Call with invalid CIDRs - result := checker.IPInRanges(context.Background(), clientIP) + result := checker.IPInRanges(context.Background(), ipAddr) assert.False(t, result, "Expected IP to not be in range due to invalid CIDRs") } func TestIPInRangesInvalidIP(t *testing.T) { // Create a new IPChecker with valid CIDRs - checker := NewIPChecker(validCIDRs, testLogger) + checker := NewIPChecker(validCIDRs, []string{}, testLogger) // Test invalid IP clientIP := net.IP([]byte{1, 2, 3}) // Invalid IP assert.NotNil(t, clientIP, "Failed to create invalid IP") + ipAddr, err := ipToAddr(clientIP) + assert.Error(t, err, "Failed to convert IP to netip.Addr") + // Call with invalid IP - result := checker.IPInRanges(context.Background(), clientIP) + result := checker.IPInRanges(context.Background(), ipAddr) assert.False(t, result, "Expected IP to not be in range due to invalid IP") } @@ -216,11 +225,14 @@ func TestPredefinedCIDRGroups(t *testing.T) { return nil })) - checker := NewIPChecker(tt.groups, logger) + checker := NewIPChecker(tt.groups, []string{}, logger) clientIP := net.ParseIP(tt.ip) assert.NotNil(t, clientIP, "Failed to parse IP") - result := checker.IPInRanges(context.Background(), clientIP) + ipAddr, err := ipToAddr(clientIP) + assert.NoError(t, err, "Failed to convert IP to netip.Addr") + + result := checker.IPInRanges(context.Background(), ipAddr) assert.Equal(t, tt.expected, result, "Unexpected result for IP %s", tt.ip) // Verify error logging for problematic cases diff --git a/utils/ip/whitelist/whitelist.go b/utils/ip/whitelist/whitelist.go new file mode 100644 index 0000000..b3be30d --- /dev/null +++ b/utils/ip/whitelist/whitelist.go @@ -0,0 +1,44 @@ +package whitelist + +import ( + "fmt" + "net/netip" +) + +// Whitelist holds the allowed IP addresses. +type Whitelist struct { + ips map[netip.Addr]struct{} // Use netip.Addr for efficient IP handling +} + +// NewWhitelist initializes a new Whitelist from IP strings. +func NewWhitelist(ipStrings []string) (*Whitelist, error) { + wl := &Whitelist{ + ips: make(map[netip.Addr]struct{}), + } + for _, ipStr := range ipStrings { + ip, err := netip.ParseAddr(ipStr) + if err != nil { + return nil, fmt.Errorf("invalid IP address: %s", ipStr) + } + wl.ips[ip] = struct{}{} + } + return wl, nil +} + +// Whitelisted checks if the remote address is in the whitelist. +func (wl *Whitelist) Whitelisted(ip netip.Addr) bool { + // Check if the IP is in the whitelist + _, ok := wl.ips[ip] + return ok +} + +// ValidateWhitelist checks if a list of IP strings are valid. +func ValidateWhitelist(ipStrings []string) error { + for _, ipStr := range ipStrings { + _, err := netip.ParseAddr(ipStr) + if err != nil { + return fmt.Errorf("invalid IP address: %s", ipStr) + } + } + return nil +} diff --git a/utils/ip/whitelist/whitelist_test.go b/utils/ip/whitelist/whitelist_test.go new file mode 100644 index 0000000..13da13d --- /dev/null +++ b/utils/ip/whitelist/whitelist_test.go @@ -0,0 +1,133 @@ +package whitelist + +import ( + "net/netip" + "testing" +) + +func TestNewWhitelist(t *testing.T) { + tests := []struct { + name string + ipStrings []string + expectError bool + }{ + { + name: "Valid IPv4 and IPv6", + ipStrings: []string{"192.168.1.1", "2001:db8::1"}, + expectError: false, + }, + { + name: "Invalid IP", + ipStrings: []string{"invalid-ip"}, + expectError: true, + }, + { + name: "Mixed valid and invalid IPs", + ipStrings: []string{"192.168.1.1", "invalid-ip"}, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + wl, err := NewWhitelist(tt.ipStrings) + if tt.expectError { + if err == nil { + t.Error("Expected error for invalid IPs, but got nil") + } + if wl != nil { + t.Error("Expected whitelist to be nil on error, but it was not") + } + } else { + if err != nil { + t.Errorf("Unexpected error for valid IPs: %v", err) + } + if wl == nil { + t.Error("Expected whitelist to be non-nil, but got nil") + } + } + }) + } +} + +func TestWhitelisted(t *testing.T) { + wl, err := NewWhitelist([]string{"192.168.1.1", "2001:db8::1"}) + if err != nil { + t.Fatalf("Failed to create whitelist: %v", err) + } + + tests := []struct { + name string + ip netip.Addr + expected bool + }{ + { + name: "IPv4 in whitelist", + ip: netip.MustParseAddr("192.168.1.1"), + expected: true, + }, + { + name: "IPv6 in whitelist", + ip: netip.MustParseAddr("2001:db8::1"), + expected: true, + }, + { + name: "IPv4 not in whitelist", + ip: netip.MustParseAddr("192.168.1.2"), + expected: false, + }, + { + name: "IPv6 not in whitelist", + ip: netip.MustParseAddr("2001:db8::2"), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := wl.Whitelisted(tt.ip) + if result != tt.expected { + t.Errorf("Expected %v for IP %v, but got %v", tt.expected, tt.ip, result) + } + }) + } +} + +func TestValidateWhitelist(t *testing.T) { + tests := []struct { + name string + ipStrings []string + expectError bool + }{ + { + name: "Valid IPv4 and IPv6", + ipStrings: []string{"192.168.1.1", "2001:db8::1"}, + expectError: false, + }, + { + name: "Invalid IP", + ipStrings: []string{"invalid-ip"}, + expectError: true, + }, + { + name: "Mixed valid and invalid IPs", + ipStrings: []string{"192.168.1.1", "invalid-ip"}, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateWhitelist(tt.ipStrings) + if tt.expectError { + if err == nil { + t.Error("Expected error for invalid IPs, but got nil") + } + } else { + if err != nil { + t.Errorf("Unexpected error for valid IPs: %v", err) + } + } + }) + } +}