training
mjyoussef committed Apr 22, 2024
1 parent 293bf3a commit a318419
Showing 5 changed files with 122 additions and 240 deletions.
14 changes: 14 additions & 0 deletions configs/AmazonPhoto.json
@@ -0,0 +1,14 @@
{
    "k": 2,
    "p_damp": 0.1,
    "p_trunc": 0.7,
    "dropout": 0.1,
    "hidden_dimension": 256,
    "out_dimension": 256,
    "margin": 1.0,
    "lr": 0.001,
    "batch_size": 48,
    "epochs": 50,
    "train_split": 0.7,
    "val_split": 0.15
}
14 changes: 14 additions & 0 deletions configs/CoauthorCS.json
@@ -0,0 +1,14 @@
{
    "k": 2,
    "p_damp": 0.35,
    "p_trunc": 0.7,
    "dropout": 0.1,
    "hidden_dimension": 256,
    "out_dimension": 256,
    "margin": 1.0,
    "lr": 0.001,
    "batch_size": 48,
    "epochs": 50,
    "train_split": 0.7,
    "val_split": 0.15
}
14 changes: 14 additions & 0 deletions configs/WikiCS.json
@@ -0,0 +1,14 @@
{
    "k": 2,
    "p_damp": 0.1,
    "p_trunc": 0.7,
    "dropout": 0.1,
    "hidden_dimension": 256,
    "out_dimension": 256,
    "margin": 1.0,
    "lr": 0.001,
    "batch_size": 48,
    "epochs": 50,
    "train_split": 0.7,
    "val_split": 0.15
}
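
All three configs share the same flat key set and are read into a plain dictionary by the training script further down. As a quick illustration (assuming the script runs from the embedding/ directory, so the configs sit at ../configs/ as in training.py below), a config can be loaded like this; load_config is a hypothetical helper for this sketch, not part of the commit:

    import json

    def load_config(dataset_name):
        # hypothetical helper: read the hyperparameters for one dataset by name
        with open(f'../configs/{dataset_name}.json') as f:
            return json.load(f)

    params = load_config('WikiCS')
    print(params['k'], params['lr'])  # 2 0.001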
173 changes: 0 additions & 173 deletions embedding/load.py

This file was deleted.

147 changes: 80 additions & 67 deletions embedding/training.py
@@ -1,82 +1,95 @@
import argparse
import json
import torch
import torch.optim as optim
from typing import Dict, Any
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data.data import BaseData
from torch_geometric.data.dataset import Dataset
from torch_geometric.datasets import WikiCS, Amazon, Coauthor
from utils import GraphPairsDataset, contrastive_loss, create_loaders
from utils import *
from model import GraphSAGE

def train(args: argparse.ArgumentParser) -> None:
    # earlier entry point; superseded by train_and_test below
    # load in the graphs depending on the dataset
    graphs = []
    if (args.dataset == 'WikiCS'):
        graphs = WikiCS('../datasets/wikics')
    elif (args.dataset == 'AmazonPhoto'):
        graphs = Amazon('../datasets/amazon_photo', 'photo')
    else:
        graphs = Coauthor('../datasets/coauthor_cs', 'CS')

    obj = GraphPairsDataset(graphs, 2)

def train_routine(params, model, train_loader, val_loader):
    # create the optimizer once so its state persists across epochs
    optimizer = optim.Adam(model.parameters(), lr=params['lr'])

    for epoch in range(params['epochs']):
        # training
        model.train()
        batch_idx = 0
        for graph_pair, labels in train_loader:
            optimizer.zero_grad()
            out1 = model(graph_pair[0])
            out2 = model(graph_pair[1])
            loss = contrastive_loss(out1, out2, labels, params['margin'])
            loss.backward()
            optimizer.step()
            print(f"Training: epoch {epoch}, batch {batch_idx}")
            batch_idx += 1

        # validation
        model.eval()
        avg_val_loss = 0
        val_loader_size = len(val_loader)
        with torch.no_grad():
            for graph_pair, labels in val_loader:
                out1 = model(graph_pair[0])
                out2 = model(graph_pair[1])
                loss = contrastive_loss(out1, out2, labels, params['margin'])
                avg_val_loss += loss.item()

        avg_val_loss /= val_loader_size
        print(f"Average validation loss: {avg_val_loss}")

def test_routine(params, model, test_loader):
    model.eval()
    avg_test_loss = 0
    test_loader_size = len(test_loader)
    with torch.no_grad():
        for graph_pair, labels in test_loader:
            out1 = model(graph_pair[0])
            out2 = model(graph_pair[1])
            loss = contrastive_loss(out1, out2, labels, params['margin'])
            avg_test_loss += loss.item()

    avg_test_loss /= test_loader_size
    print(f"Average test loss: {avg_test_loss}")

def train_and_test(dataset_name, params):
    # load in the graphs depending on the dataset
    graphs = None
    if (dataset_name == 'WikiCS'):
        graphs = WikiCS('../datasets/WikiCS')
    elif (dataset_name == 'AmazonPhoto'):
        graphs = Amazon('../datasets/AmazonPhoto', 'Photo')
    else:
        graphs = Coauthor('../datasets/CoauthorCS', 'CS')

    # update the input channels based on the graph's feature matrix
    params['in_channels'] = graphs[0].x.shape[1]

    # create the dataset
    dataset = GraphPairsDataset(graphs, params['k'], params['p_damp'], params['p_trunc'], params['dropout'])

    # create the loaders
    train_loader, val_loader, test_loader = create_loaders(
        dataset,
        params['train_split'],
        params['val_split'],
        params['batch_size']
    )

    # initialize the model
    model = GraphSAGE(
        params['in_channels'],
        params['hidden_dimension'],
        params['out_dimension'],
        params['k'],
    )

    train_routine(params, model, train_loader, val_loader)

    test_routine(params, model, test_loader)

if __name__ == '__main__':
    '''
    Recommended parameters: https://arxiv.org/pdf/2010.14945.pdf
    WikiCS:
        k: 2
        p_e: 0.3
        p_f: 0.1
        p_t: 0.7
        lr: 0.01
        activation: PReLU
        hd: 256
    AmazonPhoto:
        k: 2
        p_e: 0.4
        p_f: 0.1
        p_t: 0.7
        lr: 0.1
        activation: ReLU
        hd: 256
    CoauthorCS:
        k: 2
        p_e: 0.25
        p_f: 0.35
        p_t: 0.7
        lr: 0.0005
        activation: RReLU
        hd: 256
    '''

    # get the dataset
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, required=True, choices=['WikiCS', 'AmazonPhoto', 'CoauthorCS'])
    # parser.add_argument('--k', type=int, required=True)
    # parser.add_argument('--p_f', type=float, required=True)
    # parser.add_argument('--p_e', type=float, required=True)
    # parser.add_argument('--p_t', type=float, required=True)
    # parser.add_argument('--max_subgraphs', type=int, required=True)
    # parser.add_argument('--hidden_dimensions', type=float, required=True)
    # parser.add_argument('--activation', type=str, required=True, choices=['PReLU', 'RReLU', 'ReLU'])
    # parser.add_argument('--lr', type=float, required=True)
    # parser.add_argument('--epochs', type=int, required=True)
    # parser.add_argument('--bs', type=int, required=True)
    # parser.add_argument('--logging', type=bool, required=True)

    args = parser.parse_args()

    # load the hyperparameters (json.load needs a file object, not a path string)
    with open(f'../configs/{args.dataset}.json') as f:
        params = json.load(f)

    train_and_test(args.dataset, params)
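
With the configs above checked in, a run only needs the dataset name: it selects both the PyG dataset to download and the matching JSON config, and every other hyperparameter comes from that file. For example (assuming the command is issued from the embedding/ directory, since the dataset and config paths are relative):

    python training.py --dataset AmazonPhoto

contrastive_loss and GraphSAGE come from utils.py and model.py, which this commit does not touch. For orientation only, a margin-based pairwise loss consistent with the call contrastive_loss(out1, out2, labels, margin) could look roughly like the sketch below; this is an assumption about that interface, not the repository's implementation:

    import torch.nn.functional as F

    def contrastive_loss(out1, out2, labels, margin):
        # Euclidean distance between the two subgraph embeddings in a pair
        dist = F.pairwise_distance(out1, out2)
        labels = labels.float()
        # positive pairs (label 1) are pulled together; negative pairs (label 0)
        # are pushed at least `margin` apart
        pos_term = labels * dist.pow(2)
        neg_term = (1.0 - labels) * F.relu(margin - dist).pow(2)
        return (pos_term + neg_term).mean()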
