Skip to content

Commit

Permalink
NID filters
Browse files Browse the repository at this point in the history
  • Loading branch information
ArnovanHilten committed Oct 20, 2023
1 parent 2208905 commit 1e8f175
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
13 changes: 7 additions & 6 deletions GenNet_utils/Interpret.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,10 @@ def get_weight_scores(args):


def get_NID_scores(args):

print("Interpreting with NID:")
model, masks = load_trained_network(args)

print(model.summary())

if args.layer == None:
if args.onehot == 1:
interp_layer = 3
Expand All @@ -58,10 +57,12 @@ def get_NID_scores(args):
w_in, w_out = Get_weight_tsang(model, interp_layer, masks)
print("Computing interactions")

interaction_ranking1 = GenNet_pairwise_interactions_topn(w_in[:,1] ,w_out[:,1], masks, n=4)
interaction_ranking2 = GenNet_pairwise_interactions_topn(w_in[:,0] ,w_out[:,0], masks, n=4)

interaction_ranking = interaction_ranking1.append(interaction_ranking2)
pairwise_interactions_dfs = []
for filter in range(w_in.shape[1]): # for all the filters
pairwise_interactions = GenNet_pairwise_interactions_topn(w_in[:,filter] ,w_out[:,filter], masks, n="auto")
pairwise_interactions_dfs.append(pairwise_interactions)

interaction_ranking = pd.concat([pairwise_interactions_dfs])
interaction_ranking = interaction_ranking.sort_values("strength", ascending =False)
interaction_ranking.to_csv(args.resultpath + "/NID.csv")
print("NID results are saved in", args.resultpath + "/NID.csv")
Expand Down
3 changes: 3 additions & 0 deletions interpretation/NID.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,13 @@ def GenNet_pairwise_interactions_topn(w_input, w_later, mask, n):

n_genes = int(mask.shape[1])

if n=="auto":
n=10e6

if (min(mask.sum(axis=0)) < n).any(): # Cannot get the top n if genes have less than n
n = np.min(mask.sum(axis=0))



num_combinations = int(np.round(np.math.factorial(n) / (np.math.factorial(2) * np.math.factorial((n - 2))) ))

Expand Down

0 comments on commit 1e8f175

Please sign in to comment.