-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataloader.py
38 lines (27 loc) · 914 Bytes
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import os
import torch
import pandas as pd
class Preprocess:
    """Load CSV tables found under ``args.data_dir`` as NumPy arrays.

    The training and test tables are cached on the instance as
    ``train_data`` / ``test_data`` and are also returned by their loaders.
    """

    def __init__(self, args):
        """Store the run configuration.

        Args:
            args: configuration object; ``load_data`` reads its ``data_dir``
                attribute. No other attribute is read by this class.
        """
        self.args = args
        # Populated by load_train_data / load_test_data; None until loaded.
        self.train_data = None
        self.test_data = None

    def load_data(self, file_name):
        """Read a CSV file from ``self.args.data_dir`` and return its rows.

        Args:
            file_name: CSV file name, relative to ``self.args.data_dir``.

        Returns:
            The CSV records as a NumPy array (one row per record). The first
            line of the file is consumed as the header by ``pd.read_csv``'s
            default settings and is not part of the result.
        """
        csv_file_name = os.path.join(self.args.data_dir, file_name)
        df = pd.read_csv(csv_file_name)
        # to_numpy() is pandas' recommended replacement for DataFrame.values.
        return df.to_numpy()

    def load_train_data(self, file_name='train_data.csv'):
        """Load the training CSV into ``self.train_data`` and return it.

        Args:
            file_name: CSV file name relative to ``args.data_dir``;
                defaults to ``'train_data.csv'``.
        """
        self.train_data = self.load_data(file_name)
        return self.train_data

    def load_test_data(self, file_name='test_data.csv'):
        """Load the test CSV into ``self.test_data`` and return it.

        Args:
            file_name: CSV file name relative to ``args.data_dir``;
                defaults to ``'test_data.csv'``.
        """
        self.test_data = self.load_data(file_name)
        return self.test_data
class YNAT_dataset(torch.utils.data.Dataset):
    """Map-style torch ``Dataset`` that serves each row of ``data`` as a list."""

    def __init__(self, args, data, is_inference):
        """Wrap an indexable collection of rows.

        Args:
            args: run configuration; stored but not read in this class.
            data: indexable, sized collection of iterable rows (presumably
                the array produced by ``Preprocess.load_data`` — confirm
                against callers).
            is_inference: flag stored on the instance; not read in this
                class (NOTE(review): presumably consumed by external code).
        """
        super().__init__()
        self.args = args
        self.data = data
        self.is_inference = is_inference

    def __len__(self):
        """Return the number of rows in ``data``."""
        return len(self.data)

    def __getitem__(self, index):
        """Return row ``index`` of ``data`` as a new Python list.

        The list is a fresh copy of the row's elements, so mutating the
        returned value does not modify ``self.data``.
        """
        return list(self.data[index])