-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathContactDataSet.py
60 lines (49 loc) · 2.65 KB
/
ContactDataSet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os

import numpy as np
import torch
from torch.utils.data import Dataset

from utils import read_dataset, remove_outliers, merge_slip_with_fly, remove_features, normalize
class ContactDataSet(Dataset):
    """Contact-state dataset read from a CSV file and exposed as a torch ``Dataset``.

    Each item is a ``(features, label)`` pair: ``features`` is a float32 row of
    the CSV's feature columns (after optional column removal, noise and
    per-column scaling) and ``label`` is a float32 one-hot vector over 2 classes.
    """

    # Columns removed for point-feet robots. Per the per-slice comments used for
    # flat feet, the layout is Fx Fy Fz | Tx Ty Tz | ax ay az | wx wy wz, so this
    # drops Fx, Fy, Tx, Ty, Tz and leaves Fz | ax ay az | wx wy wz.
    _POINT_FEET_DROPPED_FEATURES = (0, 1, 3, 4, 5)

    # (column slice, noise standard deviation) for each sensor group.
    _POINT_FEET_NOISE = (
        (slice(0, 1), 0.6325),   # Fz
        (slice(1, 4), 0.0078),   # ax ay az
        (slice(4, 7), 0.00523),  # wx wy wz
    )
    _FLAT_FEET_NOISE = (
        (slice(0, 3), 0.6325),    # Fx Fy Fz
        (slice(3, 6), 0.03),      # Tx Ty Tz
        (slice(6, 9), 0.0078),    # ax ay az
        (slice(9, 12), 0.00523),  # wx wy wz
    )

    def __init__(self, csv_file, root_dir, transform=None, point_feet=False, add_noise=False):
        """
        Args:
            csv_file (string): Name of the CSV file, relative to ``root_dir``. Its
                last column holds the labels; every other column is a feature.
            root_dir (string): Directory containing ``csv_file``.
            transform (callable, optional): Applied to the feature tensor of each
                sample in ``__getitem__``.
            point_feet (boolean): Set to true if the robot does not have flat feet;
                drops the feature columns listed in ``_POINT_FEET_DROPPED_FEATURES``.
            add_noise (boolean): Whether to perturb the features with noise
                (see ``_add_noise``) before scaling.
        """
        dataset = read_dataset(os.path.join(root_dir, csv_file))
        # Split the label column (the last one) off from the features.
        labels = dataset[:, -1]
        dataset = np.delete(dataset, -1, axis=1)
        # Merge slip labels with no_contact labels = Unstable contact
        labels = merge_slip_with_fly(labels)
        dataset, labels = remove_outliers(dataset, labels)

        if point_feet:
            dataset = remove_features(list(self._POINT_FEET_DROPPED_FEATURES), dataset)
            noise_spec = self._POINT_FEET_NOISE
        else:
            noise_spec = self._FLAT_FEET_NOISE

        # The ``add_noise`` parameter is a bool and shadows any function of the same
        # name, so the perturbation goes through the dedicated ``_add_noise`` helper.
        if add_noise:
            for columns, std in noise_spec:
                dataset[:, columns] = self._add_noise(dataset[:, columns], std)

        # Scale every column by its largest absolute value via ``normalize``.
        # NOTE(review): an older comment claimed the result lies in [0, 1]; if
        # ``normalize`` simply divides, signed features land in [-1, 1] — confirm.
        # All-zero columns are skipped so they are not divided by zero.
        for col in range(dataset.shape[1]):
            scale = np.max(np.abs(dataset[:, col]))
            if scale > 0:
                dataset[:, col] = normalize(dataset[:, col], scale)
        dataset, labels = remove_outliers(dataset, labels)

        self.data = torch.from_numpy(dataset).float()
        self.labels = torch.nn.functional.one_hot(
            torch.from_numpy(labels).long(), num_classes=2
        ).float()
        self.root_dir = root_dir
        self.transform = transform

    @staticmethod
    def _add_noise(data, std):
        """Return ``data`` plus zero-mean Gaussian noise with standard deviation ``std``.

        Draws from NumPy's global RNG, so ``np.random.seed`` makes it reproducible.
        NOTE(review): the original code called an undefined/shadowed ``add_noise``;
        the per-sensor constants are assumed to be noise standard deviations — confirm.
        """
        return data + np.random.normal(0.0, std, size=np.shape(data))

    def __len__(self):
        """Return the number of samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(features, one_hot_label)`` for sample ``idx``.

        ``transform``, when set, is applied to the feature tensor only.
        """
        features = self.data[idx, :]
        if self.transform is not None:
            features = self.transform(features)
        return features, self.labels[idx, :]