diff --git a/auth/oauth.go b/auth/oauth.go index bc567385..3a3da407 100644 --- a/auth/oauth.go +++ b/auth/oauth.go @@ -2,6 +2,7 @@ package auth import ( "fmt" + "golang.org/x/net/context" "golang.org/x/oauth2" "golang.org/x/oauth2/google" "net/http" @@ -10,7 +11,7 @@ import ( type authCodeFn func(string) func() string -func NewFileSourceClient(clientId, clientSecret, tokenFile string, authFn authCodeFn) (*http.Client, error) { +func NewFileSourceClient(clientId, clientSecret string, ctx context.Context, tokenFile string, authFn authCodeFn) (*http.Client, error) { conf := getConfig(clientId, clientSecret) // Read cached token @@ -31,12 +32,12 @@ func NewFileSourceClient(clientId, clientSecret, tokenFile string, authFn authCo } return oauth2.NewClient( - oauth2.NoContext, + ctx, FileSource(tokenFile, token, conf), ), nil } -func NewRefreshTokenClient(clientId, clientSecret, refreshToken string) *http.Client { +func NewRefreshTokenClient(clientId, clientSecret string, ctx context.Context, refreshToken string) *http.Client { conf := getConfig(clientId, clientSecret) token := &oauth2.Token{ @@ -46,12 +47,12 @@ func NewRefreshTokenClient(clientId, clientSecret, refreshToken string) *http.Cl } return oauth2.NewClient( - oauth2.NoContext, + ctx, conf.TokenSource(oauth2.NoContext, token), ) } -func NewAccessTokenClient(clientId, clientSecret, accessToken string) *http.Client { +func NewAccessTokenClient(clientId, clientSecret string, ctx context.Context, accessToken string) *http.Client { conf := getConfig(clientId, clientSecret) token := &oauth2.Token{ @@ -60,12 +61,12 @@ func NewAccessTokenClient(clientId, clientSecret, accessToken string) *http.Clie } return oauth2.NewClient( - oauth2.NoContext, + ctx, conf.TokenSource(oauth2.NoContext, token), ) } -func NewServiceAccountClient(serviceAccountFile string) (*http.Client, error) { +func NewServiceAccountClient(serviceAccountFile string, ctx context.Context) (*http.Client, error) { content, exists, err := ReadFile(serviceAccountFile) if(!exists) { return nil, fmt.Errorf("Service account filename %q not found", serviceAccountFile) @@ -79,7 +80,7 @@ func NewServiceAccountClient(serviceAccountFile string) (*http.Client, error) { if(err != nil) { return nil, err } - return conf.Client(oauth2.NoContext), nil + return conf.Client(ctx), nil } func getConfig(clientId, clientSecret string) *oauth2.Config { diff --git a/gdrive.go b/gdrive.go index a505d789..492ca666 100644 --- a/gdrive.go +++ b/gdrive.go @@ -45,6 +45,11 @@ func main() { Patterns: []string{"--service-account"}, Description: "Oauth service account filename, used for server to server communication without user interaction (filename path is relative to config dir)", }, + cli.BoolFlag{ + Name: "disable-compression", + Patterns: []string{"--disable-compression"}, + Description: "Disable gzip compression in HTTP requests. This might be useful to trade higher bandwidth for reduced CPU.", + OmitValue: true}, } handlers := []*cli.Handler{ diff --git a/handlers_drive.go b/handlers_drive.go index 7bda872f..e1d7ac52 100644 --- a/handlers_drive.go +++ b/handlers_drive.go @@ -2,6 +2,8 @@ package main import ( "fmt" + "golang.org/x/net/context" + "golang.org/x/oauth2" "io" "io/ioutil" "net/http" @@ -345,19 +347,25 @@ func getOauthClient(args cli.Arguments) (*http.Client, error) { ExitF("Access token not needed when refresh token is provided") } + oauth_context := context.TODO() + if args.Bool("disable-compression") { + oauth_context = context.WithValue(oauth_context, oauth2.HTTPClient, + &http.Client{Transport: &http.Transport{DisableCompression: true}}) + } + if args.String("refreshToken") != "" { - return auth.NewRefreshTokenClient(ClientId, ClientSecret, args.String("refreshToken")), nil + return auth.NewRefreshTokenClient(ClientId, ClientSecret, oauth_context, args.String("refreshToken")), nil } if args.String("accessToken") != "" { - return auth.NewAccessTokenClient(ClientId, ClientSecret, args.String("accessToken")), nil + return auth.NewAccessTokenClient(ClientId, ClientSecret, oauth_context, args.String("accessToken")), nil } configDir := getConfigDir(args) if args.String("serviceAccount") != "" { serviceAccountPath := ConfigFilePath(configDir, args.String("serviceAccount")) - serviceAccountClient, err := auth.NewServiceAccountClient(serviceAccountPath) + serviceAccountClient, err := auth.NewServiceAccountClient(serviceAccountPath, oauth_context) if err != nil { return nil, err } @@ -365,7 +373,7 @@ func getOauthClient(args cli.Arguments) (*http.Client, error) { } tokenPath := ConfigFilePath(configDir, TokenFilename) - return auth.NewFileSourceClient(ClientId, ClientSecret, tokenPath, authCodePrompt) + return auth.NewFileSourceClient(ClientId, ClientSecret, oauth_context, tokenPath, authCodePrompt) } func getConfigDir(args cli.Arguments) string {