Feat/gather import #1843

Merged: 11 commits (Jun 3, 2024)
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
@@ -63,7 +63,7 @@ represent the corresponding Burn Op.
| [EyeLike][55] | ❌ | ❌ |
| [Flatten][56] | ✅ | ✅ |
| [Floor][57] | ❌ | ❌ |
| [Gather][58] | ❌ | ✅ |
| [Gather][58] | ✅ | ✅ |
| [GatherElements][59] | ✅ | ✅ |
| [GatherND][60] | ❌ | ❌ |
| [Gelu][61] | ✅ | ✅ |
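For context (an editorial note, not part of the diff): ONNX `Gather` selects whole slices along an axis, which maps to PyTorch's `index_select` and Burn's `select`, while ONNX `GatherElements` performs per-element indexing, which maps to PyTorch's `gather` and Burn's `gather`. A minimal PyTorch sketch of the distinction:

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])

# ONNX Gather ~ torch.index_select / Burn select:
# picks whole columns 0 and 2.
print(torch.index_select(x, 1, torch.tensor([0, 2])))
# tensor([[1., 3.],
#         [4., 6.]])

# ONNX GatherElements ~ torch.gather / Burn gather:
# per-element lookup along dim 1 (index has the same rank as x).
print(torch.gather(x, 1, torch.tensor([[0, 0], [1, 0]])))
# tensor([[1., 1.],
#         [5., 4.]])
```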
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
@@ -27,6 +27,7 @@ fn main() {
.input("tests/exp/exp.onnx")
.input("tests/flatten/flatten.onnx")
.input("tests/gather/gather.onnx")
.input("tests/gather_elements/gather_elements.onnx")
.input("tests/gelu/gelu.onnx")
.input("tests/global_avr_pool/global_avr_pool.onnx")
.input("tests/layer_norm/layer_norm.onnx")
20 changes: 10 additions & 10 deletions crates/burn-import/onnx-tests/tests/gather/gather.onnx
@@ -1,16 +1,16 @@
(binary protobuf diff: the graph's single node changes from a `GatherElements` op with inputs `onnx::GatherElements_0` and `onnx::GatherElements_1` (exporter tag pytorch2.2.2) to a `Gather` op with inputs `onnx::Gather_0` and `onnx::Gather_1` (exporter tag pytorch2.1.1); both carry an `axis` attribute in `main_graph`)
16 changes: 8 additions & 8 deletions crates/burn-import/onnx-tests/tests/gather/gather.py
@@ -11,8 +11,8 @@ def __init__(self):
super(Model, self).__init__()

def forward(self, x, index):
x = torch.gather(x, 1, index)
return x
gathered = torch.index_select(x, 1, index)
return gathered


def main():
@@ -24,19 +24,19 @@ def main():
model.eval()
device = torch.device("cpu")
onnx_name = "gather.onnx"
dummy_input = torch.randn(2, 2, device=device)
dummy_index = torch.randint(high=2, size=(2, 2), device=device, dtype=torch.int64)

dummy_input = torch.randn(2, 3, device=device)
dummy_index = torch.tensor([0, 2], device=device, dtype=torch.int64)

torch.onnx.export(model, (dummy_input, dummy_index), onnx_name,
verbose=False, opset_version=16)

print("Finished exporting model to {}".format(onnx_name))

# Output some test data for use in the test
test_input = torch.tensor([[1.0, 2.0],
[3.0, 4.0]])
test_index = torch.tensor([[0, 0],
[1, 0]])
test_input = torch.tensor([[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0]])
test_index = torch.tensor([0, 2], dtype=torch.int64)

print("Test input data: {}, {}".format(test_input, test_index))
output = model.forward(test_input, test_index)
18 changes: 18 additions & 0 deletions crates/burn-import/onnx-tests/tests/gather_elements/gather_elements.onnx
@@ -0,0 +1,18 @@
(new binary protobuf file: a single `GatherElements` node with an `axis` attribute in `main_graph`, inputs `onnx::GatherElements_0` and `onnx::GatherElements_1`, exporter tag pytorch2.1.1)
48 changes: 48 additions & 0 deletions crates/burn-import/onnx-tests/tests/gather_elements/gather_elements.py
@@ -0,0 +1,48 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/gather_elements/gather_elements.onnx
# note that the ONNX specification for `GatherElements` corresponds to PyTorch's/Burn's `gather` function

import torch
import torch.nn as nn


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, index):
x = torch.gather(x, 1, index)
return x


def main():
# Set random seed for reproducibility
torch.manual_seed(0)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "gather_elements.onnx"
dummy_input = torch.randn(2, 2, device=device)
dummy_index = torch.randint(high=2, size=(2, 2), device=device, dtype=torch.int64)

torch.onnx.export(model, (dummy_input, dummy_index), onnx_name,
verbose=False, opset_version=16)

print("Finished exporting model to {}".format(onnx_name))

# Output some test data for use in the test
test_input = torch.tensor([[1.0, 2.0],
[3.0, 4.0]])
test_index = torch.tensor([[0, 0],
[1, 0]])

print("Test input data: {}, {}".format(test_input, test_index))
output = model.forward(test_input, test_index)
print("Test output data: {}".format(output))


if __name__ == '__main__':
main()
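As a quick sanity check (mine, not from the PR), the script's test data can be worked through by hand with `torch.gather`:

```python
import torch

x = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0]])
index = torch.tensor([[0, 0],
                      [1, 0]])

# Along dim 1: out[i][j] = x[i][index[i][j]]
print(torch.gather(x, 1, index))
# tensor([[1., 1.],
#         [4., 3.]])
```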
17 changes: 16 additions & 1 deletion crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -37,6 +37,7 @@ include_models!(
expand,
flatten,
gather,
gather_elements,
gelu,
global_avr_pool,
layer_norm,
@@ -358,9 +359,23 @@

#[test]
fn gather() {
// Initialize the model with weights (loaded from the exported file)
let model: gather::Model<Backend> = gather::Model::default();

let device = Default::default();

let input = Tensor::<Backend, 2>::from_floats([[1., 2., 3.], [4., 5., 6.]], &device);
let index = Tensor::<Backend, 1, Int>::from_ints([0, 2], &device);
let output = model.forward(input, index);
let expected = Data::from([[1., 3.], [4., 6.]]);

assert_eq!(output.to_data(), expected);
}

#[test]
fn gather_elements() {
// Initialize the model with weights (loaded from the exported file)
let model: gather_elements::Model<Backend> = gather_elements::Model::default();

let device = Default::default();
// Run the model
let input = Tensor::<Backend, 2>::from_floats([[1., 2.], [3., 4.]], &device);
13 changes: 8 additions & 5 deletions crates/burn-import/src/burn/node/base.rs
@@ -4,11 +4,11 @@ use super::{
batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode, concat::ConcatNode,
constant::ConstantNode, conv1d::Conv1dNode, conv2d::Conv2dNode,
conv_transpose_2d::ConvTranspose2dNode, dropout::DropoutNode, gather::GatherNode,
global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode,
mask_where::WhereNode, matmul::MatmulNode, max_pool1d::MaxPool1dNode,
max_pool2d::MaxPool2dNode, prelu::PReluNode, random_normal::RandomNormalNode,
random_uniform::RandomUniformNode, range::RangeNode, reshape::ReshapeNode,
squeeze::SqueezeNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode,
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode,
random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode,
reshape::ReshapeNode, squeeze::SqueezeNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
@@ -93,6 +93,7 @@ pub enum Node<PS: PrecisionSettings> {
Dropout(DropoutNode),
Expand(ExpandNode),
Gather(GatherNode),
GatherElements(GatherElementsNode),
GlobalAvgPool(GlobalAvgPoolNode),
LayerNorm(LayerNormNode<PS>),
Linear(LinearNode<PS>),
@@ -128,6 +129,7 @@ macro_rules! match_all {
Node::Dropout(node) => $func(node),
Node::Expand(node) => $func(node),
Node::Gather(node) => $func(node),
Node::GatherElements(node) => $func(node),
Node::GlobalAvgPool(node) => $func(node),
Node::LayerNorm(node) => $func(node),
Node::Linear(node) => $func(node),
@@ -173,6 +175,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::Dropout(_) => "dropout",
Node::Expand(_) => "expand",
Node::Gather(_) => "gather",
Node::GatherElements(_) => "gather_elements",
Node::GlobalAvgPool(_) => "global_avg_pool",
Node::LayerNorm(_) => "layer_norm",
Node::Linear(_) => "linear",
10 changes: 5 additions & 5 deletions crates/burn-import/src/burn/node/gather.rs
@@ -35,7 +35,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for GatherNode {
let output = &self.output.name;

quote! {
let #output = #input.gather(#dim, #index);
let #output = #input.select(#dim, #index);
}
}

@@ -62,9 +62,9 @@ mod tests {

graph.register(GatherNode::new(
TensorType::new_float("tensor1", 2),
TensorType::new_int("tensor2", 2),
TensorType::new_int("tensor2", 1),
TensorType::new_float("tensor3", 2),
1,
0,
));

graph.register_input_output(
@@ -98,9 +98,9 @@
pub fn forward(
&self,
tensor1: Tensor<B, 2>,
tensor2: Tensor<B, 2, Int>
tensor2: Tensor<B, 1, Int>
) -> Tensor<B, 2> {
let tensor3 = tensor1.gather(1, tensor2);
let tensor3 = tensor1.select(0, tensor2);

tensor3
}
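A note on the updated codegen test (editorial reading, not stated in the diff): Burn's `select` takes a rank-1 index tensor, which is why `tensor2` drops from rank 2 to rank 1, and the test now exercises `dim = 0` (whole-row selection). A PyTorch sketch of what the generated `tensor1.select(0, tensor2)` computes:

```python
import torch

x = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0]])

# index_select along dim 0 picks whole rows, here rows 1 and 0.
print(torch.index_select(x, 0, torch.tensor([1, 0])))
# tensor([[3., 4.],
#         [1., 2.]])
```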
112 changes: 112 additions & 0 deletions crates/burn-import/src/burn/node/gather_elements.rs
@@ -0,0 +1,112 @@
use super::{Node, NodeCodegen};
use crate::burn::{TensorType, ToTokens, Type};

use burn::record::PrecisionSettings;
use quote::quote;

#[derive(Debug, Clone, new)]
pub struct GatherElementsNode {
pub input: TensorType,
pub index: TensorType,
pub output: TensorType,
pub dim: usize,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for GatherElementsNode {
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
}

fn input_types(&self) -> Vec<crate::burn::Type> {
vec![
Type::Tensor(self.input.clone()),
Type::Tensor(self.index.clone()),
]
}

fn forward(
&self,
scope: &mut crate::burn::Scope,
node_position: usize,
) -> proc_macro2::TokenStream {
let dim = self.dim.to_tokens();
let input = scope.tensor_use_owned(&self.input, node_position);
let index = scope.tensor_use_owned(&self.index, node_position);
let output = &self.output.name;

quote! {
let #output = #input.gather(#dim, #index);
}
}

fn into_node(self) -> super::Node<PS> {
Node::GatherElements(self)
}
}

#[cfg(test)]
mod tests {

use burn::record::FullPrecisionSettings;

use super::*;
use crate::burn::{
graph::BurnGraph,
node::{gather_elements::GatherElementsNode, test::assert_tokens},
TensorType,
};

#[test]
fn test_codegen_gather_elements() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

graph.register(GatherElementsNode::new(
TensorType::new_float("tensor1", 2),
TensorType::new_int("tensor2", 2),
TensorType::new_float("tensor3", 2),
1,
));

graph.register_input_output(
vec!["tensor1".to_string(), "tensor2".to_string()],
vec!["tensor3".to_string()],
);

let expected = quote! {
use burn::tensor::Int;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}

#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(
&self,
tensor1: Tensor<B, 2>,
tensor2: Tensor<B, 2, Int>
) -> Tensor<B, 2> {
let tensor3 = tensor1.gather(1, tensor2);

tensor3
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
@@ -14,6 +14,7 @@ pub(crate) mod conv_transpose_2d;
pub(crate) mod dropout;
pub(crate) mod expand;
pub(crate) mod gather;
pub(crate) mod gather_elements;
pub(crate) mod global_avg_pool;
pub(crate) mod layer_norm;
pub(crate) mod linear;