Skip to content

Commit

Permalink
move auth and email flag to login
Browse files Browse the repository at this point in the history
  • Loading branch information
theFong committed Nov 15, 2024
1 parent 14ed1ec commit d08c877
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 103 deletions.
63 changes: 63 additions & 0 deletions pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,66 @@ func GetEmailFromToken(token string) string {
}
return email
}

func AuthProviderFlagToCredentialProvider(authProviderFlag string) entity.CredentialProvider {
if authProviderFlag == "" {
return ""
}
if authProviderFlag == "nvidia" {
return CredentialProviderKAS
}
return CredentialProviderAuth0
}

func StandardLogin(authProvider string, email string, tokens *entity.AuthTokens) OAuth {
// the default authenticator
var authenticator OAuth = Auth0Authenticator{
Issuer: "https://brevdev.us.auth0.com/",
Audience: "https://brevdev.us.auth0.com/api/v2/",
ClientID: "JaqJRLEsdat5w7Tb0WqmTxzIeqwqepmk",
DeviceCodeEndpoint: "https://brevdev.us.auth0.com/oauth/device/code",
OauthTokenEndpoint: "https://brevdev.us.auth0.com/oauth/token",
}

shouldPromptEmail := false
if email == "" && tokens != nil && tokens.AccessToken != "" {
email = GetEmailFromToken(tokens.AccessToken)
shouldPromptEmail = true
}

authRetriever := NewOAuthRetriever([]OAuth{
authenticator,
NewKasAuthenticator(
email,
"https://api.ngc.nvidia.com",
"https://login.nvidia.com",
shouldPromptEmail,
"https://brev.nvidia.com",
),
})

if tokens != nil && tokens.AccessToken != "" {
authenticatorFromToken, errr := authRetriever.GetByToken(tokens.AccessToken)
if errr != nil {
fmt.Printf("%v\n", errr)
} else {
authenticator = authenticatorFromToken
}
}

if authProvider != "" {
provider := AuthProviderFlagToCredentialProvider(authProvider)
oauth, errr := authRetriever.GetByProvider(provider)
if errr != nil {
fmt.Printf("%v\n", errr)
} else {
authenticator = oauth
}
}

if authenticator.GetCredentialProvider() == CredentialProviderKAS {
config.ConsoleBaseURL = "https://brev.nvidia.com"
}

return authenticator
}
91 changes: 8 additions & 83 deletions pkg/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
package cmd

import (
"flag"
"fmt"

"github.com/brevdev/brev-cli/pkg/auth"
Expand Down Expand Up @@ -48,7 +47,6 @@ import (
"github.com/brevdev/brev-cli/pkg/cmd/workspacegroups"
"github.com/brevdev/brev-cli/pkg/cmd/writeconnectionevent"
"github.com/brevdev/brev-cli/pkg/config"
"github.com/brevdev/brev-cli/pkg/entity"
"github.com/brevdev/brev-cli/pkg/featureflag"
"github.com/brevdev/brev-cli/pkg/files"
"github.com/brevdev/brev-cli/pkg/remoteversion"
Expand All @@ -61,29 +59,23 @@ import (
)

var (
userFlag string
emailFlag string
authProviderFlag string
userFlag string
printVersion bool
)

func init() {
flag.StringVar(&emailFlag, "email", "", "email to use for authentication")
flag.StringVar(&authProviderFlag, "auth", "", "authentication provider to use (nvidia or legacy, default is legacy)")
flag.Parse()
}
// We define the global flags twice because the first call to flag.Parse which we use for
// will eat up the flags and we need to define them again here.

func NewDefaultBrevCommand() *cobra.Command {
cmd := NewBrevCommand()
cmd.PersistentFlags().StringVar(&userFlag, "user", "", "non root user to use for per user configuration of commands run as root")
cmd.PersistentFlags().StringVar(&emailFlag, "email", "", "email to use for authentication")
cmd.PersistentFlags().StringVar(&authProviderFlag, "auth", "", "authentication provider to use (nvidia or legacy, default is legacy)")
cmd.PersistentFlags().BoolVar(&printVersion, "version", false, "Print version output")
return cmd
}

func NewBrevCommand() *cobra.Command { //nolint:funlen,gocognit,gocyclo // define brev command
// in io.Reader, out io.Writer, err io.Writer
t := terminal.New()
var printVersion bool

conf := config.NewConstants()
fs := files.AppFs
Expand All @@ -92,64 +84,9 @@ func NewBrevCommand() *cobra.Command { //nolint:funlen,gocognit,gocyclo // defin
NewBasicStore().
WithFileSystem(fs)

tokens, err := fsStore.GetAuthTokens()
if err != nil {
fmt.Printf("%v\n", err)
}

// the default authenticator
var authenticator auth.OAuth = auth.Auth0Authenticator{
Issuer: "https://brevdev.us.auth0.com/",
Audience: "https://brevdev.us.auth0.com/api/v2/",
ClientID: "JaqJRLEsdat5w7Tb0WqmTxzIeqwqepmk",
DeviceCodeEndpoint: "https://brevdev.us.auth0.com/oauth/device/code",
OauthTokenEndpoint: "https://brevdev.us.auth0.com/oauth/token",
}

var email string
shouldPromptEmail := false
if tokens != nil && tokens.AccessToken != "" {
email = auth.GetEmailFromToken(tokens.AccessToken)
shouldPromptEmail = true
}
if emailFlag != "" {
email = emailFlag
shouldPromptEmail = false
}
tokens, _ := fsStore.GetAuthTokens()

authRetriever := auth.NewOAuthRetriever([]auth.OAuth{
authenticator,
auth.NewKasAuthenticator(
email,
"https://api.ngc.nvidia.com",
"https://login.nvidia.com",
shouldPromptEmail,
"https://brev.nvidia.com",
),
})

if tokens != nil && tokens.AccessToken != "" {
authenticatorFromToken, errr := authRetriever.GetByToken(tokens.AccessToken)
if errr != nil {
fmt.Printf("%v\n", errr)
} else {
authenticator = authenticatorFromToken
}
}

if authProviderFlag != "" {
provider := authProviderFlagToCredentialProvider(authProviderFlag)
authenticatorByProviderFlag, errr := authRetriever.GetByProvider(provider)
if errr != nil {
fmt.Printf("%v\n", errr)
} else {
authenticator = authenticatorByProviderFlag
}
}

if authenticator.GetCredentialProvider() == auth.CredentialProviderKAS {
config.ConsoleBaseURL = "https://brev.nvidia.com"
}
authenticator := auth.StandardLogin("", "", tokens)

// super annoying. this is needed to make the import stay
_ = color.New(color.FgYellow, color.Bold).SprintFunc()
Expand All @@ -162,7 +99,7 @@ func NewBrevCommand() *cobra.Command { //nolint:funlen,gocognit,gocyclo // defin
).
WithAuth(loginAuth, store.WithDebug(conf.GetDebugHTTP()))

err = loginCmdStore.SetForbiddenStatusRetryHandler(func() error {
err := loginCmdStore.SetForbiddenStatusRetryHandler(func() error {
_, err1 := loginAuth.GetAccessToken()
if err1 != nil {
return breverrors.WrapAndTrace(err1)
Expand Down Expand Up @@ -285,8 +222,6 @@ func NewBrevCommand() *cobra.Command { //nolint:funlen,gocognit,gocyclo // defin

cmds.SetUsageTemplate(usageTemplate)

cmds.PersistentFlags().BoolVar(&printVersion, "version", false, "Print version output")

createCmdTree(cmds, t, loginCmdStore, noLoginCmdStore, loginAuth)

return cmds
Expand Down Expand Up @@ -537,13 +472,3 @@ var (
_ store.Auth = auth.NoLoginAuth{}
_ auth.AuthStore = store.FileStore{}
)

func authProviderFlagToCredentialProvider(authProviderFlag string) entity.CredentialProvider {
if authProviderFlag == "" {
return ""
}
if authProviderFlag == "nvidia" {
return auth.CredentialProviderKAS
}
return auth.CredentialProviderAuth0
}
41 changes: 21 additions & 20 deletions pkg/cmd/login/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type LoginOptions struct {
}

type LoginStore interface {
auth.AuthStore
GetCurrentUser() (*entity.User, error)
CreateUser(idToken string) (*entity.User, error)
GetOrganizations(options *store.GetOrganizationsOptions) ([]entity.Organization, error)
Expand Down Expand Up @@ -55,6 +56,8 @@ func NewCmdLogin(t *terminal.Terminal, loginStore LoginStore, auth Auth) *cobra.

var loginToken string
var skipBrowser bool
var emailFlag string
var authProviderFlag string

cmd := &cobra.Command{
Annotations: map[string]string{"housekeeping": ""},
Expand All @@ -79,40 +82,29 @@ func NewCmdLogin(t *terminal.Terminal, loginStore LoginStore, auth Auth) *cobra.
},
Args: cmderrors.TransformToValidationError(cobra.NoArgs),
RunE: func(cmd *cobra.Command, args []string) error {
err := opts.RunLogin(t, loginToken, skipBrowser)
err := opts.RunLogin(t, loginToken, skipBrowser, emailFlag, authProviderFlag)
if err != nil {
// if err is ImportIDEConfigError, log err with sentry but continue
if _, ok := err.(*importideconfig.ImportIDEConfigError); !ok {
return breverrors.WrapAndTrace(err)
return err
}
// todo alert sentry
err2 := RunTasksForUser(t)
if err2 != nil {
err = multierror.Append(err, err2)
}
return breverrors.WrapAndTrace(err)
return err
}
return nil
},
}
cmd.Flags().StringVarP(&loginToken, "token", "", "", "token provided to auto login")
cmd.Flags().BoolVar(&skipBrowser, "skip-browser", false, "print url instead of auto opening browser")
cmd.Flags().StringVar(&emailFlag, "email", "", "email to use for authentication")
cmd.Flags().StringVar(&authProviderFlag, "auth", "", "authentication provider to use (nvidia or legacy, default is legacy)")
return cmd
}

func (o LoginOptions) checkIfInWorkspace() error {
workspaceID, err := o.LoginStore.GetCurrentWorkspaceID()
if err != nil {
return breverrors.WrapAndTrace(err)
}
if workspaceID != "" {
fmt.Println("can not login to instance")
return breverrors.NewValidationError("can not login to instance")
}

return nil
}

func (o LoginOptions) loginAndGetOrCreateUser(loginToken string, skipBrowser bool) (*entity.User, error) {
if loginToken != "" {
err := o.Auth.LoginWithToken(loginToken)
Expand Down Expand Up @@ -158,12 +150,21 @@ func (o LoginOptions) getOrCreateOrg(username string) (*entity.Organization, err
return org, nil
}

func (o LoginOptions) RunLogin(t *terminal.Terminal, loginToken string, skipBrowser bool) error {
err := o.checkIfInWorkspace()
if err != nil {
return breverrors.WrapAndTrace(err)
func (o LoginOptions) RunLogin(t *terminal.Terminal, loginToken string, skipBrowser bool, emailFlag string, authProviderFlag string) error {
tokens, _ := o.LoginStore.GetAuthTokens()

if authProviderFlag != "" && authProviderFlag != "nvidia" && authProviderFlag != "legacy" {
return breverrors.NewValidationError("auth provider must be nvidia or legacy")
}

if emailFlag != "" && authProviderFlag != "nvidia" {
return breverrors.NewValidationError("email flag can only be used with nvidia auth provider")
}

authenticator := auth.StandardLogin(authProviderFlag, emailFlag, tokens)

o.Auth = auth.NewAuth(o.LoginStore, authenticator)

caretType := color.New(color.FgGreen, color.Bold).SprintFunc()
fmt.Print("\n")
fmt.Println(" ", caretType("▸"), " Starting Login")
Expand Down

0 comments on commit d08c877

Please sign in to comment.