Skip to content

Commit

Permalink
More invariants
Browse files Browse the repository at this point in the history
  • Loading branch information
mkeeter committed Aug 28, 2024
1 parent 7a69e63 commit cc3bfeb
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 5 deletions.
7 changes: 7 additions & 0 deletions fidget/src/core/eval/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ pub trait Tape {

/// Returns a mapping from [`Var`](crate::var::Var) to evaluation index
fn vars(&self) -> &VarMap;

/// Returns the number of outputs written by this tape
///
/// The order of outputs is set by the caller at tape construction, so we
/// don't need a map to determine the index of a particular output (unlike
/// variables).
fn output_count(&self) -> usize;
}

/// Represents the trace captured by a tracing evaluation
Expand Down
12 changes: 12 additions & 0 deletions fidget/src/core/shape/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,12 @@ where
z: F,
vars: &HashMap<VarIndex, F>,
) -> Result<(E::Data, Option<&E::Trace>), Error> {
assert_eq!(
tape.tape.output_count(),
1,
"ShapeTape has multiple outputs"
);

let x = x.into();
let y = y.into();
let z = z.into();
Expand Down Expand Up @@ -497,6 +503,12 @@ where
z: &[E::Data],
vars: &HashMap<VarIndex, V>,
) -> Result<&[E::Data], Error> {
assert_eq!(
tape.tape.output_count(),
1,
"ShapeTape has multiple outputs"
);

// Make sure our scratch arrays are big enough for this evaluation
if x.len() != y.len() || x.len() != z.len() {
return Err(Error::MismatchedSlices);
Expand Down
9 changes: 4 additions & 5 deletions fidget/src/core/vm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ impl<const N: usize> Tape for GenericVmFunction<N> {
fn vars(&self) -> &VarMap {
&self.0.vars
}

fn output_count(&self) -> usize {
self.0.output_count()
}
}

/// A trace captured by a VM evaluation
Expand Down Expand Up @@ -140,11 +144,6 @@ impl<const N: usize> GenericVmFunction<N> {
pub fn choice_count(&self) -> usize {
self.0.choice_count()
}

/// Returns the number of output clauses in the tape
pub fn output_count(&self) -> usize {
self.0.output_count()
}
}

impl<const N: usize> Function for GenericVmFunction<N> {
Expand Down
10 changes: 10 additions & 0 deletions fidget/src/jit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,7 @@ impl JitFunction {
let ptr = f.as_ptr();
JitBulkFn {
mmap: f,
output_count: self.0.output_count(),
vars: self.0.data().vars.clone(),
fn_bulk: unsafe { std::mem::transmute(ptr) },
}
Expand Down Expand Up @@ -998,6 +999,10 @@ impl<T> Tape for JitTracingFn<T> {
fn vars(&self) -> &VarMap {
&self.vars
}

fn output_count(&self) -> usize {
self.output_count
}
}

// SAFETY: there is no mutable state in a `JitTracingFn`, and the pointer
Expand Down Expand Up @@ -1082,6 +1087,7 @@ pub struct JitBulkFn<T> {
#[allow(unused)]
mmap: Mmap,
vars: Arc<VarMap>,
output_count: usize,
fn_bulk: jit_fn!(
unsafe fn(
*const *const T, // vars
Expand All @@ -1100,6 +1106,10 @@ impl<T> Tape for JitBulkFn<T> {
fn vars(&self) -> &VarMap {
&self.vars
}

fn output_count(&self) -> usize {
self.output_count
}
}

/// Maximum SIMD width for any type, checked at runtime (alas)
Expand Down

0 comments on commit cc3bfeb

Please sign in to comment.