Skip to content

Commit

Permalink
Fix that Service traffic cannot be allowed through L7 NetworkPolicy
Browse files Browse the repository at this point in the history
Signed-off-by: Hongliang Liu <hongliang.liu@broadcom.com>
  • Loading branch information
hongliangl committed Feb 21, 2025
1 parent 2a83607 commit 4657620
Show file tree
Hide file tree
Showing 16 changed files with 389 additions and 122 deletions.
2 changes: 1 addition & 1 deletion cmd/antrea-agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ func run(o *Options) error {
}
var l7Reconciler *l7engine.Reconciler
if l7NetworkPolicyEnabled || l7FlowExporterEnabled {
l7Reconciler = l7engine.NewReconciler()
l7Reconciler = l7engine.NewReconciler(ofClient)
}
networkPolicyController, err := networkpolicy.NewNetworkPolicyController(
antreaClientProvider,
Expand Down
10 changes: 7 additions & 3 deletions pkg/agent/controller/l7flowexporter/l7_flow_export_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ var (
errPodInterfaceNotFound = fmt.Errorf("interface of Pod not found")
)

// startSuricataOnceFn is meant to be overridden for testing.
var startSuricataOnceFn func() error

type L7FlowExporterController struct {
ofClient openflow.Client
interfaceStore interfacestore.InterfaceStore
Expand All @@ -63,7 +66,6 @@ type L7FlowExporterController struct {
namespaceLister corelisters.NamespaceLister
namespaceListerSynced cache.InformerSynced

l7Reconciler *l7engine.Reconciler
podToDirectionMap map[string]v1alpha2.Direction
podToDirectionMapMutex sync.RWMutex

Expand All @@ -87,7 +89,6 @@ func NewL7FlowExporterController(
namespaceInformer: namespaceInformer.Informer(),
namespaceLister: namespaceInformer.Lister(),
namespaceListerSynced: namespaceInformer.Informer().HasSynced,
l7Reconciler: l7Reconciler,
podToDirectionMap: make(map[string]v1alpha2.Direction),
queue: workqueue.NewTypedRateLimitingQueueWithConfig(
workqueue.NewTypedItemExponentialFailureRateLimiter[string](minRetryDelay, maxRetryDelay),
Expand All @@ -111,6 +112,7 @@ func NewL7FlowExporterController(
},
resyncPeriod,
)
startSuricataOnceFn = l7Reconciler.StartSuricataOnce
return l7c
}

Expand Down Expand Up @@ -324,7 +326,9 @@ func (l7c *L7FlowExporterController) syncPod(podNN string) error {
sourceOfPort := []uint32{uint32(podInterfaces[0].OFPort)}

// Start Suricata before starting traffic control mark flows
l7c.l7Reconciler.StartSuricataOnce()
if err := startSuricataOnceFn(); err != nil {
return err
}

oldDirection, exists := l7c.getMirroredDirection(podNN)
if exists {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,13 @@ func newFakeControllerAndWatcher(t *testing.T, objects []runtime.Object, interfa
ifaceStore.AddInterface(itf)
}

l7Reconciler := l7engine.NewReconciler()
l7Reconciler := l7engine.NewReconciler(nil)
l7w := NewL7FlowExporterController(mockOFClient, ifaceStore, localPodInformer, nsInformer, l7Reconciler)
prevStartSuricataOnceFn := startSuricataOnceFn
startSuricataOnceFn = func() error {
return nil
}
t.Cleanup(func() { startSuricataOnceFn = prevStartSuricataOnceFn })

return &fakeController{
L7FlowExporterController: l7w,
Expand Down
50 changes: 31 additions & 19 deletions pkg/agent/controller/networkpolicy/l7engine/reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ import (
"k8s.io/klog/v2"

"antrea.io/antrea/pkg/agent/config"
"antrea.io/antrea/pkg/agent/openflow"
v1beta "antrea.io/antrea/pkg/apis/controlplane/v1beta2"
"antrea.io/antrea/pkg/util/logdir"
utilsync "antrea.io/antrea/pkg/util/sync"
)

const (
Expand Down Expand Up @@ -153,10 +155,13 @@ type Reconciler struct {
suricataTenantCache *threadSafeSet[uint32]
suricataTenantHandlerCache *threadSafeSet[uint32]

once sync.Once
ofClient openflow.Client

startSuricataOnce utilsync.OnceWithNoError
initializeL7FlowsOnce utilsync.OnceWithNoError
}

func NewReconciler() *Reconciler {
func NewReconciler(ofClient openflow.Client) *Reconciler {
return &Reconciler{
suricataScFn: suricataSc,
startSuricataFn: startSuricata,
Expand All @@ -166,6 +171,7 @@ func NewReconciler() *Reconciler {
suricataTenantHandlerCache: &threadSafeSet[uint32]{
cached: sets.New[uint32](),
},
ofClient: ofClient,
}
}

Expand Down Expand Up @@ -260,10 +266,15 @@ func convertProtocolTLS(tls *v1beta.TLSProtocol) string {
return strings.Join(keywords, " ")
}

func (r *Reconciler) StartSuricataOnce() {
r.once.Do(func() {
r.startSuricata()
})
func (r *Reconciler) StartSuricataOnce() error {
return r.startSuricataOnce.Do(r.startSuricata)
}

func (r *Reconciler) initializeL7Flows() error {
if err := r.ofClient.InstallL7NetworkPolicyFlows(); err != nil {
return fmt.Errorf("failed to install L7 NetworkPolicy flows: %w", err)
}
return nil
}

func (r *Reconciler) AddRule(ruleID, policyName string, vlanID uint32, l7Protocols []v1beta.L7Protocol) error {
Expand All @@ -272,7 +283,12 @@ func (r *Reconciler) AddRule(ruleID, policyName string, vlanID uint32, l7Protoco
klog.V(5).Infof("AddRule took %v", time.Since(start))
}()

r.StartSuricataOnce()
if err := r.StartSuricataOnce(); err != nil {
return err
}
if err := r.initializeL7FlowsOnce.Do(r.initializeL7Flows); err != nil {
return err
}

// Generate the keyword part used in Suricata rules.
protoKeywords := make(map[string]sets.Set[string])
Expand Down Expand Up @@ -461,29 +477,25 @@ func (r *Reconciler) unregisterSuricataTenantHandler(tenantID, vlanID uint32) (*
return r.suricataScFn(scCmd)
}

func (r *Reconciler) startSuricata() {
func (r *Reconciler) startSuricata() error {
f, err := defaultFS.Create(antreaSuricataConfigPath)
if err != nil {
klog.ErrorS(err, "Failed to create Suricata config file", "FilePath", antreaSuricataConfigPath)
return
return fmt.Errorf("failed to create Suricata config file %s: %w", antreaSuricataConfigPath, err)
}
defer f.Close()
if _, err = f.WriteString(suricataAntreaConfigData); err != nil {
klog.ErrorS(err, "Failed to write Suricata config file", "FilePath", antreaSuricataConfigPath)
return
return fmt.Errorf("failed to write Suricata config file %s: %w", antreaSuricataConfigPath, err)
}

// Open the default Suricata config file /etc/suricata/suricata.yaml.
f, err = defaultFS.OpenFile(defaultSuricataConfigPath, os.O_APPEND|os.O_WRONLY, 0600)
if err != nil {
klog.ErrorS(err, "Failed to open default Suricata config file", "FilePath", defaultSuricataConfigPath)
return
return fmt.Errorf("failed to open default Suricata config file %s: %w", defaultSuricataConfigPath, err)
}
defer f.Close()
// Include the config file /etc/suricata/antrea.yaml for Antrea in the default Suricata config file /etc/suricata/suricata.yaml.
if _, err = f.WriteString(fmt.Sprintf("include: %s\n", antreaSuricataConfigPath)); err != nil {
klog.ErrorS(err, "Failed to update default Suricata config file", "FilePath", defaultSuricataConfigPath)
return
return fmt.Errorf("failed to update default Suricata config file %s: %w", defaultSuricataConfigPath, err)
}

r.startSuricataFn()
Expand All @@ -496,10 +508,10 @@ func (r *Reconciler) startSuricata() {
return true, nil
})
if err != nil {
klog.ErrorS(err, "Failed to find Suricata command socket file")
} else {
klog.InfoS("Started Suricata instance successfully")
return fmt.Errorf("failed to find Suricata command socket file: %w", err)
}
klog.InfoS("Started Suricata instance successfully")
return nil
}

func startSuricata() {
Expand Down
38 changes: 36 additions & 2 deletions pkg/agent/controller/networkpolicy/l7engine/reconciler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,18 @@
package l7engine

import (
"fmt"
"sync"
"sync/atomic"
"testing"

"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"k8s.io/apimachinery/pkg/util/sets"

oftesting "antrea.io/antrea/pkg/agent/openflow/testing"
v1beta "antrea.io/antrea/pkg/apis/controlplane/v1beta2"
)

Expand Down Expand Up @@ -124,7 +130,7 @@ func TestStartSuricata(t *testing.T) {
_, err := defaultFS.Create(defaultSuricataConfigPath)
assert.NoError(t, err)

fe := NewReconciler()
fe := NewReconciler(nil)
fs := newFakeSuricata()
fe.suricataScFn = fs.suricataScFunc
fe.startSuricataFn = fs.startSuricataFn
Expand Down Expand Up @@ -183,11 +189,15 @@ func TestRuleLifecycle(t *testing.T) {
_, err := defaultFS.Create(defaultSuricataConfigPath)
assert.NoError(t, err)

fe := NewReconciler()
ctrl := gomock.NewController(t)
mockOfClient := oftesting.NewMockClient(ctrl)
fe := NewReconciler(mockOfClient)
fs := newFakeSuricata()
fe.suricataScFn = fs.suricataScFunc
fe.startSuricataFn = fs.startSuricataFn

mockOfClient.EXPECT().InstallL7NetworkPolicyFlows().Times(1)

// Test add a L7 NetworkPolicy.
assert.NoError(t, fe.AddRule(ruleID, policyName, vlanID, tc.l7Protocols))

Expand Down Expand Up @@ -225,3 +235,27 @@ func TestRuleLifecycle(t *testing.T) {
})
}
}

func TestInitializeL7FlowsOnce(t *testing.T) {
ctrl := gomock.NewController(t)
mockOfClient := oftesting.NewMockClient(ctrl)
fe := NewReconciler(mockOfClient)

mockOfClient.EXPECT().InstallL7NetworkPolicyFlows().Return(fmt.Errorf("error"))
mockOfClient.EXPECT().InstallL7NetworkPolicyFlows().Return(nil)

var wg sync.WaitGroup
var errOccurred int32
for i := 0; i < 3; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
err := fe.initializeL7FlowsOnce.Do(fe.initializeL7Flows)
if err != nil {
atomic.AddInt32(&errOccurred, 1)
}
}(i)
}
wg.Wait()
require.Equal(t, int32(1), errOccurred)
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func newTestController() (*Controller, *fake.Clientset, *mockReconciler) {
groupIDAllocator := openflow.NewGroupAllocator()
groupCounters := []proxytypes.GroupCounter{proxytypes.NewGroupCounter(groupIDAllocator, ch2)}
fs := afero.NewMemMapFs()
l7reconciler := l7engine.NewReconciler()
l7reconciler := l7engine.NewReconciler(nil)
controller, _ := NewNetworkPolicyController(&antreaClientGetter{clientset},
nil,
nil,
Expand Down
13 changes: 13 additions & 0 deletions pkg/agent/openflow/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,9 @@ type Client interface {

// SubscribeOFPortStatusMessage registers a channel to listen the OpenFlow PortStatus message.
SubscribeOFPortStatusMessage(statusCh chan *openflow15.PortStatus)

// InstallL7NetworkPolicyFlows will be called only when at least one L7 NetworkPolicy is applied locally.
InstallL7NetworkPolicyFlows() error
}

// GetFlowTableStatus returns an array of flow table status.
Expand Down Expand Up @@ -1704,3 +1707,13 @@ func (c *client) getMeterStats() {
func (c *client) SubscribeOFPortStatusMessage(statusCh chan *openflow15.PortStatus) {
c.bridge.SubscribePortStatusConsumer(statusCh)
}

// InstallL7NetworkPolicyFlows will be called only when at least one L7 NetworkPolicy is applied locally.
func (c *client) InstallL7NetworkPolicyFlows() error {
c.replayMutex.RLock()
defer c.replayMutex.RUnlock()

cacheKey := "l7_np_flows"
flows := c.featureNetworkPolicy.l7NPTrafficControlFlows()
return c.addFlows(c.featureNetworkPolicy.cachedFlows, cacheKey, flows)
}
29 changes: 28 additions & 1 deletion pkg/agent/openflow/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2719,7 +2719,7 @@ func Test_client_ReplayFlows(t *testing.T) {

expectedFlows := append(pipelineDefaultFlows(true /* egressTrafficShapingEnabled */, false /* externalNodeEnabled */, true /* isEncap */, true /* isIPv4 */), egressInitFlows(true)...)
expectedFlows = append(expectedFlows, multicastInitFlows(true)...)
expectedFlows = append(expectedFlows, networkPolicyInitFlows(true, false, false)...)
expectedFlows = append(expectedFlows, networkPolicyInitFlows(true, false)...)
expectedFlows = append(expectedFlows, podConnectivityInitFlows(config.TrafficEncapModeEncap, config.TrafficEncryptionModeNone, false, true, true, true)...)
expectedFlows = append(expectedFlows, serviceInitFlows(true, true, false, false)...)

Expand Down Expand Up @@ -2917,3 +2917,30 @@ func TestSubscribeOFPortStatusMessage(t *testing.T) {
bridge.EXPECT().SubscribePortStatusConsumer(ch).Times(1)
c.SubscribeOFPortStatusMessage(ch)
}

func Test_client_InstallL7NetworkPolicyFlows(t *testing.T) {
ctrl := gomock.NewController(t)
m := opstest.NewMockOFEntryOperations(ctrl)

fc := newFakeClient(m, true, false, config.K8sNode, config.TrafficEncapModeEncap, enableL7NetworkPolicy)
defer resetPipelines()

expectedFlows := []string{
"cookie=0x1020000000000, table=Classifier, priority=200,in_port=11,vlan_tci=0x1000/0x1000 actions=pop_vlan,set_field:0x7/0xf->reg0,goto_table:UnSNAT",
"cookie=0x1020000000000, table=ConntrackZone, priority=212,ip,reg0=0x0/0x800000 actions=set_field:0x800000/0x800000->reg0,ct(table=ConntrackZone,zone=65520)",
"cookie=0x1020000000000, table=ConntrackZone, priority=210,ct_state=+rpl+trk,ct_mark=0x80/0x80,ip actions=goto_table:Output",
"cookie=0x1020000000000, table=ConntrackZone, priority=211,ct_state=+rpl+trk,ip,reg0=0x7/0xf actions=ct(table=L3Forwarding,zone=65520,nat)",
"cookie=0x1020000000000, table=ConntrackZone, priority=211,ct_state=-rpl+trk,ip,reg0=0x7/0xf actions=goto_table:L3Forwarding",
"cookie=0x1020000000000, table=ConntrackZone, priority=210,ct_state=-rpl+trk,ct_mark=0x80/0x80,ip actions=ct(table=ConntrackState,zone=65520,nat)",
"cookie=0x1020000000000, table=TrafficControl, priority=210,reg0=0x7/0xf actions=goto_table:Output",
"cookie=0x1020000000000, table=Output, priority=213,reg0=0x7/0xf actions=output:NXM_NX_REG1[]",
"cookie=0x1020000000000, table=Output, priority=212,ct_mark=0x80/0x80 actions=push_vlan:0x8100,move:NXM_NX_CT_LABEL[64..75]->OXM_OF_VLAN_VID[0..11],output:10",
}

m.EXPECT().AddAll(gomock.Any()).Return(nil).Times(1)
cacheKey := "l7_np_flows"
require.NoError(t, fc.InstallL7NetworkPolicyFlows())
fCacheI, ok := fc.featureNetworkPolicy.cachedFlows.Load(cacheKey)
require.True(t, ok)
assert.ElementsMatch(t, expectedFlows, getFlowStrings(fCacheI))
}
34 changes: 20 additions & 14 deletions pkg/agent/openflow/fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ import (

// Fields using reg.
var (
tunnelVal = uint32(1)
gatewayVal = uint32(2)
localVal = uint32(3)
uplinkVal = uint32(4)
bridgeVal = uint32(5)
tcReturnVal = uint32(6)
tunnelVal = uint32(1)
gatewayVal = uint32(2)
localVal = uint32(3)
uplinkVal = uint32(4)
bridgeVal = uint32(5)
tcReturnVal = uint32(6)
l7NPReturnVal = uint32(7)

outputToPortVal = uint32(1)
outputToControllerVal = uint32(2)
Expand All @@ -37,14 +38,16 @@ var (
// - 3: from local Pods.
// - 4: from uplink port.
// - 5: from bridge local port.
// - 6: from traffic control return port.
PktSourceField = binding.NewRegField(0, 0, 3)
FromTunnelRegMark = binding.NewRegMark(PktSourceField, tunnelVal)
FromGatewayRegMark = binding.NewRegMark(PktSourceField, gatewayVal)
FromPodRegMark = binding.NewRegMark(PktSourceField, localVal)
FromUplinkRegMark = binding.NewRegMark(PktSourceField, uplinkVal)
FromBridgeRegMark = binding.NewRegMark(PktSourceField, bridgeVal)
FromTCReturnRegMark = binding.NewRegMark(PktSourceField, tcReturnVal)
// - 6: from TrafficControl return port.
// - 7: from application-aware engine (L7 NetworkPolicy return port).
PktSourceField = binding.NewRegField(0, 0, 3)
FromTunnelRegMark = binding.NewRegMark(PktSourceField, tunnelVal)
FromGatewayRegMark = binding.NewRegMark(PktSourceField, gatewayVal)
FromPodRegMark = binding.NewRegMark(PktSourceField, localVal)
FromUplinkRegMark = binding.NewRegMark(PktSourceField, uplinkVal)
FromBridgeRegMark = binding.NewRegMark(PktSourceField, bridgeVal)
FromTCReturnRegMark = binding.NewRegMark(PktSourceField, tcReturnVal)
FromL7NPReturnRegMark = binding.NewRegMark(PktSourceField, l7NPReturnVal)
// reg0[4..7]: Field to store the packet destination. Marks in this field include:
// - 1: to tunnel port.
// - 2: to Antrea gateway port.
Expand Down Expand Up @@ -85,6 +88,9 @@ var (
OutputRegField = binding.NewRegField(0, 21, 22)
OutputToOFPortRegMark = binding.NewRegMark(OutputRegField, outputToPortVal)
OutputToControllerRegMark = binding.NewRegMark(OutputRegField, outputToControllerVal)
// reg0[23]:
CtStateNotRestoredRegMark = binding.NewOneBitZeroRegMark(0, 23)
CtStateRestoredRegMark = binding.NewOneBitRegMark(0, 23)
// reg0[25..32]: Field to indicate Antrea-native policy packetIn operations
PacketInOperationField = binding.NewRegField(0, 25, 32)

Expand Down
Loading

0 comments on commit 4657620

Please sign in to comment.