Skip to content

Commit

Permalink
Fix T5 large/T5-3B accuracy issue
Browse files Browse the repository at this point in the history
Signed-off-by: Rajeev Rao <rajeevrao@nvidia.com>
  • Loading branch information
parthchadha authored and rajeevsrao committed Oct 21, 2021
1 parent 6a97ec1 commit 9ec6eb6
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 50 deletions.
41 changes: 40 additions & 1 deletion demo/HuggingFace/NNDF/tensorrt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,53 @@ def clamp_weights_onnx(onnx_input_fpath: str, onnx_output_fpath: str, min: float
np.clip(node_attr.values, min, max, out=node_attr.values)

model = gs.export_onnx(graph)
onnx.save(model, onnx_output_fpath)
onnx.save(model, onnx_output_fpath, save_as_external_data=True)


def clamp_weights_onnx_to_fp16_bounds(onnx_input_fpath: str, onnx_output_fpath: str, ignore_nodes: List = None):
    """
    Clamp an ONNX model's weights to the finite FP16 range.

    Thin wrapper around clamp_weights_onnx that supplies symmetric bounds of
    +/-65504, the largest finite magnitude representable in FP16, and returns
    whatever clamp_weights_onnx returns.
    """
    fp16_max = 65504
    fp16_min = -fp16_max
    return clamp_weights_onnx(onnx_input_fpath, onnx_output_fpath, fp16_min, fp16_max, ignore_nodes)


def move_t5_cast_op(onnx_input_fpath: str, onnx_output_fpath: str):
    """
    Move the Cast that follows a T5 residual Add in front of the Add.

    T5 encoder and decoder have cast ops after the residual add operation.
    Moving the cast operation before the add helps with FP16 accuracy, as the
    addition itself can overflow in FP16.

    Args:
        onnx_input_fpath: Path of the ONNX model to load.
        onnx_output_fpath: Path the rewritten model is saved to. The model is
            saved with external data enabled.
    """
    graph = gs.import_onnx(onnx.load(onnx_input_fpath))

    # Pass 1: remove every Cast that sits between an Add and a Pow by wiring the
    # Add directly to the Cast's output tensors.
    for cast in [node for node in graph.nodes if node.op == "Cast"]:
        # A Cast fed by a graph input or an initializer has no producer node;
        # Node.i() would raise IndexError on it, so skip it explicitly.
        if not cast.inputs or not cast.inputs[0].inputs:
            continue

        producer = cast.i()
        if producer.op != "Add":
            continue

        # Cast appears at the output of add and feeds into a Pow op.
        feeds_pow = any(
            consumer.op == "Pow"
            for tensor in cast.outputs
            for consumer in tensor.outputs
        )
        if feeds_pow:
            producer.outputs = cast.outputs
            # Detach the Cast so the cleanup() call below removes it.
            cast.outputs.clear()

    graph.cleanup().toposort()

    # Pass 2: insert a Cast to FP32 on every input of each Add that now feeds a
    # Pow, so the addition itself is carried out in FP32.
    for add in [node for node in graph.nodes if node.op == "Add"]:
        # An Add whose first output has no consumer (e.g. it is a graph output)
        # would make Node.o() raise IndexError, so skip it explicitly.
        if not add.outputs or not add.outputs[0].outputs:
            continue
        if add.o().op != "Pow":
            continue

        fp32_inputs = []
        for tensor in add.inputs:
            fp32_tensor = gs.Variable("identity_out" + tensor.name, dtype=np.float32)
            # "to": 1 is the ONNX TensorProto.FLOAT (float32) element type.
            cast_to_fp32 = gs.Node(op="Cast", inputs=[tensor], outputs=[fp32_tensor], attrs={"to": 1})
            graph.nodes.append(cast_to_fp32)
            fp32_inputs.append(fp32_tensor)
        add.inputs = fp32_inputs

    graph.cleanup().toposort()
    model = gs.export_onnx(graph)
    onnx.save(model, onnx_output_fpath, save_as_external_data=True)

# Helper Classes
class TRTNativeRunner:
"""TRTNativeRunner avoids the high overheads with Polygraphy runner providing performance comparable to C++ implementation."""
Expand Down
5 changes: 3 additions & 2 deletions demo/HuggingFace/T5/T5ModelConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,13 @@ def from_inference_args(args: argparse.Namespace):

class T5ModelTRTConfig(NNConfig):

TARGET_MODELS = ["t5-small", "t5-base", "t5-large"]
NUMBER_OF_LAYERS = {TARGET_MODELS[0]: 6, TARGET_MODELS[1]: 12, TARGET_MODELS[2]: 24}
TARGET_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b"]
NUMBER_OF_LAYERS = {TARGET_MODELS[0]: 6, TARGET_MODELS[1]: 12, TARGET_MODELS[2]: 24, TARGET_MODELS[3]: 24}
MAX_SEQUENCE_LENGTH = {
TARGET_MODELS[0]: 512,
TARGET_MODELS[1]: 768,
TARGET_MODELS[2]: 1024,
TARGET_MODELS[3]: 1024,
}

NETWORK_FULL_NAME = "full"
Expand Down
99 changes: 52 additions & 47 deletions demo/HuggingFace/T5/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

# TRT-HuggingFace
from T5.T5ModelConfig import T5ModelTRTConfig
from NNDF.tensorrt_utils import clamp_weights_onnx_to_fp16_bounds
from NNDF.tensorrt_utils import clamp_weights_onnx_to_fp16_bounds, move_t5_cast_op
from NNDF.networks import NetworkMetadata
from NNDF.logger import G_LOGGER
from NNDF.models import (
Expand All @@ -48,55 +48,58 @@
)

def add_extra_fp32(network_definition):
# NOTE(review): this span appears to be a rendered diff hunk with the +/- markers and
# indentation stripped: statements from an earlier sliding-window implementation
# (`window`, `l_1` ... `l_5`) are interleaved with a newer `pow_ops` implementation.
# Confirm against the repository before treating it as a single function body.
def window(seq, n=2):
"Returns a sliding window (of width n) over data from the iterable"
" s -> (s0,s1,...s[n-1]), (s1,s2,...,sn), ... "
it = iter(seq)
result = tuple(islice(it, n))
if len(result) == n:
yield result
for elem in it:
result = result[1:] + (elem,)
yield result

# Slide a six-layer window over the layer indices of network_definition[1].
indices = list(range(0, network_definition[1].num_layers))
for i, i_1, i_2, i_3, i_4, i_5 in window(indices, 6):
l = network_definition[1].get_layer(i)
l_1 = network_definition[1].get_layer(i_1)
l_2 = network_definition[1].get_layer(i_2)
l_3 = network_definition[1].get_layer(i_3)
l_4 = network_definition[1].get_layer(i_4)
l_5 = network_definition[1].get_layer(i_5)

# Skip windows whose first layer has any output that is not an execution tensor.
if not all([l.get_output(k).is_execution_tensor for k in range(l.num_outputs)]):
continue

# Only consider windows whose first layer reports FP32 for output 0.
if l.get_output_type(0) != trt.float32:
continue

# Match the layer-type pattern ELEMENTWISE, REDUCE, CONSTANT, <any>, ELEMENTWISE, UNARY.
if l.type == trt.LayerType.ELEMENTWISE and \
l_1.type == trt.LayerType.REDUCE and \
l_2.type == trt.LayerType.CONSTANT and \
l_4.type == trt.LayerType.ELEMENTWISE and \
l_5.type == trt.LayerType.UNARY:

l.__class__ = getattr(trt, "IElementWiseLayer")
if l.op == trt.ElementWiseOperation.POW:
"""
Force operations involved in layer norm to run in FP32 precision.
"""
# First pass: pin qualifying Identity layers and every elementwise POW layer to FP32,
# recording the index of each POW layer for the second pass.
pow_ops = {}
for layer_index, layer in enumerate(network_definition[1]):
if layer.type == trt.LayerType.IDENTITY:
# Only Identity layers whose outputs are all explicitly set to FP32 and whose
# first input is FP32 get their precision forced.
all_fp32 = all([layer.output_type_is_set(o) and layer.get_output_type(o) == trt.float32 for o in range(layer.num_outputs)])
if all_fp32:
if layer.get_input(0).dtype == trt.float32:
layer.precision = trt.float32

if layer.type == trt.LayerType.ELEMENTWISE:
# Presumably the layer is re-classed so that the elementwise-specific `op`
# attribute can be read -- TODO confirm against the TensorRT Python API.
layer.__class__ = getattr(trt, "IElementWiseLayer")
if layer.op == trt.ElementWiseOperation.POW:
pow_ops[layer] = layer_index
layer.precision = trt.float32
layer.set_output_type(0, trt.float32)

# Second pass: around each recorded POW layer, force REDUCE layers and the
# SUM, SQRT, DIV and PROD operations to FP32.
for _, index in pow_ops.items():
# Start START_OFFSET layers before the Pow op to include the residual add and cast ops.
# Stop at index + END_OFFSET (exclusive) to cover all operations included in layer norm.
START_OFFSET = 4
END_OFFSET = 10
# NOTE(review): the range is not clamped to [0, num_layers); confirm how get_layer
# behaves for indices outside the network when a Pow op sits near either end.
for i in range(index-START_OFFSET, index+END_OFFSET):
l = network_definition[1].get_layer(i)
if l.type == trt.LayerType.REDUCE:
l.precision = trt.float32
l.set_output_type(0, trt.float32)

l_1.precision = trt.float32
l_1.set_output_type(0, trt.float32)

l_4.__class__ = getattr(trt, "IElementWiseLayer")
if l_4.op == trt.ElementWiseOperation.SUM:
l_4.precision = trt.float32
l_4.set_output_type(0, trt.float32)

l_5.__class__ = getattr(trt, "IUnaryLayer")
if l_5.op == trt.UnaryOperation.SQRT:
l_5.precision = trt.float32
l_5.set_output_type(0, trt.float32)
# Elementwise SUM near the Pow op is forced to FP32.
if l.type == trt.LayerType.ELEMENTWISE:
l.__class__ = getattr(trt, "IElementWiseLayer")
if l.op == trt.ElementWiseOperation.SUM:
l.precision = trt.float32
l.set_output_type(0, trt.float32)

# Unary SQRT near the Pow op is forced to FP32.
if l.type == trt.LayerType.UNARY:
l.__class__ = getattr(trt, "IUnaryLayer")
if l.op == trt.UnaryOperation.SQRT:
l.precision = trt.float32
l.set_output_type(0, trt.float32)

# Elementwise DIV near the Pow op is forced to FP32.
if l.type == trt.LayerType.ELEMENTWISE:
l.__class__ = getattr(trt, "IElementWiseLayer")
if l.op == trt.ElementWiseOperation.DIV:
l.precision = trt.float32
l.set_output_type(0, trt.float32)

# Elementwise PROD near the Pow op is forced to FP32.
if l.type == trt.LayerType.ELEMENTWISE:
l.__class__ = getattr(trt, "IElementWiseLayer")
if l.op == trt.ElementWiseOperation.PROD:
l.precision = trt.float32
l.set_output_type(0, trt.float32)

# The same network_definition object is handed back to the caller after in-place edits.
return network_definition

Expand Down Expand Up @@ -285,6 +288,7 @@ def _export_forward(*args, **kwargs):

if network_metadata.precision.fp16:
G_LOGGER.debug("Clamping FP16 weights for T5")
move_t5_cast_op(output_fpath, output_fpath)
clamp_weights_onnx_to_fp16_bounds(output_fpath, output_fpath)

return T5DecoderONNXFile(output_fpath, network_metadata)
Expand Down Expand Up @@ -332,6 +336,7 @@ def torch_to_onnx(

if network_metadata.precision.fp16:
G_LOGGER.debug("Clamping FP16 weights for T5")
move_t5_cast_op(output_fpath, output_fpath)
clamp_weights_onnx_to_fp16_bounds(output_fpath, output_fpath)

return T5EncoderONNXFile(output_fpath, network_metadata)

0 comments on commit 9ec6eb6

Please sign in to comment.