diff --git a/kn-cuda-eval/src/planner.rs b/kn-cuda-eval/src/planner.rs
index 81414b1..ece59c7 100644
--- a/kn-cuda-eval/src/planner.rs
+++ b/kn-cuda-eval/src/planner.rs
@@ -17,6 +17,7 @@ use kn_graph::dtype::{DisplayCFloat, DScalar, DType};
 use kn_graph::graph::{BinaryOp, Graph, Operation, SliceRange, UnaryOp, Value};
 use kn_graph::optimizer::recurse::heap_recurse;
 use kn_graph::shape::{ConcreteShape, Size};
+use kn_graph::wrap_debug::WrapDebug;
 
 use crate::autokernel::gather::GatherKernel;
 use crate::autokernel::layernorm::LayernormKernel;
@@ -327,7 +328,7 @@ impl<'a> Planner<'a> {
 
         let result: PlanTensor = match &result_info.operation {
             &Operation::Input { index: _ } => self.alloc_tensor_shared(result_shape, result_dtype, Some(value)),
-            Operation::Constant { tensor } => {
+            Operation::Constant { tensor: WrapDebug(tensor) } => {
                 let result = self.alloc_tensor_dedicated(result_shape, tensor.dtype());
 
                 // copy values
diff --git a/kn-graph/src/cpu.rs b/kn-graph/src/cpu.rs
index 692a243..efca81d 100644
--- a/kn-graph/src/cpu.rs
+++ b/kn-graph/src/cpu.rs
@@ -5,16 +5,17 @@ use std::time::Instant;
 use indexmap::IndexMap;
 use itertools::Itertools;
 use ndarray::{
-    s, ArcArray, Array3, Array4, ArrayView, ArrayView3, ArrayView4, Ix3, Ix4, IxDyn, LinalgScalar, SliceInfo,
+    ArcArray, Array3, Array4, ArrayView, ArrayView3, ArrayView4, Ix3, Ix4, IxDyn, LinalgScalar, s, SliceInfo,
     SliceInfoElem, Zip,
 };
 
 use crate::dtype::{
-    dispatch_dtensor, dispatch_dtype, map_dtensor, map_dtensor_pair, DTensor, DType, IntoDScalar, Tensor,
+    dispatch_dtensor, dispatch_dtype, DTensor, DType, IntoDScalar, map_dtensor, map_dtensor_pair, Tensor,
 };
 use crate::graph::{ConvDetails, Graph, Operation, SliceRange, Value, ValueInfo};
 use crate::ndarray::{Array, ArrayBase, Axis};
 use crate::shape::ConcreteShape;
+use crate::wrap_debug::WrapDebug;
 
 pub fn cpu_eval_graph(graph: &Graph, batch_size: usize, inputs: &[DTensor]) -> Vec<DTensor> {
     let exec = cpu_eval_graph_exec(graph, batch_size, inputs, false);
@@ -131,7 +132,7 @@ fn try_run_cpu_operation(
 
     let result: DTensor = match info.operation {
         Operation::Input { index } => input(index)?,
-        Operation::Constant { ref tensor } => tensor.clone(),
+        Operation::Constant { tensor: WrapDebug(ref tensor) } => tensor.clone(),
         Operation::View { input } => {
             let input = map(input)?;
             input.reshape(output_shape_dyn)
diff --git a/kn-graph/src/graph.rs b/kn-graph/src/graph.rs
index b0f13ce..7bf98e8 100644
--- a/kn-graph/src/graph.rs
+++ b/kn-graph/src/graph.rs
@@ -5,15 +5,16 @@ use std::fmt::{Debug, Display, Formatter};
 use std::ops::Index;
 
 use decorum::Total;
-use itertools::{zip_eq, Itertools};
+use itertools::{Itertools, zip_eq};
 use ndarray::{ArrayView, IxDyn};
 use rand::random;
 
-use crate::cpu::{run_cpu_const_operation, OperationError, OperationResult};
-use crate::dtype::{dispatch_dtensor, dispatch_dtype, map_dscalar_pair, DScalar, DTensor, DType, IntoDScalar, Tensor};
+use crate::cpu::{OperationError, OperationResult, run_cpu_const_operation};
+use crate::dtype::{dispatch_dtensor, dispatch_dtype, DScalar, DTensor, DType, IntoDScalar, map_dscalar_pair, Tensor};
 use crate::optimizer::recurse::heap_recurse;
 use crate::shape;
 use crate::shape::{Shape, Size};
+use crate::wrap_debug::WrapDebug;
 
 /// The core graph datastructure.
 ///
@@ -106,7 +107,7 @@ pub enum Operation {
     /// A runtime-variable input.
     Input { index: usize },
     /// A constant built into the network.
-    Constant { tensor: DTensor },
+    Constant { tensor: WrapDebug<DTensor> },
 
     //TODO maybe fuse a bunch of these operations into a single "Restride" operation?
     /// View a value as a different shape.
@@ -490,7 +491,7 @@ impl Graph {
 
         match info.operation {
             Operation::Input { .. } => None,
-            Operation::Constant { ref tensor } => dispatch_dtensor!(tensor, |_T, _f, tensor| {
+            Operation::Constant { tensor: WrapDebug(ref tensor) } => dispatch_dtensor!(tensor, |_T, _f, tensor| {
                let &e = tensor.iter().next()?;
                tensor.iter().all(|&d| d == e).then(|| e.to_dscalar())
            }),
@@ -590,7 +591,7 @@ impl Graph {
     #[must_use]
     pub fn constant_tensor(&mut self, tensor: DTensor) -> Value {
         let shape = Shape::fixed(tensor.shape());
-        self.push(shape, tensor.dtype(), Operation::Constant { tensor })
+        self.push(shape, tensor.dtype(), Operation::Constant { tensor: WrapDebug(tensor) })
     }
 
     #[must_use]
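
Note: this diff only shows the call sites of WrapDebug; the new kn-graph/src/wrap_debug.rs module itself is not among the hunks above. Judging from the usage (constructed and destructured as WrapDebug(tensor), stored as WrapDebug<DTensor> in Operation::Constant), it is presumably a transparent newtype whose Debug impl abbreviates the wrapped tensor, so that debug-printing a Graph no longer dumps every element of every constant. A minimal sketch of such a wrapper follows; the exact Debug output and derive list are assumptions, not the crate's actual code:

// Hypothetical sketch of kn_graph::wrap_debug::WrapDebug; the real module is
// not part of this diff, only its usage sites are.
use std::fmt::{Debug, Formatter};
use std::ops::Deref;

/// Transparent newtype with a public field, so call sites can construct it
/// as `WrapDebug(tensor)` and destructure it as `WrapDebug(ref tensor)`,
/// exactly as the hunks above do.
#[derive(Clone, PartialEq)]
pub struct WrapDebug<T>(pub T);

impl<T> Debug for WrapDebug<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Assumption: emit a short placeholder instead of forwarding to the
        // (potentially enormous) Debug output of the wrapped value.
        write!(f, "WrapDebug(..)")
    }
}

// Convenience so a wrapped value can still be used in place of the inner
// type without destructuring first.
impl<T> Deref for WrapDebug<T> {
    type Target = T;
    fn deref(&self) -> &T {
        &self.0
    }
}

With a wrapper shaped like this, the pattern Operation::Constant { tensor: WrapDebug(ref tensor) } binds tensor back to the inner &DTensor, which is why the bodies of the match arms in planner.rs and cpu.rs are otherwise unchanged.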