Skip to content

Commit

Permalink
Add first-class Output opcode to the tape (#155)
Browse files Browse the repository at this point in the history
This is a step towards supporting multiple outputs in a single tape!
  • Loading branch information
mkeeter authored Aug 4, 2024
1 parent b1c1ba1 commit 2207b56
Show file tree
Hide file tree
Showing 17 changed files with 292 additions and 169 deletions.
21 changes: 21 additions & 0 deletions fidget/src/core/compiler/alloc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ impl<const N: usize> RegisterAllocator<N> {
#[inline(always)]
pub fn op(&mut self, op: SsaOp) {
match op {
SsaOp::Output(reg, i) => self.op_output(reg, i),
SsaOp::Input(out, i) => self.op_input(out, i),
SsaOp::CopyImm(out, imm) => self.op_copy_imm(out, imm),

Expand Down Expand Up @@ -676,4 +677,24 @@ impl<const N: usize> RegisterAllocator<N> {
fn op_input(&mut self, out: u32, i: u32) {
self.op_out_only(out, |out| RegOp::Input(out, i));
}

/// Pushes an [`Output`](crate::compiler::RegOp::Output) operation to the
/// tape
#[inline(always)]
fn op_output(&mut self, arg: u32, i: u32) {
match self.get_allocation(arg) {
Allocation::Register(r_y) => self.out.push(RegOp::Output(r_y, i)),
Allocation::Memory(m_y) => {
let r_a = self.get_register();
self.push_store(r_a, m_y);
self.out.push(RegOp::Output(r_a, i));
self.bind_register(arg, r_a);
}
Allocation::Unassigned => {
let r_a = self.get_register();
self.out.push(RegOp::Output(r_a, i));
self.bind_register(arg, r_a);
}
}
}
}
9 changes: 7 additions & 2 deletions fidget/src/core/compiler/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ macro_rules! opcodes {
#[doc = "Read an input variable by index"]
Input($t, u32),

#[doc = "Writes an output variable by index"]
Output($t, u32),

#[doc = "Negate the given register"]
NegReg($t, $t),

Expand Down Expand Up @@ -164,7 +167,7 @@ opcodes!(

impl SsaOp {
/// Returns the output pseudo-register
pub fn output(&self) -> u32 {
pub fn output(&self) -> Option<u32> {
match self {
SsaOp::Input(out, ..)
| SsaOp::CopyImm(out, ..)
Expand Down Expand Up @@ -212,13 +215,15 @@ impl SsaOp {
| SsaOp::AndRegImm(out, ..)
| SsaOp::AndRegReg(out, ..)
| SsaOp::OrRegImm(out, ..)
| SsaOp::OrRegReg(out, ..) => *out,
| SsaOp::OrRegReg(out, ..) => Some(*out),
SsaOp::Output(..) => None,
}
}
/// Returns true if the given opcode is associated with a choice
pub fn has_choice(&self) -> bool {
match self {
SsaOp::Input(..)
| SsaOp::Output(..)
| SsaOp::CopyImm(..)
| SsaOp::NegReg(..)
| SsaOp::AbsReg(..)
Expand Down
22 changes: 17 additions & 5 deletions fidget/src/core/compiler/ssa_tape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl SsaTape {
let mut slot_count = 0;

// Get either a node or constant index
#[derive(Copy, Clone)]
#[derive(Copy, Clone, Debug)]
enum Slot {
Reg(u32),
Immediate(f32),
Expand Down Expand Up @@ -78,12 +78,22 @@ impl SsaTape {
let mut seen = HashSet::new();
let mut todo = vec![root];
let mut choice_count = 0;

let mut tape = vec![];
match mapping[&root] {
Slot::Reg(out_reg) => tape.push(SsaOp::Output(out_reg, 0)),
Slot::Immediate(imm) => {
tape.push(SsaOp::Output(0, 0));
tape.push(SsaOp::CopyImm(0, imm));
}
}

while let Some(node) = todo.pop() {
if *parent_count.get(&node).unwrap_or(&0) > 0 || !seen.insert(node)
{
continue;
}

let op = ctx.get_op(node).unwrap();
for child in op.iter_children() {
todo.push(child);
Expand Down Expand Up @@ -231,7 +241,6 @@ impl SsaTape {
let c = ctx.get_const(root).unwrap() as f32;
tape.push(SsaOp::CopyImm(0, c));
}

Ok((SsaTape { tape, choice_count }, vars))
}

Expand Down Expand Up @@ -261,8 +270,11 @@ impl SsaTape {
pub fn pretty_print(&self) {
for &op in self.tape.iter().rev() {
match op {
SsaOp::Output(arg, i) => {
println!("OUTPUT[{i}] = ${arg}");
}
SsaOp::Input(out, i) => {
println!("${out} = INPUT {i}");
println!("${out} = INPUT[{i}]");
}
SsaOp::NegReg(out, arg)
| SsaOp::AbsReg(out, arg)
Expand Down Expand Up @@ -405,7 +417,7 @@ mod test {
let c9 = ctx.max(c8, c6).unwrap();

let (tape, vs) = SsaTape::new(&ctx, c9).unwrap();
assert_eq!(tape.len(), 8);
assert_eq!(tape.len(), 9);
assert_eq!(vs.len(), 2);
}

Expand All @@ -416,7 +428,7 @@ mod test {
let x_squared = ctx.mul(x, x).unwrap();

let (tape, vs) = SsaTape::new(&ctx, x_squared).unwrap();
assert_eq!(tape.len(), 2);
assert_eq!(tape.len(), 3); // x, square, output
assert_eq!(vs.len(), 1);
}
}
4 changes: 2 additions & 2 deletions fidget/src/core/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1567,7 +1567,7 @@ mod test {
let c9 = ctx.max(c8, c6).unwrap();

let tape = VmData::<255>::new(&ctx, c9).unwrap();
assert_eq!(tape.len(), 8);
assert_eq!(tape.len(), 9);
assert_eq!(tape.vars.len(), 2);
}

Expand All @@ -1578,7 +1578,7 @@ mod test {
let x_squared = ctx.mul(x, x).unwrap();

let tape = VmData::<255>::new(&ctx, x_squared).unwrap();
assert_eq!(tape.len(), 2);
assert_eq!(tape.len(), 3); // x, square, output
assert_eq!(tape.vars.len(), 1);
}

Expand Down
2 changes: 1 addition & 1 deletion fidget/src/core/eval/test/point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ where
&mut Default::default(),
)
.unwrap();
assert_eq!(next.size(), 1);
assert_eq!(next.size(), 2); // constant, output

let tape = next.point_tape(Default::default());
assert_eq!(eval.eval(&tape, &[2.0]).unwrap().0, 1.5);
Expand Down
20 changes: 12 additions & 8 deletions fidget/src/core/vm/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ use std::sync::Arc;
/// let mut ctx = Context::new();
/// let sum = ctx.import(&tree);
/// let data = VmData::<255>::new(&ctx, sum)?;
/// assert_eq!(data.len(), 3); // X, Y, and (X + Y)
/// assert_eq!(data.len(), 4); // X, Y, (X + Y), and output
///
/// let mut iter = data.iter_asm();
/// let vars = &data.vars; // map from var to index
Expand Down Expand Up @@ -132,18 +132,21 @@ impl<const N: usize> VmData<N> {

let mut choice_count = 0;

// The tape is constructed so that the output slot is first
assert_eq!(self.ssa.tape[0].output(), 0);
workspace.set_active(self.ssa.tape[0].output(), 0);
workspace.count += 1;

// Other iterators to consume various arrays in order
let mut choice_iter = choices.iter().rev();

let mut ops_out = tape.ssa.tape;

for mut op in self.ssa.tape.iter().cloned() {
let index = op.output();
let index = match &mut op {
SsaOp::Output(reg, _i) => {
*reg = workspace.get_or_insert_active(*reg);
workspace.alloc.op(op);
ops_out.push(op);
continue;
}
_ => op.output().unwrap(),
};

if workspace.active(index).is_none() {
if op.has_choice() {
Expand All @@ -158,6 +161,7 @@ impl<const N: usize> VmData<N> {
let new_index = workspace.active(index).unwrap();

match &mut op {
SsaOp::Output(..) => unreachable!(),
SsaOp::Input(index, ..) | SsaOp::CopyImm(index, ..) => {
*index = new_index;
}
Expand Down Expand Up @@ -286,7 +290,7 @@ impl<const N: usize> VmData<N> {
ops_out.push(op);
}

assert_eq!(workspace.count as usize, ops_out.len());
assert_eq!(workspace.count as usize + 1, ops_out.len());
let asm_tape = workspace.alloc.finalize();

Ok(VmData {
Expand Down
32 changes: 28 additions & 4 deletions fidget/src/core/vm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,14 @@ impl<const N: usize> TracingEvaluator for VmIntervalEval<N> {
let mut simplify = false;
let mut v = SlotArray(&mut self.0.slots);
let mut choices = self.0.choices.as_mut_slice().iter_mut();
let mut out = None;
for op in tape.iter_asm() {
match op {
RegOp::Output(arg, i) => {
assert_eq!(i, 0);
assert!(out.is_none());
out = Some(v[arg]);
}
RegOp::Input(out, i) => {
v[out] = vars[i as usize];
}
Expand Down Expand Up @@ -470,7 +476,7 @@ impl<const N: usize> TracingEvaluator for VmIntervalEval<N> {
}
}
Ok((
self.0.slots[0],
out.unwrap(),
if simplify {
Some(&self.0.choices)
} else {
Expand Down Expand Up @@ -501,8 +507,14 @@ impl<const N: usize> TracingEvaluator for VmPointEval<N> {
let mut choices = self.0.choices.as_mut_slice().iter_mut();
let mut simplify = false;
let mut v = SlotArray(&mut self.0.slots);
let mut out = None;
for op in tape.iter_asm() {
match op {
RegOp::Output(arg, i) => {
assert_eq!(i, 0);
assert!(out.is_none());
out = Some(v[arg]);
}
RegOp::Input(out, i) => {
v[out] = vars[i as usize];
}
Expand Down Expand Up @@ -765,7 +777,7 @@ impl<const N: usize> TracingEvaluator for VmPointEval<N> {
}
}
Ok((
self.0.slots[0],
out.unwrap(),
if simplify {
Some(&self.0.choices)
} else {
Expand All @@ -782,6 +794,9 @@ impl<const N: usize> TracingEvaluator for VmPointEval<N> {
struct BulkVmEval<T> {
/// Workspace for data
slots: Vec<Vec<T>>,

/// Output array
out: Vec<T>,
}

impl<T: From<f32> + Clone> BulkVmEval<T> {
Expand All @@ -792,6 +807,7 @@ impl<T: From<f32> + Clone> BulkVmEval<T> {
for s in self.slots.iter_mut() {
s.resize(size, f32::NAN.into());
}
self.out.resize(size, f32::NAN.into());
}
}

Expand All @@ -817,6 +833,10 @@ impl<const N: usize> BulkEvaluator for VmFloatSliceEval<N> {
let mut v = SlotArray(&mut self.0.slots);
for op in tape.iter_asm() {
match op {
RegOp::Output(arg, i) => {
assert_eq!(i, 0);
self.0.out[0..size].copy_from_slice(&v[arg][0..size]);
}
RegOp::Input(out, i) => {
v[out][0..size].copy_from_slice(&vars[i as usize]);
}
Expand Down Expand Up @@ -1100,7 +1120,7 @@ impl<const N: usize> BulkEvaluator for VmFloatSliceEval<N> {
}
}
}
Ok(&self.0.slots[0])
Ok(&self.0.out[0..size])
}
}

Expand All @@ -1125,6 +1145,10 @@ impl<const N: usize> BulkEvaluator for VmGradSliceEval<N> {
let mut v = SlotArray(&mut self.0.slots);
for op in tape.iter_asm() {
match op {
RegOp::Output(arg, i) => {
assert_eq!(i, 0);
self.0.out[0..size].copy_from_slice(&v[arg][0..size]);
}
RegOp::Input(out, i) => {
v[out][0..size].copy_from_slice(&vars[i as usize]);
}
Expand Down Expand Up @@ -1428,7 +1452,7 @@ impl<const N: usize> BulkEvaluator for VmGradSliceEval<N> {
}
}
}
Ok(&self.0.slots[0])
Ok(&self.0.out[0..size])
}
}

Expand Down
19 changes: 12 additions & 7 deletions fidget/src/jit/aarch64/float_slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,15 @@ impl Assembler for FloatSliceAssembler {
; ldr Q(reg(out_reg)), [x4]
);
}

fn build_output(&mut self, arg_reg: u8, out_index: u32) {
assert_eq!(out_index, 0);
dynasm!(self.0.ops
; add x4, x1, x3 // apply array offset
; str Q(reg(arg_reg)), [x4] // write to the output array
);
}

fn build_sin(&mut self, out_reg: u8, lhs_reg: u8) {
extern "C" fn float_sin(f: f32) -> f32 {
f.sin()
Expand Down Expand Up @@ -400,19 +409,15 @@ impl Assembler for FloatSliceAssembler {
IMM_REG.wrapping_sub(OFFSET)
}

fn finalize(mut self, out_reg: u8) -> Result<Mmap, Error> {
fn finalize(mut self) -> Result<Mmap, Error> {
dynasm!(self.0.ops
// update our "items remaining" counter
; sub x2, x2, 4 // We handle 4 items at a time

// Adjust the array offset pointer
// Adjust the input array offset amount
; add x3, x3, 16

// Prepare our return value, writing to the pointer in x1
// It's fine to overwrite X at this point in V0, since we're not
// using it anymore.
; mov v0.d[0], V(reg(out_reg)).d[1]
; stp D(reg(out_reg)), d0, [x1], 16
// Keep looping!
; b ->L

; ->E:
Expand Down
13 changes: 10 additions & 3 deletions fidget/src/jit/aarch64/grad_slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,15 @@ impl Assembler for GradSliceAssembler {
; ldr Q(reg(out_reg)), [x4]
);
}

fn build_output(&mut self, arg_reg: u8, out_index: u32) {
assert_eq!(out_index, 0);
dynasm!(self.0.ops
; add x4, x1, x3 // apply array offset
; str Q(reg(arg_reg)), [x4] // write to the output array
);
}

fn build_sin(&mut self, out_reg: u8, lhs_reg: u8) {
extern "C" fn grad_sin(v: Grad) -> Grad {
v.sin()
Expand Down Expand Up @@ -486,16 +495,14 @@ impl Assembler for GradSliceAssembler {
IMM_REG.wrapping_sub(OFFSET)
}

fn finalize(mut self, out_reg: u8) -> Result<Mmap, Error> {
fn finalize(mut self) -> Result<Mmap, Error> {
dynasm!(self.0.ops
// update our "items remaining" counter
; sub x2, x2, 1 // We handle 1 item at a time

// Adjust the array offset pointer
; add x3, x3, 16 // 1 item = 16 bytes

// Prepare our return value, writing to the pointer in x1
; str Q(reg(out_reg)), [x1], 16
; b ->L // Jump back to the loop start

; ->E:
Expand Down
Loading

0 comments on commit 2207b56

Please sign in to comment.