-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathCnn.py
35 lines (27 loc) · 1.21 KB
/
Cnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch as th
import torch.nn as nn
import gym
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
class CustomCNN(BaseFeaturesExtractor):
    """Convolutional feature extractor with LayerNorm after every layer.

    The network is three unpadded, bias-free convolutions (each followed by a
    LayerNorm and a LeakyReLU), a flatten, and a bias-free linear projection
    to ``features_dim`` (again followed by LayerNorm and LeakyReLU).  The raw
    observation is layer-normalised before the first convolution.

    All LayerNorm shapes and the linear input size are derived from
    ``observation_space.shape``, read as ``(channels, height, width)``.  For a
    ``(3, 100, 156)`` space this yields intermediate shapes ``[32, 24, 38]``,
    ``[64, 11, 18]``, ``[64, 9, 16]`` and ``9216`` flattened features.

    Args:
        observation_space: Box space whose ``shape`` is
            ``(channels, height, width)``.
        features_dim: Size of the feature vector returned by ``forward``.
        **kwargs: Forwarded unchanged to every ``nn.LeakyReLU`` in the
            network.

    Raises:
        ValueError: If the observation shape is not three-dimensional, or if
            the image is too small for one of the convolution kernels.
    """

    # (out_channels, kernel_size, stride) of each convolution, in order.
    _CONV_SPECS = ((32, 8, 4), (64, 4, 2), (64, 3, 1))

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 128, **kwargs):
        super().__init__(observation_space, features_dim)

        shape = observation_space.shape
        if shape is None or len(shape) != 3:
            raise ValueError(
                "CustomCNN expects a (channels, height, width) observation "
                f"space, got shape {shape!r}"
            )
        channels, height, width = (int(dim) for dim in shape)

        # Layer order is kept as LayerNorm, then (Conv2d, LayerNorm, LeakyReLU)
        # three times, then Flatten, so the indices inside ``self.cnn`` (and
        # therefore the state_dict keys) stay stable.
        layers = [nn.LayerNorm([channels, height, width])]
        for out_channels, kernel, stride in self._CONV_SPECS:
            if height < kernel or width < kernel:
                raise ValueError(
                    f"Feature map of size {height}x{width} is smaller than "
                    f"the {kernel}x{kernel} convolution kernel; observation "
                    f"shape {tuple(shape)!r} is too small for CustomCNN"
                )
            # Output size of a convolution with padding=0 and dilation=1.
            height = (height - kernel) // stride + 1
            width = (width - kernel) // stride + 1
            layers += [
                nn.Conv2d(
                    channels,
                    out_channels,
                    kernel_size=kernel,
                    stride=stride,
                    padding=0,
                    bias=False,
                ),
                nn.LayerNorm([out_channels, height, width]),
                nn.LeakyReLU(**kwargs),
            ]
            channels = out_channels
        layers.append(nn.Flatten())
        self.cnn = nn.Sequential(*layers)

        self.linear = nn.Sequential(
            nn.Linear(channels * height * width, features_dim, bias=False),
            nn.LayerNorm(features_dim),
            nn.LeakyReLU(**kwargs),
        )

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Map a batch of observations to feature vectors.

        Args:
            observations: Tensor of shape ``(batch, channels, height, width)``
                whose trailing three dimensions match the observation space
                given at construction.

        Returns:
            Tensor of shape ``(batch, features_dim)``.
        """
        return self.linear(self.cnn(observations))