diff --git a/src/rewrite_rnn.cpp b/src/rewrite_rnn.cpp index 308a95d58ee..62532473cb0 100644 --- a/src/rewrite_rnn.cpp +++ b/src/rewrite_rnn.cpp @@ -969,10 +969,11 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const // process weight of the peephole instruction_ref pph = m.end(); - // if(args.size() == 8 and not args[7]->is_undefined()) - // { - // pph = args[7]; - // } + if(args.size() == 8 and not args[7]->is_undefined() and + not args[7]->get_shape().lens().empty()) + { + pph = args[7]; + } if(not is_forward and variable_seq_len) {