Skip to content

Commit

Permalink
tests interpret
Browse files Browse the repository at this point in the history
  • Loading branch information
ArnovanHilten committed Oct 19, 2023
1 parent 8bf3e1a commit ea1ce13
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 4 deletions.
2 changes: 1 addition & 1 deletion GenNet_utils/Create_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def create_network_from_csv(datapath,
model = K.layers.Reshape(input_shape=(inputsize,), target_shape=(inputsize, 1))(input_layer)

for i in range(len(columns) - 1):
matrix_ones = np.ones(len(network_csv[[columns[i], columns[i + 1]]]), np.bool)
matrix_ones = np.ones(len(network_csv[[columns[i], columns[i + 1]]]), bool)
matrix_coord = (network_csv[columns[i]].values, network_csv[columns[i + 1]].values)
if i == 0:
matrixshape = (inputsize, network_csv[columns[i + 1]].max() + 1)
Expand Down
2 changes: 1 addition & 1 deletion GenNet_utils/Interpret.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def get_NID_scores(args):
print("Interpreting with NID:")
model, masks = load_trained_network(args)

print(model.summary())
G print(model.summary())

if args.layer == None:
if args.onehot == 1:
Expand Down
1 change: 1 addition & 0 deletions GenNet_utils/Train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,3 +321,4 @@ def load_trained_network(args):




4 changes: 2 additions & 2 deletions requirements_GenNet.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ joblib>=0.16.0
Markdown>=3.2.1
matplotlib>=3.3.2
jupyter>=1.0.0
numpy==1.21
numpy>=1.21
pandas>=0.25.3
Pillow>=7.2.0
plotly>=4.12.0
Expand All @@ -21,5 +21,5 @@ statsmodels
shap
psutil
kaleido
tensorflow==2.2.0
tensorflow
bitarray
31 changes: 31 additions & 0 deletions tests/test_interpret.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import os
import sys
import pandas as pd
import shutil
from os.path import dirname, abspath
import os
sys.path.insert(1, os.getcwd())
from GenNet_utils.Interpret import interpret

# TODO: add test without covariates
# TODO add test with covariates for regression + classification
# TODO add test with multiple genotype files.
# test randomnesss after .. epoch shuffles.
# ToDO add test for each file.

class ArgparseSimulator():
def __init__(self,
resultpath = "/trinity/home/avanhilten/repositories/epistasis/prototyping/GenNet_realLife_data/GenNet/results/GenNet_experiment_2_/",
type = "NID",
layer = None
):

self.resultpath = resultpath
self.type = type
self.layer = layer


args = ArgparseSimulator()
print("done")

interpret(args)

0 comments on commit ea1ce13

Please sign in to comment.