diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 7278284f..224d306d 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -381,3 +381,66 @@ func GetEmailFromToken(token string) string { } return email } + +func AuthProviderFlagToCredentialProvider(authProviderFlag string) entity.CredentialProvider { + if authProviderFlag == "" { + return "" + } + if authProviderFlag == "nvidia" { + return CredentialProviderKAS + } + return CredentialProviderAuth0 +} + +func StandardLogin(authProvider string, email string, tokens *entity.AuthTokens) OAuth { + // the default authenticator + var authenticator OAuth = Auth0Authenticator{ + Issuer: "https://brevdev.us.auth0.com/", + Audience: "https://brevdev.us.auth0.com/api/v2/", + ClientID: "JaqJRLEsdat5w7Tb0WqmTxzIeqwqepmk", + DeviceCodeEndpoint: "https://brevdev.us.auth0.com/oauth/device/code", + OauthTokenEndpoint: "https://brevdev.us.auth0.com/oauth/token", + } + + shouldPromptEmail := false + if email == "" && tokens != nil && tokens.AccessToken != "" { + email = GetEmailFromToken(tokens.AccessToken) + shouldPromptEmail = true + } + + authRetriever := NewOAuthRetriever([]OAuth{ + authenticator, + NewKasAuthenticator( + email, + "https://api.ngc.nvidia.com", + "https://login.nvidia.com", + shouldPromptEmail, + "https://brev.nvidia.com", + ), + }) + + if tokens != nil && tokens.AccessToken != "" { + authenticatorFromToken, errr := authRetriever.GetByToken(tokens.AccessToken) + if errr != nil { + fmt.Printf("%v\n", errr) + } else { + authenticator = authenticatorFromToken + } + } + + if authProvider != "" { + provider := AuthProviderFlagToCredentialProvider(authProvider) + oauth, errr := authRetriever.GetByProvider(provider) + if errr != nil { + fmt.Printf("%v\n", errr) + } else { + authenticator = oauth + } + } + + if authenticator.GetCredentialProvider() == CredentialProviderKAS { + config.ConsoleBaseURL = "https://brev.nvidia.com" + } + + return authenticator +} diff --git a/pkg/cmd/cmd.go b/pkg/cmd/cmd.go index 4718f622..6830dba2 100644 --- a/pkg/cmd/cmd.go +++ b/pkg/cmd/cmd.go @@ -2,7 +2,6 @@ package cmd import ( - "flag" "fmt" "github.com/brevdev/brev-cli/pkg/auth" @@ -48,7 +47,6 @@ import ( "github.com/brevdev/brev-cli/pkg/cmd/workspacegroups" "github.com/brevdev/brev-cli/pkg/cmd/writeconnectionevent" "github.com/brevdev/brev-cli/pkg/config" - "github.com/brevdev/brev-cli/pkg/entity" "github.com/brevdev/brev-cli/pkg/featureflag" "github.com/brevdev/brev-cli/pkg/files" "github.com/brevdev/brev-cli/pkg/remoteversion" @@ -61,29 +59,20 @@ import ( ) var ( - userFlag string - emailFlag string - authProviderFlag string + userFlag string + printVersion bool ) -func init() { - flag.StringVar(&emailFlag, "email", "", "email to use for authentication") - flag.StringVar(&authProviderFlag, "auth", "", "authentication provider to use (nvidia or legacy, default is legacy)") - flag.Parse() -} - func NewDefaultBrevCommand() *cobra.Command { cmd := NewBrevCommand() cmd.PersistentFlags().StringVar(&userFlag, "user", "", "non root user to use for per user configuration of commands run as root") - cmd.PersistentFlags().StringVar(&emailFlag, "email", "", "email to use for authentication") - cmd.PersistentFlags().StringVar(&authProviderFlag, "auth", "", "authentication provider to use (nvidia or legacy, default is legacy)") + cmd.PersistentFlags().BoolVar(&printVersion, "version", false, "Print version output") return cmd } func NewBrevCommand() *cobra.Command { //nolint:funlen,gocognit,gocyclo // define brev command // in io.Reader, out io.Writer, err io.Writer t := terminal.New() - var printVersion bool conf := config.NewConstants() fs := files.AppFs @@ -92,64 +81,9 @@ func NewBrevCommand() *cobra.Command { //nolint:funlen,gocognit,gocyclo // defin NewBasicStore(). WithFileSystem(fs) - tokens, err := fsStore.GetAuthTokens() - if err != nil { - fmt.Printf("%v\n", err) - } - - // the default authenticator - var authenticator auth.OAuth = auth.Auth0Authenticator{ - Issuer: "https://brevdev.us.auth0.com/", - Audience: "https://brevdev.us.auth0.com/api/v2/", - ClientID: "JaqJRLEsdat5w7Tb0WqmTxzIeqwqepmk", - DeviceCodeEndpoint: "https://brevdev.us.auth0.com/oauth/device/code", - OauthTokenEndpoint: "https://brevdev.us.auth0.com/oauth/token", - } - - var email string - shouldPromptEmail := false - if tokens != nil && tokens.AccessToken != "" { - email = auth.GetEmailFromToken(tokens.AccessToken) - shouldPromptEmail = true - } - if emailFlag != "" { - email = emailFlag - shouldPromptEmail = false - } - - authRetriever := auth.NewOAuthRetriever([]auth.OAuth{ - authenticator, - auth.NewKasAuthenticator( - email, - "https://api.ngc.nvidia.com", - "https://login.nvidia.com", - shouldPromptEmail, - "https://brev.nvidia.com", - ), - }) - - if tokens != nil && tokens.AccessToken != "" { - authenticatorFromToken, errr := authRetriever.GetByToken(tokens.AccessToken) - if errr != nil { - fmt.Printf("%v\n", errr) - } else { - authenticator = authenticatorFromToken - } - } - - if authProviderFlag != "" { - provider := authProviderFlagToCredentialProvider(authProviderFlag) - authenticatorByProviderFlag, errr := authRetriever.GetByProvider(provider) - if errr != nil { - fmt.Printf("%v\n", errr) - } else { - authenticator = authenticatorByProviderFlag - } - } + tokens, _ := fsStore.GetAuthTokens() - if authenticator.GetCredentialProvider() == auth.CredentialProviderKAS { - config.ConsoleBaseURL = "https://brev.nvidia.com" - } + authenticator := auth.StandardLogin("", "", tokens) // super annoying. this is needed to make the import stay _ = color.New(color.FgYellow, color.Bold).SprintFunc() @@ -162,7 +96,7 @@ func NewBrevCommand() *cobra.Command { //nolint:funlen,gocognit,gocyclo // defin ). WithAuth(loginAuth, store.WithDebug(conf.GetDebugHTTP())) - err = loginCmdStore.SetForbiddenStatusRetryHandler(func() error { + err := loginCmdStore.SetForbiddenStatusRetryHandler(func() error { _, err1 := loginAuth.GetAccessToken() if err1 != nil { return breverrors.WrapAndTrace(err1) @@ -285,8 +219,6 @@ func NewBrevCommand() *cobra.Command { //nolint:funlen,gocognit,gocyclo // defin cmds.SetUsageTemplate(usageTemplate) - cmds.PersistentFlags().BoolVar(&printVersion, "version", false, "Print version output") - createCmdTree(cmds, t, loginCmdStore, noLoginCmdStore, loginAuth) return cmds @@ -537,13 +469,3 @@ var ( _ store.Auth = auth.NoLoginAuth{} _ auth.AuthStore = store.FileStore{} ) - -func authProviderFlagToCredentialProvider(authProviderFlag string) entity.CredentialProvider { - if authProviderFlag == "" { - return "" - } - if authProviderFlag == "nvidia" { - return auth.CredentialProviderKAS - } - return auth.CredentialProviderAuth0 -} diff --git a/pkg/cmd/login/login.go b/pkg/cmd/login/login.go index 61b6c9f0..07a4d109 100644 --- a/pkg/cmd/login/login.go +++ b/pkg/cmd/login/login.go @@ -28,6 +28,7 @@ type LoginOptions struct { } type LoginStore interface { + auth.AuthStore GetCurrentUser() (*entity.User, error) CreateUser(idToken string) (*entity.User, error) GetOrganizations(options *store.GetOrganizationsOptions) ([]entity.Organization, error) @@ -55,6 +56,8 @@ func NewCmdLogin(t *terminal.Terminal, loginStore LoginStore, auth Auth) *cobra. var loginToken string var skipBrowser bool + var emailFlag string + var authProviderFlag string cmd := &cobra.Command{ Annotations: map[string]string{"housekeeping": ""}, @@ -79,40 +82,29 @@ func NewCmdLogin(t *terminal.Terminal, loginStore LoginStore, auth Auth) *cobra. }, Args: cmderrors.TransformToValidationError(cobra.NoArgs), RunE: func(cmd *cobra.Command, args []string) error { - err := opts.RunLogin(t, loginToken, skipBrowser) + err := opts.RunLogin(t, loginToken, skipBrowser, emailFlag, authProviderFlag) if err != nil { // if err is ImportIDEConfigError, log err with sentry but continue if _, ok := err.(*importideconfig.ImportIDEConfigError); !ok { - return breverrors.WrapAndTrace(err) + return err } // todo alert sentry err2 := RunTasksForUser(t) if err2 != nil { err = multierror.Append(err, err2) } - return breverrors.WrapAndTrace(err) + return err } return nil }, } cmd.Flags().StringVarP(&loginToken, "token", "", "", "token provided to auto login") cmd.Flags().BoolVar(&skipBrowser, "skip-browser", false, "print url instead of auto opening browser") + cmd.Flags().StringVar(&emailFlag, "email", "", "email to use for authentication") + cmd.Flags().StringVar(&authProviderFlag, "auth", "", "authentication provider to use (nvidia or legacy, default is legacy)") return cmd } -func (o LoginOptions) checkIfInWorkspace() error { - workspaceID, err := o.LoginStore.GetCurrentWorkspaceID() - if err != nil { - return breverrors.WrapAndTrace(err) - } - if workspaceID != "" { - fmt.Println("can not login to instance") - return breverrors.NewValidationError("can not login to instance") - } - - return nil -} - func (o LoginOptions) loginAndGetOrCreateUser(loginToken string, skipBrowser bool) (*entity.User, error) { if loginToken != "" { err := o.Auth.LoginWithToken(loginToken) @@ -158,12 +150,21 @@ func (o LoginOptions) getOrCreateOrg(username string) (*entity.Organization, err return org, nil } -func (o LoginOptions) RunLogin(t *terminal.Terminal, loginToken string, skipBrowser bool) error { - err := o.checkIfInWorkspace() - if err != nil { - return breverrors.WrapAndTrace(err) +func (o LoginOptions) RunLogin(t *terminal.Terminal, loginToken string, skipBrowser bool, emailFlag string, authProviderFlag string) error { + tokens, _ := o.LoginStore.GetAuthTokens() + + if authProviderFlag != "" && authProviderFlag != "nvidia" && authProviderFlag != "legacy" { + return breverrors.NewValidationError("auth provider must be nvidia or legacy") } + authenticator := auth.StandardLogin(authProviderFlag, emailFlag, tokens) + + if emailFlag != "" && authenticator.GetCredentialProvider() != auth.CredentialProviderKAS { + return breverrors.NewValidationError("email flag can only be used with nvidia auth provider") + } + + o.Auth = auth.NewAuth(o.LoginStore, authenticator) + caretType := color.New(color.FgGreen, color.Bold).SprintFunc() fmt.Print("\n") fmt.Println(" ", caretType("▸"), " Starting Login")