Skip to content

Commit

Permalink
Add/custom registerstatistics (#9)
Browse files Browse the repository at this point in the history
* Add custom tf.python.framework.ops.registerstatistics 
* Support global max pooling: add registerstatistics for Max op
* support batch normalization: add registerstatistics for FusedBatchNormV3 op
  • Loading branch information
tokusumi authored Aug 16, 2020
1 parent cf6d828 commit bb85ceb
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 59 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ Support `tf.keras.layers` as follows,
| Pooling | AveragePooling[1D/2D] |
| | GlobalAveragePooling[1D/2D/3D]|
| | MaxPooling[1D/2D] |
| | GlobalMaxPool[1D/2D/3D] |
| Normalization | BatchNormalization |
| Activation | Softmax |
| Attention | Attention |
| | AdditiveAttention |
Expand All @@ -72,10 +74,8 @@ Not support `tf.keras.layers` as follows. They are calculated as zero or smaller
| Conv | Conv3DTranspose |
| Pooling | AveragePooling3D |
| | MaxPooling3D |
| | GlobalMaxPool[1D/2D/3D] |
| | UpSampling[1D/2D/3D] |
| Normalization | BatchNormalization |
| | LayerNormalization |
| Normalization | LayerNormalization |
| RNN | SimpleRNN |
| | LSTM |
| | GRU |
Expand Down
3 changes: 3 additions & 0 deletions keras_flops/flops_calculation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from tensorflow.keras import Sequential, Model

import keras_flops.flops_registory


def get_flops(model: Union[Model, Sequential], batch_size: Optional[int] = None) -> int:
"""
Expand Down Expand Up @@ -35,5 +37,6 @@ def get_flops(model: Union[Model, Sequential], batch_size: Optional[int] = None)
flops = tf.compat.v1.profiler.profile(
graph=frozen_func.graph, run_meta=run_meta, cmd="scope", options=opts
)
# print(frozen_func.graph.get_operations())
# TODO: show each FLOPS
return flops.total_float_ops
33 changes: 33 additions & 0 deletions keras_flops/flops_registory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import numpy as np
from tensorflow.python.framework import ops
from tensorflow.python.framework import graph_util
from tensorflow.python.profiler.internal.flops_registry import _reduction_op_flops


@ops.RegisterStatistics("FusedBatchNormV3", "flops")
def _flops_fused_batch_norm_v3(graph, node):
"""inference is only supportted"""
in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
in_shape.assert_is_fully_defined()
mean_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[3])
mean_shape.assert_is_fully_defined()
variance_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[4])
variance_shape.assert_is_fully_defined()

if node.attr["is_training"].b is True:
raise ValueError("Only supports inference mode")

num_flops = (
in_shape.num_elements()
+ 4 * variance_shape.num_elements()
+ mean_shape.num_elements()
)
return ops.OpStats("flops", num_flops)


@ops.RegisterStatistics("Max", "flops")
def _flops_max(graph, node):
"""inference is supportted"""
# reduction - comparison, no finalization
return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)

47 changes: 42 additions & 5 deletions tests/test_flops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
GlobalAveragePooling3D,
MaxPooling1D,
MaxPooling2D,
GlobalMaxPooling1D,
GlobalMaxPooling2D,
GlobalMaxPooling3D,
BatchNormalization,
AdditiveAttention,
Attention,
Expand Down Expand Up @@ -235,6 +238,29 @@ def test_maxpooling1d2d3d():
assert flops == in_w * in_h * kernel


def test_global_maxpooling1d2d3d():
"""
reduct rest (Ndim) of target axis.
compare Ndim - 1 ops.
"""
in_w = 32
in_h = 32
in_z = 32
kernel = 3

model = Sequential(GlobalMaxPooling1D(input_shape=(in_w, kernel)))
flops = get_flops(model, batch_size=1)
assert flops == (in_w - 1) * kernel

model = Sequential(GlobalMaxPooling2D(input_shape=(in_w, in_h, kernel)))
flops = get_flops(model, batch_size=1)
assert flops == (in_w * in_h - 1) * kernel

model = Sequential(GlobalMaxPooling3D(input_shape=(in_w, in_h, in_z, kernel)))
flops = get_flops(model, batch_size=1)
assert flops == (in_w * in_h * in_z - 1) * kernel


def test_softmax():
kernel = 8
model = Sequential(Activation("softmax", input_shape=(kernel,)))
Expand Down Expand Up @@ -293,10 +319,7 @@ def test_batchnormalization():
2. (1 ops * |var|) inv *= gamma (scale)
3. (|x| + |mean| + |var| ops) x' = inv * x + beta (shift) - mean * inv
, where |var| = |mean| = channel size in default
Thus, 5 * channel size + input element size.
NOTE: support only fused=False
Use gen_nn_ops.fused_batch_norm_v3 but this is not registered yet and calculated as zero.
Thus, tot FLOPs = 5 * channel size + input element size.
"""
in_w = 32
in_h = 32
Expand All @@ -310,7 +333,21 @@ def test_batchnormalization():
)
)
flops = get_flops(model, batch_size=1)
assert flops == 5 * in_ch + in_w * in_ch, "fused is False"
assert (
flops == 5 * in_ch + in_w * in_ch
), "fused is False. see nn_impl.batch_normalization"

model = Sequential(
BatchNormalization(
beta_initializer="ones",
gamma_initializer="ones",
input_shape=(in_w, in_h, in_ch),
)
)
flops = get_flops(model, batch_size=1)
assert (
flops == 5 * in_ch + in_w * in_h * in_ch
), "fused is True, see gen_nn.fused_batch_norm_v3"


def test_additive_attention():
Expand Down
64 changes: 13 additions & 51 deletions tests/test_not_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,9 @@
Conv3DTranspose,
AveragePooling3D,
MaxPooling3D,
GlobalMaxPooling1D,
GlobalMaxPooling2D,
GlobalMaxPooling3D,
UpSampling1D,
UpSampling2D,
UpSampling3D,
BatchNormalization,
LayerNormalization,
)
from keras_flops import get_flops
Expand Down Expand Up @@ -135,26 +131,6 @@ def test_maxpooling1d2d3d():
assert flops == in_w * in_h * in_z * kernel


@pytest.mark.xfail
def test_global_maxpooling1d2d3d():
in_w = 32
in_h = 32
in_z = 32
kernel = 32

model = Sequential(GlobalMaxPooling1D(input_shape=(in_w, kernel)))
flops = get_flops(model, batch_size=1)
assert flops == in_w * kernel

model = Sequential(GlobalMaxPooling2D(input_shape=(in_w, in_h, kernel)))
flops = get_flops(model, batch_size=1)
assert flops == in_w * in_h * kernel

model = Sequential(GlobalMaxPooling3D(input_shape=(in_w, in_h, in_z, kernel)))
flops = get_flops(model, batch_size=1)
assert flops == in_w * in_h * in_z * kernel


@pytest.mark.xfail
def test_upsampling1d2d3d():
in_w = 32
Expand Down Expand Up @@ -182,28 +158,6 @@ def test_upsampling1d2d3d():
assert flops == in_w * in_h * in_z * kernel


@pytest.mark.xfail
def test_batchnormalization():
"""
batch normalization in tf uses gen_nn_ops.fused_batch_norm_v3 if input shape are 4D
"""
in_w = 32
in_h = 32
in_ch = 3

model = Sequential(
BatchNormalization(
beta_initializer="ones",
gamma_initializer="ones",
input_shape=(in_w, in_h, in_ch),
)
)
flops = get_flops(model, batch_size=1)
assert (
flops == 5 * in_ch + in_w * in_h * in_ch
), "fused is True, fused_batch_norm_v3 is not supportted"


@pytest.mark.xfail
def test_layernormalization():
"""
Expand All @@ -213,11 +167,12 @@ def test_layernormalization():
2. (1 ops * |var|) inv *= gamma (scale)
3. (|x| + |mean| + |var| ops) x' = inv * x + beta (shift) - mean * inv
, where |var| = |mean| = 1 in default
Thus, 5 channel size + input element size.
Thus, 5 + input element size.
Use nn.fused_batch_norm (gen_nn_ops.fused_batch_norm_v3) for layer normalization, above calculation,
but gen_nn_ops.fused_batch_norm_v3 is not registered yet, so can not evaluate corrent FLOPs.
Use nn.fused_batch_norm (gen_nn_ops.fused_batch_norm_v3) for layer normalization, above calculation.
gen_nn_ops.fused_batch_norm_v3 support only 4D, so reshape data as 4D and input them.
squeezed_shape (ndim ops), scale (|x| ops) and shift (not float ops) is calculated.
NOTE: is_training = True, if make trainable attributes of tf.keras.Model instanse False. So, statistics will be incorrect.
"""
in_w = 32
in_h = 32
Expand All @@ -244,6 +199,13 @@ def test_layernormalization():
)
)
flops = get_flops(model, batch_size=1)
assert flops == len(input_shape) + 1 + in_w * in_h * in_ch, "fused is True"
assert (
flops
== len(input_shape)
+ 1
+ 5
+ in_w * in_h * in_ch
+ 5 * in_ch
+ in_w * in_h * in_ch
), "fused is True. check gen_nn_ops.fused_batch_norm_v3"

assert flops == len(input_shape) + 1 + 5 * in_ch + in_w * in_h * in_ch

0 comments on commit bb85ceb

Please sign in to comment.