From 05ef807d8b88a386156a5cd2544c0c5da9822217 Mon Sep 17 00:00:00 2001 From: Alexander Bulimov Date: Wed, 18 Sep 2024 06:08:05 -0700 Subject: [PATCH] fix bugs in GetLocationByMap Summary: Fix couple of bugs with location unpacking: * `GetLocationByMap` doesn't handle multi-value unpacking * Added more slice bound checks to harden the code. Reviewed By: deathowl Differential Revision: D62937783 --- dnsrocks/db/rdbdriver.go | 77 +++++++++++--------- dnsrocks/db/rdbdriver_test.go | 130 ++++++++++++++++++++++++++++++++++ 2 files changed, 174 insertions(+), 33 deletions(-) create mode 100644 dnsrocks/db/rdbdriver_test.go diff --git a/dnsrocks/db/rdbdriver.go b/dnsrocks/db/rdbdriver.go index 4dce624..40192f7 100644 --- a/dnsrocks/db/rdbdriver.go +++ b/dnsrocks/db/rdbdriver.go @@ -15,6 +15,7 @@ package db import ( "bytes" + "encoding/binary" "fmt" "net" "time" @@ -167,42 +168,24 @@ func isIPv4(addr net.IP) bool { return addr != nil && (len(addr) == net.IPv4len || net.IP.Equal(addr[:12], firstIPv4[:12])) } -// GetLocationByMap finds and returns location and mask. If the location is not found, returns nil and 0. -func (r *rdbdriver) GetLocationByMap(ipnet *net.IPNet, mapID []byte, context Context) (loc []byte, mlen uint8, err error) { - // Lookup the subnet in IP MAP: map ID, IP address -> LocID, mask - // Build key prefix: "\000\000\000!{MapID}" - // We prime the key byte array (fullKey) with the key prefix, - // followed by the original IP address in V6 format. - nmap := len(mapID) - fullKey := make([]byte, 4+nmap+net.IPv6len+1) // 4 bytes for prefix, n bytes for mapID, and the rest is IP and masklen - copy(fullKey, ipMapRangePointKeyElement) // prefix, 4 bytes - copy(fullKey[4:], mapID) // mapID, n bytes - copy(fullKey[4+nmap:], ipnet.IP.To16()) - reqMaskLen, _ := ipnet.Mask.Size() - if isIPv4(ipnet.IP) { - reqMaskLen += 128 - 32 - } - copy(fullKey[4+nmap+16:], []byte{uint8(reqMaskLen)}) - - ctx := context.(*rdb.Context) - - // NOTE: Rearranger has merging on adjacent locations with same mask and locID, - // so findClosest() might return the key that will match some other IP. It is fine for our purposes. - // foundVal will consist of mask (1 byte) and LocID (2 bytes); if LocID is null, then there will be mask only. - foundKey, foundVal, err := r.db.FindClosest(fullKey, ctx) - if err != nil { - return nil, 0, err - } +func unpackLocation(foundKey, foundVal []byte) (loc []byte, mlen uint8, err error) { if len(foundVal) == 0 { return nil, 0, nil // consistent with the return at the end of cdbdriver.go:/GetLocationByMap } + if len(foundKey) == 0 { + return nil, 0, fmt.Errorf("empty key, value %v", foundVal) + } if len(foundVal) < 4 { - err = fmt.Errorf("short value: length %d, value %v, map %v", len(foundVal), foundVal, mapID) - return nil, 0, err + return nil, 0, fmt.Errorf("short value: length %d, value %v", len(foundVal), foundVal) + } + // take the first value from the potential multi-value - see ../dnsdata/rdb/rdb_util.go + valLen := binary.LittleEndian.Uint32(foundVal[:4]) + if len(foundVal) < int(valLen)+4 { + return nil, 0, fmt.Errorf("short value: length %d, length from header %d, value %v", len(foundVal), valLen, foundVal) } - foundVal = foundVal[4:] // skip over the multi-value header - see ../dnsdata/rdb/rdb_util.go:/Put + foundVal = foundVal[4 : 4+valLen] // skip over the multi-value header mlen = foundKey[len(foundKey)-1] - switch len(foundVal) { + switch valLen { case 0: // Rearranger will always add /0 mask, so if anything - the empty location will match return nil, mlen, nil @@ -210,7 +193,7 @@ func (r *rdbdriver) GetLocationByMap(ipnet *net.IPNet, mapID []byte, context Con loc = foundVal return loc, mlen, nil default: - if len(foundVal) < 2 { + if valLen < 3 { err = fmt.Errorf("Invalid location length %d, value %v", len(foundVal), foundVal) return nil, 0, err } @@ -219,8 +202,8 @@ func (r *rdbdriver) GetLocationByMap(ipnet *net.IPNet, mapID []byte, context Con return loc, mlen, nil } locLen, foundVal := foundVal[1], foundVal[2:] - if int(locLen) > len(foundVal) { - err = fmt.Errorf("invalid location length byte %d > %d", locLen, len(foundVal)) + if int(locLen) > int(valLen) { + err = fmt.Errorf("invalid location length byte %d > %d", locLen, valLen) return nil, 0, err } loc = foundVal[:locLen] @@ -228,6 +211,34 @@ func (r *rdbdriver) GetLocationByMap(ipnet *net.IPNet, mapID []byte, context Con } } +// GetLocationByMap finds and returns location and mask. If the location is not found, returns nil and 0. +func (r *rdbdriver) GetLocationByMap(ipnet *net.IPNet, mapID []byte, context Context) (loc []byte, mlen uint8, err error) { + // Lookup the subnet in IP MAP: map ID, IP address -> LocID, mask + // Build key prefix: "\000\000\000!{MapID}" + // We prime the key byte array (fullKey) with the key prefix, + // followed by the original IP address in V6 format. + nmap := len(mapID) + fullKey := make([]byte, 4+nmap+net.IPv6len+1) // 4 bytes for prefix, n bytes for mapID, and the rest is IP and masklen + copy(fullKey, ipMapRangePointKeyElement) // prefix, 4 bytes + copy(fullKey[4:], mapID) // mapID, n bytes + copy(fullKey[4+nmap:], ipnet.IP.To16()) + reqMaskLen, _ := ipnet.Mask.Size() + if isIPv4(ipnet.IP) { + reqMaskLen += 128 - 32 + } + copy(fullKey[4+nmap+16:], []byte{uint8(reqMaskLen)}) //nolint:gosec + + ctx := context.(*rdb.Context) + + // NOTE: Rearranger has merging on adjacent locations with same mask and locID, + // so findClosest() might return the key that will match some other IP. It is fine for our purposes. + foundKey, foundVal, err := r.db.FindClosest(fullKey, ctx) + if err != nil { + return nil, 0, err + } + return unpackLocation(foundKey, foundVal) +} + func (r *rdbdriver) Close() error { return r.db.Close() } diff --git a/dnsrocks/db/rdbdriver_test.go b/dnsrocks/db/rdbdriver_test.go new file mode 100644 index 0000000..0877bfd --- /dev/null +++ b/dnsrocks/db/rdbdriver_test.go @@ -0,0 +1,130 @@ +/* +Copyright (c) Meta Platforms, Inc. and affiliates. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package db + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUnpackLocation(t *testing.T) { + testCases := []struct { + name string + key []byte + value []byte + wantLocation []byte + wantMask uint8 + wantErr bool + }{ + { + name: "empty key and value", + wantErr: false, + }, + { + name: "empty key", + value: []byte{2, 0, 0, 0, 3, 4}, + wantErr: true, + }, + { + name: "empty location", + key: []byte{1, 2, 3}, + wantErr: false, // empty location is valid + }, + { + name: "short location", + key: []byte{1, 2, 3}, + value: []byte{2, 0, 0}, // value should be at least 4 bytes (uint32 for length of multi-value field) + wantErr: true, + }, + { + name: "not enough bytes for provided length", + key: []byte{1, 2, 3}, + value: []byte{4, 0, 0, 0, 3, 2}, // 4 is the length of the first multi-value field + wantErr: true, + }, + { + name: "not enough bytes for long location", + key: []byte{1, 2, 3}, + value: []byte{6, 0, 0, 0, 255, 8, 3, 4, 5, 6}, // 4 is the length of the first multi-value, 255 is the marker, 8 is the length of the location + wantErr: true, + }, + { + name: "zero length location", + key: []byte{1, 2, 3}, + value: []byte{0, 0, 0, 0}, + wantErr: false, + wantMask: 3, + }, + { + name: "two-byte location", + key: []byte{1, 2, 3}, + value: []byte{2, 0, 0, 0, 82, 10}, // 2 is the length of the first multi-value field, 82, 10 is the location + wantErr: false, + wantLocation: []byte{82, 10}, + wantMask: 3, + }, + { + name: "long location without marker", + key: []byte{1, 2, 3}, + value: []byte{4, 0, 0, 0, 82, 10, 7, 8}, // 4 is the length of the first multi-value field, 82, 10, 7, 8 is the location + wantErr: false, + wantLocation: []byte{82, 10, 7, 8}, + wantMask: 3, + }, + { + name: "long location with marker", + key: []byte{1, 2, 3}, + value: []byte{6, 0, 0, 0, 255, 4, 97, 108, 101, 120}, // 6 is the length of the first multi-value field, 255 is the marker, 4 is the length of the location, 'alex' is the location + wantErr: false, + wantLocation: []byte("alex"), + wantMask: 3, + }, + { + name: "multi-value two-byte location", + key: []byte{1, 2, 3}, + value: []byte{2, 0, 0, 0, 82, 10, 2, 0, 0, 0, 3, 4}, // 2 is the length of the first multi-value field, 82, 10 is the location, 2 is the length of the second multi-value field, 3, 4 is the location + wantErr: false, + wantLocation: []byte{82, 10}, + wantMask: 3, + }, + { + name: "multi-value empty location", + key: []byte{1, 2, 3}, + value: []byte{0, 0, 0, 0, 2, 0, 0, 0, 3, 4}, // 0 is the length of the first multi-value field, 2 is the length of the second multi-value field, 3, 4 is the second location + wantErr: false, + wantMask: 3, + }, + { + name: "multi-value long location", + key: []byte{1, 2, 3}, + value: []byte{6, 0, 0, 0, 255, 4, 97, 108, 101, 120, 2, 0, 0, 0, 3, 4}, // 6 is the length of the first multi-value field, 255 is the marker, 4 is the length of the location, 'alex' is the location, 2 is the length of the second multi-value field, 3, 4 is the second location + wantErr: false, + wantLocation: []byte("alex"), + wantMask: 3, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + gotLocation, gotMask, err := unpackLocation(tc.key, tc.value) + if tc.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tc.wantLocation, gotLocation) + require.Equal(t, tc.wantMask, gotMask) + }) + } +}