diff --git a/requirements.txt b/requirements.txt
index 683eb65..df9a237 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,4 +7,5 @@ matplotlib
 tqdm
 networkx
 ninja
-jinja2
\ No newline at end of file
+jinja2
+class-resolver
diff --git a/setup.py b/setup.py
index a924387..21356d3 100644
--- a/setup.py
+++ b/setup.py
@@ -39,6 +39,7 @@
         "networkx",
         "ninja",
         "jinja2",
+        "class-resolver",
     ],
     python_requires=">=3.7,<3.9",
     classifiers=[
diff --git a/torchdrug/layers/__init__.py b/torchdrug/layers/__init__.py
index 7d79f9b..ea18488 100644
--- a/torchdrug/layers/__init__.py
+++ b/torchdrug/layers/__init__.py
@@ -23,7 +23,7 @@
     "MessagePassingBase", "GraphConv", "GraphAttentionConv", "RelationalGraphConv", "GraphIsomorphismConv",
     "NeuralFingerprintConv", "ContinuousFilterConv", "MessagePassing", "ChebyshevConv",
     "DiffPool", "MinCutPool",
-    "MeanReadout", "SumReadout", "MaxReadout", "Softmax", "Set2Set", "Sort",
+    "MeanReadout", "SumReadout", "MaxReadout", "Softmax", "Set2Set", "Sort", "readout_resolver",
     "Readout", "ConditionalFlow",
     "NodeSampler", "EdgeSampler",
     "distribution", "functional",
diff --git a/torchdrug/models/chebnet.py b/torchdrug/models/chebnet.py
index cb1793d..521aaef 100644
--- a/torchdrug/models/chebnet.py
+++ b/torchdrug/models/chebnet.py
@@ -1,10 +1,12 @@
 from collections.abc import Sequence
 
 import torch
+from class_resolver import Hint
 from torch import nn
 
 from torchdrug import core, layers
 from torchdrug.core import Registry as R
+from torchdrug.layers import Readout, readout_resolver
 
 
 @R.register("models.ChebNet")
@@ -25,11 +27,11 @@ class ChebyshevConvolutionalNetwork(nn.Module, core.Configurable):
         batch_norm (bool, optional): apply batch normalization or not
         activation (str or function, optional): activation function
         concat_hidden (bool, optional): concat hidden representations from all layers as output
-        readout (str, optional): readout function. Available functions are ``sum`` and ``mean``.
+        readout: readout function. Available functions are ``sum`` and ``mean``.
     """
 
     def __init__(self, input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=False, batch_norm=False,
-                 activation="relu", concat_hidden=False, readout="sum"):
+                 activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"):
         super(ChebyshevConvolutionalNetwork, self).__init__()
 
         if not isinstance(hidden_dims, Sequence):
@@ -45,12 +47,7 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=F
             self.layers.append(layers.ChebyshevConv(self.dims[i], self.dims[i + 1], edge_input_dim, k,
                                                     batch_norm, activation))
 
-        if readout == "sum":
-            self.readout = layers.SumReadout()
-        elif readout == "mean":
-            self.readout = layers.MeanReadout()
-        else:
-            raise ValueError("Unknown readout `%s`" % readout)
+        self.readout = readout_resolver.make(readout)
 
     def forward(self, graph, input, all_loss=None, metric=None):
         """
diff --git a/torchdrug/models/gat.py b/torchdrug/models/gat.py
index 211f8d0..db2b5e7 100644
--- a/torchdrug/models/gat.py
+++ b/torchdrug/models/gat.py
@@ -1,10 +1,12 @@
 from collections.abc import Sequence
 
 import torch
+from class_resolver import Hint
 from torch import nn
 
 from torchdrug import core, layers
 from torchdrug.core import Registry as R
+from torchdrug.layers import Readout, readout_resolver
 
 
 @R.register("models.GAT")
diff --git a/torchdrug/models/gcn.py b/torchdrug/models/gcn.py
index 968f2b1..679ae28 100644
--- a/torchdrug/models/gcn.py
+++ b/torchdrug/models/gcn.py
@@ -5,8 +5,8 @@
 from torch import nn
 
 from torchdrug import core, layers
-from torchdrug.layers import readout_resolver, Readout
 from torchdrug.core import Registry as R
+from torchdrug.layers import Readout, readout_resolver
 
 
 @R.register("models.GCN")
@@ -99,11 +99,11 @@ class RelationalGraphConvolutionalNetwork(nn.Module, core.Configurable):
         batch_norm (bool, optional): apply batch normalization or not
         activation (str or function, optional): activation function
         concat_hidden (bool, optional): concat hidden representations from all layers as output
-        readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``.
+        readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
     """
 
     def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, short_cut=False, batch_norm=False,
-                 activation="relu", concat_hidden=False, readout="sum"):
+                 activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"):
         super(RelationalGraphConvolutionalNetwork, self).__init__()
 
         if not isinstance(hidden_dims, Sequence):
@@ -120,14 +120,7 @@ def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, sh
             self.layers.append(layers.RelationalGraphConv(self.dims[i], self.dims[i + 1], num_relation,
                                                           edge_input_dim, batch_norm, activation))
 
-        if readout == "sum":
-            self.readout = layers.SumReadout()
-        elif readout == "mean":
-            self.readout = layers.MeanReadout()
-        elif readout == "max":
-            self.readout = layers.MaxReadout()
-        else:
-            raise ValueError("Unknown readout `%s`" % readout)
+        self.readout = readout_resolver.make(readout)
 
     def forward(self, graph, input, all_loss=None, metric=None):
         """
diff --git a/torchdrug/models/gin.py b/torchdrug/models/gin.py
index 4b945cf..f32e2ca 100644
--- a/torchdrug/models/gin.py
+++ b/torchdrug/models/gin.py
@@ -1,10 +1,12 @@
 from collections.abc import Sequence
 
 import torch
+from class_resolver import Hint
 from torch import nn
 
 from torchdrug import core, layers
 from torchdrug.core import Registry as R
+from torchdrug.layers import Readout, readout_resolver
 
 
 @R.register("models.GIN")
@@ -26,12 +28,12 @@ class GraphIsomorphismNetwork(nn.Module, core.Configurable):
         batch_norm (bool, optional): apply batch normalization or not
         activation (str or function, optional): activation function
         concat_hidden (bool, optional): concat hidden representations from all layers as output
-        readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``.
+        readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
     """
 
     def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_mlp_layer=2, eps=0, learn_eps=False,
                  short_cut=False, batch_norm=False, activation="relu", concat_hidden=False,
-                 readout="sum"):
+                 readout: Hint[Readout] = "sum"):
         super(GraphIsomorphismNetwork, self).__init__()
 
         if not isinstance(hidden_dims, Sequence):
diff --git a/torchdrug/models/neuralfp.py b/torchdrug/models/neuralfp.py
index 4d9a1ab..a1f9512 100644
--- a/torchdrug/models/neuralfp.py
+++ b/torchdrug/models/neuralfp.py
@@ -1,11 +1,13 @@
 from collections.abc import Sequence
 
 import torch
+from class_resolver import Hint
 from torch import nn
 from torch.nn import functional as F
 
 from torchdrug import core, layers
 from torchdrug.core import Registry as R
+from torchdrug.layers import Readout, readout_resolver
 
 
 @R.register("models.NeuralFP")
@@ -25,11 +27,11 @@ class NeuralFingerprint(nn.Module, core.Configurable):
         batch_norm (bool, optional): apply batch normalization or not
         activation (str or function, optional): activation function
         concat_hidden (bool, optional): concat hidden representations from all layers as output
-        readout (str, optional): readout function. Available functions are ``sum`` and ``mean``.
+        readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
""" def __init__(self, input_dim, output_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False, - activation="relu", concat_hidden=False, readout="sum"): + activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"): super(NeuralFingerprint, self).__init__() if not isinstance(hidden_dims, Sequence): diff --git a/torchdrug/models/schnet.py b/torchdrug/models/schnet.py index 3644e2e..7861a61 100644 --- a/torchdrug/models/schnet.py +++ b/torchdrug/models/schnet.py @@ -1,10 +1,12 @@ from collections.abc import Sequence import torch +from class_resolver import Hint from torch import nn from torchdrug import core, layers from torchdrug.core import Registry as R +from torchdrug.layers import Readout, readout_resolver @R.register("models.SchNet") @@ -25,6 +27,7 @@ class SchNet(nn.Module, core.Configurable): batch_norm (bool, optional): apply batch normalization or not activation (str or function, optional): activation function concat_hidden (bool, optional): concat hidden representations from all layers as output + readout: readout function. Available functions are ``sum``, ``mean``, and ``max``. """ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, cutoff=5, num_gaussian=100, short_cut=True, diff --git a/torchdrug/tasks/pretrain.py b/torchdrug/tasks/pretrain.py index 5e8deec..8b9460e 100644 --- a/torchdrug/tasks/pretrain.py +++ b/torchdrug/tasks/pretrain.py @@ -1,13 +1,14 @@ import copy import torch +from class_resolver import Hint from torch import nn from torch.nn import functional as F from torch_scatter import scatter_max, scatter_min from torchdrug import core, tasks, layers from torchdrug.data import constant -from torchdrug.layers import functional +from torchdrug.layers import functional, readout_resolver, Readout from torchdrug.core import Registry as R @@ -169,9 +170,10 @@ class ContextPrediction(tasks.Task, core.Configurable): r2 (int, optional): outer radius for context graphs readout (nn.Module, optional): readout function over context anchor nodes num_negative (int, optional): number of negative samples per positive sample + readout: readout function. Available functions are ``sum``, ``mean``, and ``max``. """ - def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout="mean", num_negative=1): + def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout: Hint[Readout] = "mean", num_negative=1): super(ContextPrediction, self).__init__() self.model = model self.k = k @@ -184,12 +186,8 @@ def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout="mean", n self.context_model = copy.deepcopy(model) else: self.context_model = context_model - if readout == "sum": - self.readout = layers.SumReadout() - elif readout == "mean": - self.readout = layers.MeanReadout() - else: - raise ValueError("Unknown readout `%s`" % readout) + + self.readout = readout_resolver.make(readout) def substruct_and_context(self, graph): center_index = (torch.rand(len(graph), device=self.device) * graph.num_nodes).long()