diff --git a/cmd/bootstrap.go b/cmd/bootstrap.go new file mode 100644 index 0000000..c126804 --- /dev/null +++ b/cmd/bootstrap.go @@ -0,0 +1,327 @@ +package cmd + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + + "github.com/cedana/cedana-cli/utils" + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +var username string +var password string + +var bootstrapCmd = &cobra.Command{ + Use: "bootstrap", + Short: "bootstrap cedana with cloud providers", + RunE: func(cmd *cobra.Command, args []string) error { + err := createConfig() + if err != nil { + return err + } + + r := BuildRunner() + + if r.cfg.AuthToken == "" { + return fmt.Errorf("no auth token detected, please login first with cedana-cli login") + } + + if r.cfg.EnabledProviders == nil || len(r.cfg.EnabledProviders) == 0 { + return fmt.Errorf("no providers specified in config, add provider-specific config and enabled providers, regions and try again.") + } + + // assemble cloudInfo from enabledProviders + var cInfo []CloudInfo + for _, provider := range r.cfg.EnabledProviders { + var info CloudInfo + switch provider { + case "aws": + info.Name = "aws" + if r.cfg.AWSConfig.EnabledRegions == nil || len(r.cfg.AWSConfig.EnabledRegions) == 0 { + return fmt.Errorf("no regions specified in config, add regions and try again.") + } + info.Regions = r.cfg.AWSConfig.EnabledRegions + case "azure": + info.Name = "azure" + return fmt.Errorf("azure not yet supported") + case "gcp": + info.Name = "gcp" + return fmt.Errorf("gcp not yet supported") + case "paperspace": + info.Name = "paperspace" + if r.cfg.PaperspaceConfig.EnabledRegions == nil || len(r.cfg.PaperspaceConfig.EnabledRegions) == 0 { + return fmt.Errorf("no regions specified in config, add regions and try again.") + } + info.Regions = r.cfg.PaperspaceConfig.EnabledRegions + } + + cInfo = append(cInfo, info) + } + + r.logger.Info().Msgf("cinfo = %+v", cInfo) + err = r.bootstrap(cInfo, true) + if err != nil { + return err + } + + for _, info := range cInfo { + switch info.Name { + case "aws": + r.logger.Info().Msgf("setting credentials for AWS") + err = r.setCredentialsAWS() + if err != nil { + return err + } + } + } + + return nil + }, +} + +var loginCmd = &cobra.Command{ + Use: "login", + Short: "Login to cedana. Create an account at https://auth.cedana.com/ui/registration", + RunE: func(cmd *cobra.Command, args []string) error { + r := BuildRunner() + + if r.cfg.AuthToken != "" { + err := validateAuthToken() + if err != nil { + return err + } + } + + // auth token not set, prompt for username and password + if (username == "") || (password == "") { + return fmt.Errorf("no username or password specified!") + } + + // Get UI action flow URL + actionUrl, err := getActionURL("https://auth.cedana.com/self-service/login/api") + if err != nil { + return fmt.Errorf("could not get actionUrl for authentication") + } + + token, err := authenticate(actionUrl, username, password) + if err != nil { + r.logger.Fatal().Err(err).Msgf("could not authenticate with cedana server") + } + + fmt.Println("Token:", token) + + // set token in config + viper.Set("auth_token", token) + err = viper.WriteConfig() + if err != nil { + return err + } + + return nil + }, +} + +func validateAuthToken() error { + return nil +} + +func getActionURL(url string) (string, error) { + resp, err := http.Get(url) + if err != nil { + return "", err + } + defer resp.Body.Close() + + var result map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", err + } + + ui, ok := result["ui"].(map[string]interface{}) + if !ok { + return "", fmt.Errorf("unexpected response format") + } + + action, ok := ui["action"].(string) + if !ok { + return "", fmt.Errorf("action URL not found") + } + + return action, nil +} + +func authenticate(actionUrl, email, password string) (string, error) { + authData := map[string]string{ + "identifier": email, + "password": password, + "method": "password", + } + data, err := json.Marshal(authData) + if err != nil { + return "", err + } + + req, err := http.NewRequest("POST", actionUrl, bytes.NewBuffer(data)) + if err != nil { + return "", err + } + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + var result map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", err + } + + token, ok := result["session_token"].(string) + if !ok { + return "", fmt.Errorf("unexpected response format") + } + + return token, nil +} + +func createConfig() error { + homeDir := os.Getenv("HOME") + configFolderPath := filepath.Join(homeDir, ".cedana") + // check that $HOME/.cedana folder exists - create if it doesn't + _, err := os.Stat(configFolderPath) + if err != nil { + err = os.Mkdir(configFolderPath, 0o755) + if err != nil { + return err + } + } + + _, err = os.OpenFile(filepath.Join(homeDir, "/.cedana/cedana_config.json"), 0, 0o644) + if errors.Is(err, os.ErrNotExist) { + // copy template, use viper to set programatically + err = utils.CreateCedanaConfig(filepath.Join(configFolderPath, "cedana_config.json"), username) + if err != nil { + return err + } + } + return nil +} + +type CloudInfo struct { + Name string `json:"name"` + Regions []string `json:"regions"` +} + +type bootstrapRequest struct { + SessionToken string `json:"-"` + CloudInfo []CloudInfo `json:"cloud_info"` + LeaveRunning bool `json:"leaveRunning"` +} + +func (r *Runner) bootstrap(cloudInfo []CloudInfo, leaveRunning bool) error { + br := bootstrapRequest{ + SessionToken: r.cfg.AuthToken, + CloudInfo: cloudInfo, + LeaveRunning: leaveRunning, + } + + jsonBody, err := json.Marshal(br) + if err != nil { + return err + } + + url := r.cfg.MarketServiceUrl + "/" + "/bootstrap" + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody)) + if err != nil { + return err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+r.cfg.AuthToken) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return err + } + + defer resp.Body.Close() + + if err != nil { + return fmt.Errorf("request failed with status code: %d and error: %s", resp.StatusCode, err.Error()) + } + + r.logger.Info().Msgf("Bootstrap completed") + return nil +} + +type setCredentialsRequestAWS struct { + AccessKeyID string `json:"access_key_id"` + SecretKey string `json:"secret_access_key"` +} + +func (r *Runner) setCredentialsAWS() error { + if r.cfg.AWSConfig.AccessKeyID == "" || r.cfg.AWSConfig.SecretAccessKey == "" { + return fmt.Errorf("AWS credentials not set") + } + + scr := setCredentialsRequestAWS{ + AccessKeyID: r.cfg.AWSConfig.AccessKeyID, + SecretKey: r.cfg.AWSConfig.SecretAccessKey, + } + + jsonBody, err := json.Marshal(scr) + if err != nil { + return err + } + + url := r.cfg.MarketServiceUrl + "/" + "/cloud/" + "aws" + "/credentials" + + req, err := http.NewRequest("PUT", url, bytes.NewBuffer(jsonBody)) + if err != nil { + return err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+r.cfg.AuthToken) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return err + } + + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("request failed with status code: %d", resp.StatusCode) + } + + r.logger.Info().Msgf("AWS credentials set with response %s", string(body)) + + return nil +} + +func init() { + RootCmd.AddCommand(bootstrapCmd) + RootCmd.AddCommand(loginCmd) + loginCmd.Flags().StringVarP(&username, "username", "u", "", "username") + loginCmd.Flags().StringVarP(&password, "password", "p", "", "password") +} diff --git a/cmd/managed/bootstrap.go b/cmd/managed/bootstrap.go deleted file mode 100644 index 258f2a8..0000000 --- a/cmd/managed/bootstrap.go +++ /dev/null @@ -1,594 +0,0 @@ -package managed - -import ( - "bufio" - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "os" - "path/filepath" - "strings" - "time" - - "github.com/cedana/cedana-cli/utils" - "github.com/manifoldco/promptui" - "github.com/spf13/cobra" - "github.com/spf13/viper" - "github.com/tidwall/gjson" - "golang.org/x/term" - - ory "github.com/ory/client-go" - "github.com/ory/x/cmdx" - "github.com/ory/x/stringsx" - "github.com/tidwall/sjson" -) - -var registerCmd = &cobra.Command{ - Use: "login", - Short: "login user with managed platform for access to Cedana", - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - logger := utils.GetLogger() - - err := createConfig() - if err != nil { - logger.Fatal().Err(err).Msg("could not create config") - } - - r := BuildRunner() - - // using arg, set url - if args[0] == "" { - return fmt.Errorf("no url provided") - } - - viper.Set("managed_config.market_service_url", args[0]) - err = viper.WriteConfig() - if err != nil { - logger.Fatal().Err(err).Msg("could not write config") - } - - // reload config - r.cfg, err = utils.InitCedanaConfig() - if err != nil { - logger.Fatal().Err(err).Msg("could not set up config") - } - - // if authToken unset, redirect user and get token - if r.cfg.ManagedConfig.AuthToken == "" { - sessionToken, err := r.signin(cmd, cmd.Context(), "") - if err != nil { - return err - } - fmt.Printf("sessionToken: %s\n", sessionToken) - } - - return err - }, -} - -// returns sessionToken -func (r *Runner) signin(cmd *cobra.Command, ctx context.Context, sessionToken string) (string, error) { - // init new ory client - cfg := ory.NewConfiguration() - cfg.Servers = ory.ServerConfigurations{ - { - URL: "https://auth.cedana.com", - }, - } - oryClient := ory.NewAPIClient(cfg) - req := oryClient.FrontendAPI.CreateNativeLoginFlow(ctx) - if len(sessionToken) > 0 { - req = req.XSessionToken(sessionToken).Aal("aal2") - } - - flow, _, err := req.Execute() - if err != nil { - return "", err - } - - var form interface{} = &ory.UpdateLoginFlowWithPasswordMethod{} - method := "password" - if len(sessionToken) > 0 { - var foundTOTP bool - var foundLookup bool - for _, n := range flow.Ui.Nodes { - if n.Group == "totp" { - foundTOTP = true - } else if n.Group == "lookup_secret" { - foundLookup = true - } - } - if !foundLookup && !foundTOTP { - return "", errors.New("only TOTP and lookup secrets are supported for two-step verification in the CLI") - } - - method = "lookup_secret" - if foundTOTP { - form = &ory.UpdateLoginFlowWithTotpMethod{} - method = "totp" - } - } - - type PasswordReader struct{} - - pwReader := func() ([]byte, error) { - return term.ReadPassword(int(os.Stdin.Fd())) - } - if p, ok := cmd.Context().Value(PasswordReader{}).(passwordReader); ok { - pwReader = p - } - - if err := renderForm(bufio.NewReader(cmd.InOrStdin()), pwReader, cmd.ErrOrStderr(), flow.Ui, method, form); err != nil { - return "", err - } - - var body ory.UpdateLoginFlowBody - switch e := form.(type) { - case *ory.UpdateLoginFlowWithTotpMethod: - body.UpdateLoginFlowWithTotpMethod = e - case *ory.UpdateLoginFlowWithPasswordMethod: - body.UpdateLoginFlowWithPasswordMethod = e - default: - panic("unexpected type") - } - - login, _, err := oryClient.FrontendAPI.UpdateLoginFlow(ctx).XSessionToken(sessionToken). - Flow(flow.Id).UpdateLoginFlowBody(body).Execute() - if err != nil { - return "", err - } - - sessionToken = stringsx.Coalesce(*login.SessionToken, sessionToken) - _, _, err = oryClient.FrontendAPI.ToSession(ctx).XSessionToken(sessionToken).Execute() - if err == nil { - return sessionToken, nil - } - - if e, ok := err.(interface{ Body() []byte }); ok { - switch gjson.GetBytes(e.Body(), "error.id").String() { - case "session_aal2_required": - return r.signin(cmd, ctx, sessionToken) - } - } - return "", err -} - -var bootstrapCmd = &cobra.Command{ - Use: "bootstrap", - Short: "bootstrap cedana with cloud providers", - RunE: func(cmd *cobra.Command, args []string) error { - r := BuildRunner() - - if r.cfg.EnabledProviders == nil || len(r.cfg.EnabledProviders) == 0 { - return fmt.Errorf("no providers specified in config, add provider-specific config and enabled providers, regions and try again.") - } - - // assemble cloudInfo from enabledProviders - var cInfo []CloudInfo - for _, provider := range r.cfg.EnabledProviders { - var info CloudInfo - switch provider { - case "aws": - info.Name = "aws" - if r.cfg.AWSConfig.EnabledRegions == nil || len(r.cfg.AWSConfig.EnabledRegions) == 0 { - return fmt.Errorf("no regions specified in config, add regions and try again.") - } - info.Regions = r.cfg.AWSConfig.EnabledRegions - case "azure": - info.Name = "azure" - return fmt.Errorf("azure not yet supported") - case "gcp": - info.Name = "gcp" - return fmt.Errorf("gcp not yet supported") - case "paperspace": - info.Name = "paperspace" - if r.cfg.PaperspaceConfig.EnabledRegions == nil || len(r.cfg.PaperspaceConfig.EnabledRegions) == 0 { - return fmt.Errorf("no regions specified in config, add regions and try again.") - } - info.Regions = r.cfg.PaperspaceConfig.EnabledRegions - } - - cInfo = append(cInfo, info) - } - - r.logger.Info().Msgf("cinfo = %+v", cInfo) - err := r.bootstrap(cInfo, true) - if err != nil { - return err - } - - for _, info := range cInfo { - switch info.Name { - case "aws": - r.logger.Info().Msgf("setting credentials for AWS") - err = r.setCredentialsAWS() - if err != nil { - return err - } - } - } - - return nil - }, -} - -func createConfig() error { - homeDir := os.Getenv("HOME") - configFolderPath := filepath.Join(homeDir, ".cedana") - // check that $HOME/.cedana folder exists - create if it doesn't - _, err := os.Stat(configFolderPath) - if err != nil { - err = os.Mkdir(configFolderPath, 0o755) - if err != nil { - return err - } - } - - _, err = os.OpenFile(filepath.Join(homeDir, "/.cedana/cedana_config.json"), 0, 0o644) - if errors.Is(err, os.ErrNotExist) { - username := "" - prompt := promptui.Prompt{ - Label: "Enter username", - } - username, err = prompt.Run() - if err != nil { - return err - } - // copy template, use viper to set programatically - err = utils.CreateCedanaConfig(filepath.Join(configFolderPath, "cedana_config.json"), username) - if err != nil { - return err - } - } - return nil -} - -type registerRequest struct { - Email string `json:"email"` -} - -type registerResponse struct { - Token string `json:"token"` - Owner string `json:"owner"` -} - -func (r *Runner) register(email string) (*registerResponse, error) { - reg := registerRequest{ - Email: email, - } - - jsonBody, err := json.Marshal(reg) - if err != nil { - return nil, err - } - - url := r.cfg.ManagedConfig.MarketServiceUrl + "/registration" - - req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody)) - if err != nil { - return nil, err - } - - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - return nil, err - } - - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("request failed with status code: %d", resp.StatusCode) - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - var regResp registerResponse - err = json.Unmarshal(body, ®Resp) - if err != nil { - return nil, err - } - - r.logger.Info().Msgf("Registered user email %s, received unique token %s", reg.Email, regResp.Token) - r.logger.Info().Msgf("setting info in config...") - - viper.Set("managed_config.username", reg.Email) - viper.Set("managed_config.user_id", regResp.Owner) - - viper.WriteConfig() - - return ®Resp, nil - -} - -type validateRegistrationRequest struct { - Password string `json:"password"` - Confirm string `json:"confirm_password"` - Token string `json:"token"` -} - -func (r *Runner) validateRegistration(password, confirm, uid, token string) error { - if password != confirm { - return fmt.Errorf("passwords do not match") - } - - vrr := validateRegistrationRequest{ - Password: password, - Confirm: confirm, - Token: token, - } - - jsonBody, err := json.Marshal(vrr) - if err != nil { - return err - } - - url := r.cfg.ManagedConfig.MarketServiceUrl + "/registration/" + uid + "/validation" - - req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody)) - if err != nil { - return err - } - - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - return err - } - - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("request failed with status code: %d", resp.StatusCode) - } - - r.logger.Info().Msgf("password set successfully") - - return nil -} - -type generateJWTRequest struct { - Password string `json:"password"` -} - -type generateJWTResponse struct { - JWT string `json:"jwt"` -} - -type CloudInfo struct { - Name string `json:"name"` - Regions []string `json:"regions"` -} - -type bootstrapRequest struct { - SessionToken string `json:"-"` - CloudInfo []CloudInfo `json:"cloud_info"` - LeaveRunning bool `json:"leaveRunning"` -} - -func (r *Runner) bootstrap(cloudInfo []CloudInfo, leaveRunning bool) error { - br := bootstrapRequest{ - SessionToken: r.cfg.ManagedConfig.AuthToken, - CloudInfo: cloudInfo, - LeaveRunning: leaveRunning, - } - - jsonBody, err := json.Marshal(br) - if err != nil { - return err - } - - url := r.cfg.ManagedConfig.MarketServiceUrl + "/" + "/bootstrap" - - req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody)) - if err != nil { - return err - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+r.cfg.ManagedConfig.AuthToken) - - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - return err - } - - defer resp.Body.Close() - - if err != nil { - return fmt.Errorf("request failed with status code: %d and error: %s", resp.StatusCode, err.Error()) - } - - r.logger.Info().Msgf("Bootstrap completed") - return nil -} - -type setCredentialsRequestAWS struct { - AccessKeyID string `json:"access_key_id"` - SecretKey string `json:"secret_access_key"` -} - -func (r *Runner) setCredentialsAWS() error { - if r.cfg.AWSConfig.AccessKeyID == "" || r.cfg.AWSConfig.SecretAccessKey == "" { - return fmt.Errorf("AWS credentials not set") - } - - scr := setCredentialsRequestAWS{ - AccessKeyID: r.cfg.AWSConfig.AccessKeyID, - SecretKey: r.cfg.AWSConfig.SecretAccessKey, - } - - jsonBody, err := json.Marshal(scr) - if err != nil { - return err - } - - url := r.cfg.ManagedConfig.MarketServiceUrl + "/" + "/cloud/" + "aws" + "/credentials" - - req, err := http.NewRequest("PUT", url, bytes.NewBuffer(jsonBody)) - if err != nil { - return err - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+r.cfg.ManagedConfig.AuthToken) - - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - return err - } - - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("request failed with status code: %d", resp.StatusCode) - } - - r.logger.Info().Msgf("AWS credentials set with response %s", string(body)) - - return nil -} - -type passwordReader = func() ([]byte, error) - -func getLabel(attrs *ory.UiNodeInputAttributes, node *ory.UiNode) string { - if attrs.Name == "identifier" { - return fmt.Sprintf("%s: ", "Email") - } else if node.Meta.Label != nil { - return fmt.Sprintf("%s: ", node.Meta.Label.Text) - } else if attrs.Label != nil { - return fmt.Sprintf("%s: ", attrs.Label.Text) - } - return fmt.Sprintf("%s: ", attrs.Name) -} - -func renderForm(stdin *bufio.Reader, pwReader passwordReader, stderr io.Writer, ui ory.UiContainer, method string, out interface{}) (err error) { - for _, message := range ui.Messages { - _, _ = fmt.Fprintf(stderr, "%s\n", message.Text) - } - - for _, node := range ui.Nodes { - for _, message := range node.Messages { - _, _ = fmt.Fprintf(stderr, "%s\n", message.Text) - } - } - - values := json.RawMessage(`{}`) - for k := range ui.Nodes { - node := ui.Nodes[k] - if node.Group != method && node.Group != "default" { - continue - } - - switch node.Type { - case "input": - attrs := node.Attributes.UiNodeInputAttributes - switch attrs.Type { - case "button": - continue - case "submit": - continue - } - - if attrs.Name == "traits.consent.tos" { - for { - ok, err := cmdx.AskScannerForConfirmation(getLabel(attrs, &node), stdin, stderr) - if err != nil { - return err - } - if ok { - break - } - } - values, err = sjson.SetBytes(values, attrs.Name, time.Now().UTC().Format(time.RFC3339)) - if err != nil { - return err - } - continue - } - - if strings.Contains(attrs.Name, "traits.details") { - continue - } - - switch attrs.Type { - case "hidden": - continue - case "checkbox": - result, err := cmdx.AskScannerForConfirmation(getLabel(attrs, &node), stdin, stderr) - if err != nil { - return err - } - - values, err = sjson.SetBytes(values, attrs.Name, result) - if err != nil { - return err - } - case "password": - var password string - for password == "" { - _, _ = fmt.Fprint(stderr, getLabel(attrs, &node)) - v, err := pwReader() - if err != nil { - return err - } - password = strings.ReplaceAll(string(v), "\n", "") - fmt.Println("") - } - - values, err = sjson.SetBytes(values, attrs.Name, password) - if err != nil { - return err - } - default: - var value string - for value == "" { - _, _ = fmt.Fprint(stderr, getLabel(attrs, &node)) - v, err := stdin.ReadString('\n') - if err != nil { - return err - } - value = strings.ReplaceAll(v, "\n", "") - } - - values, err = sjson.SetBytes(values, attrs.Name, value) - if err != nil { - return err - } - } - default: - // Do nothing - } - } - - values, err = sjson.SetBytes(values, "method", method) - if err != nil { - return err - } - - return err -} - -func init() { - managedCmd.AddCommand(registerCmd) - managedCmd.AddCommand(bootstrapCmd) -} diff --git a/cmd/managed/run.go b/cmd/run.go similarity index 86% rename from cmd/managed/run.go rename to cmd/run.go index b9badc4..6850389 100644 --- a/cmd/managed/run.go +++ b/cmd/run.go @@ -1,4 +1,4 @@ -package managed +package cmd import ( "bytes" @@ -10,7 +10,6 @@ import ( "net/http" "os" - "github.com/cedana/cedana-cli/cmd" "github.com/cedana/cedana-cli/utils" "github.com/rs/zerolog" "github.com/spf13/cobra" @@ -51,11 +50,6 @@ func BuildRunner() *Runner { } } -var managedCmd = &cobra.Command{ - Use: "managed", - Short: "Run your workloads on the Cedana system.", -} - var setupTaskCmd = &cobra.Command{ Use: "setup", Short: "Setup a task to run on Cedana", @@ -135,7 +129,7 @@ func (r *Runner) setupTask(encodedJob, taskLabel string) error { return err } - url := r.cfg.ManagedConfig.MarketServiceUrl + "/" + "/task" + url := r.cfg.MarketServiceUrl + "/" + "/task" req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody)) if err != nil { @@ -143,7 +137,7 @@ func (r *Runner) setupTask(encodedJob, taskLabel string) error { } req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+r.cfg.ManagedConfig.AuthToken) + req.Header.Set("Authorization", "Bearer "+r.cfg.AuthToken) client := &http.Client{} resp, err := client.Do(req) @@ -173,7 +167,7 @@ type listTaskResponse struct { } func (r *Runner) listTask() error { - url := r.cfg.ManagedConfig.MarketServiceUrl + "/" + "/task" + url := r.cfg.MarketServiceUrl + "/" + "/task" req, err := http.NewRequest("GET", url, nil) if err != nil { @@ -181,7 +175,7 @@ func (r *Runner) listTask() error { } req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+r.cfg.ManagedConfig.AuthToken) + req.Header.Set("Authorization", "Bearer "+r.cfg.AuthToken) client := &http.Client{} resp, err := client.Do(req) @@ -213,7 +207,7 @@ type runTaskResponse struct { } func (r *Runner) runTask(taskLabel string) error { - url := r.cfg.ManagedConfig.MarketServiceUrl + "/" + "/task/" + taskLabel + "/run" + url := r.cfg.MarketServiceUrl + "/" + "/task/" + taskLabel + "/run" req, err := http.NewRequest("POST", url, nil) if err != nil { @@ -221,7 +215,7 @@ func (r *Runner) runTask(taskLabel string) error { } req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+r.cfg.ManagedConfig.AuthToken) + req.Header.Set("Authorization", "Bearer "+r.cfg.AuthToken) client := &http.Client{} resp, err := client.Do(req) @@ -247,10 +241,9 @@ func (r *Runner) runTask(taskLabel string) error { } func init() { - cmd.RootCmd.AddCommand(managedCmd) - managedCmd.AddCommand(setupTaskCmd) - managedCmd.AddCommand(listTasksCmd) - managedCmd.AddCommand(runTaskCmd) + RootCmd.AddCommand(setupTaskCmd) + RootCmd.AddCommand(listTasksCmd) + RootCmd.AddCommand(runTaskCmd) setupTaskCmd.Flags().StringVarP(&jobFile, "job", "j", "", "job file") } diff --git a/main.go b/main.go index b8dd3b6..3e7f509 100644 --- a/main.go +++ b/main.go @@ -4,7 +4,6 @@ package main // because it's not imported anywhere else. import ( "github.com/cedana/cedana-cli/cmd" - _ "github.com/cedana/cedana-cli/cmd/managed" ) // these get set by goreleaser diff --git a/utils/config.go b/utils/config.go index 72600ff..9c83476 100644 --- a/utils/config.go +++ b/utils/config.go @@ -19,19 +19,14 @@ var ValidProviders = []string{ } type CedanaConfig struct { - CedanaManaged bool `json:"cedana_managed" mapstructure:"cedana_managed"` - ManagedConfig ManagedConfig `json:"managed_config" mapstructure:"managed_config"` + MarketServiceUrl string `json:"market_service_url" mapstructure:"market_service_url"` + AuthToken string `json:"auth_token" mapstructure:"auth_token"` EnabledProviders []string `json:"enabled_providers" mapstructure:"enabled_providers"` KeepRunning bool `json:"keep_running" mapstructure:"keep_running"` AWSConfig AWSConfig `json:"aws" mapstructure:"aws"` PaperspaceConfig PaperspaceConfig `json:"paperspace" mapstructure:"paperspace"` } -type ManagedConfig struct { - MarketServiceUrl string `json:"market_service_url" mapstructure:"market_service_url"` - AuthToken string `json:"auth_token" mapstructure:"auth_token"` -} - type AWSConfig struct { AccessKeyID string `json:"access_key_id" mapstructure:"access_key_id"` SecretAccessKey string `json:"secret_access_key" mapstructure:"secret_access_key"` @@ -110,11 +105,9 @@ func InitCedanaConfig() (*CedanaConfig, error) { // Used in bootstrap to create a placeholder config func CreateCedanaConfig(path, username string) error { sc := &CedanaConfig{ - ManagedConfig: ManagedConfig{ - MarketServiceUrl: "https://market.cedana.com", - AuthToken: "", - }, - EnabledProviders: []string{""}, + MarketServiceUrl: "https://market.cedana.com", + AuthToken: "", + EnabledProviders: []string{"aws"}, AWSConfig: AWSConfig{}, PaperspaceConfig: PaperspaceConfig{}, }