Skip to content

Commit

Permalink
NID pre onehot
Browse files Browse the repository at this point in the history
  • Loading branch information
ArnovanHilten committed Oct 18, 2023
1 parent 5fd148d commit 5268526
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 19 deletions.
2 changes: 1 addition & 1 deletion GenNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(self):
self.make_parser_topology(parser_topology)

parser_interpret = subparsers.add_parser("interpret", help="Post-hoc interpretation analysis on the network")
self.parser_interpret(parser_interpret)
self.make_parser_interpret(parser_interpret)

self.parser = parser

Expand Down
12 changes: 9 additions & 3 deletions GenNet_utils/Interpret.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@

from interpretation.weight_importance import make_importance_values_input
from interpretation.NID import Get_weight_tsang, GenNet_pairwise_interactions_topn
from interpretation.RLIPP import calculate_RLIPP

from GenNet_utils.Utility_functions import get_SLURM_id, evaluate_performance
from GenNet_utils.Train_network import load_trained_network

def interpret(args):
Expand Down Expand Up @@ -38,27 +36,35 @@ def get_weight_scores(args):


def get_NID_scores(args):
print("Interpreting with NID:")
model, masks = load_trained_network(args)

if args.layer == "None":
print(model.summary())

if args.layer == None:
if args.one_hot == 1:
interp_layer = 3
else:
interp_layer = 2
else:
interp_layer = args.layer

print("Interrpeting layer", interp_layer)
if os.path.exists(args.resultpath + "/NID.csv"):
print('RLIPP Done')
interaction_ranking = pd.read_csv(args.resultpath + "/NID.csv")
else:
print("Obtaining the weights")
w_in, w_out = Get_weight_tsang(model, interp_layer, masks)
print("Computing interactions")

interaction_ranking1 = GenNet_pairwise_interactions_topn(w_in[:,1] ,w_out[:,1], masks, n=4)
interaction_ranking2 = GenNet_pairwise_interactions_topn(w_in[:,0] ,w_out[:,0], masks, n=4)

interaction_ranking = interaction_ranking1.append(interaction_ranking2)
interaction_ranking = interaction_ranking.sort_values("strength", ascending =False)
interaction_ranking.to_csv(args.resultpath + "/NID.csv")
print("NID results are saved in", args.resultpath + "/NID.csv")

return interaction_ranking

Expand Down
26 changes: 14 additions & 12 deletions GenNet_utils/Train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def train_model(args):
model = None
masks = None

args.datapath = args.path


if args.genotype_path == "undefined":
args.genotype_path = args.path

Expand Down Expand Up @@ -138,8 +141,6 @@ def train_model(args):
plot_loss_function(args.resultpath)
model.load_weights(args.resultpath + '/bestweights_job.h5')
print("Finished")


save_train_arguments(args)


Expand Down Expand Up @@ -224,11 +225,10 @@ def train_model(args):
importance_csv = create_importance_csv(args.datapath, model, masks)
importance_csv.to_csv(args.resultpath + "connection_weights.csv")


def get_network(args):
"""needs the following inputs
args.inputsize
args.l1_value
args.L1
args.L1_act
args.datapath
args.genotype_path
Expand All @@ -237,40 +237,42 @@ def get_network(args):
"""
regression = args.regression if hasattr(args, 'regression') else False

args.L1 = args.regression if hasattr(args, 'regression') else False

if args.network_name == "lasso" and not regression:
print("lasso network")
model, masks = lasso(inputsize=args.inputsize, l1_value=args.l1_value, L1_act=args.L1_act if hasattr(args, 'L1_act') else None)
model, masks = lasso(inputsize=args.inputsize, l1_value=args.L1, L1_act=args.L1_act if hasattr(args, 'L1_act') else None)

elif args.network_name == "sparse_directed_gene_l1" and not regression:
print("sparse_directed_gene_l1 network")
model, masks = sparse_directed_gene_l1(inputsize=args.inputsize, l1_value=args.l1_value)
model, masks = sparse_directed_gene_l1(inputsize=args.inputsize, l1_value=args.L1)


elif args.network_name == "regression_height" and regression:
print("regression_height network")
model, masks = regression_height(inputsize=args.inputsize, l1_value=args.l1_value)
model, masks = regression_height(inputsize=args.inputsize, l1_value=args.L1)

elif args.network_name == "gene_network_multiple_filters":
print("gene_network_multiple_filters network")
model, masks = gene_network_multiple_filters(datapath=args.datapath, inputsize=args.inputsize, genotype_path=args.genotype_path,
l1_value=args.l1_value, L1_act=args.L1_act,
l1_value=args.L1, L1_act=args.L1_act,
regression=regression, num_covariates=args.num_covariates,
filters=args.filters)

elif args.network_name == "gene_network_snp_gene_filters":
print("gene_network_snp_gene_filters network")
model, masks = gene_network_snp_gene_filters(datapath=args.datapath, inputsize=args.inputsize, genotype_path=args.genotype_path,
l1_value=args.l1_value, L1_act=args.L1_act,
l1_value=args.L1, L1_act=args.L1_act,
regression=regression, num_covariates=args.num_covariates,
filters=args.filters)
else:
if os.path.exists(args.datapath + "/topology.csv"):
model, masks = create_network_from_csv(datapath=args.datapath, inputsize=args.inputsize, genotype_path=args.genotype_path,
l1_value=args.l1_value, L1_act=args.L1_act, regression=regression,
l1_value=args.L1, L1_act=args.L1_act, regression=regression,
num_covariates=args.num_covariates)
elif len(glob.glob(args.datapath + "/*.npz")) > 0:
model, masks = create_network_from_npz(datapath=args.datapath, inputsize=args.inputsize, genotype_path=args.genotype_path,
l1_value=args.l1_value, L1_act=args.L1_act, regression=regression,
l1_value=args.L1, L1_act=args.L1_act, regression=regression,
num_covariates=args.num_covariates,
mask_order=args.mask_order if hasattr(args, 'mask_order') else None)

Expand All @@ -294,7 +296,7 @@ def load_trained_network(args):
get_network needs the following inputs
args.inputsize
args.l1_value
args.L1
args.L1_act
args.datapath
args.genotype_path
Expand Down
17 changes: 15 additions & 2 deletions GenNet_utils/Utility_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@ def save_train_arguments(args, filename="train_args.json"):

# Convert args to a dictionary, taking care of non-serializable types if necessary
args_dict = vars(args) # Convert the Namespace to a dictionary
with open(filename, 'w') as file:
json.dump(args_dict,args.resultpath + file, indent=4)
with open(args.resultpath + filename, 'w') as file:
json.dump(args_dict, file, cls=NumpyEncoder, indent=4)



Expand All @@ -251,3 +251,16 @@ def load_train_arguments(args, filename="train_args.json"):

return args


class NumpyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, (np.int_, np.intc, np.intp, np.int8,
np.int16, np.int32, np.int64, np.uint8,
np.uint16, np.uint32, np.uint64)):
return int(obj)
elif isinstance(obj, (np.float_, np.float16, np.float32,
np.float64)):
return float(obj)
elif isinstance(obj, (np.ndarray,)): # Handle numpy arrays
return obj.tolist()
return super(NumpyEncoder, self).default(obj)
2 changes: 1 addition & 1 deletion interpretation/NID.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def GenNet_pairwise_interactions_simplified_mask(w_input, w_later, mask):
interaction_ranking = []
list_of_combinations = []
for gene_id in range(mask.shape[1]):
neuron_combinations=list(itertools.combinations(mask.row[genemask.col == gene_id], 2 ))
neuron_combinations=list(itertools.combinations(mask.row[mask.col == gene_id], 2 ))

for candidate in neuron_combinations:
strength = (np.minimum(w_input[candidate[0]], w_input[candidate[1]])*w_later[gene_id]).sum()
Expand Down

0 comments on commit 5268526

Please sign in to comment.