From 17ab229fc427f048ba82ffdd5ca242636e039a5e Mon Sep 17 00:00:00 2001 From: Zhaocheng Zhu Date: Sun, 16 Jul 2023 16:31:27 -0400 Subject: [PATCH] support customized edge feature in graph construction --- torchdrug/core/core.py | 1 - torchdrug/core/engine.py | 5 +++-- torchdrug/layers/geometry/graph.py | 18 +++++++++++++----- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/torchdrug/core/core.py b/torchdrug/core/core.py index 5ff7c9e..0227b66 100644 --- a/torchdrug/core/core.py +++ b/torchdrug/core/core.py @@ -86,7 +86,6 @@ def __setattr__(self, key, value): def __delattr__(self, key): if hasattr(self, "meta_dict") and key in self.meta_dict: del self.meta_dict[key] - del self.data_dict[key] super(_MetaContainer, self).__delattr__(self, key) def _setattr(self, key, value): diff --git a/torchdrug/core/engine.py b/torchdrug/core/engine.py index 99e5582..1bc83e8 100644 --- a/torchdrug/core/engine.py +++ b/torchdrug/core/engine.py @@ -225,20 +225,21 @@ def evaluate(self, split, log=True): return metric - def load(self, checkpoint, load_optimizer=True): + def load(self, checkpoint, load_optimizer=True, strict=True): """ Load a checkpoint from file. Parameters: checkpoint (file-like): checkpoint file load_optimizer (bool, optional): load optimizer state or not + strict (bool, optional): whether to strictly check the checkpoint matches the model parameters """ if comm.get_rank() == 0: logger.warning("Load checkpoint from %s" % checkpoint) checkpoint = os.path.expanduser(checkpoint) state = torch.load(checkpoint, map_location=self.device) - self.model.load_state_dict(state["model"]) + self.model.load_state_dict(state["model"], strict=strict) if load_optimizer: self.optimizer.load_state_dict(state["optimizer"]) diff --git a/torchdrug/layers/geometry/graph.py b/torchdrug/layers/geometry/graph.py index 9783a01..7aa16a8 100644 --- a/torchdrug/layers/geometry/graph.py +++ b/torchdrug/layers/geometry/graph.py @@ -26,6 +26,16 @@ class GraphConstruction(nn.Module, core.Configurable): 2. For ``gearnet``, the feature of the edge :math:`e_{ij}` between residue :math:`i` and residue :math:`j` is the concatenation ``[residue_type(i), residue_type(j), edge_type(e_ij), sequential_distance(i,j), spatial_distance(i,j)]``. + + .. note:: + You may customize your own edge features by inheriting this class and define a member function + for your features. Use ``edge_feature="my_feature"`` to call the following feature function. + + .. code:: python + + def edge_my_feature(self, graph, edge_list, num_relation): + ... + return feature # the first dimension must be ``graph.num_edge`` """ max_seq_dist = 10 @@ -43,7 +53,7 @@ def __init__(self, node_layers=None, edge_layers=None, edge_feature="residue_typ self.edge_layers = edge_layers self.edge_feature = edge_feature - def edge_residue_type(self, graph, edge_list): + def edge_residue_type(self, graph, edge_list, num_relation): node_in, node_out, _ = edge_list.t() residue_in, residue_out = graph.atom2residue[node_in], graph.atom2residue[node_out] in_residue_type = graph.residue_type[residue_in] @@ -103,10 +113,8 @@ def apply_edge_layer(self, graph): num_edges = edge2graph.bincount(minlength=graph.batch_size) offsets = (graph.num_cum_nodes - graph.num_nodes).repeat_interleave(num_edges) - if self.edge_feature == "residue_type": - edge_feature = self.edge_residue_type(graph, edge_list) - elif self.edge_feature == "gearnet": - edge_feature = self.edge_gearnet(graph, edge_list, num_relation) + if hasattr(self, "edge_%s" % self.edge_feature): + edge_feature = getattr(self, "edge_%s" % self.edge_feature)(graph, edge_list, num_relation) elif self.edge_feature is None: edge_feature = None else: