Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add whitelist functionality (#31) #35

Merged
merged 6 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
"github.com/jasonlovesdoggo/caddy-defender/ranges/data"
"github.com/jasonlovesdoggo/caddy-defender/responders"
"github.com/jasonlovesdoggo/caddy-defender/utils/ip/whitelist"
"net"
"reflect"
"slices"
Expand All @@ -18,11 +19,13 @@ var responderTypes = []string{"block", "garbage", "custom", "ratelimit"}
// UnmarshalCaddyfile sets up the handler from Caddyfile tokens. Syntax:
//
// defender <responder> {
// # IP ranges to block
// ranges
// # Custom message to return to the client when using "custom" middleware (optional)
// message
// }
// # IP ranges to block
// ranges
// # Whitelisted IP addresses to allow to bypass ranges (optional)
// whitelist
// # Custom message to return to the client when using "custom" middleware (optional)
// message
// }
func (m *Defender) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
d.Next() // consume directive name

Expand Down Expand Up @@ -52,6 +55,10 @@ func (m *Defender) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
}
Message := d.Val()
m.Message = Message
case "whitelist":
for d.NextArg() {
m.Whitelist = append(m.Whitelist, d.Val())
}
default:
return d.Errf("unknown subdirective '%s'", d.Val())
}
Expand Down Expand Up @@ -112,10 +119,6 @@ func (m *Defender) Validate() error {
if m.responder == nil {
return fmt.Errorf("responder not configured")
}
if len(m.Ranges) == 0 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, the new behavior would be to use all ranges, right?

This should be documented in the README.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Always been the behavior you can look at the provision step to see where it's set.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll make a PR to document that behavior.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By using all predefined ranges, wouldn't this include private and essentially block all access for most if ranges is omitted?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. It's not a perfect solution. If you have any suggestions for an alternative let me know

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. It's not a perfect solution. If you have any suggestions for an alternative let me know

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, because that isn't entirely obvious to me, do the aws ranges include the aws_region ranges? If they overlap, maybe a list of clean default ranges would be better.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, because that isn't entirely obvious to me, do the aws ranges include the aws_region ranges? If they overlap, maybe a list of clean default ranges would be better.

Hmm. This does seem like a better solution.

What do you suggest? All services that don't have overlap - private?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that might be a good thing. A list referencing all the default IP ranges that are useful as an out-of-the-box setup.

E.g.: aws azurepubliccloud deepseek gcloud githubcopilot openai

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay okay 👌

If you want to add that into your documentation PR, I'll merge that in as well

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR is ready but includes private right now (to be correct).

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean by to be correct?

Didn't you say it wouldn't be beneficial to have private be in the defaults?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant it includes private, to reflect the current state.

After your PR that would have to be changed, if you decided to remove private/local from the default rwnge.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant it includes private, to reflect the current state.

After your PR that would have to be changed, if you decided to remove private/local from the default rwnge.

Yes, I removed it from the default range and updated the docs accordingly. You can see plugin.go#L23 for the list

// set the default ranges to be all of the predefined ranges
return fmt.Errorf("no ranges specified, this is required")
}

for _, ipRange := range m.Ranges {
// Check if the range is a predefined key (e.g., "openai")
Expand All @@ -131,6 +134,12 @@ func (m *Defender) Validate() error {
}
}

// Check if the whitelist is valid
err := whitelist.ValidateWhitelist(m.Whitelist)
if err != nil {
return err
}

return nil
}

Expand Down
40 changes: 40 additions & 0 deletions config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@ package caddydefender

import (
"encoding/json"
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/caddytest"
"github.com/jasonlovesdoggo/caddy-defender/ranges/data"
"github.com/jasonlovesdoggo/caddy-defender/responders"
"maps"
"slices"
"sort"
"testing"

"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
Expand Down Expand Up @@ -166,6 +171,7 @@ func TestValidation(t *testing.T) {
def := Defender{
RawResponder: "block",
Ranges: []string{"10.0.0.0/8"},
Whitelist: []string{"126.39.0.3"},
responder: &responders.BlockResponder{},
}
require.NoError(t, def.Validate())
Expand All @@ -186,6 +192,40 @@ func TestValidation(t *testing.T) {
}
require.ErrorContains(t, def.Validate(), "invalid IP range")
})

t.Run("invalid whitelist IP", func(t *testing.T) {
def := Defender{
RawResponder: "block",
Whitelist: []string{"invalid"},
responder: &responders.BlockResponder{},
}
require.ErrorContains(t, def.Validate(), "invalid IP address")
})

t.Run("Missing ranges", func(t *testing.T) {
def := Defender{
RawResponder: "block",
responder: &responders.BlockResponder{},
}
err := def.Provision(caddy.Context{Context: caddy.ActiveContext()})
if err != nil {
return
}

// We must sort the ranges to compare them as the order is not guaranteed
sort.Slice(def.Ranges, func(i, j int) bool {
return def.Ranges[i] < def.Ranges[j]
})

defaultRanges := slices.Collect(maps.Keys(data.IPRanges))

sort.Slice(defaultRanges, func(i, j int) bool {
return defaultRanges[i] < defaultRanges[j]
})

require.Equal(t, defaultRanges, def.Ranges)
})

}

func TestDefenderValidation(t *testing.T) {
Expand Down
25 changes: 25 additions & 0 deletions examples/whitelist/Caddyfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
auto_https off
order defender after header
debug
}

:80 {
bind 127.0.0.1 ::1
# Everything in AWS besides my EC2 instance is blocked from accessing this site.
defender block {
ranges aws
whitelist 169.254.169.254 # my ec2's public IP.
}
respond "This is what a human sees"
}

:81 {
bind 127.0.0.1 ::1
# My localhost ipv6 is blocked but not my ipv4
defender block {
ranges private
whitelist 127.0.0.1
}
respond "This is what a ipv4 human sees"
}
7 changes: 4 additions & 3 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ func (m Defender) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
}
m.log.Debug("Ranges", zap.Strings("ranges", m.Ranges))
// Check if the client IP is in any of the ranges using the optimized checker
if m.ipChecker.IPInRanges(r.Context(), clientIP) {
if m.ipChecker.ReqAllowed(r.Context(), clientIP) {
m.log.Debug("IP is not in ranges", zap.String("ip", clientIP.String()))

} else {
m.log.Debug("IP is in ranges", zap.String("ip", clientIP.String()))
return m.responder.ServeHTTP(w, r, next)
} else {
m.log.Debug("IP is not in ranges", zap.String("ip", clientIP.String()))
}

// IP is not in any of the ranges, proceed to the next handler
Expand Down
9 changes: 7 additions & 2 deletions plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ type Defender struct {
// Default: All predefined ranges if empty
Ranges []string `json:"ranges,omitempty"`

// An optional whitelist of IP addresses to exclude from blocking. If empty, no IPs are whitelisted.
// NOTE: this only supports IP addresses, not ranges.
// Default: []
Whitelist []string `json:"whitelist,omitempty"`

// Message specifies the custom response message for 'custom' responder type.
// Required only when using 'custom' responder.
Message string `json:"message,omitempty"`
Expand All @@ -81,12 +86,12 @@ func (m *Defender) Provision(ctx caddy.Context) error {

if len(m.Ranges) == 0 {
// set the default ranges to be all of the predefined ranges
m.log.Debug("no ranges specified, this is required")
m.log.Debug("no ranges specified, defaulting to all predefined ranges")
m.Ranges = slices.Collect(maps.Keys(data.IPRanges))
}

// ensure to keep AFTER the ranges are checked (above)
m.ipChecker = ip.NewIPChecker(m.Ranges, m.log)
m.ipChecker = ip.NewIPChecker(m.Ranges, m.Whitelist, m.log)

return nil
}
Expand Down
39 changes: 30 additions & 9 deletions utils/ip/ip.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ip
import (
"context"
"fmt"
Whitelist "github.com/jasonlovesdoggo/caddy-defender/utils/ip/whitelist"
"net"
"net/netip"
"time"
Expand All @@ -14,12 +15,13 @@ import (
)

type IPChecker struct {
table *bart.Table[struct{}]
cache *sturdyc.Client[string]
log *zap.Logger
table *bart.Table[struct{}]
cache *sturdyc.Client[string]
whitelist *Whitelist.Whitelist
log *zap.Logger
}

func NewIPChecker(cidrRanges []string, log *zap.Logger) *IPChecker {
func NewIPChecker(cidrRanges, whitelistedIPs []string, log *zap.Logger) *IPChecker {
const (
capacity = 10000
numShards = 10
Expand All @@ -30,6 +32,13 @@ func NewIPChecker(cidrRanges []string, log *zap.Logger) *IPChecker {
retryBaseDelay = 10 * time.Millisecond
)

whitelist, err := Whitelist.NewWhitelist(whitelistedIPs)
if err != nil {
log.Warn("Invalid whitelist IP",
zap.Strings("whitelist", whitelistedIPs),
zap.Error(err))
}

cache := sturdyc.New[string](
capacity,
numShards,
Expand All @@ -45,14 +54,15 @@ func NewIPChecker(cidrRanges []string, log *zap.Logger) *IPChecker {
)

return &IPChecker{
table: buildTable(cidrRanges, log),
cache: cache,
log: log,
table: buildTable(cidrRanges, log),
cache: cache,
log: log,
whitelist: whitelist,
}
}

func (c *IPChecker) IPInRanges(ctx context.Context, clientIP net.IP) bool {
// Convert to netip.Addr first to handle IPv4-mapped IPv6 addresses
func (c *IPChecker) ReqAllowed(ctx context.Context, clientIP net.IP) bool {
// convert net.IP to netip.Addr
ipAddr, err := ipToAddr(clientIP)
if err != nil {
c.log.Warn("Invalid IP address format",
Expand All @@ -61,6 +71,17 @@ func (c *IPChecker) IPInRanges(ctx context.Context, clientIP net.IP) bool {
return false
}

// Check if the IP is whitelisted
if c.whitelist.Whitelisted(ipAddr) {
c.log.Debug("IP is whitelisted", zap.String("ip", clientIP.String()))
return true
}
// Check if the IP is in the blocked ranges
return !c.IPInRanges(ctx, ipAddr)
}

func (c *IPChecker) IPInRanges(ctx context.Context, ipAddr netip.Addr) bool {
// Convert to netip.Addr first to handle IPv4-mapped IPv6 addresses
// Use the normalized string representation for cache keys
cacheKey := ipAddr.String()

Expand Down
44 changes: 28 additions & 16 deletions utils/ip/ip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func TestIPInRanges(t *testing.T) {
data.IPRanges = predefinedCIDRs

// Create a new IPChecker with valid CIDRs
checker := NewIPChecker(validCIDRs, testLogger)
checker := NewIPChecker(validCIDRs, []string{}, testLogger)

tests := []struct {
name string
Expand Down Expand Up @@ -84,73 +84,82 @@ func TestIPInRanges(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
clientIP := net.ParseIP(tt.ip)
assert.NotNil(t, clientIP, "Failed to parse IP")
ipAddr, err := ipToAddr(clientIP)
assert.NoError(t, err, "Failed to convert IP to netip.Addr")

result := checker.IPInRanges(context.Background(), clientIP)
result := checker.IPInRanges(context.Background(), ipAddr)
assert.Equal(t, tt.expected, result, "Unexpected result for IP %s", tt.ip)
})
}
}

func TestIPInRangesCache(t *testing.T) {
// Create a new IPChecker with valid CIDRs
checker := NewIPChecker(validCIDRs, testLogger)
checker := NewIPChecker(validCIDRs, []string{}, testLogger)

// Test IP
clientIP := net.ParseIP("192.168.1.100")
assert.NotNil(t, clientIP, "Failed to parse IP")

ipAddr, err := ipToAddr(clientIP)
assert.NoError(t, err, "Failed to convert IP to netip.Addr")
// First call (not cached)
result := checker.IPInRanges(context.Background(), clientIP)
result := checker.IPInRanges(context.Background(), ipAddr)
assert.True(t, result, "Expected IP to be in range (first call)")

// Second call (cached)
result = checker.IPInRanges(context.Background(), clientIP)
result = checker.IPInRanges(context.Background(), ipAddr)
assert.True(t, result, "Expected IP to be in range (second call)")
}

func TestIPInRangesCacheExpiration(t *testing.T) {
// Create a new IPChecker with a short cache TTL for testing
checker := NewIPChecker(validCIDRs, testLogger)
checker := NewIPChecker(validCIDRs, []string{}, testLogger)

// Test IP
clientIP := net.ParseIP("192.168.1.100")
assert.NotNil(t, clientIP, "Failed to parse IP")
ipAddr, err := ipToAddr(clientIP)
assert.NoError(t, err, "Failed to convert IP to netip.Addr")

// First call (not cached)
result := checker.IPInRanges(context.Background(), clientIP)
result := checker.IPInRanges(context.Background(), ipAddr)
assert.True(t, result, "Expected IP to be in range (first call)")

// Wait for cache to expire
time.Sleep(100 * time.Millisecond)

// Second call (cache expired)
result = checker.IPInRanges(context.Background(), clientIP)
result = checker.IPInRanges(context.Background(), ipAddr)
assert.True(t, result, "Expected IP to be in range (second call, cache expired)")
}

func TestIPInRangesInvalidCIDR(t *testing.T) {
// Create a new IPChecker with invalid CIDRs
checker := NewIPChecker(invalidCIDRs, testLogger)
checker := NewIPChecker(invalidCIDRs, []string{}, testLogger)

// Test IP
clientIP := net.ParseIP("192.168.1.100")
assert.NotNil(t, clientIP, "Failed to parse IP")

ipAddr, err := ipToAddr(clientIP)
assert.NoError(t, err, "Failed to convert IP to netip.Addr")
// Call with invalid CIDRs
result := checker.IPInRanges(context.Background(), clientIP)
result := checker.IPInRanges(context.Background(), ipAddr)
assert.False(t, result, "Expected IP to not be in range due to invalid CIDRs")
}

func TestIPInRangesInvalidIP(t *testing.T) {
// Create a new IPChecker with valid CIDRs
checker := NewIPChecker(validCIDRs, testLogger)
checker := NewIPChecker(validCIDRs, []string{}, testLogger)

// Test invalid IP
clientIP := net.IP([]byte{1, 2, 3}) // Invalid IP
assert.NotNil(t, clientIP, "Failed to create invalid IP")

ipAddr, err := ipToAddr(clientIP)
assert.Error(t, err, "Failed to convert IP to netip.Addr")

// Call with invalid IP
result := checker.IPInRanges(context.Background(), clientIP)
result := checker.IPInRanges(context.Background(), ipAddr)
assert.False(t, result, "Expected IP to not be in range due to invalid IP")
}

Expand Down Expand Up @@ -216,11 +225,14 @@ func TestPredefinedCIDRGroups(t *testing.T) {
return nil
}))

checker := NewIPChecker(tt.groups, logger)
checker := NewIPChecker(tt.groups, []string{}, logger)
clientIP := net.ParseIP(tt.ip)
assert.NotNil(t, clientIP, "Failed to parse IP")

result := checker.IPInRanges(context.Background(), clientIP)
ipAddr, err := ipToAddr(clientIP)
assert.NoError(t, err, "Failed to convert IP to netip.Addr")

result := checker.IPInRanges(context.Background(), ipAddr)
assert.Equal(t, tt.expected, result, "Unexpected result for IP %s", tt.ip)

// Verify error logging for problematic cases
Expand Down
Loading
Loading