From 59a5c48f9bb8ac758beff6ed266c77a531160ba7 Mon Sep 17 00:00:00 2001 From: "Baruch Odem (Rothkoff)" Date: Sun, 14 May 2023 18:50:39 +0300 Subject: [PATCH] feat: use sub-command for plugins (#61) BREAKING CHANGE: the CLI command and arguments was changed See the discussion: https://github.com/Checkmarx/2ms/discussions/20#discussioncomment-5647413 I think we don't need to support running *2ms* for multiple plugins at once. It is a rare case and it is confusing the command line arguments. Instead, I'm proposing using *SubCommand* for each plugin. --------- Co-authored-by: Jossef Harush Kadouri --- .github/workflows/pr-validation.yml | 2 +- Makefile | 8 ++- cmd/main.go | 81 +++++++++++---------------- plugins/confluence.go | 85 +++++++++++++++++------------ plugins/discord.go | 63 ++++++++++++++------- plugins/plugins.go | 20 ++++--- plugins/repository.go | 39 +++++++++---- 7 files changed, 173 insertions(+), 125 deletions(-) diff --git a/.github/workflows/pr-validation.yml b/.github/workflows/pr-validation.yml index 33eb409d..54349a84 100644 --- a/.github/workflows/pr-validation.yml +++ b/.github/workflows/pr-validation.yml @@ -38,7 +38,7 @@ jobs: - run: make build - name: Run docker and check its output - run: if docker run -t checkmarx/2ms:latest | grep "no scan plugin initialized"; then + run: if docker run -t checkmarx/2ms:latest --version | grep "2ms version"; then echo "Docker ran as expected"; else echo "Docker did not run as expected"; diff --git a/Makefile b/Makefile index 7a7b97d6..4df81607 100644 --- a/Makefile +++ b/Makefile @@ -9,4 +9,10 @@ save: build docker save $(image_name) > $(image_file_name) run: - docker run -it $(image_name) $(ARGS) \ No newline at end of file + docker run -it $(image_name) $(ARGS) + +# To run golangci-lint, you need to install it first: https://golangci-lint.run/usage/install/#local-installation +lint: + golangci-lint run -v -E gofmt --timeout=5m +lint-fix: + golangci-lint run -v -E gofmt --fix --timeout=5m \ No newline at end of file diff --git a/cmd/main.go b/cmd/main.go index 373ccabb..5125a6d2 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -1,6 +1,7 @@ package cmd import ( + "fmt" "os" "strings" @@ -18,21 +19,29 @@ import ( const timeSleepInterval = 50 +var Version = "0.0.0" + var rootCmd = &cobra.Command{ Use: "2ms", Short: "2ms Secrets Detection", - Run: execute, Version: Version, } -var Version = "" - var allPlugins = []plugins.IPlugin{ &plugins.ConfluencePlugin{}, &plugins.DiscordPlugin{}, &plugins.RepositoryPlugin{}, } +var channels = plugins.Channels{ + Items: make(chan plugins.Item), + Errors: make(chan error), + WaitGroup: &sync.WaitGroup{}, +} + +var report = reporting.Init() +var secretsChan = make(chan reporting.Secret) + func initLog() { zerolog.SetGlobalLevel(zerolog.InfoLevel) ll, err := rootCmd.Flags().GetString("log-level") @@ -59,18 +68,21 @@ func initLog() { func Execute() { cobra.OnInitialize(initLog) - rootCmd.Flags().BoolP("all", "", true, "scan all plugins") - rootCmd.Flags().StringSlice("tags", []string{"all"}, "select rules to be applied") + rootCmd.PersistentFlags().BoolP("all", "", true, "scan all plugins") + rootCmd.PersistentFlags().StringSlice("tags", []string{"all"}, "select rules to be applied") + rootCmd.PersistentFlags().StringP("log-level", "", "info", "log level (trace, debug, info, warn, error, fatal)") + + rootCmd.PersistentPreRun = preRun + rootCmd.PersistentPostRun = postRun for _, plugin := range allPlugins { - err := plugin.DefineCommandLineArgs(rootCmd) + subCommand, err := plugin.DefineCommand(channels) if err != nil { - log.Fatal().Msg(err.Error()) + log.Fatal().Msg(fmt.Sprintf("error while defining command for plugin %s: %s", plugin.GetName(), err.Error())) } + rootCmd.AddCommand(subCommand) } - rootCmd.PersistentFlags().StringP("log-level", "", "info", "log level (trace, debug, info, warn, error, fatal)") - if err := rootCmd.Execute(); err != nil { log.Fatal().Msg(err.Error()) } @@ -90,7 +102,7 @@ func validateTags(tags []string) { } } -func execute(cmd *cobra.Command, args []string) { +func preRun(cmd *cobra.Command, args []string) { tags, err := cmd.Flags().GetStringSlice("tags") if err != nil { log.Fatal().Msg(err.Error()) @@ -99,51 +111,18 @@ func execute(cmd *cobra.Command, args []string) { validateTags(tags) secrets := secrets.Init(tags) - report := reporting.Init() - - var itemsChannel = make(chan plugins.Item) - var secretsChannel = make(chan reporting.Secret) - var errorsChannel = make(chan error) - - var wg sync.WaitGroup - - // ------------------------------------- - // Get content from plugins - pluginsInitialized := 0 - for _, plugin := range allPlugins { - err := plugin.Initialize(cmd) - if err != nil { - log.Error().Msg(err.Error()) - continue - } - pluginsInitialized += 1 - } - - if pluginsInitialized == 0 { - log.Fatal().Msg("no scan plugin initialized. At least one plugin must be initialized to proceed. Stopping") - os.Exit(1) - } - - for _, plugin := range allPlugins { - if !plugin.IsEnabled() { - continue - } - - wg.Add(1) - go plugin.GetItems(itemsChannel, errorsChannel, &wg) - } go func() { for { select { - case item := <-itemsChannel: + case item := <-channels.Items: report.TotalItemsScanned++ - wg.Add(1) - go secrets.Detect(secretsChannel, item, &wg) - case secret := <-secretsChannel: + channels.WaitGroup.Add(1) + go secrets.Detect(secretsChan, item, channels.WaitGroup) + case secret := <-secretsChan: report.TotalSecretsFound++ report.Results[secret.ID] = append(report.Results[secret.ID], secret) - case err, ok := <-errorsChannel: + case err, ok := <-channels.Errors: if !ok { return } @@ -151,7 +130,10 @@ func execute(cmd *cobra.Command, args []string) { } } }() - wg.Wait() +} + +func postRun(cmd *cobra.Command, args []string) { + channels.WaitGroup.Wait() // Wait for last secret to be added to report time.Sleep(time.Millisecond * timeSleepInterval) @@ -170,5 +152,4 @@ func execute(cmd *cobra.Command, args []string) { } else { os.Exit(0) } - } diff --git a/plugins/confluence.go b/plugins/confluence.go index 9afbf008..2e510730 100644 --- a/plugins/confluence.go +++ b/plugins/confluence.go @@ -2,7 +2,6 @@ package plugins import ( "encoding/json" - "errors" "fmt" "net/http" "strings" @@ -14,11 +13,11 @@ import ( ) const ( - argConfluence = "confluence" - argConfluenceSpaces = "confluence-spaces" - argConfluenceUsername = "confluence-username" - argConfluenceToken = "confluence-token" - argConfluenceHistory = "history" + argUrl = "url" + argSpaces = "spaces" + argUsername = "username" + argToken = "token" + argHistory = "history" confluenceDefaultWindow = 25 confluenceMaxRequests = 500 ) @@ -32,62 +31,76 @@ type ConfluencePlugin struct { History bool } -func (p *ConfluencePlugin) IsEnabled() bool { - return p.Enabled +func (p *ConfluencePlugin) GetName() string { + return "confluence" } func (p *ConfluencePlugin) GetCredentials() (string, string) { return p.Username, p.Token } -func (p *ConfluencePlugin) DefineCommandLineArgs(cmd *cobra.Command) error { - flags := cmd.Flags() - flags.StringP(argConfluence, "", "", "scan confluence url") - flags.StringArray(argConfluenceSpaces, []string{}, "confluence spaces") - flags.StringP(argConfluenceUsername, "", "", "confluence username or email") - flags.StringP(argConfluenceToken, "", "", "confluence token") - flags.BoolP(argConfluenceHistory, "", false, "scan pages history") - return nil +func (p *ConfluencePlugin) DefineCommand(channels Channels) (*cobra.Command, error) { + var confluenceCmd = &cobra.Command{ + Use: p.GetName(), + Short: "Scan confluence", + } + + flags := confluenceCmd.Flags() + flags.StringP(argUrl, "", "", "confluence url") + flags.StringArray(argSpaces, []string{}, "confluence spaces") + flags.StringP(argUsername, "", "", "confluence username or email") + flags.StringP(argToken, "", "", "confluence token") + flags.BoolP(argHistory, "", false, "scan pages history") + err := confluenceCmd.MarkFlagRequired(argUrl) + if err != nil { + return nil, fmt.Errorf("error while marking '%s' flag as required: %w", argUrl, err) + } + + confluenceCmd.Run = func(cmd *cobra.Command, args []string) { + err := p.Initialize(cmd) + if err != nil { + channels.Errors <- fmt.Errorf("error while initializing confluence plugin: %w", err) + return + } + + p.GetItems(channels.Items, channels.Errors, channels.WaitGroup) + } + + return confluenceCmd, nil } func (p *ConfluencePlugin) Initialize(cmd *cobra.Command) error { flags := cmd.Flags() - confluenceUrl, _ := flags.GetString(argConfluence) - if confluenceUrl == "" { - return errors.New("confluence URL arg is missing. Plugin initialization failed") + url, err := flags.GetString(argUrl) + if err != nil { + return fmt.Errorf("error while getting '%s' flag value: %w", argUrl, err) } - confluenceUrl = strings.TrimRight(confluenceUrl, "/") + url = strings.TrimRight(url, "/") - confluenceSpaces, _ := flags.GetStringArray(argConfluenceSpaces) - confluenceUsername, _ := flags.GetString(argConfluenceUsername) - confluenceToken, _ := flags.GetString(argConfluenceToken) - runHistory, _ := flags.GetBool(argConfluenceHistory) + spaces, _ := flags.GetStringArray(argSpaces) + username, _ := flags.GetString(argUsername) + token, _ := flags.GetString(argToken) + runHistory, _ := flags.GetBool(argHistory) - if confluenceUsername == "" || confluenceToken == "" { + if username == "" || token == "" { log.Warn().Msg("confluence credentials were not provided. The scan will be made anonymously only for the public pages") } - p.Token = confluenceToken - p.Username = confluenceUsername - p.URL = confluenceUrl - p.Spaces = confluenceSpaces - p.Enabled = true + p.Token = token + p.Username = username + p.URL = url + p.Spaces = spaces p.History = runHistory p.Limit = make(chan struct{}, confluenceMaxRequests) return nil } func (p *ConfluencePlugin) GetItems(items chan Item, errs chan error, wg *sync.WaitGroup) { - defer wg.Done() - - go p.getSpacesItems(items, errs, wg) - wg.Add(1) + p.getSpacesItems(items, errs, wg) } func (p *ConfluencePlugin) getSpacesItems(items chan Item, errs chan error, wg *sync.WaitGroup) { - defer wg.Done() - spaces, err := p.getSpaces() if err != nil { errs <- err diff --git a/plugins/discord.go b/plugins/discord.go index 320bace1..d41e7dbb 100644 --- a/plugins/discord.go +++ b/plugins/discord.go @@ -12,11 +12,11 @@ import ( ) const ( - discordTokenFlag = "discord-token" - discordServersFlag = "discord-server" - discordChannelsFlag = "discord-channel" - discordFromDateFlag = "discord-duration" - discordMessagesCountFlag = "discord-messages-count" + tokenFlag = "token" + serversFlag = "server" + channelsFlag = "channel" + fromDateFlag = "duration" + messagesCountFlag = "messages-count" ) const defaultDateFrom = time.Hour * 24 * 14 @@ -35,39 +35,64 @@ type DiscordPlugin struct { waitGroup *sync.WaitGroup } -func (p *DiscordPlugin) DefineCommandLineArgs(cmd *cobra.Command) error { - flags := cmd.Flags() +func (p *DiscordPlugin) GetName() string { + return "discord" +} - flags.String(discordTokenFlag, "", "discord token") - flags.StringArray(discordServersFlag, []string{}, "discord servers") - flags.StringArray(discordChannelsFlag, []string{}, "discord channels") - flags.Duration(discordFromDateFlag, defaultDateFrom, "discord from date") - flags.Int(discordMessagesCountFlag, 0, "discord messages count") +func (p *DiscordPlugin) DefineCommand(channels Channels) (*cobra.Command, error) { + var discordCmd = &cobra.Command{ + Use: p.GetName(), + Short: "Scan discord", + } + flags := discordCmd.Flags() - cmd.MarkFlagsRequiredTogether(discordTokenFlag, discordServersFlag) + flags.String(tokenFlag, "", "discord token") + flags.StringArray(serversFlag, []string{}, "discord servers") + flags.StringArray(channelsFlag, []string{}, "discord channels") + flags.Duration(fromDateFlag, defaultDateFrom, "discord from date") + flags.Int(messagesCountFlag, 0, "discord messages count") - return nil + err := discordCmd.MarkFlagRequired(tokenFlag) + if err != nil { + return nil, fmt.Errorf("error while marking '%s' flag as required: %w", tokenFlag, err) + } + err = discordCmd.MarkFlagRequired(serversFlag) + if err != nil { + return nil, fmt.Errorf("error while marking '%s' flag as required: %w", serversFlag, err) + } + + discordCmd.Run = func(cmd *cobra.Command, args []string) { + err := p.Initialize(cmd) + if err != nil { + channels.Errors <- fmt.Errorf("discord plugin initialization failed: %w", err) + return + } + + p.GetItems(channels.Items, channels.Errors, channels.WaitGroup) + } + + return discordCmd, nil } func (p *DiscordPlugin) Initialize(cmd *cobra.Command) error { flags := cmd.Flags() - token, _ := flags.GetString(discordTokenFlag) + token, _ := flags.GetString(tokenFlag) if token == "" { return fmt.Errorf("discord token arg is missing. Plugin initialization failed") } - guilds, _ := flags.GetStringArray(discordServersFlag) + guilds, _ := flags.GetStringArray(serversFlag) if len(guilds) == 0 { return fmt.Errorf("discord servers arg is missing. Plugin initialization failed") } - channels, _ := flags.GetStringArray(discordChannelsFlag) + channels, _ := flags.GetStringArray(channelsFlag) if len(channels) == 0 { log.Warn().Msg("discord channels not provided. Will scan all channels") } - fromDate, _ := flags.GetDuration(discordFromDateFlag) - count, _ := flags.GetInt(discordMessagesCountFlag) + fromDate, _ := flags.GetDuration(fromDateFlag) + count, _ := flags.GetInt(messagesCountFlag) if count == 0 && fromDate == 0 { return fmt.Errorf("discord messages count or from date arg is missing. Plugin initialization failed") } diff --git a/plugins/plugins.go b/plugins/plugins.go index 75334952..c5e2753e 100644 --- a/plugins/plugins.go +++ b/plugins/plugins.go @@ -1,8 +1,9 @@ package plugins import ( - "github.com/spf13/cobra" "sync" + + "github.com/spf13/cobra" ) type Item struct { @@ -12,14 +13,17 @@ type Item struct { } type Plugin struct { - ID string - Enabled bool - Limit chan struct{} + ID string + Limit chan struct{} +} + +type Channels struct { + Items chan Item + Errors chan error + WaitGroup *sync.WaitGroup } type IPlugin interface { - DefineCommandLineArgs(cmd *cobra.Command) error - Initialize(cmd *cobra.Command) error - GetItems(chan Item, chan error, *sync.WaitGroup) - IsEnabled() bool + GetName() string + DefineCommand(channels Channels) (*cobra.Command, error) } diff --git a/plugins/repository.go b/plugins/repository.go index a2d58938..c4002da8 100644 --- a/plugins/repository.go +++ b/plugins/repository.go @@ -1,7 +1,7 @@ package plugins import ( - "errors" + "fmt" "os" "path/filepath" "sync" @@ -17,25 +17,44 @@ type RepositoryPlugin struct { Path string } -func (p *RepositoryPlugin) IsEnabled() bool { - return p.Enabled +func (p *RepositoryPlugin) GetName() string { + return "repository" } -func (p *RepositoryPlugin) DefineCommandLineArgs(cmd *cobra.Command) error { - flags := cmd.Flags() +func (p *RepositoryPlugin) DefineCommand(channels Channels) (*cobra.Command, error) { + var repositoryCmd = &cobra.Command{ + Use: p.GetName(), + Short: "Scan repository", + } + + flags := repositoryCmd.Flags() flags.String(argRepository, "", "scan repository folder") - return nil + err := repositoryCmd.MarkFlagRequired(argRepository) + if err != nil { + return nil, fmt.Errorf("error while marking '%s' flag as required: %w", argRepository, err) + } + + repositoryCmd.Run = func(cmd *cobra.Command, args []string) { + err := p.Initialize(cmd) + if err != nil { + channels.Errors <- fmt.Errorf("error while initializing plugin: %w", err) + return + } + + p.GetItems(channels.Items, channels.Errors, channels.WaitGroup) + } + + return repositoryCmd, nil } func (p *RepositoryPlugin) Initialize(cmd *cobra.Command) error { flags := cmd.Flags() - directoryPath, _ := flags.GetString(argRepository) - if directoryPath == "" { - return errors.New("path to repository missing. Plugin initialization failed") + directoryPath, err := flags.GetString(argRepository) + if err != nil { + return fmt.Errorf("error while getting '%s' flag value: %w", argRepository, err) } p.Path = directoryPath - p.Enabled = true return nil }