Skip to content

Commit

Permalink
fix: fix stratified split dataset bug
Browse files Browse the repository at this point in the history
Referenced Issue: #2
  • Loading branch information
jerry-ryu committed Jan 9, 2023
1 parent b4eeb25 commit 8e2c539
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions model/font_classifier/dataset_font.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ class FontDataset(Dataset):
def __init__(self, data_dir, val_ratio=0.2, is_train = True):
self.data_dir = data_dir
self.transform = None
self.is_train = is_train
if is_train:
self.val_ratio = val_ratio
self.is_train = is_train
self.setup()


Expand All @@ -38,13 +38,13 @@ def setup(self):
image_path.append(os.path.join(self.data_dir,profile,path))
image_label.append(idx)
tmp_all = set(range(len(image_path)))
tmp_train = set(random.sample(list(range(len(image_path))), int(len(image_path) * self.val_ratio)))
tmp_val = tmp_all - tmp_train
tmp_val = set(random.sample(list(range(len(image_path))), int(len(image_path) * self.val_ratio)))
tmp_train = tmp_all - tmp_val

self.image_paths_train.extend([image_path[x] for x in tmp_train])
self.image_labels_train.extend([image_label[x] for x in tmp_train])
self.image_paths_val.extend([image_path[x] for x in tmp_val])
self.image_labels_val.extend([image_path[x] for x in tmp_val])
self.image_labels_val.extend([image_label[x] for x in tmp_val])

def set_transform(self, transform):
self.transform = transform
Expand Down Expand Up @@ -83,7 +83,7 @@ def __init__(self, img_paths, resize, mean=(0.548, 0.504, 0.479), std=(0.237, 0.
self.transform = transforms.Compose([
transforms.Resize(resize, Image.BILINEAR),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
#transforms.Normalize(mean=mean, std=std),
])

def __getitem__(self, index):
Expand All @@ -104,7 +104,7 @@ def __init__(self, resize, mean, std, **args):
self.transform = transforms.Compose([
transforms.Resize(resize, Image.BILINEAR),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
#transforms.Normalize(mean=mean, std=std),
])

def __call__(self, image):
Expand Down

0 comments on commit 8e2c539

Please sign in to comment.