Skip to content

Commit

Permalink
support customized edge feature in graph construction
Browse files Browse the repository at this point in the history
  • Loading branch information
KiddoZhu committed Jul 16, 2023
1 parent 91e9bd9 commit 17ab229
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 8 deletions.
1 change: 0 additions & 1 deletion torchdrug/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ def __setattr__(self, key, value):
def __delattr__(self, key):
if hasattr(self, "meta_dict") and key in self.meta_dict:
del self.meta_dict[key]
del self.data_dict[key]
super(_MetaContainer, self).__delattr__(self, key)

def _setattr(self, key, value):
Expand Down
5 changes: 3 additions & 2 deletions torchdrug/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,20 +225,21 @@ def evaluate(self, split, log=True):

return metric

def load(self, checkpoint, load_optimizer=True):
def load(self, checkpoint, load_optimizer=True, strict=True):
"""
Load a checkpoint from file.
Parameters:
checkpoint (file-like): checkpoint file
load_optimizer (bool, optional): load optimizer state or not
strict (bool, optional): whether to strictly check the checkpoint matches the model parameters
"""
if comm.get_rank() == 0:
logger.warning("Load checkpoint from %s" % checkpoint)
checkpoint = os.path.expanduser(checkpoint)
state = torch.load(checkpoint, map_location=self.device)

self.model.load_state_dict(state["model"])
self.model.load_state_dict(state["model"], strict=strict)

if load_optimizer:
self.optimizer.load_state_dict(state["optimizer"])
Expand Down
18 changes: 13 additions & 5 deletions torchdrug/layers/geometry/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ class GraphConstruction(nn.Module, core.Configurable):
2. For ``gearnet``, the feature of the edge :math:`e_{ij}` between residue :math:`i` and residue :math:`j`
is the concatenation ``[residue_type(i), residue_type(j), edge_type(e_ij),
sequential_distance(i,j), spatial_distance(i,j)]``.
.. note::
You may customize your own edge features by inheriting this class and define a member function
for your features. Use ``edge_feature="my_feature"`` to call the following feature function.
.. code:: python
def edge_my_feature(self, graph, edge_list, num_relation):
...
return feature # the first dimension must be ``graph.num_edge``
"""

max_seq_dist = 10
Expand All @@ -43,7 +53,7 @@ def __init__(self, node_layers=None, edge_layers=None, edge_feature="residue_typ
self.edge_layers = edge_layers
self.edge_feature = edge_feature

def edge_residue_type(self, graph, edge_list):
def edge_residue_type(self, graph, edge_list, num_relation):
node_in, node_out, _ = edge_list.t()
residue_in, residue_out = graph.atom2residue[node_in], graph.atom2residue[node_out]
in_residue_type = graph.residue_type[residue_in]
Expand Down Expand Up @@ -103,10 +113,8 @@ def apply_edge_layer(self, graph):
num_edges = edge2graph.bincount(minlength=graph.batch_size)
offsets = (graph.num_cum_nodes - graph.num_nodes).repeat_interleave(num_edges)

if self.edge_feature == "residue_type":
edge_feature = self.edge_residue_type(graph, edge_list)
elif self.edge_feature == "gearnet":
edge_feature = self.edge_gearnet(graph, edge_list, num_relation)
if hasattr(self, "edge_%s" % self.edge_feature):
edge_feature = getattr(self, "edge_%s" % self.edge_feature)(graph, edge_list, num_relation)
elif self.edge_feature is None:
edge_feature = None
else:
Expand Down

0 comments on commit 17ab229

Please sign in to comment.