From 1e8f175ea7a96488e5d42eea5e6c41836b0f66d8 Mon Sep 17 00:00:00 2001 From: arnovanhilten Date: Fri, 20 Oct 2023 14:19:03 +0200 Subject: [PATCH] NID filters --- GenNet_utils/Interpret.py | 13 +++++++------ interpretation/NID.py | 3 +++ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/GenNet_utils/Interpret.py b/GenNet_utils/Interpret.py index 6a4a4e3..33f9525 100644 --- a/GenNet_utils/Interpret.py +++ b/GenNet_utils/Interpret.py @@ -36,11 +36,10 @@ def get_weight_scores(args): def get_NID_scores(args): + print("Interpreting with NID:") model, masks = load_trained_network(args) - print(model.summary()) - if args.layer == None: if args.onehot == 1: interp_layer = 3 @@ -58,10 +57,12 @@ def get_NID_scores(args): w_in, w_out = Get_weight_tsang(model, interp_layer, masks) print("Computing interactions") - interaction_ranking1 = GenNet_pairwise_interactions_topn(w_in[:,1] ,w_out[:,1], masks, n=4) - interaction_ranking2 = GenNet_pairwise_interactions_topn(w_in[:,0] ,w_out[:,0], masks, n=4) - - interaction_ranking = interaction_ranking1.append(interaction_ranking2) + pairwise_interactions_dfs = [] + for filter in range(w_in.shape[1]): # for all the filters + pairwise_interactions = GenNet_pairwise_interactions_topn(w_in[:,filter] ,w_out[:,filter], masks, n="auto") + pairwise_interactions_dfs.append(pairwise_interactions) + + interaction_ranking = pd.concat([pairwise_interactions_dfs]) interaction_ranking = interaction_ranking.sort_values("strength", ascending =False) interaction_ranking.to_csv(args.resultpath + "/NID.csv") print("NID results are saved in", args.resultpath + "/NID.csv") diff --git a/interpretation/NID.py b/interpretation/NID.py index 7489621..b14d960 100644 --- a/interpretation/NID.py +++ b/interpretation/NID.py @@ -29,10 +29,13 @@ def GenNet_pairwise_interactions_topn(w_input, w_later, mask, n): n_genes = int(mask.shape[1]) + if n=="auto": + n=10e6 if (min(mask.sum(axis=0)) < n).any(): # Cannot get the top n if genes have less than n n = np.min(mask.sum(axis=0)) + num_combinations = int(np.round(np.math.factorial(n) / (np.math.factorial(2) * np.math.factorial((n - 2))) ))