Commit: training
mjyoussef committed Apr 22, 2024
1 parent 5f64180 commit 293bf3a
Showing 3 changed files with 273 additions and 206 deletions.
220 changes: 149 additions & 71 deletions embedding/load.py
@@ -1,95 +1,173 @@
import argparse
import torch
import torch.optim as optim
from utils import GraphPairsDataset, contrastive_loss, create_loaders
from torch_geometric.datasets import WikiCS, Amazon, Coauthor
from model import GraphSAGE

def train_routine(args, model, train_loader, val_loader):
    # create the optimizer once, outside the epoch loop, so Adam's
    # moment estimates are not reset every epoch
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    for epoch in range(args.epochs):
        # training
        model.train()
        for batch_idx, (graph_pair, labels) in enumerate(train_loader):
            optimizer.zero_grad()
            out1 = model(graph_pair[0])
            out2 = model(graph_pair[1])
            loss = contrastive_loss(out1, out2, labels, args.margin)
            loss.backward()
            optimizer.step()
            print(f"Training: epoch {epoch}, batch {batch_idx}")

        # validation (no gradients needed)
        model.eval()
        avg_val_loss = 0
        with torch.no_grad():
            for graph_pair, labels in val_loader:
                out1 = model(graph_pair[0])
                out2 = model(graph_pair[1])
                loss = contrastive_loss(out1, out2, labels, args.margin)
                avg_val_loss += loss.item()

        avg_val_loss /= len(val_loader)
        print(f"Average validation loss: {avg_val_loss}")

def test_routine(args, model, test_loader):
    model.eval()
    avg_test_loss = 0
    with torch.no_grad():
        for graph_pair, labels in test_loader:
            out1 = model(graph_pair[0])
            out2 = model(graph_pair[1])
            loss = contrastive_loss(out1, out2, labels, args.margin)
            avg_test_loss += loss.item()

    avg_test_loss /= len(test_loader)
    print(f"Average test loss: {avg_test_loss}")

def train_and_test(args):
    if (args.dataset == 'WikiCS'):
        args, train_loader, val_loader, test_loader = wikics(args)
    elif (args.dataset == 'AmazonPhoto'):
        args, train_loader, val_loader, test_loader = amazon_photo(args)
    else:
        args, train_loader, val_loader, test_loader = coauthor_cs(args)

    # initialize the model (the dimensions are filled in by the
    # dataset helpers above)
    model = GraphSAGE(
        args.in_dimension,
        args.hidden_dimension,
        args.out_dimension,
        args.k
    )

    train_routine(args, model, train_loader, val_loader)

    test_routine(args, model, test_loader)
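
Note: GraphSAGE is imported from model.py, which is outside this diff. A minimal sketch consistent with the four-argument constructor used above, assuming k is the number of SAGEConv message-passing layers, could be:

    import torch
    import torch.nn.functional as F
    from torch_geometric.nn import SAGEConv

    class GraphSAGE(torch.nn.Module):
        def __init__(self, in_dim, hidden_dim, out_dim, k):
            super().__init__()
            # k layers: in_dim -> hidden_dim -> ... -> out_dim
            dims = [in_dim] + [hidden_dim] * (k - 1) + [out_dim]
            self.convs = torch.nn.ModuleList(
                SAGEConv(dims[i], dims[i + 1]) for i in range(k)
            )

        def forward(self, data):
            x, edge_index = data.x, data.edge_index
            for conv in self.convs[:-1]:
                x = F.relu(conv(x, edge_index))
            return self.convs[-1](x, edge_index)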

def wikics(args):
    # store the dataset
    graphs = WikiCS('../datasets/wikics')

    # handle optional parameters w/o defaults
    if (args.p_damp is None):
        args.p_damp = 0.1

    if (args.p_trunc is None):
        args.p_trunc = 0.7

    if (args.in_dimension is None):
        args.in_dimension = graphs[0].x.shape[1]

    if (args.hidden_dimension is None):
        args.hidden_dimension = 256
    args.out_dimension = args.hidden_dimension

    if (args.lr is None):
        args.lr = 0.001

    if (args.batch_size is None):
        args.batch_size = 64

    if (args.epochs is None):
        args.epochs = 50

    dataset = GraphPairsDataset(graphs, args.k, args.p_damp, args.p_trunc, args.dropout)

    # unpack so callers can bind args plus the three loaders directly
    train_loader, val_loader, test_loader = create_loaders(dataset, args.train_split, args.val_split, args.batch_size)
    return args, train_loader, val_loader, test_loader
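
Note: create_loaders is also defined in utils.py and not shown in this diff. A plausible sketch, assuming it splits the dataset by the train/val fractions (the remainder becoming the test split) and wraps each split in a DataLoader, is:

    from torch.utils.data import DataLoader, random_split

    def create_loaders(dataset, train_split, val_split, batch_size):
        n = len(dataset)
        n_train = int(train_split * n)
        n_val = int(val_split * n)
        n_test = n - n_train - n_val
        train_set, val_set, test_set = random_split(dataset, [n_train, n_val, n_test])
        # only the training loader needs shuffling
        return (
            DataLoader(train_set, batch_size=batch_size, shuffle=True),
            DataLoader(val_set, batch_size=batch_size),
            DataLoader(test_set, batch_size=batch_size),
        )

In practice the real implementation likely supplies a custom collate_fn, since the items here are pairs of PyG graphs rather than plain tensors.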

def amazon_photo(args):
    pass

def coauthor_cs(args):
    pass

# def amazon_photo(args):
#     buffer_size = args.buffer_size
#     k = args.k
#     p_damp = args.p_damp
#     p_trunc = args.p_trunc
#     dropout = args.dropout

#     # set optional parameters
#     if (p_damp is None):
#         p_damp = 0.1

#     if (p_trunc is None):
#         p_trunc = 0.7

#     graphs = Amazon('../datasets/amazon_photo', 'photo')

#     generate_samples(
#         f"./data/amazon_photo/{args.version}",
#         buffer_size,
#         graphs,
#         k,
#         p_damp,
#         p_trunc,
#         dropout=dropout
#     )

# def coauthor_cs(args):
#     buffer_size = args.buffer_size
#     k = args.k
#     p_damp = args.p_damp
#     p_trunc = args.p_trunc
#     dropout = args.dropout

#     # set optional parameters
#     if (p_damp is None):
#         p_damp = 0.35

#     if (p_trunc is None):
#         p_trunc = 0.7

#     graphs = Coauthor('../datasets/coauthor_cs', 'CS')

#     generate_samples(
#         f"./data/coauthor_cs/{args.version}",
#         buffer_size,
#         graphs,
#         k,
#         p_damp,
#         p_trunc,
#         dropout=dropout
#     )

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, required=True, choices=['WikiCS', 'AmazonPhoto', 'CoauthorCS'])
    parser.add_argument('--version', type=int, required=True)
    parser.add_argument('--k', type=int, default=2, required=False)
    parser.add_argument('--p_damp', type=float, default=None, required=False)
    parser.add_argument('--p_trunc', type=float, default=None, required=False)
    parser.add_argument('--dropout', type=float, default=0.2, required=False)
    parser.add_argument('--train_split', type=float, default=0.7, required=False)
    parser.add_argument('--val_split', type=float, default=0.15, required=False)
    parser.add_argument('--margin', type=float, default=1.0, required=False)
    # flags read by wikics() and train_routine() but missing from the
    # parser; defaults of None defer to the per-dataset values above
    parser.add_argument('--in_dimension', type=int, default=None, required=False)
    parser.add_argument('--hidden_dimension', type=int, default=None, required=False)
    parser.add_argument('--lr', type=float, default=None, required=False)
    parser.add_argument('--batch_size', type=int, default=None, required=False)
    parser.add_argument('--epochs', type=int, default=None, required=False)

    args = parser.parse_args()

    train_and_test(args)
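
With the defaults above, a typical invocation (assuming the script is run from the embedding/ directory so the relative ../datasets/ and ./data/ paths resolve) would be:

    python load.py --dataset WikiCS --version 1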
48 changes: 28 additions & 20 deletions embedding/training.py
@@ -1,24 +1,34 @@
import argparse
from typing import Dict, Any
from torch.utils.data import DataLoader
from torch_geometric.data.data import BaseData
from torch_geometric.data.dataset import Dataset
from torch_geometric.datasets import WikiCS, Amazon, Coauthor
from utils import *

def train(args: argparse.Namespace) -> None:
    # load in the graphs depending on the dataset
    if (args.dataset == 'WikiCS'):
        graphs = WikiCS('../datasets/wikics')
    elif (args.dataset == 'AmazonPhoto'):
        graphs = Amazon('../datasets/amazon_photo', 'photo')
    else:
        graphs = Coauthor('../datasets/coauthor_cs', 'CS')

    obj = GraphPairsDataset(graphs, 2)
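
Note: GraphPairsDataset also lives in utils.py, outside this diff. Judging only from its two call sites, GraphPairsDataset(graphs, k, p_damp, p_trunc, dropout) in load.py and GraphPairsDataset(graphs, 2) here, a hedged sketch of its shape might be the following; what p_damp, p_trunc, and dropout actually control is an assumption:

    import torch
    import torch.nn.functional as F
    from torch.utils.data import Dataset

    class GraphPairsDataset(Dataset):
        def __init__(self, graphs, k, p_damp=0.1, p_trunc=0.7, dropout=0.2):
            self.graphs = graphs
            self.k = k              # assumed: neighborhood depth
            self.p_damp = p_damp    # assumed: feature-damping probability
            self.p_trunc = p_trunc  # assumed: edge-truncation probability
            self.dropout = dropout  # assumed: feature-dropout rate

        def __len__(self):
            return len(self.graphs)

        def __getitem__(self, idx):
            g = self.graphs[idx]
            # two independently feature-dropped views of the same graph
            # form a positive pair (label 1); a real implementation would
            # also emit negative pairs drawn from different graphs
            view1, view2 = g.clone(), g.clone()
            view1.x = F.dropout(g.x, p=self.dropout)
            view2.x = F.dropout(g.x, p=self.dropout)
            return (view1, view2), torch.tensor(1.0)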

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
@@ -54,9 +64,7 @@ def train(args: argparse.ArgumentParser) -> None:
    '''

    parser.add_argument('--dataset', type=str, required=True, choices=['WikiCS', 'AmazonPhoto', 'CoauthorCS'])
    # parser.add_argument('--k', type=int, required=True)
    # parser.add_argument('--p_f', type=float, required=True)
    # parser.add_argument('--p_e', type=float, required=True)
@@ -69,6 +77,6 @@ def train(args: argparse.ArgumentParser) -> None:
    # parser.add_argument('--bs', type=int, required=True)
    # parser.add_argument('--logging', type=bool, required=True)

    args = parser.parse_args()

    train(args)