-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathCaraNet.py
138 lines (96 loc) · 4.88 KB
/
CaraNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 21 14:58:14 2021
@author: angelou
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from pretrain.Res2Net_v1b import res2net50_v1b_26w_4s, res2net101_v1b_26w_4s
import math
import torchvision.models as models
from lib.conv_layer import Conv, BNPReLU
from lib.axial_atten import AA_kernel
from lib.context_module import CFPModule
from lib.partial_decoder import aggregation
import os
class caranet(nn.Module):
    """Res2Net-101 backbone followed by a partial decoder and three
    reverse-attention refinement stages.

    ``forward`` returns four single-channel maps.  The first three are the
    successively refined maps (finest refinement first) and the last one is
    the partial decoder's own prediction.  Each map is upsampled by a fixed
    factor (8, 16, 32, 8); per the original author's shape notes these factors
    restore the input resolution for a 352x352 input.

    Sub-module attribute names and their registration order are kept exactly
    as in the original implementation so existing checkpoints keep loading.
    """

    def __init__(self, channel: int = 32, pretrained: bool = True):
        """Build the network.

        Args:
            channel: Width of every decoder/attention feature map.  It used to
                be honoured only by ``aggregation`` while everything else was
                hard-coded to 32; it is now applied consistently.
                NOTE(review): widths other than 32 presumably also have to be
                accepted by ``CFPModule``/``AA_kernel``/``aggregation`` --
                confirm against ``lib``.
            pretrained: Forwarded to ``res2net101_v1b_26w_4s``.  Defaults to
                ``True``, which is what the original code always passed.
        """
        super().__init__()
        # ---- ResNet Backbone ----
        self.resnet = res2net101_v1b_26w_4s(pretrained=pretrained)

        # Reduce backbone stages 2-4 (512 / 1024 / 2048 channels) to `channel`.
        self.rfb2_1 = Conv(512, channel, 3, 1, padding=1, bn_acti=True)
        self.rfb3_1 = Conv(1024, channel, 3, 1, padding=1, bn_acti=True)
        self.rfb4_1 = Conv(2048, channel, 3, 1, padding=1, bn_acti=True)

        # Partial Decoder
        self.agg1 = aggregation(channel)

        # One context module per refinement stage.
        self.CFP_1 = CFPModule(channel, d=8)
        self.CFP_2 = CFPModule(channel, d=8)
        self.CFP_3 = CFPModule(channel, d=8)

        # Per-stage refinement heads: channel -> channel -> channel -> 1.
        self.ra1_conv1 = Conv(channel, channel, 3, 1, padding=1, bn_acti=True)
        self.ra1_conv2 = Conv(channel, channel, 3, 1, padding=1, bn_acti=True)
        self.ra1_conv3 = Conv(channel, 1, 3, 1, padding=1, bn_acti=True)

        self.ra2_conv1 = Conv(channel, channel, 3, 1, padding=1, bn_acti=True)
        self.ra2_conv2 = Conv(channel, channel, 3, 1, padding=1, bn_acti=True)
        self.ra2_conv3 = Conv(channel, 1, 3, 1, padding=1, bn_acti=True)

        self.ra3_conv1 = Conv(channel, channel, 3, 1, padding=1, bn_acti=True)
        self.ra3_conv2 = Conv(channel, channel, 3, 1, padding=1, bn_acti=True)
        self.ra3_conv3 = Conv(channel, 1, 3, 1, padding=1, bn_acti=True)

        # One attention kernel (from lib.axial_atten) per refinement stage.
        self.aa_kernel_1 = AA_kernel(channel, channel)
        self.aa_kernel_2 = AA_kernel(channel, channel)
        self.aa_kernel_3 = AA_kernel(channel, channel)

    @staticmethod
    def _resize(tensor, scale):
        """Bilinearly rescale ``tensor`` spatially by ``scale``."""
        # align_corners=False is what F.interpolate already uses when the
        # argument is omitted; spelling it out keeps results identical and
        # avoids the "default behaviour" warning some torch versions emit.
        return F.interpolate(tensor, scale_factor=scale, mode='bilinear',
                             align_corners=False)

    @staticmethod
    def _reverse_attention(coarse_map, features, context, attention, heads):
        """Refine ``coarse_map`` using ``features`` gated by reverse attention.

        Args:
            coarse_map: Single-channel map at the same spatial size as
                ``features``.
            features: Reduced backbone features for this stage.
            context: The stage's ``CFP_*`` module.
            attention: The stage's ``aa_kernel_*`` module.
            heads: The stage's three ``ra*_conv*`` modules, applied in order.

        Returns:
            ``coarse_map`` plus the single-channel residual predicted from the
            gated features.
        """
        attended = attention(context(features))
        # Reverse attention weight: large where the coarse map is *not*
        # confident (1 - sigmoid).
        reverse_weight = 1 - torch.sigmoid(coarse_map)
        # Take the channel count from the tensor instead of hard-coding 32.
        residual = reverse_weight.expand(-1, attended.size(1), -1, -1).mul(attended)
        for head in heads:
            residual = head(residual)
        return residual + coarse_map

    def forward(self, x):
        """Run the network on an image batch ``x``.

        Shape comments below are the original author's notes for a
        352x352 input.

        Returns:
            Tuple ``(lateral_map_5, lateral_map_3, lateral_map_2,
            lateral_map_1)``: the stage-three, stage-two and stage-one
            refinements followed by the partial-decoder map, each upsampled
            by its fixed scale factor.
        """
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)      # bs, 64, 88, 88

        # ----------- backbone features -------------
        x1 = self.resnet.layer1(x)      # bs, 256, 88, 88
        x2 = self.resnet.layer2(x1)     # bs, 512, 44, 44
        x3 = self.resnet.layer3(x2)     # bs, 1024, 22, 22
        x4 = self.resnet.layer4(x3)     # bs, 2048, 11, 11

        x2_rfb = self.rfb2_1(x2)        # 512  -> channel
        x3_rfb = self.rfb3_1(x3)        # 1024 -> channel
        x4_rfb = self.rfb4_1(x4)        # 2048 -> channel

        # ----------- partial decoder -------------
        decoder_1 = self.agg1(x4_rfb, x3_rfb, x2_rfb)   # 1, 44, 44
        lateral_map_1 = self._resize(decoder_1, 8)

        # ------------------- atten-one (x4 resolution) -------------------
        decoder_2 = self._resize(decoder_1, 0.25)
        x_3 = self._reverse_attention(
            decoder_2, x4_rfb, self.CFP_3, self.aa_kernel_3,
            (self.ra3_conv1, self.ra3_conv2, self.ra3_conv3))
        lateral_map_2 = self._resize(x_3, 32)

        # ------------------- atten-two (x3 resolution) -------------------
        decoder_3 = self._resize(x_3, 2)
        x_2 = self._reverse_attention(
            decoder_3, x3_rfb, self.CFP_2, self.aa_kernel_2,
            (self.ra2_conv1, self.ra2_conv2, self.ra2_conv3))
        lateral_map_3 = self._resize(x_2, 16)

        # ------------------- atten-three (x2 resolution) -------------------
        decoder_4 = self._resize(x_2, 2)
        x_1 = self._reverse_attention(
            decoder_4, x2_rfb, self.CFP_1, self.aa_kernel_1,
            (self.ra1_conv1, self.ra1_conv2, self.ra1_conv3))
        lateral_map_5 = self._resize(x_1, 8)

        return lateral_map_5, lateral_map_3, lateral_map_2, lateral_map_1
if __name__ == '__main__':
    # Smoke test: build the model and push one random 352x352 RGB image
    # through it.  Use the GPU when one is available, otherwise fall back to
    # the CPU instead of crashing on the unconditional .cuda() call.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    ras = caranet().to(device)
    input_tensor = torch.randn(1, 3, 352, 352).to(device)
    out = ras(input_tensor)