-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcorrupted_dataloader.py
45 lines (36 loc) · 1.58 KB
/
corrupted_dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import warnings

import numpy as np
import torch
class CorruptedLabelDataLoader(torch.utils.data.DataLoader):
    '''
    This is a wrapper around a PyTorch DataLoader.

    Simply wrap it around an instantiated DataLoader instance,
    and use the resulting returned instance as if you are using
    a normal DataLoader instance.

    Example:
    ----------------
    train_loader = ...  # define `train_loader` as you normally would
    train_loader = CorruptedLabelDataLoader(train_loader)
    for (x, y) in train_loader:
        ...
    ----------------

    Purpose of this wrapper:
    Randomly permute the labels such that there is an
    intentional mismatch between the images and labels.

    Notes:
    - The labels are permuted in place on the wrapped loader's dataset
      object. Every loader sharing that dataset, including the original
      `dataloader`, therefore sees the permuted labels.
    - Labels are looked up in the dataset attribute `targets` (MNIST,
      CIFAR10, CIFAR100) or, failing that, `labels` (STL10). If neither
      exists, a warning is issued and the labels are left untouched.
    - The permutation uses a private RNG seeded with `random_seed`.
      NumPy's global random state is not modified.
    - `DataLoader.__init__` is intentionally not called. DataLoader's
      methods (e.g. `__iter__`, `__len__`) run with this wrapper as
      `self`, and any attribute not found on the wrapper is read from
      the wrapped loader.
    '''

    # Dataset attributes that may hold the labels, checked in this order.
    # `targets` is used in MNIST, CIFAR10, CIFAR100; `labels` in STL10.
    _LABEL_ATTRIBUTES = ('targets', 'labels')

    def __init__(self,
                 dataloader: torch.utils.data.DataLoader,
                 random_seed: int = 1) -> None:
        '''
        Args:
            dataloader: the DataLoader to wrap. Its dataset is modified
                in place.
            random_seed: seed for the label permutation. It is passed to
                `np.random.RandomState`, so the resulting permutation is
                the one `np.random.seed(random_seed)` followed by
                `np.random.permutation(...)` would give.
        '''
        self.dataloader = dataloader
        dataset = dataloader.dataset
        # A private RNG keeps the permutation reproducible without
        # reseeding NumPy's global generator for the whole process.
        rng = np.random.RandomState(random_seed)
        available = dir(dataset)
        for attribute in self._LABEL_ATTRIBUTES:
            if attribute in available:
                setattr(dataset, attribute,
                        self._permuted_like(getattr(dataset, attribute), rng))
                break
        else:
            # Without this, a caller would silently train on clean labels
            # while believing they were corrupted.
            warnings.warn(
                f'{type(dataset).__name__} has neither a `targets` nor a '
                '`labels` attribute; labels were NOT corrupted.',
                stacklevel=2)

    @staticmethod
    def _permuted_like(values, rng):
        '''Return `values` permuted along the first axis, keeping its type.

        The permutation is drawn as `rng.permutation(len(values))`, which
        gives the same ordering as `rng.permutation(values)`. Tensors stay
        tensors (on the same device), ndarrays stay ndarrays, and lists
        stay lists. Any other sequence becomes an ndarray, as
        `np.random.permutation` would return.
        '''
        order = rng.permutation(len(values))
        if isinstance(values, torch.Tensor):
            return values[torch.as_tensor(order, dtype=torch.long,
                                          device=values.device)]
        if isinstance(values, np.ndarray):
            return values[order]
        if isinstance(values, list):
            return [values[i] for i in order]
        return np.asarray(values)[order]

    def __getattr__(self, name):
        # Only invoked when normal lookup fails, i.e. for attributes that
        # are neither set on the wrapper nor defined on its class. Those
        # are forwarded to the wrapped loader.
        # `dataloader` is read straight from __dict__. Going through
        # `self.dataloader` would re-enter __getattr__ forever on an
        # instance whose __init__ has not run (e.g. one being rebuilt by
        # copy/pickle).
        try:
            wrapped = self.__dict__['dataloader']
        except KeyError:
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute '
                f'{name!r}') from None
        # Unlike calling `__getattribute__` directly, `getattr` also honours
        # the wrapped object's own __getattr__, so wrappers can be nested.
        return getattr(wrapped, name)