Skip to content

Commit

Permalink
feat(auth): support quota project in UserCredentials (#661)
Browse files Browse the repository at this point in the history
  • Loading branch information
dbolduc authored Jan 10, 2025
1 parent 0db355e commit 1b7ee25
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 15 deletions.
2 changes: 2 additions & 0 deletions src/auth/src/credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ use http::header::{HeaderName, HeaderValue};
use std::future::Future;
use std::sync::Arc;

pub(crate) const QUOTA_PROJECT_KEY: &str = "x-goog-user-project";

/// An implementation of [crate::credentials::CredentialTrait].
///
/// Represents a [Credential] used to obtain auth [Token][crate::token::Token]s
Expand Down
109 changes: 94 additions & 15 deletions src/auth/src/credentials/user_credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
// limitations under the License.

use crate::credentials::dynamic::CredentialTrait;
use crate::credentials::Credential;
use crate::credentials::Result;
use crate::credentials::{Credential, Result, QUOTA_PROJECT_KEY};
use crate::errors::{is_retryable, CredentialError};
use crate::token::{Token, TokenProvider};
use http::header::{HeaderName, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
Expand All @@ -36,7 +35,10 @@ pub(crate) fn creds_from(js: serde_json::Value) -> Result<Credential> {
};

Ok(Credential {
inner: Arc::new(UserCredential { token_provider }),
inner: Arc::new(UserCredential {
token_provider,
quota_project_id: au.quota_project_id,
}),
})
}

Expand Down Expand Up @@ -118,6 +120,7 @@ where
T: TokenProvider,
{
token_provider: T,
quota_project_id: Option<String>,
}

#[async_trait::async_trait]
Expand All @@ -134,7 +137,15 @@ where
let mut value = HeaderValue::from_str(&format!("{} {}", token.token_type, token.token))
.map_err(|e| CredentialError::new(false, e.into()))?;
value.set_sensitive(true);
Ok(vec![(AUTHORIZATION, value)])
let mut headers = vec![(AUTHORIZATION, value)];
if let Some(project) = &self.quota_project_id {
headers.push((
HeaderName::from_static(QUOTA_PROJECT_KEY),
HeaderValue::from_str(project)
.map_err(|e| CredentialError::new(false, e.into()))?,
));
}
Ok(headers)
}

async fn get_universe_domain(&self) -> Option<String> {
Expand All @@ -149,6 +160,8 @@ pub(crate) struct AuthorizedUser {
client_id: String,
client_secret: String,
refresh_token: String,
#[serde(skip_serializing_if = "Option::is_none")]
quota_project_id: Option<String>,
}

#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
Expand Down Expand Up @@ -220,6 +233,7 @@ mod test {
client_id: "test-client-id".to_string(),
client_secret: "test-client-secret".to_string(),
refresh_token: "test-refresh-token".to_string(),
quota_project_id: Some("test-project".to_string()),
};
let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
assert_eq!(actual, expected);
Expand All @@ -239,6 +253,7 @@ mod test {
client_id: "test-client-id".to_string(),
client_secret: "test-client-secret".to_string(),
refresh_token: "test-refresh-token".to_string(),
quota_project_id: None,
};
let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
assert_eq!(actual, expected);
Expand Down Expand Up @@ -283,6 +298,7 @@ mod test {

let uc = UserCredential {
token_provider: mock,
quota_project_id: None,
};
let actual = uc.get_token().await.unwrap();
assert_eq!(actual, expected);
Expand All @@ -297,19 +313,21 @@ mod test {

let uc = UserCredential {
token_provider: mock,
quota_project_id: None,
};
assert!(uc.get_token().await.is_err());
}

// Convenience struct for verifying (HeaderName, HeaderValue) pairs.
#[derive(Debug, Eq, Ord, PartialEq, PartialOrd)]
struct HV {
header: String,
value: String,
is_sensitive: bool,
}

#[tokio::test]
async fn get_headers_success() {
#[derive(Debug, PartialEq)]
struct HV {
header: String,
value: String,
is_sensitive: bool,
}

let token = Token {
token: "test-token".to_string(),
token_type: "Bearer".to_string(),
Expand All @@ -322,6 +340,7 @@ mod test {

let uc = UserCredential {
token_provider: mock,
quota_project_id: None,
};
let headers: Vec<HV> = uc
.get_headers()
Expand Down Expand Up @@ -354,10 +373,58 @@ mod test {

let uc = UserCredential {
token_provider: mock,
quota_project_id: None,
};
assert!(uc.get_headers().await.is_err());
}

#[tokio::test]
async fn get_headers_with_quota_project_success() {
let token = Token {
token: "test-token".to_string(),
token_type: "Bearer".to_string(),
expires_at: None,
metadata: None,
};

let mut mock = MockTokenProvider::new();
mock.expect_get_token().times(1).return_once(|| Ok(token));

let uc = UserCredential {
token_provider: mock,
quota_project_id: Some("test-project".to_string()),
};
let mut headers: Vec<HV> = uc
.get_headers()
.await
.unwrap()
.into_iter()
.map(|(h, v)| HV {
header: h.to_string(),
value: v.to_str().unwrap().to_string(),
is_sensitive: v.is_sensitive(),
})
.collect();

// The ordering of the headers does not matter.
headers.sort();
assert_eq!(
headers,
vec![
HV {
header: AUTHORIZATION.to_string(),
value: "Bearer test-token".to_string(),
is_sensitive: true,
},
HV {
header: QUOTA_PROJECT_KEY.to_string(),
value: "test-project".to_string(),
is_sensitive: false,
}
]
);
}

#[test]
fn oauth2_request_serde() {
let request = Oauth2RefreshRequest {
Expand Down Expand Up @@ -476,7 +543,10 @@ mod test {
refresh_token: "test-refresh-token".to_string(),
endpoint: endpoint,
};
let uc = UserCredential { token_provider };
let uc = UserCredential {
token_provider,
quota_project_id: None,
};
let now = OffsetDateTime::now_utc();
let token = uc.get_token().await?;
assert_eq!(token.token, "test-access-token");
Expand Down Expand Up @@ -507,7 +577,10 @@ mod test {
refresh_token: "test-refresh-token".to_string(),
endpoint: endpoint,
};
let uc = UserCredential { token_provider };
let uc = UserCredential {
token_provider,
quota_project_id: None,
};
let token = uc.get_token().await?;
assert_eq!(token.token, "test-access-token");
assert_eq!(token.token_type, "test-token-type");
Expand All @@ -528,7 +601,10 @@ mod test {
refresh_token: "test-refresh-token".to_string(),
endpoint: endpoint,
};
let uc = UserCredential { token_provider };
let uc = UserCredential {
token_provider,
quota_project_id: None,
};
let e = uc.get_token().await.err().unwrap();
assert!(e.is_retryable());
assert!(e.source().unwrap().to_string().contains("try again"));
Expand All @@ -547,7 +623,10 @@ mod test {
refresh_token: "test-refresh-token".to_string(),
endpoint: endpoint,
};
let uc = UserCredential { token_provider };
let uc = UserCredential {
token_provider,
quota_project_id: None,
};
let e = uc.get_token().await.err().unwrap();
assert!(!e.is_retryable());
assert!(e.source().unwrap().to_string().contains("epic fail"));
Expand Down

0 comments on commit 1b7ee25

Please sign in to comment.