forked from SimingYan/HPNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
130 lines (101 loc) · 5.51 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from trainer import Trainer
from option import build_option
from utils.loss_utils import compute_embedding_loss, compute_normal_loss, \
compute_param_loss, compute_nnl_loss, compute_miou, compute_type_miou_abc
from utils.main_utils import npy
from utils.abc_utils import mean_shift, compute_entropy, construction_affinity_matrix_type, \
construction_affinity_matrix_normal
import scipy.stats as stats
from src.mean_shift import MeanShift
class MyTrainer(Trainer):
    """HPNet trainer.

    Computes the per-batch training losses and, when requested, the
    spectral-clustering segmentation used for evaluation.
    """

    def process_batch(self, batch_data_label, postprocess=False):
        """Run the model on one batch and compute losses (and optionally clusters).

        Args:
            batch_data_label: dict of batch tensors. Keys read here: 'gt_pc',
                'gt_normal', 'I_gt', 'T_gt', plus 'T_param' when the parameter
                loss is enabled and 'index' is not used. Presumably
                'gt_pc'/'gt_normal' are (B, N, 3) since they are permuted to
                channel-first before the model call -- TODO confirm against the
                data loader.
            postprocess: when True, also build the spectral embedding, cluster
                it with mean shift and compute the mIoU metrics.

        Returns:
            (total_loss, loss_dict, result):
              * total_loss -- sum of all weighted losses in ``loss_dict``;
              * loss_dict -- the individual weighted losses, plus 'miou' and
                'type_miou' when ``postprocess`` is True;
              * result -- predictions: 'types', 'params', 'gt_indices', and
                'labels' / 'matching' when ``postprocess`` is True.

        Raises:
            ValueError: if the normal loss ('n' in ``opt.loss_class``) is
                requested while ``opt.input_normal`` is False; in that
                configuration the model returns no per-point normals.
        """
        # Fail fast instead of hitting an unbound `normal_per_point` after an
        # expensive forward pass.
        if 'n' in self.opt.loss_class and not self.opt.input_normal:
            raise ValueError(
                "normal loss ('n' in loss_class) requires input_normal=True: "
                "the model does not return per-point normals otherwise")

        result = {}
        # Channel-first inputs for the model.
        inputs_xyz_th = batch_data_label['gt_pc'].float().cuda().permute(0, 2, 1)
        inputs_n_th = batch_data_label['gt_normal'].float().cuda().permute(0, 2, 1)

        if self.opt.input_normal:
            (affinity_feat, type_per_point, normal_per_point,
             param_per_point, sub_idx) = self.model(
                inputs_xyz_th, inputs_n_th, postprocess=postprocess)
        else:
            normal_per_point = None
            affinity_feat, type_per_point, param_per_point, sub_idx = self.model(
                inputs_xyz_th, inputs_n_th, postprocess=postprocess)

        result['types'] = torch.argmax(type_per_point, dim=-1)
        result['params'] = param_per_point
        result['gt_indices'] = sub_idx

        # Restrict the ground truth to the subset of points selected by the
        # model (`sub_idx`, presumably (B, M) point indices -- TODO confirm).
        N_gt = batch_data_label['gt_normal'].float().cuda()
        N_gt = torch.gather(N_gt, 1, sub_idx.unsqueeze(-1).repeat(1, 1, 3))
        I_gt = torch.gather(batch_data_label['I_gt'], -1, sub_idx)
        T_gt = torch.gather(batch_data_label['T_gt'], -1, sub_idx)

        loss_dict = self._batch_losses(
            batch_data_label, sub_idx, affinity_feat, type_per_point,
            normal_per_point, param_per_point, N_gt, I_gt, T_gt)
        total_loss = sum(value for key, value in loss_dict.items() if 'loss' in key)

        if postprocess:
            inputs_xyz_sub = torch.gather(
                inputs_xyz_th, -1, sub_idx.unsqueeze(1).repeat(1, 3, 1))
            cluster_pred = self._spectral_clusters(
                inputs_xyz_sub, affinity_feat, type_per_point, param_per_point, N_gt)

            miou, pred_ind, gt_ind = compute_miou(cluster_pred, I_gt)
            loss_dict['miou'] = miou
            loss_dict['type_miou'] = compute_type_miou_abc(
                type_per_point, T_gt, cluster_pred, I_gt)
            result['labels'] = cluster_pred

            # matching[predicted cluster id] -> matched ground-truth instance
            # id, or -1 when the cluster was not matched.
            num_clusters = int(torch.max(cluster_pred).item()) + 1
            matching = np.full(max(0, num_clusters), -1, dtype=np.int32)
            if len(pred_ind) > 0:
                matching[pred_ind] = gt_ind
            result['matching'] = matching

        return total_loss, loss_dict, result

    def _batch_losses(self, batch_data_label, sub_idx, affinity_feat, type_per_point,
                      normal_per_point, param_per_point, N_gt, I_gt, T_gt):
        """Return the weighted losses selected by the letters in ``opt.loss_class``.

        'f' -> 'feat_loss', 'n' -> 'normal_loss', 'p' -> 'param_loss',
        'r' -> 'nnl_loss'. Keys are inserted in that order.
        """
        loss_class = self.opt.loss_class
        loss_dict = {}
        if 'f' in loss_class:
            # Embedding loss on the network feature; the pull/push components
            # returned alongside it are not used here.
            feat_loss, _, _ = compute_embedding_loss(affinity_feat, I_gt)
            loss_dict['feat_loss'] = feat_loss
        if 'n' in loss_class:
            normal_loss = compute_normal_loss(normal_per_point, N_gt)
            loss_dict['normal_loss'] = self.opt.normal_weight * normal_loss
        if 'p' in loss_class:
            # 22 parameter channels are gathered per point; this must match
            # what compute_param_loss expects.
            T_param_gt = torch.gather(
                batch_data_label['T_param'], 1, sub_idx.unsqueeze(-1).repeat(1, 1, 22))
            param_loss = compute_param_loss(param_per_point, T_gt, T_param_gt)
            loss_dict['param_loss'] = self.opt.param_weight * param_loss
        if 'r' in loss_class:
            type_loss = compute_nnl_loss(type_per_point, T_gt)
            loss_dict['nnl_loss'] = self.opt.type_weight * type_loss
        return loss_dict

    @staticmethod
    def _normalized_eigvecs(affinity, k):
        """Return ``k`` eigenvectors of ``affinity`` from ``torch.lobpcg``
        (10 iterations), with each row rescaled to unit L2 norm."""
        _, vecs = torch.lobpcg(affinity, k=k, niter=10)
        return vecs / (torch.norm(vecs, dim=-1, keepdim=True) + 1e-16)

    def _spectral_clusters(self, inputs_xyz_sub, affinity_feat, type_per_point,
                           param_per_point, N_gt):
        """Cluster points with mean shift on an entropy-weighted joint embedding.

        Three feature sources are concatenated:
          1. the network affinity feature,
          2. spectral eigenvectors of the type/parameter affinity matrix,
          3. spectral eigenvectors of the normal (edge) affinity matrix.
        Each source is scaled by ``<configured weight> - compute_entropy(feature)``.
        """
        affinity_matrix = construction_affinity_matrix_type(
            inputs_xyz_sub, type_per_point, param_per_point, self.opt.sigma)
        affinity_matrix_normal = construction_affinity_matrix_normal(
            inputs_xyz_sub, N_gt, sigma=self.opt.normal_sigma, knn=self.opt.edge_knn)

        # Order matters: lobpcg draws a random initial guess, so the two
        # eigen-solves keep their original sequence.
        geo_vecs = self._normalized_eigvecs(affinity_matrix, self.opt.topK)
        edge_vecs = self._normalized_eigvecs(affinity_matrix_normal, self.opt.edge_topK)

        sources = [
            (affinity_feat, self.opt.feat_ent_weight),
            (geo_vecs, self.opt.dis_ent_weight),
            (edge_vecs, self.opt.edge_ent_weight),
        ]
        # The raw (unnormalized) entropy weights scale each source; the scale
        # interacts with the mean-shift bandwidth.
        weighted = [
            feat * (base_weight - float(npy(compute_entropy(feat))))
            for feat, base_weight in sources
        ]
        spectral_embedding = torch.cat(weighted, dim=-1)
        return mean_shift(spectral_embedding, bandwidth=self.opt.bandwidth)
if __name__ == '__main__':
    # Script entry point: parse the command-line options, wrap them in the
    # HPNet trainer and call its train() method.
    FLAGS = build_option()
    MyTrainer(FLAGS).train()