Skip to content

Commit

Permalink
feat(compression): implement tensor decompression in op depthwise conv
Browse files Browse the repository at this point in the history
Implement tensor decompression in op depthwise conv. Extend tests
to validate operation on compressed tensors.

BUG=part of tensorflow#2636
  • Loading branch information
ddavis-2015 authored and rkuester committed Dec 16, 2024
1 parent f9fecab commit 2268ddb
Show file tree
Hide file tree
Showing 7 changed files with 581 additions and 28 deletions.
41 changes: 40 additions & 1 deletion tensorflow/lite/micro/kernels/depthwise_conv.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -52,16 +52,37 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) {
? tflite::micro::GetEvalInput(context, node, kDepthwiseConvBiasTensor)
: nullptr;

#ifdef USE_TFLM_COMPRESSION

MicroContext* micro_context = GetMicroContext(context);

const CompressionTensorData* filter_comp_td =
micro_context->GetTensorCompressionData(node,
kDepthwiseConvWeightsTensor);
const CompressionTensorData* bias_comp_td =
micro_context->GetTensorCompressionData(node, kDepthwiseConvBiasTensor);

#endif // USE_TFLM_COMPRESSION

switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32: {
tflite::reference_ops::DepthwiseConv(
DepthwiseConvParamsFloat(params, data),
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorShape(filter),
#ifdef USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<float>(micro_context, filter,
filter_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<float>(
micro_context, bias, bias_comp_td, data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<float>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<float>(bias),
#endif // USE_TFLM_COMPRESSION
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
break;
Expand Down Expand Up @@ -94,9 +115,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(filter),
#ifdef USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(micro_context, filter,
filter_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int32_t>(
micro_context, bias, bias_comp_td, data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int32_t>(bias),
#endif // USE_TFLM_COMPRESSION
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
break;
Expand All @@ -118,9 +148,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(filter),
#ifdef USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(micro_context, filter,
filter_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int64_t>(
micro_context, bias, bias_comp_td, data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int64_t>(bias),
#endif // USE_TFLM_COMPRESSION
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
break;
Expand Down
23 changes: 21 additions & 2 deletions tensorflow/lite/micro/kernels/depthwise_conv_common.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -127,7 +127,9 @@ TfLiteStatus CalculateOpDataDepthwiseConv(

micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(filter);
micro_context->DeallocateTempTfLiteTensor(bias);
if (has_bias) {
micro_context->DeallocateTempTfLiteTensor(bias);
}
micro_context->DeallocateTempTfLiteTensor(output);

return kTfLiteOk;
Expand Down Expand Up @@ -209,6 +211,23 @@ TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node) {
context, node, params, input_width, input_height, filter_width,
filter_height, output_width, output_height, input->type, data));

#ifdef USE_TFLM_COMPRESSION

// Compression scratch buffers.
// These will only be allocated if the tensor is compressed.
if (micro_context->IsTensorCompressed(node, kDepthwiseConvWeightsTensor) &&
filter->type == kTfLiteInt4) {
MicroPrintf("Compression not supported with INT4 tensors");
return kTfLiteError;
}
data->weights_scratch_index =
micro_context->AllocateDecompressionScratchBuffer(
node, kDepthwiseConvWeightsTensor);
data->bias_scratch_index = micro_context->AllocateDecompressionScratchBuffer(
node, kDepthwiseConvBiasTensor);

#endif // USE_TFLM_COMPRESSION

micro_context->DeallocateTempTfLiteTensor(output);
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(filter);
Expand Down
Loading

0 comments on commit 2268ddb

Please sign in to comment.