-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexport_model.py
43 lines (31 loc) · 1.36 KB
/
export_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import numpy as np
import torch as th
from stable_baselines3 import PPO
class OnnxablePolicy(th.nn.Module):
    """Export-friendly wrapper around the pieces of an SB3 actor-critic policy.

    Bundles the shared feature extractor with the action and value heads so a
    single ``forward`` call produces both outputs, which is what tracing /
    ONNX export expects.
    """

    def __init__(self, extractor, action_net, value_net):
        super().__init__()
        # Sub-modules taken straight from the loaded SB3 policy.
        self.extractor = extractor
        self.action_net = action_net
        self.value_net = value_net

    def forward(self, observation):
        # NOTE: You may have to process (normalize) observation in the correct
        # way before using this. See `common.preprocessing.preprocess_obs`
        latent_pi, latent_vf = self.extractor(observation)
        action = self.action_net(latent_pi)
        value = self.value_net(latent_vf)
        return action, value
# Load the trained SB3 model on CPU (export/tracing does not need a GPU).
model = PPO.load("./models/CartPole-v1/CartPole-v1.zip", device="cpu")
onnxable_model = OnnxablePolicy(
    model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net
)

observation_size = model.observation_space.shape
# One dummy observation drives both the tracing step and the reload smoke
# test below (the original rebuilt a second random input for no reason,
# so the reloaded module was exercised on a different input than traced).
dummy_input = th.randn(1, *observation_size)
print(observation_size)

jit_path = "./models/CartPole-v1/model_traced.pt"
# Trace in eval mode, then freeze and optimize the graph for inference.
traced_module = th.jit.trace(onnxable_model.eval(), dummy_input)
frozen_module = th.jit.freeze(traced_module)
frozen_module = th.jit.optimize_for_inference(frozen_module)
th.jit.save(frozen_module, jit_path)

# Reload the saved module and run one forward pass to verify it round-trips.
loaded_module = th.jit.load(jit_path)
action_jit = loaded_module(dummy_input)