Skip to content

Commit

Permalink
Merge pull request #60 from NDoering99/main
Browse files Browse the repository at this point in the history
Updated docstrings and black formatting of openmmdl_analysis
  • Loading branch information
talagayev authored Jan 22, 2024
2 parents 9632cee + 703efb1 commit 56500f1
Show file tree
Hide file tree
Showing 13 changed files with 2,100 additions and 1,345 deletions.
147 changes: 97 additions & 50 deletions openmmdl/openmmdl_analysis/barcode_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,21 @@ def barcodegeneration(df, interaction):
interaction (str): name of the interaction to generate a barcode for
Returns:
numpy array: returns a binary array, with 1 representing that the interaction is present in the corresponding frame
np.array: returns a binary array, with 1 representing that the interaction is present in the corresponding frame
"""
barcode = []
unique_frames = df['FRAME'].unique()

unique_frames = df["FRAME"].unique()

for frame in unique_frames:
frame_data = df[df['FRAME'] == frame]
frame_data = df[df["FRAME"] == frame]

if 1 in frame_data[interaction].values:
barcode.append(1)



else:
barcode.append(0)

return np.array(barcode)


Expand All @@ -44,10 +43,10 @@ def waterids_barcode_generator(df, interaction):
waterid_barcode = []
for index, row in df.iterrows():
if row[interaction] == 1:
water_id_list.append(int(float(row['WATER_IDX'])))
water_id_list.append(int(float(row["WATER_IDX"])))

barcode = barcodegeneration(df, interaction)

for value in barcode:
if value == 1:
waterid_barcode.append(water_id_list.pop(0))
Expand All @@ -62,34 +61,48 @@ def plot_barcodes(barcodes, save_path):
Args:
barcodes (list): list of np arrays containing the barcodes for each interaction
save_path (str): name of the file to save the picture to
"""
"""
if not barcodes:
print("No barcodes to plot.")
return

num_plots = len(barcodes)
num_cols = 1
num_rows = (num_plots + num_cols - 1) // num_cols

fig, axs = plt.subplots(num_rows, num_cols, figsize=(8.50, num_rows * 1))

# If only one row, axs is a single Axes object, not an array
if num_rows == 1:
axs = [axs]

for i, (title, barcode) in enumerate(barcodes.items()):
ax = axs[i]
ax.set_axis_off()
im = ax.imshow(barcode.reshape(1, -1), cmap='binary', aspect='auto', interpolation='nearest', vmin=0, vmax=1)
im = ax.imshow(
barcode.reshape(1, -1),
cmap="binary",
aspect="auto",
interpolation="nearest",
vmin=0,
vmax=1,
)

percent_occurrence = (barcode.sum() / len(barcode)) * 100
ax.text(1.05, 0.5, f"{percent_occurrence:.2f}%", transform=ax.transAxes, va='center', fontsize=8)
ax.text(
1.05,
0.5,
f"{percent_occurrence:.2f}%",
transform=ax.transAxes,
va="center",
fontsize=8,
)

ax.set_title(title, fontweight='bold', fontsize=8)
ax.set_title(title, fontweight="bold", fontsize=8)

os.makedirs(os.path.dirname("./Barcodes/"), exist_ok=True)
plt.tight_layout()
plt.savefig(f"./Barcodes/{save_path}", dpi=300, bbox_inches='tight')
plt.savefig(f"./Barcodes/{save_path}", dpi=300, bbox_inches="tight")


def plot_waterbridge_piechart(df_all, waterbridge_barcodes, waterbridge_interactions):
Expand All @@ -104,7 +117,7 @@ def plot_waterbridge_piechart(df_all, waterbridge_barcodes, waterbridge_interact
print("No Piecharts to plot.")
return

os.makedirs('Barcodes/Waterbridge_Piecharts', exist_ok=True)
os.makedirs("Barcodes/Waterbridge_Piecharts", exist_ok=True)
plt.figure(figsize=(6, 6))
for waterbridge_interaction in waterbridge_interactions:
plt.clf()
Expand All @@ -118,36 +131,61 @@ def plot_waterbridge_piechart(df_all, waterbridge_barcodes, waterbridge_interact
else:
waters_count[waterid] = 1

labels = [f'ID {id}' for id in waters_count.keys()]
labels = [f"ID {id}" for id in waters_count.keys()]
values = waters_count.values()

# Combine small categories into "Other" category
threshold = 7 # You can adjust this threshold. It is the percentage of the pie chart, not the total number
total_second_values = sum(value for _, value in waters_count.items())
small_ids = [id for id, value in waters_count.items() if (value / total_second_values) * 100 < threshold]
small_ids = [
id
for id, value in waters_count.items()
if (value / total_second_values) * 100 < threshold
]

if small_ids:
small_count = sum(count for id, count in waters_count.items() if id in small_ids)
values = [count if id not in small_ids else small_count for id, count in waters_count.items()]
labels = [f'ID {id}' if id not in small_ids else '' for id in waters_count.keys()]

plt.pie(values, labels=labels,
autopct=lambda pct: f'{pct:.1f}%\n({int(round(pct/100.0 * sum(values)))})',
shadow=False,
startangle=140)
plt.axis('equal')
plt.title(str(waterbridge_interaction), fontweight='bold')
small_count = sum(
count for id, count in waters_count.items() if id in small_ids
)
values = [
count if id not in small_ids else small_count
for id, count in waters_count.items()
]
labels = [
f"ID {id}" if id not in small_ids else "" for id in waters_count.keys()
]

plt.pie(
values,
labels=labels,
autopct=lambda pct: f"{pct:.1f}%\n({int(round(pct/100.0 * sum(values)))})",
shadow=False,
startangle=140,
)
plt.axis("equal")
plt.title(str(waterbridge_interaction), fontweight="bold")
# Manually create the legend with the correct labels
legend_labels = [f'ID {id}' for id in waters_count.keys()]
legend_labels = [f"ID {id}" for id in waters_count.keys()]
legend = plt.legend(legend_labels, loc="upper right", bbox_to_anchor=(1.2, 1))
plt.setp(legend.get_texts(), fontsize='small') # Adjust font size for legend
plt.text(0.5, 0, f"Total frames with waterbridge: {round(((sum(1 for val in waterid_barcode if val != 0) / len(waterid_barcode)) * 100), 2)}%", size=12, ha="center", transform=plt.gcf().transFigure)
plt.setp(legend.get_texts(), fontsize="small") # Adjust font size for legend
plt.text(
0.5,
0,
f"Total frames with waterbridge: {round(((sum(1 for val in waterid_barcode if val != 0) / len(waterid_barcode)) * 100), 2)}%",
size=12,
ha="center",
transform=plt.gcf().transFigure,
)
# Adjust the position of the subplots within the figure
plt.subplots_adjust(top=0.99, bottom=0.01) # You can change the value as needed
plt.savefig(f'Barcodes/Waterbridge_Piecharts/{waterbridge_interaction}.png', bbox_inches='tight', dpi=300)


def plot_bacodes_grouped(interactions, df_all, interaction_type):
plt.savefig(
f"Barcodes/Waterbridge_Piecharts/{waterbridge_interaction}.png",
bbox_inches="tight",
dpi=300,
)


def plot_bacodes_grouped(interactions, df_all, interaction_type, peptide=False):
"""generates barcode figures and groups them by ligand atom, as well as a total interaction barcode for a given ligand atom.
Args:
Expand All @@ -158,33 +196,42 @@ def plot_bacodes_grouped(interactions, df_all, interaction_type):
# get ligand atom information
ligatoms_dict = {}
for interaction in interactions:
ligatom = interaction.split('_')
ligatom = interaction.split("_")
ligatom.pop(0)
ligatom.pop(-1)
if interaction_type in ['acceptor', "donor", "waterbridge", "saltbridge_ni", "saltbridge_pi"]:
if interaction_type in [
"acceptor",
"donor",
"waterbridge",
"saltbridge_ni",
"saltbridge_pi",
]:
ligatom.pop(-1)
if interaction_type in ["saltbridge_ni", "saltbridge_pi"]:
ligatom.pop(-1)
ligatom = '_'.join(ligatom)
ligatom = "_".join(ligatom)
if ligatom not in ligatoms_dict:
ligatoms_dict[ligatom] = [interaction]
else:
ligatoms_dict[ligatom].append(interaction)
# plot barcodes
total_interactions ={}

# plot barcodes
total_interactions = {}
for ligatom in ligatoms_dict:
ligatom_interaction_barcodes = {}
for interaction in ligatoms_dict[ligatom]:
barcode = barcodegeneration(df_all, interaction)
ligatom_interaction_barcodes[interaction] = barcode
os.makedirs(f'./Barcodes/{ligatom}', exist_ok=True)
plot_barcodes(ligatom_interaction_barcodes, f'{ligatom}/{ligatom}_{interaction_type}_barcodes.png')

os.makedirs(f"./Barcodes/{ligatom}", exist_ok=True)
plot_barcodes(
ligatom_interaction_barcodes,
f"{ligatom}/{ligatom}_{interaction_type}_barcodes.png",
)

barcodes_list = list(ligatom_interaction_barcodes.values())
grouped_array = np.logical_or.reduce(barcodes_list)
grouped_array[np.all(np.vstack(barcodes_list) == 0, axis=0)] = 0
grouped_array = grouped_array.astype(int)
total_interactions[ligatom] = grouped_array
plot_barcodes(total_interactions, f'{interaction_type}_interactions.png')

plot_barcodes(total_interactions, f"{interaction_type}_interactions.png")
Loading

0 comments on commit 56500f1

Please sign in to comment.