-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathaugmentation.py
133 lines (114 loc) · 6.17 KB
/
augmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import random
import cv2
import numpy as np
from scipy import ndimage
def crop_center(img_np, cropx, cropy):
    """Return the centered window of a 2D array.

    Args:
        img_np: 2D array; its shape is unpacked as (rows, cols).
        cropx: Width of the window, in columns.
        cropy: Height of the window, in rows.

    Returns:
        The slice ``img_np[top:top + cropy, left:left + cropx]`` where
        ``top``/``left`` place the window around the array's center
        (integer division, so odd sizes are biased toward the origin).
    """
    n_rows, n_cols = img_np.shape
    top = n_rows // 2 - cropy // 2
    left = n_cols // 2 - cropx // 2
    window = (slice(top, top + cropy), slice(left, left + cropx))
    return img_np[window]
def _map_planes(transform, *plane_lists):
    """Apply ``transform`` to every 2D plane of each list; return new lists."""
    return [[transform(plane) for plane in planes] for planes in plane_lists]


def augmentation(points_data, label_data):
    """
    Randomly augment a batch of 2D-projected scans and their labels in place.

    All of this augmentation is done in the 2D space, not on the 3D space.

    For every sample, one coin flip decides whether the sample is augmented at
    all; if so, each of the following is applied independently with
    probability 0.5:

    * left/right flip (channel 1, named ``y`` in the original code, is also
      negated),
    * a random shift, with vacated cells filled with 0,
    * a random zoom in or out by a factor in [1, 2), followed by a center
      crop (zoom in) or a nearest-neighbour resize (zoom out) back to the
      original image size.

    The image size is read from the input, so resolutions other than 64x512
    are accepted. For 64x512 inputs the random draws and the results are the
    same as with the previous hard-coded implementation.

    Args:
        points_data: array indexed as (batch, row, col, channel). Channels
            0..4 are transformed (the original code names them x, y, z,
            remission, depth); any further channels are left untouched.
        label_data: array indexed as (batch, row, col, channel) with the same
            row/col size as ``points_data``. Channels 0..5 are transformed.

    Returns:
        The same ``(points_data, label_data)`` objects, modified in place.
    """
    n_point_channels = 5
    n_label_channels = 6
    y_channel = 1

    for b in range(points_data.shape[0]):
        # Copy each channel out so no transform ever works on a view of the
        # very buffer that is overwritten at the end of the iteration.
        point_planes = [points_data[b, :, :, c].copy()
                        for c in range(n_point_channels)]
        label_planes = [label_data[b, :, :, c].copy()
                        for c in range(n_label_channels)]
        height, width = point_planes[0].shape

        # Augmentation 50% times
        aug = random.random() > 0.5

        # The coin for each transform is drawn *before* looking at `aug`, so
        # the amount of randomness consumed per sample matches the original
        # implementation (same results for a given seed).
        if random.random() > 0.5 and aug:  # Random flipping
            point_planes, label_planes = _map_planes(
                np.fliplr, point_planes, label_planes)
            # when you flip left_to_right, change the sign of the axis
            point_planes[y_channel] = -point_planes[y_channel]

        if random.random() > 0.5 and aug:  # Random shifts
            # Up to half the width horizontally and 12 rows per 64 rows
            # vertically (i.e. +-256 / +-12 for a 64x512 image).
            max_row_shift = (12 * height) // 64
            x_shift = random.randint(0, width) - width // 2
            y_shift = random.randint(0, 2 * max_row_shift) - max_row_shift
            offsets = (y_shift, x_shift)
            point_planes, label_planes = _map_planes(
                lambda plane: ndimage.shift(plane, offsets, None, order=0,
                                            mode='constant', cval=0),
                point_planes, label_planes)

        if random.random() > 0.5 and aug:  # Random zooms (in/out)
            zoom = random.random() + 1
            if random.random() > 0.5:
                zoom = 1. / zoom
            point_planes, label_planes = _map_planes(
                lambda plane: ndimage.zoom(plane, zoom, output=None, order=0,
                                           mode='constant', cval=0.0),
                point_planes, label_planes)

            if zoom > 1:
                # Zoomed-in image is larger: keep its central window.
                point_planes, label_planes = _map_planes(
                    lambda plane: crop_center(plane, width, height),
                    point_planes, label_planes)
            else:
                # Zoomed-out image is smaller: stretch it back to full size.
                point_planes, label_planes = _map_planes(
                    lambda plane: cv2.resize(
                        plane, dsize=(width, height),
                        interpolation=cv2.INTER_NEAREST),
                    point_planes, label_planes)

        for c, plane in enumerate(label_planes):
            label_data[b, :, :, c] = plane
        for c, plane in enumerate(point_planes):
            points_data[b, :, :, c] = plane

    return points_data, label_data