Skip to content

Commit

Permalink
feat!: add recoverable errors to checks
Browse files Browse the repository at this point in the history
Signed-off-by: Gustavo Inacio <gustavo@semiotic.ai>
  • Loading branch information
gusinacio committed Aug 14, 2024
1 parent 04b134a commit ce4dd19
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 28 deletions.
32 changes: 21 additions & 11 deletions tap_core/src/manager/context/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ impl EscrowHandler for InMemoryContext {
pub mod checks {
use crate::{
receipt::{
checks::{Check, CheckResult, ReceiptCheck},
checks::{Check, CheckError, CheckResult, ReceiptCheck},
state::Checking,
ReceiptError, ReceiptWithState,
},
Expand Down Expand Up @@ -306,10 +306,12 @@ pub mod checks {
{
Ok(())
} else {
Err(ReceiptError::InvalidAllocationID {
received_allocation_id,
}
.into())
Err(CheckError::Failure(
ReceiptError::InvalidAllocationID {
received_allocation_id,
}
.into(),
))
}
}
}
Expand All @@ -325,14 +327,22 @@ pub mod checks {
let recovered_address = receipt
.signed_receipt()
.recover_signer(&self.domain_separator)
.map_err(|e| ReceiptError::InvalidSignature {
source_error_message: e.to_string(),
.map_err(|e| {
CheckError::Failure(
ReceiptError::InvalidSignature {
source_error_message: e.to_string(),
}
.into(),
)
})?;

if !self.valid_signers.contains(&recovered_address) {
Err(ReceiptError::InvalidSignature {
source_error_message: "Invalid signer".to_string(),
}
.into())
Err(CheckError::Failure(
ReceiptError::InvalidSignature {
source_error_message: "Invalid signer".to_string(),
}
.into(),
))
} else {
Ok(())
}
Expand Down
7 changes: 5 additions & 2 deletions tap_core/src/manager/tap_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
receipt::{
checks::{CheckBatch, CheckList, TimestampCheck, UniqueCheck},
state::{Failed, Reserved},
ReceiptWithState, SignedReceipt,
ReceiptError, ReceiptWithState, SignedReceipt,
},
Error,
};
Expand Down Expand Up @@ -139,7 +139,10 @@ where
failed_receipts.extend(already_failed);

for receipt in checking_receipts.into_iter() {
let receipt = receipt.finalize_receipt_checks(&self.checks).await;
let receipt = receipt
.finalize_receipt_checks(&self.checks)
.await
.map_err(|e| Error::ReceiptError(ReceiptError::RecoverableCheckError(e)))?;

match receipt {
Ok(checked) => awaiting_reserve_receipts.push(checked),
Expand Down
22 changes: 16 additions & 6 deletions tap_core/src/receipt/checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,15 @@ use std::{
pub type ReceiptCheck = Arc<dyn Check + Sync + Send>;

/// Result of a check operation. It uses the `anyhow` crate to handle errors.
pub type CheckResult = anyhow::Result<()>;
pub type CheckResult = Result<(), CheckError>;

#[derive(thiserror::Error, Debug)]
pub enum CheckError {
#[error(transparent)]
Recoverable(anyhow::Error),
#[error(transparent)]
Failure(anyhow::Error),
}

/// CheckList is a NewType pattern to store a list of checks.
/// It is a wrapper around an Arc of ReceiptCheck[].
Expand Down Expand Up @@ -115,11 +123,13 @@ impl Check for StatefulTimestampCheck {
let min_timestamp_ns = *self.min_timestamp_ns.read().unwrap();
let signed_receipt = receipt.signed_receipt();
if signed_receipt.message.timestamp_ns <= min_timestamp_ns {
return Err(ReceiptError::InvalidTimestamp {
received_timestamp: signed_receipt.message.timestamp_ns,
timestamp_min: min_timestamp_ns,
}
.into());
return Err(CheckError::Failure(
ReceiptError::InvalidTimestamp {
received_timestamp: signed_receipt.message.timestamp_ns,
timestamp_min: min_timestamp_ns,
}
.into(),
));
}
Ok(())
}
Expand Down
2 changes: 2 additions & 0 deletions tap_core/src/receipt/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,6 @@ pub enum ReceiptError {
SubtractEscrowFailed,
#[error("Issue encountered while performing check: {0}")]
CheckFailedToComplete(String),
#[error("Issue encountered while performing check: {0}")]
RecoverableCheckError(String),
}
20 changes: 11 additions & 9 deletions tap_core/src/receipt/received_receipt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
use alloy::dyn_abi::Eip712Domain;

use super::checks::CheckError;
use super::{Receipt, ReceiptError, ReceiptResult, SignedReceipt};
use crate::receipt::state::{AwaitingReserve, Checking, Failed, ReceiptState, Reserved};
use crate::{
Expand Down Expand Up @@ -92,10 +93,10 @@ impl ReceiptWithState<Checking> {
pub async fn perform_checks(&mut self, checks: &[ReceiptCheck]) -> ReceiptResult<()> {
for check in checks {
// return early on an error
check
.check(self)
.await
.map_err(|e| ReceiptError::CheckFailedToComplete(e.to_string()))?;
check.check(self).await.map_err(|e| match e {
CheckError::Recoverable(e) => ReceiptError::RecoverableCheckError(e.to_string()),
CheckError::Failure(e) => ReceiptError::CheckFailedToComplete(e.to_string()),
})?;
}
Ok(())
}
Expand All @@ -108,14 +109,15 @@ impl ReceiptWithState<Checking> {
pub async fn finalize_receipt_checks(
mut self,
checks: &[ReceiptCheck],
) -> ResultReceipt<AwaitingReserve> {
) -> Result<ResultReceipt<AwaitingReserve>, String> {
let all_checks_passed = self.perform_checks(checks).await;

if let Err(e) = all_checks_passed {
Err(self.perform_state_error(e))
if let Err(ReceiptError::RecoverableCheckError(e)) = all_checks_passed {
Err(e.to_string())
} else if let Err(e) = all_checks_passed {
Ok(Err(self.perform_state_error(e)))
} else {
let checked = self.perform_state_changes(AwaitingReserve);
Ok(checked)
Ok(Ok(checked))
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions tap_core/tests/received_receipt_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ async fn partial_then_finalize_valid_receipt(

let awaiting_escrow_receipt = awaiting_escrow_receipt.unwrap();
let receipt = awaiting_escrow_receipt
.unwrap()
.check_and_reserve_escrow(&context, &domain_separator)
.await;
assert!(receipt.is_ok());
Expand Down Expand Up @@ -234,6 +235,7 @@ async fn standard_lifetime_valid_receipt(

let awaiting_escrow_receipt = awaiting_escrow_receipt.unwrap();
let receipt = awaiting_escrow_receipt
.unwrap()
.check_and_reserve_escrow(&context, &domain_separator)
.await;
assert!(receipt.is_ok());
Expand Down

0 comments on commit ce4dd19

Please sign in to comment.