Showing 5 changed files with 122 additions and 240 deletions.
@@ -0,0 +1,14 @@
{
    "k": 2,
    "p_damp": 0.1,
    "p_trunc": 0.7,
    "dropout": 0.1,
    "hidden_dimension": 256,
    "out_dimension": 256,
    "margin": 1.0,
    "lr": 0.001,
    "batch_size": 48,
    "epochs": 50,
    "train_split": 0.7,
    "val_split": 0.15
}
@@ -0,0 +1,14 @@
{
    "k": 2,
    "p_damp": 0.35,
    "p_trunc": 0.7,
    "dropout": 0.1,
    "hidden_dimension": 256,
    "out_dimension": 256,
    "margin": 1.0,
    "lr": 0.001,
    "batch_size": 48,
    "epochs": 50,
    "train_split": 0.7,
    "val_split": 0.15
}
@@ -0,0 +1,14 @@
{
    "k": 2,
    "p_damp": 0.1,
    "p_trunc": 0.7,
    "dropout": 0.1,
    "hidden_dimension": 256,
    "out_dimension": 256,
    "margin": 1.0,
    "lr": 0.001,
    "batch_size": 48,
    "epochs": 50,
    "train_split": 0.7,
    "val_split": 0.15
}
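
These configs are consumed by train.py below, which builds the path from the --dataset flag. A minimal sketch of that lookup (the file name here assumes the configs are stored one per dataset, as the f-string in train.py implies):

import json

# resolve the per-dataset config, mirroring train.py's
# f'../configs/{args.dataset}.json' lookup
with open('../configs/WikiCS.json') as f:
    params = json.load(f)
print(params['k'], params['lr'])  # 2 0.001
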
This file was deleted.
@@ -1,82 +1,95 @@
import argparse
import json

import torch
import torch.optim as optim
from torch_geometric.datasets import WikiCS, Amazon, Coauthor

from model import GraphSAGE
from utils import GraphPairsDataset, contrastive_loss, create_loaders

def train_routine(params, model, train_loader, val_loader):
    # create the optimizer once, outside the epoch loop, so Adam's moment
    # estimates persist across epochs instead of being reset every epoch
    optimizer = optim.Adam(model.parameters(), lr=params['lr'])

    for epoch in range(params['epochs']):
        # training
        model.train()
        for batch_idx, (graph_pair, labels) in enumerate(train_loader):
            optimizer.zero_grad()
            out1 = model(graph_pair[0])
            out2 = model(graph_pair[1])
            loss = contrastive_loss(out1, out2, labels, params['margin'])
            loss.backward()
            optimizer.step()
            print(f"Training: epoch {epoch}, batch {batch_idx}")

        # validation; no gradients needed in eval
        model.eval()
        avg_val_loss = 0
        with torch.no_grad():
            for graph_pair, labels in val_loader:
                out1 = model(graph_pair[0])
                out2 = model(graph_pair[1])
                loss = contrastive_loss(out1, out2, labels, params['margin'])
                avg_val_loss += loss.item()

        avg_val_loss /= len(val_loader)
        print(f"Average validation loss: {avg_val_loss}")

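contrastive_loss is imported from utils, which this commit does not touch. As a reference for the call above, here is a minimal sketch of a margin-based pairwise contrastive loss (Hadsell et al., 2006) with the same signature; the actual utils implementation may differ:

import torch.nn.functional as F

def contrastive_loss(out1, out2, labels, margin):
    # hypothetical stand-in for utils.contrastive_loss
    # Euclidean distance between the two embeddings of each pair
    dist = F.pairwise_distance(out1, out2)
    labels = labels.float()
    # positive pairs (label 1) are pulled together; negative pairs
    # (label 0) are pushed at least `margin` apart
    pos = labels * dist.pow(2)
    neg = (1.0 - labels) * F.relu(margin - dist).pow(2)
    return (pos + neg).mean()
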
def test_routine(params, model, test_loader):
    model.eval()
    avg_test_loss = 0
    with torch.no_grad():
        for graph_pair, labels in test_loader:
            out1 = model(graph_pair[0])
            out2 = model(graph_pair[1])
            loss = contrastive_loss(out1, out2, labels, params['margin'])
            avg_test_loss += loss.item()

    avg_test_loss /= len(test_loader)
    print(f"Average test loss: {avg_test_loss}")

def train_and_test(dataset_name, params):
    # load the graphs for the requested dataset
    if dataset_name == 'WikiCS':
        graphs = WikiCS('../datasets/WikiCS')
    elif dataset_name == 'AmazonPhoto':
        graphs = Amazon('../datasets/AmazonPhoto', 'Photo')
    else:
        graphs = Coauthor('../datasets/CoauthorCS', 'CS')

    # update the input channels based on the graph's feature matrix
    params['in_channels'] = graphs[0].x.shape[1]

    # create the dataset of augmented graph pairs
    dataset = GraphPairsDataset(graphs, params['k'], params['p_damp'], params['p_trunc'], params['dropout'])

    # create the loaders
    train_loader, val_loader, test_loader = create_loaders(
        dataset,
        params['train_split'],
        params['val_split'],
        params['batch_size'],
    )

    # initialize the model
    model = GraphSAGE(
        params['in_channels'],
        params['hidden_dimension'],
        params['out_dimension'],
        params['k'],
    )

    train_routine(params, model, train_loader, val_loader)
    test_routine(params, model, test_loader)

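create_loaders also lives in utils and is not shown in this diff. A plausible sketch, assuming it splits the pair dataset by the given fractions (the remainder becoming the test set) and wraps each split in a PyTorch Geometric DataLoader, whose collater can batch tuples of graphs:

from torch.utils.data import random_split
from torch_geometric.loader import DataLoader

def create_loaders(dataset, train_split, val_split, batch_size):
    # hypothetical stand-in for utils.create_loaders
    n = len(dataset)
    n_train = int(train_split * n)
    n_val = int(val_split * n)
    splits = [n_train, n_val, n - n_train - n_val]
    train_set, val_set, test_set = random_split(dataset, splits)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size)
    test_loader = DataLoader(test_set, batch_size=batch_size)
    return train_loader, val_loader, test_loader
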
if __name__ == '__main__':
    '''
    Recommended parameters: https://arxiv.org/pdf/2010.14945.pdf
    WikiCS:
        k: 2
        p_e: 0.3
        p_f: 0.1
        p_t: 0.7
        lr: 0.01
        activation: PReLU
        hd: 256
    AmazonPhoto:
        k: 2
        p_e: 0.4
        p_f: 0.1
        p_t: 0.7
        lr: 0.1
        activation: ReLU
        hd: 256
    CoauthorCS:
        k: 2
        p_e: 0.25
        p_f: 0.35
        p_t: 0.7
        lr: 0.0005
        activation: RReLU
        hd: 256
    '''

    # get the dataset
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, required=True, choices=['WikiCS', 'AmazonPhoto', 'CoauthorCS'])
    # parser.add_argument('--k', type=int, required=True)
    # parser.add_argument('--p_f', type=float, required=True)
    # parser.add_argument('--p_e', type=float, required=True)
    # parser.add_argument('--p_t', type=float, required=True)
    # parser.add_argument('--max_subgraphs', type=int, required=True)
    # parser.add_argument('--hidden_dimensions', type=float, required=True)
    # parser.add_argument('--activation', type=str, required=True, choices=['PReLU', 'RReLU', 'ReLU'])
    # parser.add_argument('--lr', type=float, required=True)
    # parser.add_argument('--epochs', type=int, required=True)
    # parser.add_argument('--bs', type=int, required=True)
    # parser.add_argument('--logging', type=bool, required=True)
    args = parser.parse_args()

    # load the hyperparameters; json.load expects a file object, not a path
    with open(f'../configs/{args.dataset}.json') as f:
        params = json.load(f)

    train_and_test(args.dataset, params)
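
GraphSAGE comes from model.py, which this commit also leaves untouched. For orientation, a minimal sketch consistent with the constructor call in train_and_test, assuming k stacked SAGEConv layers with mean pooling to produce one embedding per graph (the real model may differ, e.g. in its choice of activation):

import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, global_mean_pool

class GraphSAGE(nn.Module):
    # hypothetical stand-in for model.GraphSAGE: k SAGEConv layers
    # mapping in_channels -> hidden_dimension -> ... -> out_dimension
    def __init__(self, in_channels, hidden_dimension, out_dimension, k):
        super().__init__()
        dims = [in_channels] + [hidden_dimension] * (k - 1) + [out_dimension]
        self.convs = nn.ModuleList(SAGEConv(dims[i], dims[i + 1]) for i in range(k))

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < len(self.convs) - 1:
                x = F.relu(x)
        # one embedding per graph in the (possibly batched) input
        return global_mean_pool(x, data.batch)

With the configs above in ../configs/, the script is then run as, for example:

python train.py --dataset WikiCS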