-
Notifications
You must be signed in to change notification settings - Fork 139
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
165 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
97 changes: 97 additions & 0 deletions
97
..._recommenders_addons/dynamic_embedding/python/keras/layers/dynamic_layer_normalization.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import tensorflow as tf | ||
from tensorflow.keras.layers import LayerNormalization as TFLayerNormalization | ||
|
||
|
||
class LayerNormalization(TFLayerNormalization): | ||
|
||
def call(self, inputs): | ||
# TODO(b/229545225): Remove the RaggedTensor check. | ||
is_ragged = isinstance(inputs, tf.RaggedTensor) | ||
if is_ragged: | ||
inputs_lengths = inputs.nested_row_lengths() | ||
inputs = inputs.to_tensor() | ||
inputs = tf.cast(inputs, self.compute_dtype) | ||
# Compute the axes along which to reduce the mean / variance | ||
input_shape = tf.shape(inputs) | ||
# Get the number of dimensions dynamically | ||
ndims = input_shape.shape[0] | ||
|
||
# Broadcasting only necessary for norm when the axis is not just | ||
# the last dimension | ||
broadcast_shape = [1] * ndims | ||
for dim in self.axis: | ||
broadcast_shape[dim] = input_shape[dim] | ||
|
||
def _broadcast(v): | ||
if v is not None and len(v.shape) != ndims and self.axis != [ndims - 1]: | ||
return tf.reshape(v, broadcast_shape) | ||
return v | ||
|
||
if not self._fused: | ||
input_dtype = inputs.dtype | ||
if input_dtype in ("float16", "bfloat16") and self.dtype == "float32": | ||
# If mixed precision is used, cast inputs to float32 so that | ||
# this is at least as numerically stable as the fused version. | ||
inputs = tf.cast(inputs, "float32") | ||
|
||
# Calculate the moments on the last axis (layer activations). | ||
mean, variance = tf.nn.moments(inputs, self.axis, keepdims=True) | ||
|
||
scale, offset = _broadcast(self.gamma), _broadcast(self.beta) | ||
|
||
# Compute layer normalization using the batch_normalization | ||
# function. | ||
outputs = tf.nn.batch_normalization( | ||
inputs, | ||
mean, | ||
variance, | ||
offset=offset, | ||
scale=scale, | ||
variance_epsilon=self.epsilon, | ||
) | ||
outputs = tf.cast(outputs, input_dtype) | ||
else: | ||
# Collapse dims before self.axis, and dims in self.axis | ||
|
||
axis = sorted(self.axis) | ||
tensor_shape = tf.shape(inputs) | ||
pre_dim = tf.reduce_prod(tensor_shape[:axis[0]]) | ||
in_dim = tf.reduce_prod(tensor_shape[axis[0]:]) | ||
squeezed_shape = [1, pre_dim, in_dim, 1] | ||
# This fused operation requires reshaped inputs to be NCHW. | ||
data_format = "NCHW" | ||
|
||
inputs = tf.reshape(inputs, squeezed_shape) | ||
|
||
# self.gamma and self.beta have the wrong shape for | ||
# fused_batch_norm, so we cannot pass them as the scale and offset | ||
# parameters. Therefore, we create two constant tensors in correct | ||
# shapes for fused_batch_norm and later construct a separate | ||
# calculation on the scale and offset. | ||
scale = tf.ones([pre_dim], dtype=self.dtype) | ||
offset = tf.zeros([pre_dim], dtype=self.dtype) | ||
|
||
# Compute layer normalization using the fused_batch_norm function. | ||
outputs, _, _ = tf.compat.v1.nn.fused_batch_norm( | ||
inputs, | ||
scale=scale, | ||
offset=offset, | ||
epsilon=self.epsilon, | ||
data_format=data_format, | ||
) | ||
|
||
outputs = tf.reshape(outputs, tensor_shape) | ||
|
||
scale, offset = _broadcast(self.gamma), _broadcast(self.beta) | ||
|
||
if scale is not None: | ||
outputs = outputs * tf.cast(scale, outputs.dtype) | ||
if offset is not None: | ||
outputs = outputs + tf.cast(offset, outputs.dtype) | ||
|
||
# If some components of the shape got lost due to adjustments, fix that. | ||
outputs = tf.reshape(outputs, input_shape) | ||
|
||
if is_ragged: | ||
outputs = tf.RaggedTensor.from_tensor(outputs, inputs_lengths) | ||
return outputs |
59 changes: 59 additions & 0 deletions
59
...mmenders_addons/dynamic_embedding/python/keras/layers/dynamic_layer_normalization_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import numpy as np | ||
|
||
import tensorflow as tf | ||
from tensorflow_recommenders_addons import dynamic_embedding as de | ||
|
||
|
||
class DynamicLayerNormalizationTest(tf.test.TestCase): | ||
|
||
def test_dynamic_shape_support(self): | ||
input_data = tf.keras.Input(shape=(None, 10), dtype=tf.float32) | ||
layer = de.keras.layers.LayerNormalization() | ||
output = layer(input_data) | ||
|
||
model = tf.keras.models.Model(inputs=input_data, outputs=output) | ||
|
||
np.random.seed(0) | ||
test_data = np.random.randn(2, 5, 10).astype(np.float32) | ||
output_data = model.predict(test_data) | ||
self.assertAllEqual(output_data.shape, (2, 5, 10)) | ||
|
||
expected_mean = np.mean(test_data, axis=-1, keepdims=True) | ||
expected_std = np.std(test_data, axis=-1, keepdims=True) | ||
expected_normalized = (test_data - expected_mean) / (expected_std + | ||
layer.epsilon) | ||
|
||
# Calculate expected output considering gamma and beta are default (i.e., gamma=1, beta=0) | ||
# 1e-3 is the default value for epsilon in LayerNormalization | ||
self.assertAllClose(output_data, expected_normalized, rtol=1e-3, atol=1e-3) | ||
|
||
def test_training_with_layer_normalization(self): | ||
input_dim = 10 | ||
num_samples = 100 | ||
output_dim = 1 | ||
|
||
np.random.seed(0) | ||
features = np.random.randn(num_samples, input_dim).astype(np.float32) | ||
labels = (np.sum(features, axis=1) + | ||
np.random.randn(num_samples) * 0.5).astype(np.float32).reshape( | ||
-1, 1) | ||
|
||
input_data = tf.keras.Input(shape=(input_dim,), dtype=tf.float32) | ||
normalized = de.keras.layers.LayerNormalization()(input_data) | ||
output = tf.keras.layers.Dense(output_dim)(normalized) | ||
model = tf.keras.models.Model(inputs=input_data, outputs=output) | ||
|
||
model.compile(optimizer='adam', loss='mean_squared_error') | ||
initial_weights = [layer.get_weights() for layer in model.layers] | ||
|
||
model.fit(features, labels, epochs=5, batch_size=10, verbose=0) | ||
|
||
updated_weights = [layer.get_weights() for layer in model.layers] | ||
|
||
for initial, updated in zip(initial_weights, updated_weights): | ||
for ini_w, upd_w in zip(initial, updated): | ||
self.assertGreater(np.sum(np.abs(ini_w - upd_w)), 0) | ||
|
||
predictions = model.predict(features) | ||
self.assertAllEqual(predictions.shape, (num_samples, output_dim)) | ||
self.assertGreater(np.std(predictions), 0.1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters