data.py
from os import listdir
from os.path import join, splitext

from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import RandomResizedCrop, ToTensor
from torchvision.transforms.functional import resized_crop

class Flickr(Dataset):
    """Image/mask pairs stored as origin/<name>.jpg and mask/<name>.bmp under data_path."""

    def __init__(self, data_path):
        self.data_path = data_path
        self.origin_list = listdir(join(data_path, 'origin'))
        self.transform = ToTensor()

    def __len__(self):
        return len(self.origin_list)

    def process_img(self, origin, mask):
        # Sample one random crop and apply the same window to image and mask so they stay aligned.
        i, j, h, w = RandomResizedCrop.get_params(origin, scale=(0.5, 2.0), ratio=(3.0 / 4, 4.0 / 3))
        origin = resized_crop(origin, i, j, h, w, size=(512, 512), interpolation=Image.NEAREST)
        mask = resized_crop(mask, i, j, h, w, size=(512, 512), interpolation=Image.NEAREST)
        origin = self.transform(origin)
        mask = self.transform(mask)
        # Drop the channel dimension and cast the mask to integer labels.
        mask = mask.squeeze(0).long()
        return origin, mask

    def __getitem__(self, index):
        assert index < len(self), 'index out of range'
        file = self.origin_list[index]
        origin = Image.open(join(self.data_path, 'origin', file)).convert('RGB')
        # Masks share the image's basename with a .bmp extension; str.strip('jpg')
        # removes characters, not the suffix, so build the name with splitext instead.
        mask = Image.open(join(self.data_path, 'mask', splitext(file)[0] + '.bmp')).convert('1')
        origin, mask = self.process_img(origin, mask)
        return {'origin': origin, 'mask': mask}

class Icdar(Dataset):
    """Image/mask pairs stored as origin/<name>.jpg and mask/<name>.png under data_path."""

    def __init__(self, data_path):
        self.data_path = data_path
        self.origin_list = listdir(join(self.data_path, 'origin'))
        self.transform = ToTensor()

    def __len__(self):
        return len(self.origin_list)

    def process_img(self, origin, mask):
        i, j, h, w = RandomResizedCrop.get_params(origin, scale=(0.5, 2.0), ratio=(3.0 / 4, 4.0 / 3))
        # torchvision interprets `size` as (height, width), so these are 640x480 (H x W) crops.
        origin = resized_crop(origin, i, j, h, w, size=(640, 480), interpolation=Image.NEAREST)
        mask = resized_crop(mask, i, j, h, w, size=(640, 480), interpolation=Image.NEAREST)
        origin = self.transform(origin)
        mask = self.transform(mask)
        mask = mask.squeeze(0).long()
        return origin, mask

    def __getitem__(self, index):
        assert index < len(self), 'index out of range'
        file = self.origin_list[index]
        origin = Image.open(join(self.data_path, 'origin', file)).convert('RGB')
        # Same basename, .png extension for the grayscale mask.
        mask = Image.open(join(self.data_path, 'mask', splitext(file)[0] + '.png')).convert('L')
        origin, mask = self.process_img(origin, mask)
        return {'origin': origin, 'mask': mask}

if __name__ == '__main__':
    from torch.utils.data import DataLoader

    # dl = DataLoader(dataset=Flickr('./data'), batch_size=2, shuffle=True)
    dl = DataLoader(dataset=Icdar('./icdartest'), batch_size=2, shuffle=True)
    for batch in dl:
        # Each batch is a dict; the default collate stacks the tensors under each key.
        print(batch['origin'].shape, batch['mask'].shape)
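    # Optional single-sample check (a sketch assuming the same './icdartest' layout as above):
    # indexing the dataset directly returns the dict built in __getitem__, so the
    # unbatched shapes should be (3, 640, 480) for 'origin' and (640, 480) for 'mask',
    # since torchvision's crop `size` is (height, width).
    sample = Icdar('./icdartest')[0]
    print(sample['origin'].shape, sample['mask'].shape)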