Skip to content
This repository has been archived by the owner on Dec 4, 2024. It is now read-only.

Added disable event command #7

Merged
merged 2 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions client/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ func (tc *serviceClient) EnableEvent(ctx context.Context, req *pb.EnableEventReq
return tc.client.EnableEvent(ctx, req)
}

func (tc *serviceClient) DisableEvent(ctx context.Context, req *pb.DisableEventRequest) (*pb.DisableEventResponse, error) {
return tc.client.DisableEvent(ctx, req)
}

func (tc *serviceClient) StreamEvents(ctx context.Context, req *pb.StreamEventsRequest) (pb.TraceeService_StreamEventsClient, error) {
return tc.client.StreamEvents(ctx, req)
}
44 changes: 44 additions & 0 deletions cmd/disableEvent.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package cmd

import (
"context"

"github.com/ShohamBit/TraceeClient/client"

pb "github.com/aquasecurity/tracee/api/v1beta1"
"github.com/spf13/cobra"
)

var disableEventCmd = &cobra.Command{
Use: "disableEvent [event names...]",
Short: "disable specified events on the server",
Long: "long about the use",
Args: cobra.MinimumNArgs(1),
Run: func(cmd *cobra.Command, args []string) {
// Check if args are provided
if len(args) == 0 {
cmd.PrintErrln("Error: no event names provided. Please specify at least one event to disable.")
return // Exit if no arguments
}
disableEvents(cmd, args)
},
}

func disableEvents(cmd *cobra.Command, eventNames []string) {
// Create Tracee gRPC client
client, err := client.NewServiceClient(serverInfo)
if err != nil {
cmd.PrintErrln("Error creating client: ", err)
return // Exit on error
}

// Iterate over event names and disable each one
for _, eventName := range eventNames {
_, err := client.DisableEvent(context.Background(), &pb.DisableEventRequest{Name: eventName})
if err != nil {
cmd.PrintErrln("Error enabling event:", err)
continue // Continue on error with the next event
}
cmd.Println("Disabled event:", eventName)
}
}
71 changes: 71 additions & 0 deletions cmd/disableEvent_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package cmd

import (
"bytes"
"testing"
"time"

"github.com/ShohamBit/TraceeClient/mock"
"github.com/ShohamBit/TraceeClient/models"
"github.com/stretchr/testify/assert"
)

var (
DisableEventTests = []models.TestCase{
{
Name: "No events",
Args: []string{"disableEvent"},
ExpectedOutput: "Error: requires at least 1 arg(s), only received 0\n", // Update expected output
},
{
Name: "Single event",
Args: []string{"disableEvent", "event1"},
ExpectedOutput: "Disabled event: event1\n",
},
{
Name: "Multiple events",
Args: []string{"disableEvent", "event1", "event2"},
ExpectedOutput: "Disabled event: event1\nDisabled event: event2\n",
},
}
)

func TestDisableEvent(t *testing.T) {
// Start the mock server
mockServer, err := mock.StartMockServiceServer()
if err != nil {
t.Fatalf("Failed to start mock server: %v", err)
}
defer mockServer.Stop() // Ensure the server is stopped after the test

// Wait for the server to start
time.Sleep(100 * time.Millisecond)

for _, test := range DisableEventTests {
t.Run(test.Name, func(t *testing.T) {
// Capture output
var buf bytes.Buffer
rootCmd.SetOut(&buf)
rootCmd.SetErr(&buf)

// Set arguments for the test
rootCmd.SetArgs(test.Args)

// Execute the command
err := rootCmd.Execute()

// Validate output and error (if any)
output := buf.String()

// If no arguments provided, we expect an error
if test.Name == "No events" {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}

// Check if output matches expected output
assert.Contains(t, output, test.ExpectedOutput)
})
}
}
1 change: 1 addition & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ func init() {
rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(metricsCmd)
rootCmd.AddCommand(enableEventCmd)
rootCmd.AddCommand(disableEventCmd)
rootCmd.AddCommand(streamEventsCmd)

//flags
Expand Down
3 changes: 3 additions & 0 deletions mock/service_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ func (s *MockServiceServer) GetVersion(ctx context.Context, req *pb.GetVersionRe
func (s *MockServiceServer) EnableEvent(ctx context.Context, req *pb.EnableEventRequest) (*pb.EnableEventResponse, error) {
return &pb.EnableEventResponse{}, nil
}
func (s *MockServiceServer) DisableEvent(ctx context.Context, req *pb.DisableEventRequest) (*pb.DisableEventResponse, error) {
return &pb.DisableEventResponse{}, nil
}

/*
\stream events
Expand Down