Skip to content

Commit

Permalink
fixing openmm_toplogy to align with original
Browse files Browse the repository at this point in the history
  • Loading branch information
Ishan Taneja committed Jul 15, 2024
1 parent cccfd3c commit f515287
Showing 1 changed file with 60 additions and 15 deletions.
75 changes: 60 additions & 15 deletions cosolvkit/cosolvent_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def _setup_new_topology(self, cosolvents_positions: dict, receptor_topology: app

molecules_positions = np.array(molecules_positions)*openmmunit.nanometer
new_top_openff = Topology.from_molecules(molecules)
new_top = self._to_openmm_topology(new_top_openff, starting_id=last_res_id, chain_id='S')
new_top = self._to_openmm_topology(new_top_openff, starting_id=last_res_id)
residues = list(new_top.residues())
for i in range(len(cosolvent_names)):
residues[i].name = cosolvent_names[i]
Expand All @@ -506,15 +506,13 @@ def _setup_new_topology(self, cosolvents_positions: dict, receptor_topology: app
new_mod.topology.setPeriodicBoxVectors(self._periodic_box_vectors)
return new_mod

def _to_openmm_topology(self, off_topology: Topology, starting_id: int, chain_id: str) -> app.Topology:
def _to_openmm_topology(self, off_topology: Topology, starting_id: int) -> app.Topology:
"""Converts an openff topology to openmm without specifying a different chain for each residue.
:param off_topology: Openff Topology
:type off_topology: openff.Topology
:param starting_id: starting index
:type starting_id: int
:param chain_id: chain_id for solvent molecules
:type chain_id: str
:raises RuntimeError: if something goes wrong
:return: openmm topology
:rtype: openmm.app.Topology
Expand All @@ -527,26 +525,67 @@ def _to_openmm_topology(self, off_topology: Topology, starting_id: int, chain_id

# Go through atoms in OpenFF to preserve the order.
omm_atoms = []
atom_insertion_code = " "
mol_num = 0
chain = omm_topology.addChain(chain_id)

last_chain = None
cnt = 0
# For each atom in each molecule, determine which chain/residue it should be a part of
for molecule in off_topology.molecules:
curr_residue = None
# No chain or residue can span more than one OFF molecule, so reset these to None for the first
# atom in each molecule.
last_residue = None
for atom in molecule.atoms:

atom_residue_name = molecule.name
atom_residue_number = str(starting_id+mol_num)

# If the residue number is undefined, assume a default of "0"
if "residue_number" in atom.metadata:
atom_residue_number = atom.metadata["residue_number"]
else:
atom_residue_number = str(starting_id+cnt)

# If the insertion code is undefined, assume a default of " "
if "insertion_code" in atom.metadata:
atom_insertion_code = atom.metadata["insertion_code"]
else:
atom_insertion_code = " "

# If the chain ID is undefined, assume a default of "X"
if "chain_id" in atom.metadata:
atom_chain_id = atom.metadata["chain_id"]
else:
atom_chain_id = "X"

# Determine whether this atom should be part of the last atom's chain, or if it
# should start a new chain
if last_chain is None:
chain = omm_topology.addChain(atom_chain_id)
elif last_chain.id == atom_chain_id:
chain = last_chain
else:
chain = omm_topology.addChain(atom_chain_id)
# Determine whether this atom should be a part of the last atom's residue, or if it
# should start a new residue
if curr_residue is None:
if last_residue is None:
residue = omm_topology.addResidue(
atom_residue_name,
chain,
id=atom_residue_number,
insertionCode=atom_insertion_code,
)
elif (
(last_residue.name == atom_residue_name)
and (int(last_residue.id) == int(atom_residue_number))
and (last_residue.insertionCode == atom_insertion_code)
and (chain.id == last_chain.id)
):
residue = last_residue
else:
residue = curr_residue
residue = omm_topology.addResidue(
atom_residue_name,
chain,
id=atom_residue_number,
insertionCode=atom_insertion_code,
)

# Add atom.
element = app.Element.getByAtomicNumber(atom.atomic_number)
Expand All @@ -556,9 +595,10 @@ def _to_openmm_topology(self, off_topology: Topology, starting_id: int, chain_id
assert off_topology.atom_index(atom) == int(omm_atom.id) - 1
omm_atoms.append(omm_atom)

curr_residue = residue
last_chain = chain
last_residue = residue

mol_num += 1
cnt += 1
# Add all bonds.
bond_types = {1: app.Single, 2: app.Double, 3: app.Triple}
for bond in molecule.bonds:
Expand Down Expand Up @@ -587,9 +627,14 @@ def _to_openmm_topology(self, off_topology: Topology, starting_id: int, chain_id
type=bond_type,
order=bond_order,
)


if off_topology.box_vectors is not None:
from openff.units.openmm import to_openmm

omm_topology.setPeriodicBoxVectors(to_openmm(off_topology.box_vectors))
return omm_topology



def _create_system(self, forcefield: app.forcefield, topology: app.Topology) -> System:
"""Returns system created from the Forcefield and the Topology.
Expand Down

0 comments on commit f515287

Please sign in to comment.