-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathencoder.py
81 lines (71 loc) · 2.81 KB
/
encoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File : encoder.py
@Time : 2020/03/09 18:47:50
@Author : jhhuang96
@Mail : hjh096@126.com
@Version : 1.0
@Description: encoder
'''
from torch import nn
from utils import make_layers
import torch
import logging
class Encoder(nn.Module):
    """Stack of (conv subnet, recurrent cell) stages that encodes a frame sequence.

    Stage ``i`` (numbered from 1) is stored on the module as the attributes
    ``stage{i}`` (built by ``make_layers``) and ``rnn{i}``.  These attribute
    names determine the keys produced by ``state_dict()``, so they must stay
    stable for saved checkpoints to keep loading.

    Args:
        subnets: per-stage parameter objects, each passed to ``make_layers``.
        rnns: per-stage recurrent cells, each called as ``rnn(inputs, None)``
            and expected to return an ``(outputs, state)`` pair.

    Raises:
        ValueError: if ``subnets`` and ``rnns`` have different lengths.
    """

    def __init__(self, subnets, rnns):
        super().__init__()
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, after which ``zip`` would silently drop the extra
        # stages while ``self.blocks`` still counted them.
        if len(subnets) != len(rnns):
            raise ValueError(
                f"subnets and rnns must have the same length, "
                f"got {len(subnets)} and {len(rnns)}")
        self.blocks = len(subnets)
        # Stages are numbered from 1.
        for index, (params, rnn) in enumerate(zip(subnets, rnns), 1):
            setattr(self, f"stage{index}", make_layers(params))
            setattr(self, f"rnn{index}", rnn)

    def forward_by_stage(self, inputs, subnet, rnn):
        """Run one stage: apply ``subnet`` to every frame, then ``rnn`` over time.

        Args:
            inputs: 5-D sequence-first tensor ``(S, B, C, H, W)``.
            subnet: module applied to a 4-D ``(S*B, C, H, W)`` batch.
            rnn: recurrent cell called as ``rnn(x, None)``.

        Returns:
            The ``(outputs, state)`` pair returned by ``rnn``.
        """
        seq_number, batch_size, input_channel, height, width = inputs.size()
        # Fold time into the batch axis so the subnet sees ordinary 4-D input.
        inputs = torch.reshape(inputs, (-1, input_channel, height, width))
        inputs = subnet(inputs)
        # Unfold back to (S, B, C', H', W'); the subnet may change C, H and W.
        inputs = torch.reshape(inputs, (seq_number, batch_size, inputs.size(1),
                                        inputs.size(2), inputs.size(3)))
        # A None state is passed; presumably the rnn initialises its own
        # hidden state in that case (TODO confirm against the rnn classes).
        outputs_stage, state_stage = rnn(inputs, None)
        return outputs_stage, state_stage

    def forward(self, inputs):
        """Encode a batch-first sequence and collect every stage's state.

        Args:
            inputs: 5-D batch-first tensor ``(B, S, C, H, W)``; it is
                transposed to sequence-first before the first stage.

        Returns:
            Tuple of length ``self.blocks`` holding the ``state`` returned by
            each stage's rnn, ordered from stage 1 to the last stage.
        """
        inputs = inputs.transpose(0, 1)  # (B, S, ...) -> (S, B, ...)
        logging.debug(inputs.size())
        hidden_states = []
        for i in range(1, self.blocks + 1):
            # Each stage's rnn outputs become the next stage's inputs.
            inputs, state_stage = self.forward_by_stage(
                inputs, getattr(self, f"stage{i}"), getattr(self, f"rnn{i}"))
            hidden_states.append(state_stage)
        return tuple(hidden_states)
if __name__ == "__main__":
    # Smoke test: push a couple of Moving-MNIST training batches through the
    # ConvGRU encoder and print tensor shapes. Needs a CUDA device and the
    # dataset under data/.
    from net_params import convgru_encoder_params
    from data.mm import MovingMNIST

    device = torch.device("cuda:0")
    # Place the model on the same device the inputs are moved to below.
    encoder = Encoder(convgru_encoder_params[0],
                      convgru_encoder_params[1]).to(device)
    trainFolder = MovingMNIST(is_train=True,
                              root='data/',
                              n_frames_input=10,
                              n_frames_output=10,
                              num_objects=[3])
    trainLoader = torch.utils.data.DataLoader(
        trainFolder,
        batch_size=4,
        shuffle=False,
    )
    print(torch.cuda.device_count())

    # Pre-bind so the reporting below is safe even if the loader is empty.
    inputs = frozen = npzero = None
    # Forward pass only: no_grad avoids building an autograd graph that is
    # never used for backprop.
    with torch.no_grad():
        for i, (idx, targetVar, inputVar, frozen, npzero) in enumerate(trainLoader):
            inputs = inputVar.to(device)  # B,S,1,64,64
            state = encoder(inputs)
            print("running")
            if i == 1:
                break

    if inputs is None:
        print("trainLoader yielded no batches")
    else:
        # NOTE(review): labelled "frozen?" but prints the shape of the last
        # sample in the input batch; confirm what was meant to be checked.
        print("frozen?", inputs[-1].shape)
        print("inputs.shape", inputs.shape)
        print("frozen", frozen.shape)
        print("npzero1", npzero.shape)