Skip to content
This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Add oauth redirect listener #634

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 146 additions & 0 deletions auth/listener.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package auth

import (
"context"
"fmt"
"net"
"net/http"
"os"
"time"

"golang.org/x/oauth2"
)

type authorize struct{ authUrl string }
type callback struct {
done chan string
bad chan bool
state string
}

func (a authorize) ServeHTTP(w http.ResponseWriter, req *http.Request) {
w.Header().Add("Location", a.authUrl)
w.WriteHeader(302)
fmt.Fprintln(w, "<html><head>")
fmt.Fprintln(w, "<title>Redirect to authentication server</title>")
fmt.Fprintln(w, "</head><body>")
fmt.Fprintf(w, "Click <a href=\"%s\">here</a> to authorize gdrive to use Google Drive\n",
a.authUrl)
fmt.Fprintln(w, "</body></html>")
}

func (c callback) ServeHTTP(w http.ResponseWriter, req *http.Request) {
err := req.ParseForm()
if err != nil {
fmt.Printf("Could not parse form on /callback: %s\n", err)
w.WriteHeader(400)
fmt.Fprintln(w, "<html><head>")
fmt.Fprintln(w, "<title>Bad request</title>")
fmt.Fprintln(w, "</head><body>")
fmt.Fprintln(w, "Bad request: Missing authentication response")
fmt.Fprintln(w, "</body></html>")
return
}
if req.Form.Get("error") != "" {
fmt.Printf("authentication failed, server response is %s\n", req.Form.Get("error"))
c.bad <- true
fmt.Fprintln(w, "<html><head>")
fmt.Fprintln(w, "<title>Google Drive authentication failed</title>")
fmt.Fprintln(w, "</head><body>")
fmt.Fprintf(w, "Authentication failed or refused: %s\n", req.Form.Get("error"))
fmt.Fprintln(w, "</body></html>")
return
}

if req.Form.Get("code") == "" || req.Form.Get("state") == "" {
fmt.Println("callback request is missing parameters")
w.WriteHeader(400)
fmt.Fprintln(w, "<html><head>")
fmt.Fprintln(w, "<title>Bad request</title>")
fmt.Fprintln(w, "</head><body>")
fmt.Fprintln(w, "Bad request: response is missing the code or state parameters")
fmt.Fprintln(w, "</body></html>")
return
}

code := req.Form.Get("code")
state := req.Form.Get("state")
if state != c.state {
fmt.Printf("Callback state mismatch: %s vs %s", state, c.state)
w.WriteHeader(400)
fmt.Fprintln(w, "<html><head>")
fmt.Fprintln(w, "<title>Bad request</title>")
fmt.Fprintln(w, "</head><body>")
fmt.Fprintln(w, "Bad request: response state mismatch")
fmt.Fprintln(w, "</body></html>")
return
}
fmt.Fprintln(w, "<html><head>")
fmt.Fprintln(w, "<title>Authentication response received</title>")
fmt.Fprintln(w, "</head><body>")
fmt.Fprintln(w, "Authentication response has been received. Check the terminal where gdrive is running")
fmt.Fprintln(w, "</body></html>")

c.done <- code
}

func AuthCodeHTTP(conf *oauth2.Config, state string) (func() (string, error), error) {

ln, err := net.Listen("tcp4", "127.0.0.1:0")
if err != nil {
return nil, err
}

hostPort := ln.Addr().String()
_, port, err := net.SplitHostPort(hostPort)
if err != nil {
return nil, err
}

mux := http.NewServeMux()
srv := &http.Server{Handler: mux}

go func() {
err := srv.Serve(ln)
if err != http.ErrServerClosed {
fmt.Printf("Cannot start http server: %s", err)
os.Exit(1)
}
}()
myconf := conf
myconf.RedirectURL = fmt.Sprintf("http://127.0.0.1:%s/callback", port)

authUrl := myconf.AuthCodeURL(state, oauth2.AccessTypeOffline)
authorizer := authorize{authUrl: authUrl}
mux.Handle("/authorize", authorizer)
callback := callback{state: state,
done: make(chan string, 1),
bad: make(chan bool, 1),
}
mux.Handle("/callback", callback)

return func() (string, error) {
var code string
var err error
fmt.Println("Authentication needed")
fmt.Println("Go to the following url in your browser:")
fmt.Printf("http://127.0.0.1:%s/authorize\n\n", port)
fmt.Println("Waiting for authentication response")

select {
case <-callback.bad:
err = fmt.Errorf("authentication did not complete successfully")
code = ""
case code = <-callback.done:
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer func() {
cancel()
}()

if stoperr := srv.Shutdown(ctx); stoperr != nil {
fmt.Printf("Server Shutdown Failed:%+v\n", stoperr)
}
return code, err
}, nil
}
63 changes: 54 additions & 9 deletions auth/oauth.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
package auth

import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"io"
"net/http"
"time"

"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)

type authCodeFn func(string) func() string
type authCodeFn func(*oauth2.Config, string) (func() (string, error), error)

func NewFileSourceClient(clientId, clientSecret, tokenFile string, authFn authCodeFn) (*http.Client, error) {
conf := getConfig(clientId, clientSecret)
Expand All @@ -22,11 +27,21 @@ func NewFileSourceClient(clientId, clientSecret, tokenFile string, authFn authCo
// Require auth code if token file does not exist
// or refresh token is missing
if !exists || token.RefreshToken == "" {
authUrl := conf.AuthCodeURL("state", oauth2.AccessTypeOffline)
authCode := authFn(authUrl)()
state, err := makeState()
if err != nil {
return nil, fmt.Errorf("could not build state string: %s", err)
}
authFnInt, err := authFn(conf, state)
if err != nil {
return nil, fmt.Errorf("could not receive auth code: %s", err)
}
authCode, err := authFnInt()
if err != nil {
return nil, fmt.Errorf("could not receive auth code: %s", err)
}
token, err = conf.Exchange(oauth2.NoContext, authCode)
if err != nil {
return nil, fmt.Errorf("Failed to exchange auth code for token: %s", err)
return nil, fmt.Errorf("failed to exchange auth code for token: %s", err)
}
}

Expand Down Expand Up @@ -67,16 +82,16 @@ func NewAccessTokenClient(clientId, clientSecret, accessToken string) *http.Clie

func NewServiceAccountClient(serviceAccountFile string) (*http.Client, error) {
content, exists, err := ReadFile(serviceAccountFile)
if(!exists) {
if !exists {
return nil, fmt.Errorf("Service account filename %q not found", serviceAccountFile)
}

if(err != nil) {
if err != nil {
return nil, err
}

conf, err := google.JWTConfigFromJSON(content, "https://www.googleapis.com/auth/drive")
if(err != nil) {
if err != nil {
return nil, err
}
return conf.Client(oauth2.NoContext), nil
Expand All @@ -94,3 +109,33 @@ func getConfig(clientId, clientSecret string) *oauth2.Config {
},
}
}

func makeState() (string, error) {
return makeString(12)
}

func makeCodeChallenge() (string, string, error) {
verifier, err := makeString(48)
if err != nil {
return "", "", err
}

hasher := sha256.New()
_, err = hasher.Write([]byte(verifier))
if err != nil {
return "", "", err
}

hash := hasher.Sum(nil)
challenge := base64.RawURLEncoding.EncodeToString(hash)

return verifier, challenge, nil
}

func makeString(n int) (string, error) {
data := make([]byte, n)
if _, err := io.ReadFull(rand.Reader, data); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(data), nil
}
18 changes: 1 addition & 17 deletions handlers_drive.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package main

import (
"fmt"
"io"
"io/ioutil"
"net/http"
Expand Down Expand Up @@ -365,7 +364,7 @@ func getOauthClient(args cli.Arguments) (*http.Client, error) {
}

tokenPath := ConfigFilePath(configDir, TokenFilename)
return auth.NewFileSourceClient(ClientId, ClientSecret, tokenPath, authCodePrompt)
return auth.NewFileSourceClient(ClientId, ClientSecret, tokenPath, auth.AuthCodeHTTP)
}

func getConfigDir(args cli.Arguments) string {
Expand All @@ -390,21 +389,6 @@ func newDrive(args cli.Arguments) *drive.Drive {
return client
}

func authCodePrompt(url string) func() string {
return func() string {
fmt.Println("Authentication needed")
fmt.Println("Go to the following url in your browser:")
fmt.Printf("%s\n\n", url)
fmt.Print("Enter verification code: ")

var code string
if _, err := fmt.Scan(&code); err != nil {
fmt.Printf("Failed reading code: %s", err.Error())
}
return code
}
}

func progressWriter(discard bool) io.Writer {
if discard {
return ioutil.Discard
Expand Down