diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index be385ae67..9f39414e1 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -44,6 +44,11 @@ def main(): help="Include verbose logging", action="store_true", ) + parser.add_argument( + "--strict", + help="Enables strictness during export", + action="store_true", + ) args = cli.parse(parser) dataset = cli.get_input_dataset(args) @@ -113,6 +118,7 @@ def generate_batch_prefill(bs: int): name=f"prefill_bs{bs}", args=(tokens, seq_lens, seq_block_ids, cache_state), dynamic_shapes=dynamic_shapes, + strict=args.strict, ) def _(model, tokens, seq_lens, seq_block_ids, cache_state): sl = tokens.shape[1] @@ -170,6 +176,7 @@ def generate_batch_decode(bs: int): cache_state, ), dynamic_shapes=dynamic_shapes, + strict=args.strict, ) def _( model,