diff --git a/conda/torchdrug/meta.yaml b/conda/torchdrug/meta.yaml index b366902..55604f0 100644 --- a/conda/torchdrug/meta.yaml +++ b/conda/torchdrug/meta.yaml @@ -1,6 +1,6 @@ package: name: torchdrug - version: 0.1.1 + version: 0.1.2 source: path: ../.. diff --git a/diff.txt b/diff.txt new file mode 100644 index 0000000..9a0efab --- /dev/null +++ b/diff.txt @@ -0,0 +1,202 @@ +diff --git a/conda/torchdrug/meta.yaml b/conda/torchdrug/meta.yaml +index b366902..55604f0 100644 +--- a/conda/torchdrug/meta.yaml ++++ b/conda/torchdrug/meta.yaml +@@ -1,6 +1,6 @@ + package: + name: torchdrug +- version: 0.1.1 ++ version: 0.1.2 + + source: + path: ../.. +diff --git a/doc/source/paper.rst b/doc/source/paper.rst +index ab8489c..c22a7ae 100644 +--- a/doc/source/paper.rst ++++ b/doc/source/paper.rst +@@ -86,9 +86,9 @@ Readout Layers + + 1. `Order Matters: Sequence to sequence for sets `_ + +- Oriol Vinyals, Samy Bengio, Manjunath Kudlur ++ Oriol Vinyals, Samy Bengio, Manjunath Kudlur + +- :class:`Set2Set ` ++ :class:`Set2Set ` + + Normalization Layers + ^^^^^^^^^^^^^^^^^^^^ +diff --git a/setup.py b/setup.py +index ddf3cdb..9da3d27 100644 +--- a/setup.py ++++ b/setup.py +@@ -13,7 +13,7 @@ if __name__ == "__main__": + long_description_content_type="text/markdown", + url="https://torchdrug.ai/", + author="TorchDrug Team", +- version="0.1.1", ++ version="0.1.2", + license="Apache-2.0", + keywords=["deep-learning", "pytorch", "drug-discovery"], + packages=setuptools.find_packages(), +diff --git a/torchdrug/__init__.py b/torchdrug/__init__.py +index 7058780..7dca7a0 100644 +--- a/torchdrug/__init__.py ++++ b/torchdrug/__init__.py +@@ -12,4 +12,4 @@ handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(format) + logger.addHandler(handler) + +-__version__ = "0.1.1" +\ No newline at end of file ++__version__ = "0.1.2" +\ No newline at end of file +diff --git a/torchdrug/core/core.py b/torchdrug/core/core.py +index 1de312b..4c6ea18 100644 +--- a/torchdrug/core/core.py ++++ b/torchdrug/core/core.py +@@ -355,4 +355,4 @@ def make_configurable(cls, module=None, ignore_args=()): + MetaClass = type(_Configurable.__name__, (Metaclass, _Configurable), {}) + else: + MetaClass = _Configurable +- return MetaClass(cls.__name__, (cls,), {"_ignore_args": ignore_args, "__module__": module}) ++ return MetaClass(cls.__name__, (cls,), {"_ignore_args": ignore_args, "__module__": module}) +\ No newline at end of file +diff --git a/torchdrug/data/__init__.py b/torchdrug/data/__init__.py +index bfacd4c..66131f6 100644 +--- a/torchdrug/data/__init__.py ++++ b/torchdrug/data/__init__.py +@@ -1,3 +1,4 @@ ++from .dictionary import PerfectHash, Dictionary + from .graph import Graph, PackedGraph, cat + from .molecule import Molecule, PackedMolecule + from .dataset import MoleculeDataset, ReactionDataset, NodeClassificationDataset, KnowledgeGraphDataset, \ +@@ -7,7 +8,7 @@ from . import constant + from . import feature + + __all__ = [ +- "Graph", "PackedGraph", "Molecule", "PackedMolecule", ++ "Graph", "PackedGraph", "Molecule", "PackedMolecule", "PerfectHash", "Dictionary", + "MoleculeDataset", "ReactionDataset", "NodeClassificationDataset", "KnowledgeGraphDataset", "SemiSupervised", + "semisupervised", "key_split", "scaffold_split", "ordered_scaffold_split", + "DataLoader", "graph_collate", "feature", "constant", +diff --git a/torchdrug/data/dataset.py b/torchdrug/data/dataset.py +index 6285244..8db50df 100644 +--- a/torchdrug/data/dataset.py ++++ b/torchdrug/data/dataset.py +@@ -171,16 +171,32 @@ class MoleculeDataset(torch_data.Dataset, core.Configurable): + def atom_types(self): + """All atom types.""" + atom_types = set() +- for i in range(len(self.data)): +- atom_types.update(self.get_item(i)["graph"].atom_type.tolist()) ++ ++ if getattr(self, "lazy", False): ++ warnings.warn("Calling this function for dataset with lazy=True may take a large amount of time.") ++ for smiles in self.smiles_list: ++ graph = data.Molecule.from_smiles(smiles, **self.kwargs) ++ atom_types.update(graph.atom_type.tolist()) ++ else: ++ for graph in self.data: ++ atom_types.update(graph.atom_type.tolist()) ++ + return sorted(atom_types) + + @utils.cached_property + def bond_types(self): + """All bond types.""" + bond_types = set() +- for i in range(len(self.data)): +- bond_types.update(self.get_item(i)["graph"].edge_list[:, 2].tolist()) ++ ++ if getattr(self, "lazy", False): ++ warnings.warn("Calling this function for dataset with lazy=True may take a large amount of time.") ++ for smiles in self.smiles_list: ++ graph = data.Molecule.from_smiles(smiles, **self.kwargs) ++ bond_types.update(graph.edge_list[:, 2].tolist()) ++ else: ++ for graph in self.data: ++ bond_types.update(graph.edge_list[:, 2].tolist()) ++ + return sorted(bond_types) + + def __len__(self): +diff --git a/torchdrug/models/neurallp.py b/torchdrug/models/neurallp.py +index db16f7d..ef78c67 100644 +--- a/torchdrug/models/neurallp.py ++++ b/torchdrug/models/neurallp.py +@@ -104,7 +104,7 @@ class NeuralLogicProgramming(nn.Module, core.Configurable): + + h_index, t_index, r_index = self.negative_sample_to_tail(h_index, t_index, r_index) + hr_index = h_index * graph.num_relation + r_index +- hr_index_set, hr_inverse = torch.unique(hr_index, return_inverse=True) ++ hr_index_set, hr_inverse = hr_index.unique(return_inverse=True) + h_index_set = hr_index_set // graph.num_relation + r_index_set = hr_index_set % graph.num_relation + +diff --git a/torchdrug/tasks/generation.py b/torchdrug/tasks/generation.py +index bb7ddc0..942e8e3 100644 +--- a/torchdrug/tasks/generation.py ++++ b/torchdrug/tasks/generation.py +@@ -803,7 +803,7 @@ class GCPNGeneration(tasks.Task, core.Configurable): + self.batch_id += 1 + + # generation takes less time when early_stop=True +- graph = self.generate(len(batch["graph"]), max_resample=5, off_policy=True, max_step=40 * 2, verbose=1) ++ graph = self.generate(len(batch["graph"]), max_resample=20, off_policy=True, max_step=40 * 2, verbose=1) + if graph.num_nodes.max() == 1: + raise ValueError("Generation results collapse to singleton molecules") + +@@ -1338,7 +1338,7 @@ class GCPNGeneration(tasks.Task, core.Configurable): + self.best_results[task] = best_results + + @torch.no_grad() +- def generate(self, num_sample, max_resample=10, off_policy=False, max_step=30 * 2, initial_smiles="C", verbose=0): ++ def generate(self, num_sample, max_resample=20, off_policy=False, max_step=30 * 2, initial_smiles="C", verbose=0): + is_training = self.training + self.eval() + +diff --git a/torchdrug/utils/comm.py b/torchdrug/utils/comm.py +index 0980131..817c281 100644 +--- a/torchdrug/utils/comm.py ++++ b/torchdrug/utils/comm.py +@@ -147,7 +147,7 @@ def reduce(obj, op="sum", dst=None): + Available operators are ``sum``, ``mean``, ``min``, ``max``, ``product``. + dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers. + +- Examples:: ++ Example:: + + >>> # assume 4 workers + >>> rank = comm.get_rank() +@@ -190,7 +190,7 @@ def stack(obj, dst=None): + obj (Object): any container object. Can be nested list, tuple or dict. + dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers. + +- Examples:: ++ Example:: + + >>> # assume 4 workers + >>> rank = comm.get_rank() +@@ -229,7 +229,7 @@ def cat(obj, dst=None): + obj (Object): any container object. Can be nested list, tuple or dict. + dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers. + +- Examples:: ++ Example:: + + >>> # assume 4 workers + >>> rank = comm.get_rank() +diff --git a/torchdrug/utils/io.py b/torchdrug/utils/io.py +index 29659cf..d573cde 100644 +--- a/torchdrug/utils/io.py ++++ b/torchdrug/utils/io.py +@@ -77,7 +77,7 @@ def capture_rdkit_log(): + """ + Context manager to capture all rdkit loggings. + +- Examples:: ++ Example:: + + >>> with utils.capture_rdkit_log() as log: + >>> ... diff --git a/doc/source/paper.rst b/doc/source/paper.rst index ab8489c..c22a7ae 100644 --- a/doc/source/paper.rst +++ b/doc/source/paper.rst @@ -86,9 +86,9 @@ Readout Layers 1. `Order Matters: Sequence to sequence for sets `_ - Oriol Vinyals, Samy Bengio, Manjunath Kudlur + Oriol Vinyals, Samy Bengio, Manjunath Kudlur - :class:`Set2Set ` + :class:`Set2Set ` Normalization Layers ^^^^^^^^^^^^^^^^^^^^ diff --git a/setup.py b/setup.py index ddf3cdb..9da3d27 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ long_description_content_type="text/markdown", url="https://torchdrug.ai/", author="TorchDrug Team", - version="0.1.1", + version="0.1.2", license="Apache-2.0", keywords=["deep-learning", "pytorch", "drug-discovery"], packages=setuptools.find_packages(), diff --git a/torchdrug/__init__.py b/torchdrug/__init__.py index 7058780..7dca7a0 100644 --- a/torchdrug/__init__.py +++ b/torchdrug/__init__.py @@ -12,4 +12,4 @@ handler.setFormatter(format) logger.addHandler(handler) -__version__ = "0.1.1" \ No newline at end of file +__version__ = "0.1.2" \ No newline at end of file diff --git a/torchdrug/core/core.py b/torchdrug/core/core.py index 1de312b..4c6ea18 100644 --- a/torchdrug/core/core.py +++ b/torchdrug/core/core.py @@ -355,4 +355,4 @@ def make_configurable(cls, module=None, ignore_args=()): MetaClass = type(_Configurable.__name__, (Metaclass, _Configurable), {}) else: MetaClass = _Configurable - return MetaClass(cls.__name__, (cls,), {"_ignore_args": ignore_args, "__module__": module}) + return MetaClass(cls.__name__, (cls,), {"_ignore_args": ignore_args, "__module__": module}) \ No newline at end of file diff --git a/torchdrug/data/__init__.py b/torchdrug/data/__init__.py index bfacd4c..66131f6 100644 --- a/torchdrug/data/__init__.py +++ b/torchdrug/data/__init__.py @@ -1,3 +1,4 @@ +from .dictionary import PerfectHash, Dictionary from .graph import Graph, PackedGraph, cat from .molecule import Molecule, PackedMolecule from .dataset import MoleculeDataset, ReactionDataset, NodeClassificationDataset, KnowledgeGraphDataset, \ @@ -7,7 +8,7 @@ from . import feature __all__ = [ - "Graph", "PackedGraph", "Molecule", "PackedMolecule", + "Graph", "PackedGraph", "Molecule", "PackedMolecule", "PerfectHash", "Dictionary", "MoleculeDataset", "ReactionDataset", "NodeClassificationDataset", "KnowledgeGraphDataset", "SemiSupervised", "semisupervised", "key_split", "scaffold_split", "ordered_scaffold_split", "DataLoader", "graph_collate", "feature", "constant", diff --git a/torchdrug/data/dataset.py b/torchdrug/data/dataset.py index 6285244..8db50df 100644 --- a/torchdrug/data/dataset.py +++ b/torchdrug/data/dataset.py @@ -171,16 +171,32 @@ def num_bond_type(self): def atom_types(self): """All atom types.""" atom_types = set() - for i in range(len(self.data)): - atom_types.update(self.get_item(i)["graph"].atom_type.tolist()) + + if getattr(self, "lazy", False): + warnings.warn("Calling this function for dataset with lazy=True may take a large amount of time.") + for smiles in self.smiles_list: + graph = data.Molecule.from_smiles(smiles, **self.kwargs) + atom_types.update(graph.atom_type.tolist()) + else: + for graph in self.data: + atom_types.update(graph.atom_type.tolist()) + return sorted(atom_types) @utils.cached_property def bond_types(self): """All bond types.""" bond_types = set() - for i in range(len(self.data)): - bond_types.update(self.get_item(i)["graph"].edge_list[:, 2].tolist()) + + if getattr(self, "lazy", False): + warnings.warn("Calling this function for dataset with lazy=True may take a large amount of time.") + for smiles in self.smiles_list: + graph = data.Molecule.from_smiles(smiles, **self.kwargs) + bond_types.update(graph.edge_list[:, 2].tolist()) + else: + for graph in self.data: + bond_types.update(graph.edge_list[:, 2].tolist()) + return sorted(bond_types) def __len__(self): diff --git a/torchdrug/datasets/uspto50k.py b/torchdrug/datasets/uspto50k.py index 980cc5e..7e2398c 100644 --- a/torchdrug/datasets/uspto50k.py +++ b/torchdrug/datasets/uspto50k.py @@ -103,6 +103,7 @@ def _get_difference(self, reactant, product): # check edges in the product product = product.directed() + # O(n^2) brute-force match is faster than O(nlogn) data.Graph.match for small molecules mapped_edge = product.edge_list.clone() mapped_edge[:, :2] = prod2react[mapped_edge[:, :2]] is_same_index = mapped_edge.unsqueeze(0) == reactant.edge_list.unsqueeze(1) @@ -123,8 +124,10 @@ def _get_reaction_center(self, reactant, product): if len(edge_added) > 0: if len(edge_added) == 1: # add a single edge - index = product.index(edge_added[0]) - assert len(index) == 1 + any = -torch.ones(1, 1, dtype=torch.long) + pattern = torch.cat([edge_added, any], dim=-1) + index, num_match = product.match(pattern) + assert num_match.item() == 1 edge_label[index] = 1 h, t = edge_added[0] reaction_center = torch.tensor([product.atom_map[h], product.atom_map[t]]) @@ -172,7 +175,10 @@ def _get_synthon(self, reactant, product): if len(edge_added) == 1: # add a single edge edge = edge_added[0] reverse_edge = edge.flip(0) - index = torch.cat([product.index(edge), product.index(reverse_edge)]) + any = -torch.ones(2, 1, dtype=torch.long) + pattern = torch.cat([edge, reverse_edge]) + pattern = torch.cat([pattern, any], dim=-1) + index, num_match = product.match(pattern) edge_mask = torch.ones(product.num_edge, dtype=torch.bool) edge_mask[index] = 0 product = product.edge_mask(edge_mask) diff --git a/torchdrug/models/neurallp.py b/torchdrug/models/neurallp.py index db16f7d..ef78c67 100644 --- a/torchdrug/models/neurallp.py +++ b/torchdrug/models/neurallp.py @@ -104,7 +104,7 @@ def forward(self, graph, h_index, t_index, r_index, all_loss=None, metric=None): h_index, t_index, r_index = self.negative_sample_to_tail(h_index, t_index, r_index) hr_index = h_index * graph.num_relation + r_index - hr_index_set, hr_inverse = torch.unique(hr_index, return_inverse=True) + hr_index_set, hr_inverse = hr_index.unique(return_inverse=True) h_index_set = hr_index_set // graph.num_relation r_index_set = hr_index_set % graph.num_relation diff --git a/torchdrug/tasks/generation.py b/torchdrug/tasks/generation.py index bb7ddc0..942e8e3 100644 --- a/torchdrug/tasks/generation.py +++ b/torchdrug/tasks/generation.py @@ -803,7 +803,7 @@ def reinforce_forward(self, batch): self.batch_id += 1 # generation takes less time when early_stop=True - graph = self.generate(len(batch["graph"]), max_resample=5, off_policy=True, max_step=40 * 2, verbose=1) + graph = self.generate(len(batch["graph"]), max_resample=20, off_policy=True, max_step=40 * 2, verbose=1) if graph.num_nodes.max() == 1: raise ValueError("Generation results collapse to singleton molecules") @@ -1338,7 +1338,7 @@ def update_best_result(self, graph, score, task): self.best_results[task] = best_results @torch.no_grad() - def generate(self, num_sample, max_resample=10, off_policy=False, max_step=30 * 2, initial_smiles="C", verbose=0): + def generate(self, num_sample, max_resample=20, off_policy=False, max_step=30 * 2, initial_smiles="C", verbose=0): is_training = self.training self.eval() diff --git a/torchdrug/utils/comm.py b/torchdrug/utils/comm.py index 0980131..817c281 100644 --- a/torchdrug/utils/comm.py +++ b/torchdrug/utils/comm.py @@ -147,7 +147,7 @@ def reduce(obj, op="sum", dst=None): Available operators are ``sum``, ``mean``, ``min``, ``max``, ``product``. dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers. - Examples:: + Example:: >>> # assume 4 workers >>> rank = comm.get_rank() @@ -190,7 +190,7 @@ def stack(obj, dst=None): obj (Object): any container object. Can be nested list, tuple or dict. dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers. - Examples:: + Example:: >>> # assume 4 workers >>> rank = comm.get_rank() @@ -229,7 +229,7 @@ def cat(obj, dst=None): obj (Object): any container object. Can be nested list, tuple or dict. dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers. - Examples:: + Example:: >>> # assume 4 workers >>> rank = comm.get_rank() diff --git a/torchdrug/utils/io.py b/torchdrug/utils/io.py index 29659cf..d573cde 100644 --- a/torchdrug/utils/io.py +++ b/torchdrug/utils/io.py @@ -77,7 +77,7 @@ def capture_rdkit_log(): """ Context manager to capture all rdkit loggings. - Examples:: + Example:: >>> with utils.capture_rdkit_log() as log: >>> ...