Use WrapDebug for graph constants again.
KarelPeeters committed Dec 9, 2023
1 parent 92036ec · commit 5f727e5
Showing 3 changed files with 13 additions and 10 deletions.
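Note: this commit wraps the tensor stored in Operation::Constant in WrapDebug<DTensor>, presumably so that derived Debug output for Graph and Operation no longer prints the full contents of every constant tensor; the three call sites simply destructure through the wrapper. The actual implementation of kn_graph::wrap_debug::WrapDebug is not part of this diff; the sketch below only illustrates, under that assumption, what such a newtype usually looks like (a transparent wrapper with Deref impls and a shortened Debug impl; all names and impls here are assumed, not taken from the repository):

    use std::fmt::{self, Debug};
    use std::ops::{Deref, DerefMut};

    // Hypothetical sketch, not the crate's actual code: a wrapper that hides the
    // inner value's potentially huge Debug output behind a short placeholder.
    #[derive(Clone, Eq, PartialEq, Hash)]
    pub struct WrapDebug<T>(pub T);

    impl<T> Debug for WrapDebug<T> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "WrapDebug(..)")
        }
    }

    impl<T> Deref for WrapDebug<T> {
        type Target = T;
        fn deref(&self) -> &T {
            &self.0
        }
    }

    impl<T> DerefMut for WrapDebug<T> {
        fn deref_mut(&mut self) -> &mut T {
            &mut self.0
        }
    }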
3 changes: 2 additions & 1 deletion kn-cuda-eval/src/planner.rs
@@ -17,6 +17,7 @@ use kn_graph::dtype::{DisplayCFloat, DScalar, DType};
 use kn_graph::graph::{BinaryOp, Graph, Operation, SliceRange, UnaryOp, Value};
 use kn_graph::optimizer::recurse::heap_recurse;
 use kn_graph::shape::{ConcreteShape, Size};
+use kn_graph::wrap_debug::WrapDebug;
 
 use crate::autokernel::gather::GatherKernel;
 use crate::autokernel::layernorm::LayernormKernel;
@@ -327,7 +328,7 @@ impl<'a> Planner<'a> {
 
         let result: PlanTensor = match &result_info.operation {
             &Operation::Input { index: _ } => self.alloc_tensor_shared(result_shape, result_dtype, Some(value)),
-            Operation::Constant { tensor } => {
+            Operation::Constant { tensor: WrapDebug(tensor) } => {
                 let result = self.alloc_tensor_dedicated(result_shape, tensor.dtype());
 
                 // copy values
7 changes: 4 additions & 3 deletions kn-graph/src/cpu.rs
@@ -5,16 +5,17 @@ use std::time::Instant;
 use indexmap::IndexMap;
 use itertools::Itertools;
 use ndarray::{
-    s, ArcArray, Array3, Array4, ArrayView, ArrayView3, ArrayView4, Ix3, Ix4, IxDyn, LinalgScalar, SliceInfo,
+    ArcArray, Array3, Array4, ArrayView, ArrayView3, ArrayView4, Ix3, Ix4, IxDyn, LinalgScalar, s, SliceInfo,
     SliceInfoElem, Zip,
 };
 
 use crate::dtype::{
-    dispatch_dtensor, dispatch_dtype, map_dtensor, map_dtensor_pair, DTensor, DType, IntoDScalar, Tensor,
+    dispatch_dtensor, dispatch_dtype, DTensor, DType, IntoDScalar, map_dtensor, map_dtensor_pair, Tensor,
 };
 use crate::graph::{ConvDetails, Graph, Operation, SliceRange, Value, ValueInfo};
 use crate::ndarray::{Array, ArrayBase, Axis};
 use crate::shape::ConcreteShape;
+use crate::wrap_debug::WrapDebug;
 
 pub fn cpu_eval_graph(graph: &Graph, batch_size: usize, inputs: &[DTensor]) -> Vec<DTensor> {
     let exec = cpu_eval_graph_exec(graph, batch_size, inputs, false);
@@ -131,7 +132,7 @@ fn try_run_cpu_operation(
 
     let result: DTensor = match info.operation {
         Operation::Input { index } => input(index)?,
-        Operation::Constant { ref tensor } => tensor.clone(),
+        Operation::Constant { tensor: WrapDebug(ref tensor) } => tensor.clone(),
         Operation::View { input } => {
             let input = map(input)?;
             input.reshape(output_shape_dyn)
13 changes: 7 additions & 6 deletions kn-graph/src/graph.rs
@@ -5,15 +5,16 @@ use std::fmt::{Debug, Display, Formatter};
 use std::ops::Index;
 
 use decorum::Total;
-use itertools::{zip_eq, Itertools};
+use itertools::{Itertools, zip_eq};
 use ndarray::{ArrayView, IxDyn};
 use rand::random;
 
-use crate::cpu::{run_cpu_const_operation, OperationError, OperationResult};
-use crate::dtype::{dispatch_dtensor, dispatch_dtype, map_dscalar_pair, DScalar, DTensor, DType, IntoDScalar, Tensor};
+use crate::cpu::{OperationError, OperationResult, run_cpu_const_operation};
+use crate::dtype::{dispatch_dtensor, dispatch_dtype, DScalar, DTensor, DType, IntoDScalar, map_dscalar_pair, Tensor};
 use crate::optimizer::recurse::heap_recurse;
 use crate::shape;
 use crate::shape::{Shape, Size};
+use crate::wrap_debug::WrapDebug;
 
 /// The core graph datastructure.
 ///
@@ -106,7 +107,7 @@ pub enum Operation {
     /// A runtime-variable input.
     Input { index: usize },
     /// A constant built into the network.
-    Constant { tensor: DTensor },
+    Constant { tensor: WrapDebug<DTensor> },
 
     //TODO maybe fuse a bunch of these operations into a single "Restride" operation?
     /// View a value as a different shape.
@@ -490,7 +491,7 @@ impl Graph {
 
         match info.operation {
            Operation::Input { .. } => None,
-            Operation::Constant { ref tensor } => dispatch_dtensor!(tensor, |_T, _f, tensor| {
+            Operation::Constant { tensor: WrapDebug(ref tensor) } => dispatch_dtensor!(tensor, |_T, _f, tensor| {
                 let &e = tensor.iter().next()?;
                 tensor.iter().all(|&d| d == e).then(|| e.to_dscalar())
             }),
@@ -590,7 +591,7 @@ impl Graph {
     #[must_use]
     pub fn constant_tensor(&mut self, tensor: DTensor) -> Value {
         let shape = Shape::fixed(tensor.shape());
-        self.push(shape, tensor.dtype(), Operation::Constant { tensor })
+        self.push(shape, tensor.dtype(), Operation::Constant { tensor: WrapDebug(tensor) })
     }
 
     #[must_use]
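Because WrapDebug appears to be a tuple struct with a public field, the updated match arms above only need to destructure through it (tensor: WrapDebug(ref tensor)) and can keep using the bound name unchanged, while the public constant_tensor API still accepts a plain DTensor and wraps it internally. A self-contained toy example of that destructuring pattern, using made-up types rather than the crate's own:

    // Standalone toy example (types invented for illustration) showing how a
    // match arm destructures through a newtype field without changing its body.
    #[derive(Debug)]
    struct WrapDebug<T>(pub T);

    enum Op {
        Input { index: usize },
        Constant { tensor: WrapDebug<Vec<f32>> },
    }

    fn describe(op: &Op) -> String {
        match op {
            Op::Input { index } => format!("input #{index}"),
            // Binding through WrapDebug(..) yields the inner Vec directly.
            Op::Constant { tensor: WrapDebug(tensor) } => {
                format!("constant with {} elements", tensor.len())
            }
        }
    }

    fn main() {
        let ops = [
            Op::Input { index: 0 },
            Op::Constant { tensor: WrapDebug(vec![1.0, 2.0, 3.0]) },
        ];
        for op in &ops {
            println!("{}", describe(op));
        }
    }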
