From f522c4134e12f2ac30c837c4355dfff1b1c6200c Mon Sep 17 00:00:00 2001 From: dbolduc Date: Thu, 9 Jan 2025 17:34:37 -0500 Subject: [PATCH] cleanup(auth): simplify error creation --- src/auth/src/credentials.rs | 31 ++++++++------------- src/auth/src/credentials/mds_credential.rs | 19 ++++++------- src/auth/src/credentials/user_credential.rs | 19 ++++++------- src/auth/src/errors.rs | 10 +++++++ 4 files changed, 38 insertions(+), 41 deletions(-) diff --git a/src/auth/src/credentials.rs b/src/auth/src/credentials.rs index 65026a59..30d2273f 100644 --- a/src/auth/src/credentials.rs +++ b/src/auth/src/credentials.rs @@ -250,24 +250,18 @@ pub async fn create_access_token_credential() -> Result { AdcContents::FallbackToMds => return Ok(mds_credential::new()), }; let js: serde_json::Value = - serde_json::from_str(&contents).map_err(|e| CredentialError::new(false, e.into()))?; + serde_json::from_str(&contents).map_err(CredentialError::non_retryable)?; let cred_type = js .get("type") - .ok_or_else(|| CredentialError::new( - false, - Box::from("Failed to parse Application Default Credentials (ADC). No `type` field found."), - ))? + .ok_or_else(|| CredentialError::non_retryable("Failed to parse Application Default Credentials (ADC). No `type` field found."))? .as_str() - .ok_or_else(|| CredentialError::new( - false, - Box::from("Failed to parse Application Default Credentials (ADC). `type` field is not a string."), - ))?; + .ok_or_else(|| CredentialError::non_retryable("Failed to parse Application Default Credentials (ADC). `type` field is not a string.") + )?; match cred_type { "authorized_user" => user_credential::creds_from(js), - _ => Err(CredentialError::new( - false, - Box::from(format!("Unimplemented credential type: {cred_type}")), - )), + _ => Err(CredentialError::non_retryable(format!( + "Unimplemented credential type: {cred_type}" + ))), } } @@ -284,11 +278,10 @@ enum AdcContents { } fn path_not_found(path: String) -> CredentialError { - CredentialError::new( - false, - Box::from(format!( + CredentialError::non_retryable( + format!( "Failed to load Application Default Credentials (ADC) from {path}. Check that the `GOOGLE_APPLICATION_CREDENTIALS` environment variable points to a valid file." - ))) + )) } fn load_adc() -> Result { @@ -297,12 +290,12 @@ fn load_adc() -> Result { Some(AdcPath::FromEnv(path)) => match std::fs::read_to_string(&path) { Ok(contents) => Ok(AdcContents::Contents(contents)), Err(e) if e.kind() == std::io::ErrorKind::NotFound => Err(path_not_found(path)), - Err(e) => Err(CredentialError::new(false, e.into())), + Err(e) => Err(CredentialError::non_retryable(e)), }, Some(AdcPath::WellKnown(path)) => match std::fs::read_to_string(path) { Ok(contents) => Ok(AdcContents::Contents(contents)), Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(AdcContents::FallbackToMds), - Err(e) => Err(CredentialError::new(false, e.into())), + Err(e) => Err(CredentialError::non_retryable(e)), }, } } diff --git a/src/auth/src/credentials/mds_credential.rs b/src/auth/src/credentials/mds_credential.rs index b1a3b579..5d4a353f 100644 --- a/src/auth/src/credentials/mds_credential.rs +++ b/src/auth/src/credentials/mds_credential.rs @@ -58,7 +58,7 @@ where async fn get_headers(&self) -> Result> { let token = self.get_token().await?; let mut value = HeaderValue::from_str(&format!("{} {}", token.token_type, token.token)) - .map_err(|e| CredentialError::new(false, e.into()))?; + .map_err(CredentialError::non_retryable)?; value.set_sensitive(true); Ok(vec![(AUTHORIZATION, value)]) } @@ -109,19 +109,19 @@ impl MDSAccessTokenProvider { ); let url = reqwest::Url::parse_with_params(path.as_str(), params.iter()) - .map_err(|e| CredentialError::new(false, e.into()))?; + .map_err(CredentialError::non_retryable)?; let response = request .get(url.clone()) .headers(headers) .send() .await - .map_err(|e| CredentialError::new(false, e.into()))?; + .map_err(CredentialError::non_retryable)?; response .json::() .await - .map_err(|e| CredentialError::new(false, e.into())) + .map_err(CredentialError::non_retryable) } } @@ -139,10 +139,7 @@ impl TokenProvider for MDSAccessTokenProvider { HeaderValue::from_static(METADATA_FLAVOR_VALUE), ); - let response = request - .send() - .await - .map_err(|e| CredentialError::new(true, e.into()))?; + let response = request.send().await.map_err(CredentialError::retryable)?; // Process the response if !response.status().is_success() { let status = response.status(); @@ -158,7 +155,7 @@ impl TokenProvider for MDSAccessTokenProvider { let response = response .json::() .await - .map_err(|e| CredentialError::new(true, e.into()))?; + .map_err(CredentialError::retryable)?; let token = Token { token: response.access_token, token_type: response.token_type, @@ -210,7 +207,7 @@ mod test { let mut mock = MockTokenProvider::new(); mock.expect_get_token() .times(1) - .return_once(|| Err(CredentialError::new(false, Box::from("fail")))); + .return_once(|| Err(CredentialError::non_retryable("fail"))); let mdsc = MDSCredential { token_provider: mock, @@ -267,7 +264,7 @@ mod test { let mut mock = MockTokenProvider::new(); mock.expect_get_token() .times(1) - .return_once(|| Err(CredentialError::new(false, Box::from("fail")))); + .return_once(|| Err(CredentialError::non_retryable("fail"))); let mdsc = MDSCredential { token_provider: mock, diff --git a/src/auth/src/credentials/user_credential.rs b/src/auth/src/credentials/user_credential.rs index 2446568e..e30f16e8 100644 --- a/src/auth/src/credentials/user_credential.rs +++ b/src/auth/src/credentials/user_credential.rs @@ -25,8 +25,8 @@ use time::OffsetDateTime; const OAUTH2_ENDPOINT: &str = "https://oauth2.googleapis.com/token"; pub(crate) fn creds_from(js: serde_json::Value) -> Result { - let au = serde_json::from_value::(js) - .map_err(|e| CredentialError::new(false, e.into()))?; + let au = + serde_json::from_value::(js).map_err(CredentialError::non_retryable)?; let token_provider = UserTokenProvider { client_id: au.client_id, client_secret: au.client_secret, @@ -81,15 +81,12 @@ impl TokenProvider for UserTokenProvider { let resp = builder .send() .await - .map_err(|e| CredentialError::new(false, e.into()))?; + .map_err(CredentialError::non_retryable)?; // Process the response if !resp.status().is_success() { let status = resp.status(); - let body = resp - .text() - .await - .map_err(|e| CredentialError::new(false, e.into()))?; + let body = resp.text().await.map_err(CredentialError::non_retryable)?; return Err(CredentialError::new( is_retryable(status), Box::from(format!("Failed to fetch token. {body}")), @@ -98,7 +95,7 @@ impl TokenProvider for UserTokenProvider { let response = resp .json::() .await - .map_err(|e| CredentialError::new(false, e.into()))?; + .map_err(CredentialError::non_retryable)?; let token = Token { token: response.access_token, token_type: response.token_type, @@ -135,7 +132,7 @@ where async fn get_headers(&self) -> Result> { let token = self.get_token().await?; let mut value = HeaderValue::from_str(&format!("{} {}", token.token_type, token.token)) - .map_err(|e| CredentialError::new(false, e.into()))?; + .map_err(CredentialError::non_retryable)?; value.set_sensitive(true); let mut headers = vec![(AUTHORIZATION, value)]; if let Some(project) = &self.quota_project_id { @@ -309,7 +306,7 @@ mod test { let mut mock = MockTokenProvider::new(); mock.expect_get_token() .times(1) - .return_once(|| Err(CredentialError::new(false, Box::from("fail")))); + .return_once(|| Err(CredentialError::non_retryable("fail"))); let uc = UserCredential { token_provider: mock, @@ -369,7 +366,7 @@ mod test { let mut mock = MockTokenProvider::new(); mock.expect_get_token() .times(1) - .return_once(|| Err(CredentialError::new(false, Box::from("fail")))); + .return_once(|| Err(CredentialError::non_retryable("fail"))); let uc = UserCredential { token_provider: mock, diff --git a/src/auth/src/errors.rs b/src/auth/src/errors.rs index 52d0faa2..892c6d69 100644 --- a/src/auth/src/errors.rs +++ b/src/auth/src/errors.rs @@ -57,6 +57,16 @@ impl CredentialError { pub fn is_retryable(&self) -> bool { self.is_retryable } + + /// A helper to create a retryable error. + pub(crate) fn retryable>(source: T) -> Self { + CredentialError::new(true, source.into()) + } + + /// A helper to create a non-retryable error. + pub(crate) fn non_retryable>(source: T) -> Self { + CredentialError::new(false, source.into()) + } } impl std::error::Error for CredentialError {