Skip to content

Commit

Permalink
cleanup(auth): simplify error creation
Browse files Browse the repository at this point in the history
  • Loading branch information
dbolduc committed Jan 10, 2025
1 parent 1b7ee25 commit f522c41
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 41 deletions.
31 changes: 12 additions & 19 deletions src/auth/src/credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,24 +250,18 @@ pub async fn create_access_token_credential() -> Result<Credential> {
AdcContents::FallbackToMds => return Ok(mds_credential::new()),
};
let js: serde_json::Value =
serde_json::from_str(&contents).map_err(|e| CredentialError::new(false, e.into()))?;
serde_json::from_str(&contents).map_err(CredentialError::non_retryable)?;
let cred_type = js
.get("type")
.ok_or_else(|| CredentialError::new(
false,
Box::from("Failed to parse Application Default Credentials (ADC). No `type` field found."),
))?
.ok_or_else(|| CredentialError::non_retryable("Failed to parse Application Default Credentials (ADC). No `type` field found."))?
.as_str()
.ok_or_else(|| CredentialError::new(
false,
Box::from("Failed to parse Application Default Credentials (ADC). `type` field is not a string."),
))?;
.ok_or_else(|| CredentialError::non_retryable("Failed to parse Application Default Credentials (ADC). `type` field is not a string.")
)?;
match cred_type {
"authorized_user" => user_credential::creds_from(js),
_ => Err(CredentialError::new(
false,
Box::from(format!("Unimplemented credential type: {cred_type}")),
)),
_ => Err(CredentialError::non_retryable(format!(
"Unimplemented credential type: {cred_type}"
))),
}
}

Expand All @@ -284,11 +278,10 @@ enum AdcContents {
}

fn path_not_found(path: String) -> CredentialError {
CredentialError::new(
false,
Box::from(format!(
CredentialError::non_retryable(
format!(
"Failed to load Application Default Credentials (ADC) from {path}. Check that the `GOOGLE_APPLICATION_CREDENTIALS` environment variable points to a valid file."
)))
))
}

fn load_adc() -> Result<AdcContents> {
Expand All @@ -297,12 +290,12 @@ fn load_adc() -> Result<AdcContents> {
Some(AdcPath::FromEnv(path)) => match std::fs::read_to_string(&path) {
Ok(contents) => Ok(AdcContents::Contents(contents)),
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Err(path_not_found(path)),
Err(e) => Err(CredentialError::new(false, e.into())),
Err(e) => Err(CredentialError::non_retryable(e)),
},
Some(AdcPath::WellKnown(path)) => match std::fs::read_to_string(path) {
Ok(contents) => Ok(AdcContents::Contents(contents)),
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(AdcContents::FallbackToMds),
Err(e) => Err(CredentialError::new(false, e.into())),
Err(e) => Err(CredentialError::non_retryable(e)),
},
}
}
Expand Down
19 changes: 8 additions & 11 deletions src/auth/src/credentials/mds_credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ where
async fn get_headers(&self) -> Result<Vec<(HeaderName, HeaderValue)>> {
let token = self.get_token().await?;
let mut value = HeaderValue::from_str(&format!("{} {}", token.token_type, token.token))
.map_err(|e| CredentialError::new(false, e.into()))?;
.map_err(CredentialError::non_retryable)?;
value.set_sensitive(true);
Ok(vec![(AUTHORIZATION, value)])
}
Expand Down Expand Up @@ -109,19 +109,19 @@ impl MDSAccessTokenProvider {
);

let url = reqwest::Url::parse_with_params(path.as_str(), params.iter())
.map_err(|e| CredentialError::new(false, e.into()))?;
.map_err(CredentialError::non_retryable)?;

let response = request
.get(url.clone())
.headers(headers)
.send()
.await
.map_err(|e| CredentialError::new(false, e.into()))?;
.map_err(CredentialError::non_retryable)?;

response
.json::<ServiceAccountInfo>()
.await
.map_err(|e| CredentialError::new(false, e.into()))
.map_err(CredentialError::non_retryable)
}
}

Expand All @@ -139,10 +139,7 @@ impl TokenProvider for MDSAccessTokenProvider {
HeaderValue::from_static(METADATA_FLAVOR_VALUE),
);

let response = request
.send()
.await
.map_err(|e| CredentialError::new(true, e.into()))?;
let response = request.send().await.map_err(CredentialError::retryable)?;
// Process the response
if !response.status().is_success() {
let status = response.status();
Expand All @@ -158,7 +155,7 @@ impl TokenProvider for MDSAccessTokenProvider {
let response = response
.json::<MDSTokenResponse>()
.await
.map_err(|e| CredentialError::new(true, e.into()))?;
.map_err(CredentialError::retryable)?;
let token = Token {
token: response.access_token,
token_type: response.token_type,
Expand Down Expand Up @@ -210,7 +207,7 @@ mod test {
let mut mock = MockTokenProvider::new();
mock.expect_get_token()
.times(1)
.return_once(|| Err(CredentialError::new(false, Box::from("fail"))));
.return_once(|| Err(CredentialError::non_retryable("fail")));

let mdsc = MDSCredential {
token_provider: mock,
Expand Down Expand Up @@ -267,7 +264,7 @@ mod test {
let mut mock = MockTokenProvider::new();
mock.expect_get_token()
.times(1)
.return_once(|| Err(CredentialError::new(false, Box::from("fail"))));
.return_once(|| Err(CredentialError::non_retryable("fail")));

let mdsc = MDSCredential {
token_provider: mock,
Expand Down
19 changes: 8 additions & 11 deletions src/auth/src/credentials/user_credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ use time::OffsetDateTime;
const OAUTH2_ENDPOINT: &str = "https://oauth2.googleapis.com/token";

pub(crate) fn creds_from(js: serde_json::Value) -> Result<Credential> {
let au = serde_json::from_value::<AuthorizedUser>(js)
.map_err(|e| CredentialError::new(false, e.into()))?;
let au =
serde_json::from_value::<AuthorizedUser>(js).map_err(CredentialError::non_retryable)?;
let token_provider = UserTokenProvider {
client_id: au.client_id,
client_secret: au.client_secret,
Expand Down Expand Up @@ -81,15 +81,12 @@ impl TokenProvider for UserTokenProvider {
let resp = builder
.send()
.await
.map_err(|e| CredentialError::new(false, e.into()))?;
.map_err(CredentialError::non_retryable)?;

// Process the response
if !resp.status().is_success() {
let status = resp.status();
let body = resp
.text()
.await
.map_err(|e| CredentialError::new(false, e.into()))?;
let body = resp.text().await.map_err(CredentialError::non_retryable)?;
return Err(CredentialError::new(
is_retryable(status),
Box::from(format!("Failed to fetch token. {body}")),
Expand All @@ -98,7 +95,7 @@ impl TokenProvider for UserTokenProvider {
let response = resp
.json::<Oauth2RefreshResponse>()
.await
.map_err(|e| CredentialError::new(false, e.into()))?;
.map_err(CredentialError::non_retryable)?;
let token = Token {
token: response.access_token,
token_type: response.token_type,
Expand Down Expand Up @@ -135,7 +132,7 @@ where
async fn get_headers(&self) -> Result<Vec<(HeaderName, HeaderValue)>> {
let token = self.get_token().await?;
let mut value = HeaderValue::from_str(&format!("{} {}", token.token_type, token.token))
.map_err(|e| CredentialError::new(false, e.into()))?;
.map_err(CredentialError::non_retryable)?;
value.set_sensitive(true);
let mut headers = vec![(AUTHORIZATION, value)];
if let Some(project) = &self.quota_project_id {
Expand Down Expand Up @@ -309,7 +306,7 @@ mod test {
let mut mock = MockTokenProvider::new();
mock.expect_get_token()
.times(1)
.return_once(|| Err(CredentialError::new(false, Box::from("fail"))));
.return_once(|| Err(CredentialError::non_retryable("fail")));

let uc = UserCredential {
token_provider: mock,
Expand Down Expand Up @@ -369,7 +366,7 @@ mod test {
let mut mock = MockTokenProvider::new();
mock.expect_get_token()
.times(1)
.return_once(|| Err(CredentialError::new(false, Box::from("fail"))));
.return_once(|| Err(CredentialError::non_retryable("fail")));

let uc = UserCredential {
token_provider: mock,
Expand Down
10 changes: 10 additions & 0 deletions src/auth/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,16 @@ impl CredentialError {
pub fn is_retryable(&self) -> bool {
self.is_retryable
}

/// A helper to create a retryable error.
pub(crate) fn retryable<T: Into<BoxError>>(source: T) -> Self {
CredentialError::new(true, source.into())
}

/// A helper to create a non-retryable error.
pub(crate) fn non_retryable<T: Into<BoxError>>(source: T) -> Self {
CredentialError::new(false, source.into())
}
}

impl std::error::Error for CredentialError {
Expand Down

0 comments on commit f522c41

Please sign in to comment.