Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move authflags to logincmd #211

Merged
merged 3 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,66 @@ func GetEmailFromToken(token string) string {
}
return email
}

func AuthProviderFlagToCredentialProvider(authProviderFlag string) entity.CredentialProvider {
if authProviderFlag == "" {
return ""
}
if authProviderFlag == "nvidia" {
return CredentialProviderKAS
}
return CredentialProviderAuth0
}

func StandardLogin(authProvider string, email string, tokens *entity.AuthTokens) OAuth {
// the default authenticator
var authenticator OAuth = Auth0Authenticator{
Issuer: "https://brevdev.us.auth0.com/",
Audience: "https://brevdev.us.auth0.com/api/v2/",
ClientID: "JaqJRLEsdat5w7Tb0WqmTxzIeqwqepmk",
DeviceCodeEndpoint: "https://brevdev.us.auth0.com/oauth/device/code",
OauthTokenEndpoint: "https://brevdev.us.auth0.com/oauth/token",
}

shouldPromptEmail := false
if email == "" && tokens != nil && tokens.AccessToken != "" {
email = GetEmailFromToken(tokens.AccessToken)
shouldPromptEmail = true
}

authRetriever := NewOAuthRetriever([]OAuth{
authenticator,
NewKasAuthenticator(
email,
"https://api.ngc.nvidia.com",
"https://login.nvidia.com",
shouldPromptEmail,
"https://brev.nvidia.com",
),
})

if tokens != nil && tokens.AccessToken != "" {
authenticatorFromToken, errr := authRetriever.GetByToken(tokens.AccessToken)
if errr != nil {
fmt.Printf("%v\n", errr)
} else {
authenticator = authenticatorFromToken
}
}

if authProvider != "" {
provider := AuthProviderFlagToCredentialProvider(authProvider)
oauth, errr := authRetriever.GetByProvider(provider)
if errr != nil {
fmt.Printf("%v\n", errr)
} else {
authenticator = oauth
}
}

if authenticator.GetCredentialProvider() == CredentialProviderKAS {
config.ConsoleBaseURL = "https://brev.nvidia.com"
}

return authenticator
}
90 changes: 6 additions & 84 deletions pkg/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
package cmd

import (
"flag"
"fmt"

"github.com/brevdev/brev-cli/pkg/auth"
Expand Down Expand Up @@ -48,7 +47,6 @@ import (
"github.com/brevdev/brev-cli/pkg/cmd/workspacegroups"
"github.com/brevdev/brev-cli/pkg/cmd/writeconnectionevent"
"github.com/brevdev/brev-cli/pkg/config"
"github.com/brevdev/brev-cli/pkg/entity"
"github.com/brevdev/brev-cli/pkg/featureflag"
"github.com/brevdev/brev-cli/pkg/files"
"github.com/brevdev/brev-cli/pkg/remoteversion"
Expand All @@ -61,29 +59,20 @@ import (
)

var (
userFlag string
emailFlag string
authProviderFlag string
userFlag string
printVersion bool
)

func init() {
flag.StringVar(&emailFlag, "email", "", "email to use for authentication")
flag.StringVar(&authProviderFlag, "auth", "", "authentication provider to use (nvidia or legacy, default is legacy)")
flag.Parse()
}

func NewDefaultBrevCommand() *cobra.Command {
cmd := NewBrevCommand()
cmd.PersistentFlags().StringVar(&userFlag, "user", "", "non root user to use for per user configuration of commands run as root")
cmd.PersistentFlags().StringVar(&emailFlag, "email", "", "email to use for authentication")
cmd.PersistentFlags().StringVar(&authProviderFlag, "auth", "", "authentication provider to use (nvidia or legacy, default is legacy)")
cmd.PersistentFlags().BoolVar(&printVersion, "version", false, "Print version output")
return cmd
}

func NewBrevCommand() *cobra.Command { //nolint:funlen,gocognit,gocyclo // define brev command
// in io.Reader, out io.Writer, err io.Writer
t := terminal.New()
var printVersion bool

conf := config.NewConstants()
fs := files.AppFs
Expand All @@ -92,64 +81,9 @@ func NewBrevCommand() *cobra.Command { //nolint:funlen,gocognit,gocyclo // defin
NewBasicStore().
WithFileSystem(fs)

tokens, err := fsStore.GetAuthTokens()
if err != nil {
fmt.Printf("%v\n", err)
}

// the default authenticator
var authenticator auth.OAuth = auth.Auth0Authenticator{
Issuer: "https://brevdev.us.auth0.com/",
Audience: "https://brevdev.us.auth0.com/api/v2/",
ClientID: "JaqJRLEsdat5w7Tb0WqmTxzIeqwqepmk",
DeviceCodeEndpoint: "https://brevdev.us.auth0.com/oauth/device/code",
OauthTokenEndpoint: "https://brevdev.us.auth0.com/oauth/token",
}

var email string
shouldPromptEmail := false
if tokens != nil && tokens.AccessToken != "" {
email = auth.GetEmailFromToken(tokens.AccessToken)
shouldPromptEmail = true
}
if emailFlag != "" {
email = emailFlag
shouldPromptEmail = false
}

authRetriever := auth.NewOAuthRetriever([]auth.OAuth{
authenticator,
auth.NewKasAuthenticator(
email,
"https://api.ngc.nvidia.com",
"https://login.nvidia.com",
shouldPromptEmail,
"https://brev.nvidia.com",
),
})

if tokens != nil && tokens.AccessToken != "" {
authenticatorFromToken, errr := authRetriever.GetByToken(tokens.AccessToken)
if errr != nil {
fmt.Printf("%v\n", errr)
} else {
authenticator = authenticatorFromToken
}
}

if authProviderFlag != "" {
provider := authProviderFlagToCredentialProvider(authProviderFlag)
authenticatorByProviderFlag, errr := authRetriever.GetByProvider(provider)
if errr != nil {
fmt.Printf("%v\n", errr)
} else {
authenticator = authenticatorByProviderFlag
}
}
tokens, _ := fsStore.GetAuthTokens()

if authenticator.GetCredentialProvider() == auth.CredentialProviderKAS {
config.ConsoleBaseURL = "https://brev.nvidia.com"
}
authenticator := auth.StandardLogin("", "", tokens)

// super annoying. this is needed to make the import stay
_ = color.New(color.FgYellow, color.Bold).SprintFunc()
Expand All @@ -162,7 +96,7 @@ func NewBrevCommand() *cobra.Command { //nolint:funlen,gocognit,gocyclo // defin
).
WithAuth(loginAuth, store.WithDebug(conf.GetDebugHTTP()))

err = loginCmdStore.SetForbiddenStatusRetryHandler(func() error {
err := loginCmdStore.SetForbiddenStatusRetryHandler(func() error {
_, err1 := loginAuth.GetAccessToken()
if err1 != nil {
return breverrors.WrapAndTrace(err1)
Expand Down Expand Up @@ -285,8 +219,6 @@ func NewBrevCommand() *cobra.Command { //nolint:funlen,gocognit,gocyclo // defin

cmds.SetUsageTemplate(usageTemplate)

cmds.PersistentFlags().BoolVar(&printVersion, "version", false, "Print version output")

createCmdTree(cmds, t, loginCmdStore, noLoginCmdStore, loginAuth)

return cmds
Expand Down Expand Up @@ -537,13 +469,3 @@ var (
_ store.Auth = auth.NoLoginAuth{}
_ auth.AuthStore = store.FileStore{}
)

func authProviderFlagToCredentialProvider(authProviderFlag string) entity.CredentialProvider {
if authProviderFlag == "" {
return ""
}
if authProviderFlag == "nvidia" {
return auth.CredentialProviderKAS
}
return auth.CredentialProviderAuth0
}
41 changes: 21 additions & 20 deletions pkg/cmd/login/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
}

type LoginStore interface {
auth.AuthStore
GetCurrentUser() (*entity.User, error)
CreateUser(idToken string) (*entity.User, error)
GetOrganizations(options *store.GetOrganizationsOptions) ([]entity.Organization, error)
Expand Down Expand Up @@ -55,6 +56,8 @@

var loginToken string
var skipBrowser bool
var emailFlag string
var authProviderFlag string

cmd := &cobra.Command{
Annotations: map[string]string{"housekeeping": ""},
Expand All @@ -79,40 +82,29 @@
},
Args: cmderrors.TransformToValidationError(cobra.NoArgs),
RunE: func(cmd *cobra.Command, args []string) error {
err := opts.RunLogin(t, loginToken, skipBrowser)
err := opts.RunLogin(t, loginToken, skipBrowser, emailFlag, authProviderFlag)
if err != nil {
// if err is ImportIDEConfigError, log err with sentry but continue
if _, ok := err.(*importideconfig.ImportIDEConfigError); !ok {
return breverrors.WrapAndTrace(err)
return err
}
// todo alert sentry
err2 := RunTasksForUser(t)
if err2 != nil {
err = multierror.Append(err, err2)
}
return breverrors.WrapAndTrace(err)
return err

Check failure on line 96 in pkg/cmd/login/login.go

View workflow job for this annotation

GitHub Actions / ci (ubuntu-20.04)

error returned from external package is unwrapped: sig: func github.com/hashicorp/go-multierror.Append(err error, errs ...error) *github.com/hashicorp/go-multierror.Error (wrapcheck)
}
return nil
},
}
cmd.Flags().StringVarP(&loginToken, "token", "", "", "token provided to auto login")
cmd.Flags().BoolVar(&skipBrowser, "skip-browser", false, "print url instead of auto opening browser")
cmd.Flags().StringVar(&emailFlag, "email", "", "email to use for authentication")
cmd.Flags().StringVar(&authProviderFlag, "auth", "", "authentication provider to use (nvidia or legacy, default is legacy)")
return cmd
}

func (o LoginOptions) checkIfInWorkspace() error {
workspaceID, err := o.LoginStore.GetCurrentWorkspaceID()
if err != nil {
return breverrors.WrapAndTrace(err)
}
if workspaceID != "" {
fmt.Println("can not login to instance")
return breverrors.NewValidationError("can not login to instance")
}

return nil
}

func (o LoginOptions) loginAndGetOrCreateUser(loginToken string, skipBrowser bool) (*entity.User, error) {
if loginToken != "" {
err := o.Auth.LoginWithToken(loginToken)
Expand Down Expand Up @@ -158,12 +150,21 @@
return org, nil
}

func (o LoginOptions) RunLogin(t *terminal.Terminal, loginToken string, skipBrowser bool) error {
err := o.checkIfInWorkspace()
if err != nil {
return breverrors.WrapAndTrace(err)
func (o LoginOptions) RunLogin(t *terminal.Terminal, loginToken string, skipBrowser bool, emailFlag string, authProviderFlag string) error {
tokens, _ := o.LoginStore.GetAuthTokens()

if authProviderFlag != "" && authProviderFlag != "nvidia" && authProviderFlag != "legacy" {
return breverrors.NewValidationError("auth provider must be nvidia or legacy")
}

authenticator := auth.StandardLogin(authProviderFlag, emailFlag, tokens)

if emailFlag != "" && authenticator.GetCredentialProvider() != auth.CredentialProviderKAS {
return breverrors.NewValidationError("email flag can only be used with nvidia auth provider")
}

o.Auth = auth.NewAuth(o.LoginStore, authenticator)

caretType := color.New(color.FgGreen, color.Bold).SprintFunc()
fmt.Print("\n")
fmt.Println(" ", caretType("▸"), " Starting Login")
Expand Down
Loading