From f5152875b66f62458c9de71f4c892113afe00707 Mon Sep 17 00:00:00 2001 From: Ishan Taneja Date: Mon, 15 Jul 2024 12:52:36 -0700 Subject: [PATCH] fixing openmm_toplogy to align with original --- cosolvkit/cosolvent_system.py | 75 ++++++++++++++++++++++++++++------- 1 file changed, 60 insertions(+), 15 deletions(-) diff --git a/cosolvkit/cosolvent_system.py b/cosolvkit/cosolvent_system.py index 671bce4..4cd7088 100644 --- a/cosolvkit/cosolvent_system.py +++ b/cosolvkit/cosolvent_system.py @@ -493,7 +493,7 @@ def _setup_new_topology(self, cosolvents_positions: dict, receptor_topology: app molecules_positions = np.array(molecules_positions)*openmmunit.nanometer new_top_openff = Topology.from_molecules(molecules) - new_top = self._to_openmm_topology(new_top_openff, starting_id=last_res_id, chain_id='S') + new_top = self._to_openmm_topology(new_top_openff, starting_id=last_res_id) residues = list(new_top.residues()) for i in range(len(cosolvent_names)): residues[i].name = cosolvent_names[i] @@ -506,15 +506,13 @@ def _setup_new_topology(self, cosolvents_positions: dict, receptor_topology: app new_mod.topology.setPeriodicBoxVectors(self._periodic_box_vectors) return new_mod - def _to_openmm_topology(self, off_topology: Topology, starting_id: int, chain_id: str) -> app.Topology: + def _to_openmm_topology(self, off_topology: Topology, starting_id: int) -> app.Topology: """Converts an openff topology to openmm without specifying a different chain for each residue. :param off_topology: Openff Topology :type off_topology: openff.Topology :param starting_id: starting index :type starting_id: int - :param chain_id: chain_id for solvent molecules - :type chain_id: str :raises RuntimeError: if something goes wrong :return: openmm topology :rtype: openmm.app.Topology @@ -527,26 +525,67 @@ def _to_openmm_topology(self, off_topology: Topology, starting_id: int, chain_id # Go through atoms in OpenFF to preserve the order. omm_atoms = [] - atom_insertion_code = " " - mol_num = 0 - chain = omm_topology.addChain(chain_id) + last_chain = None + cnt = 0 + # For each atom in each molecule, determine which chain/residue it should be a part of for molecule in off_topology.molecules: - curr_residue = None + # No chain or residue can span more than one OFF molecule, so reset these to None for the first + # atom in each molecule. + last_residue = None for atom in molecule.atoms: + atom_residue_name = molecule.name - atom_residue_number = str(starting_id+mol_num) + + # If the residue number is undefined, assume a default of "0" + if "residue_number" in atom.metadata: + atom_residue_number = atom.metadata["residue_number"] + else: + atom_residue_number = str(starting_id+cnt) + + # If the insertion code is undefined, assume a default of " " + if "insertion_code" in atom.metadata: + atom_insertion_code = atom.metadata["insertion_code"] + else: + atom_insertion_code = " " + + # If the chain ID is undefined, assume a default of "X" + if "chain_id" in atom.metadata: + atom_chain_id = atom.metadata["chain_id"] + else: + atom_chain_id = "X" + + # Determine whether this atom should be part of the last atom's chain, or if it + # should start a new chain + if last_chain is None: + chain = omm_topology.addChain(atom_chain_id) + elif last_chain.id == atom_chain_id: + chain = last_chain + else: + chain = omm_topology.addChain(atom_chain_id) # Determine whether this atom should be a part of the last atom's residue, or if it # should start a new residue - if curr_residue is None: + if last_residue is None: residue = omm_topology.addResidue( atom_residue_name, chain, id=atom_residue_number, insertionCode=atom_insertion_code, ) + elif ( + (last_residue.name == atom_residue_name) + and (int(last_residue.id) == int(atom_residue_number)) + and (last_residue.insertionCode == atom_insertion_code) + and (chain.id == last_chain.id) + ): + residue = last_residue else: - residue = curr_residue + residue = omm_topology.addResidue( + atom_residue_name, + chain, + id=atom_residue_number, + insertionCode=atom_insertion_code, + ) # Add atom. element = app.Element.getByAtomicNumber(atom.atomic_number) @@ -556,9 +595,10 @@ def _to_openmm_topology(self, off_topology: Topology, starting_id: int, chain_id assert off_topology.atom_index(atom) == int(omm_atom.id) - 1 omm_atoms.append(omm_atom) - curr_residue = residue + last_chain = chain + last_residue = residue - mol_num += 1 + cnt += 1 # Add all bonds. bond_types = {1: app.Single, 2: app.Double, 3: app.Triple} for bond in molecule.bonds: @@ -587,9 +627,14 @@ def _to_openmm_topology(self, off_topology: Topology, starting_id: int, chain_id type=bond_type, order=bond_order, ) - + + if off_topology.box_vectors is not None: + from openff.units.openmm import to_openmm + + omm_topology.setPeriodicBoxVectors(to_openmm(off_topology.box_vectors)) return omm_topology - + + def _create_system(self, forcefield: app.forcefield, topology: app.Topology) -> System: """Returns system created from the Forcefield and the Topology.