From bb85ceb9a68c646d0f0144b8372388db2a2b912b Mon Sep 17 00:00:00 2001
From: "T. Tokusumi" <41147016+tokusumi@users.noreply.github.com>
Date: Sun, 16 Aug 2020 21:21:05 +0900
Subject: [PATCH] Add/custom registerstatistics (#9)

* Add custom tf.python.framework.ops.registerstatistics

* Support global max pooling: add registerstatistics for Max op

* support batch normalization: add registerstatistics for FusedBatchNormV3 op
---
 README.md                        |  6 +--
 keras_flops/flops_calculation.py |  3 ++
 keras_flops/flops_registory.py   | 33 ++++++++++++++++
 tests/test_flops.py              | 47 ++++++++++++++++++++---
 tests/test_not_support.py        | 64 +++++++-------------------
 5 files changed, 94 insertions(+), 59 deletions(-)
 create mode 100644 keras_flops/flops_registory.py

diff --git a/README.md b/README.md
index 2142c6f..0eae6c3 100644
--- a/README.md
+++ b/README.md
@@ -58,6 +58,8 @@ Support `tf.keras.layers` as follows,
 | Pooling | AveragePooling[1D/2D] |
 | | GlobalAveragePooling[1D/2D/3D]|
 | | MaxPooling[1D/2D] |
+| | GlobalMaxPool[1D/2D/3D] |
+| Normalization | BatchNormalization |
 | Activation | Softmax |
 | Attention | Attention |
 | | AdditiveAttention |
@@ -72,10 +74,8 @@ Not support `tf.keras.layers` as follows. They are calculated as zero or smaller
 | Conv | Conv3DTranspose |
 | Pooling | AveragePooling3D |
 | | MaxPooling3D |
-| | GlobalMaxPool[1D/2D/3D] |
 | | UpSampling[1D/2D/3D] |
-| Normalization | BatchNormalization |
-| | LayerNormalization |
+| Normalization | LayerNormalization |
 | RNN | SimpleRNN |
 | | LSTM |
 | | GRU |
diff --git a/keras_flops/flops_calculation.py b/keras_flops/flops_calculation.py
index 38e6a07..d91f693 100644
--- a/keras_flops/flops_calculation.py
+++ b/keras_flops/flops_calculation.py
@@ -6,6 +6,8 @@
 
 from tensorflow.keras import Sequential, Model
 
+import keras_flops.flops_registory
+
 
 def get_flops(model: Union[Model, Sequential], batch_size: Optional[int] = None) -> int:
     """
@@ -35,5 +37,6 @@ def get_flops(model: Union[Model, Sequential], batch_size: Optional[int] = None)
 
     flops = tf.compat.v1.profiler.profile(
         graph=frozen_func.graph, run_meta=run_meta, cmd="scope", options=opts
     )
+    # print(frozen_func.graph.get_operations())  # TODO: show each FLOPS
     return flops.total_float_ops
diff --git a/keras_flops/flops_registory.py b/keras_flops/flops_registory.py
new file mode 100644
index 0000000..075bf1a
--- /dev/null
+++ b/keras_flops/flops_registory.py
@@ -0,0 +1,33 @@
+import numpy as np
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import graph_util
+from tensorflow.python.profiler.internal.flops_registry import _reduction_op_flops
+
+
+@ops.RegisterStatistics("FusedBatchNormV3", "flops")
+def _flops_fused_batch_norm_v3(graph, node):
+    """inference is only supported"""
+    in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
+    in_shape.assert_is_fully_defined()
+    mean_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[3])
+    mean_shape.assert_is_fully_defined()
+    variance_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[4])
+    variance_shape.assert_is_fully_defined()
+
+    if node.attr["is_training"].b is True:
+        raise ValueError("Only supports inference mode")
+
+    num_flops = (
+        in_shape.num_elements()
+        + 4 * variance_shape.num_elements()
+        + mean_shape.num_elements()
+    )
+    return ops.OpStats("flops", num_flops)
+
+
+@ops.RegisterStatistics("Max", "flops")
+def _flops_max(graph, node):
+    """inference is supported"""
+    # reduction - comparison, no finalization
+    return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)
+
diff --git a/tests/test_flops.py b/tests/test_flops.py
index c8762a8..b1d93bf 100644
--- a/tests/test_flops.py
+++ b/tests/test_flops.py
@@ -15,6 +15,9 @@
     GlobalAveragePooling3D,
     MaxPooling1D,
     MaxPooling2D,
+    GlobalMaxPooling1D,
+    GlobalMaxPooling2D,
+    GlobalMaxPooling3D,
     BatchNormalization,
     AdditiveAttention,
     Attention,
@@ -235,6 +238,29 @@ def test_maxpooling1d2d3d():
     assert flops == in_w * in_h * kernel
 
 
+def test_global_maxpooling1d2d3d():
+    """
+    reduce the rest (Ndim elements) over the target axis,
+    which takes Ndim - 1 comparison ops.
+    """
+    in_w = 32
+    in_h = 32
+    in_z = 32
+    kernel = 3
+
+    model = Sequential(GlobalMaxPooling1D(input_shape=(in_w, kernel)))
+    flops = get_flops(model, batch_size=1)
+    assert flops == (in_w - 1) * kernel
+
+    model = Sequential(GlobalMaxPooling2D(input_shape=(in_w, in_h, kernel)))
+    flops = get_flops(model, batch_size=1)
+    assert flops == (in_w * in_h - 1) * kernel
+
+    model = Sequential(GlobalMaxPooling3D(input_shape=(in_w, in_h, in_z, kernel)))
+    flops = get_flops(model, batch_size=1)
+    assert flops == (in_w * in_h * in_z - 1) * kernel
+
+
 def test_softmax():
     kernel = 8
     model = Sequential(Activation("softmax", input_shape=(kernel,)))
@@ -293,10 +319,7 @@ def test_batchnormalization():
     2. (1 ops * |var|) inv *= gamma (scale)
     3. (|x| + |mean| + |var| ops) x' = inv * x + beta (shift) - mean * inv
     , where |var| = |mean| = channel size in default
-    Thus, 5 * channel size + input element size.
-
-    NOTE: support only fused=False
-    Use gen_nn_ops.fused_batch_norm_v3 but this is not registered yet and calculated as zero.
+    Thus, total FLOPs = 5 * channel size + input element size.
     """
     in_w = 32
     in_h = 32
@@ -310,7 +333,21 @@ def test_batchnormalization():
         )
     )
     flops = get_flops(model, batch_size=1)
-    assert flops == 5 * in_ch + in_w * in_ch, "fused is False"
+    assert (
+        flops == 5 * in_ch + in_w * in_ch
+    ), "fused is False. see nn_impl.batch_normalization"
+
+    model = Sequential(
+        BatchNormalization(
+            beta_initializer="ones",
+            gamma_initializer="ones",
+            input_shape=(in_w, in_h, in_ch),
+        )
+    )
+    flops = get_flops(model, batch_size=1)
+    assert (
+        flops == 5 * in_ch + in_w * in_h * in_ch
+    ), "fused is True, see gen_nn_ops.fused_batch_norm_v3"
 
 
 def test_additive_attention():
diff --git a/tests/test_not_support.py b/tests/test_not_support.py
index 73aa541..72401e2 100644
--- a/tests/test_not_support.py
+++ b/tests/test_not_support.py
@@ -8,13 +8,9 @@
     Conv3DTranspose,
     AveragePooling3D,
     MaxPooling3D,
-    GlobalMaxPooling1D,
-    GlobalMaxPooling2D,
-    GlobalMaxPooling3D,
     UpSampling1D,
     UpSampling2D,
     UpSampling3D,
-    BatchNormalization,
     LayerNormalization,
 )
 from keras_flops import get_flops
@@ -135,26 +131,6 @@ def test_maxpooling1d2d3d():
     assert flops == in_w * in_h * in_z * kernel
 
 
-@pytest.mark.xfail
-def test_global_maxpooling1d2d3d():
-    in_w = 32
-    in_h = 32
-    in_z = 32
-    kernel = 32
-
-    model = Sequential(GlobalMaxPooling1D(input_shape=(in_w, kernel)))
-    flops = get_flops(model, batch_size=1)
-    assert flops == in_w * kernel
-
-    model = Sequential(GlobalMaxPooling2D(input_shape=(in_w, in_h, kernel)))
-    flops = get_flops(model, batch_size=1)
-    assert flops == in_w * in_h * kernel
-
-    model = Sequential(GlobalMaxPooling3D(input_shape=(in_w, in_h, in_z, kernel)))
-    flops = get_flops(model, batch_size=1)
-    assert flops == in_w * in_h * in_z * kernel
-
-
 @pytest.mark.xfail
 def test_upsampling1d2d3d():
     in_w = 32
@@ -182,28 +158,6 @@ def test_upsampling1d2d3d():
     assert flops == in_w * in_h * in_z * kernel
 
 
-@pytest.mark.xfail
-def test_batchnormalization():
-    """
-    batch normalization in tf uses gen_nn_ops.fused_batch_norm_v3 if input shape are 4D
-    """
-    in_w = 32
-    in_h = 32
-    in_ch = 3
-
-    model = Sequential(
-        BatchNormalization(
-            beta_initializer="ones",
-            gamma_initializer="ones",
-            input_shape=(in_w, in_h, in_ch),
-        )
-    )
-    flops = get_flops(model, batch_size=1)
-    assert (
-        flops == 5 * in_ch + in_w * in_h * in_ch
-    ), "fused is True, fused_batch_norm_v3 is not supportted"
-
-
 @pytest.mark.xfail
 def test_layernormalization():
     """
@@ -213,11 +167,12 @@ def test_layernormalization():
     2. (1 ops * |var|) inv *= gamma (scale)
     3. (|x| + |mean| + |var| ops) x' = inv * x + beta (shift) - mean * inv
     , where |var| = |mean| = 1 in default
-    Thus, 5 channel size + input element size.
+    Thus, 5 + input element size.
 
-    Use nn.fused_batch_norm (gen_nn_ops.fused_batch_norm_v3) for layer normalization, above calculation,
-    but gen_nn_ops.fused_batch_norm_v3 is not registered yet, so can not evaluate corrent FLOPs.
+    Use nn.fused_batch_norm (gen_nn_ops.fused_batch_norm_v3) for layer normalization, as in the calculation above.
+    gen_nn_ops.fused_batch_norm_v3 supports only 4D inputs, so the data is reshaped to 4D before being fed in.
     squeezed_shape (ndim ops), scale (|x| ops) and shift (not float ops) is calculated.
+    NOTE: is_training stays True even if the trainable attribute of the tf.keras.Model instance is set to False, so the statistics will be incorrect.
     """
     in_w = 32
     in_h = 32
@@ -244,6 +199,13 @@ def test_layernormalization():
         )
     )
     flops = get_flops(model, batch_size=1)
-    assert flops == len(input_shape) + 1 + in_w * in_h * in_ch, "fused is True"
+    assert (
+        flops
+        == len(input_shape)
+        + 1
+        + 5
+        + in_w * in_h * in_ch
+        + 5 * in_ch
+        + in_w * in_h * in_ch
+    ), "fused is True. check gen_nn_ops.fused_batch_norm_v3"
 
-    assert flops == len(input_shape) + 1 + 5 * in_ch + in_w * in_h * in_ch
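
A minimal usage sketch mirroring the tests added above (illustrative layer and shape choices, not part of the diff): importing keras_flops pulls in keras_flops.flops_registory, so get_flops now counts FusedBatchNormV3 and Max ops with the formulas registered in this patch.

from tensorflow.keras import Sequential
from tensorflow.keras.layers import BatchNormalization, GlobalMaxPooling2D
from keras_flops import get_flops

# Per the registered formulas: fused batch norm = |x| + 5 * C = 3072 + 15,
# global max pooling (Max reduction) = |x| - C = 3072 - 3 = 3069.
model = Sequential(
    [
        BatchNormalization(input_shape=(32, 32, 3)),  # lowered to a FusedBatchNormV3 op
        GlobalMaxPooling2D(),  # lowered to a Max reduction op
    ]
)
print(get_flops(model, batch_size=1))  # expected 6156 from the formulas above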