-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path func_TrainModel.m
126 lines (109 loc) · 3.19 KB
/
func_TrainModel.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
function [Accuracy,ValErr,Momentum,L2Reg,InitLR,AUC,C,prb] = func_TrainModel(Parm)
%Train Model (training DeepInsight-FS)
% run Prepare_Data.m %TCGA RNA-seq data will be prepared using tSNE algorithm
% see Prepare_Data.m for details
curr_dir = pwd;
addpath(curr_dir);
Parm.fid=fopen('Results.txt','a+');
Parm.UsePreviousModel = 0;
%Parm.net=net;
%Parm.ObjcFcMeasure=ObjFcMeasure;
%if nargin<3
% Parm.ObjFcMeasure='accuracy'
%else
% Parm.ObjFcMeasure=ObjFcMeasure;
%end
[upm,Mo,Init,Reg] = func_UsePreviousModel(Parm.UsePrevModel); %UsePrevMod 'y for yes and 'n' for no
Parm.UsePreviousModel = upm;
if upm==0
if Parm.MaxObj~=1
Parm.Momentum = Mo; Parm.InitialLearnRate = Init; Parm.L2Regularization = Reg;
end
end
if Parm.Norm==2
[model, Norm] = DeepInsight_train_CAM(Parm,2); % if you want to use Norm 2 only
elseif Parm.Norm==1
[model, Norm] = DeepInsight_train_CAM(Parm,1);
else
[model, Norm] = DeepInsight_train_CAM(Parm); % best norm will be determined by the algorithm
end
model.Norm=Norm;
save('model.mat','-struct','model','-v7.3');
if Norm==1
Data = load('Out1.mat');
else
Data = load('Out2.mat');
end
if isfield(Data,'XTest')==0
Test_Empty=1;
else
if isempty(Data.XTest)==1
Test_Empty=1;
else
Test_Empty=0;
end
end
if size(Data.XTrain,3)==1
if Test_Empty==0
Data.XTest = cat(3,Data.XTest,Data.XTest,Data.XTest);
end
if Parm.ValidRatio>0
Data.XValidation = cat(3,Data.XValidation,Data.XValidation,Data.XValidation);
end
elseif size(Data.XTrain,3)==2
if Test_Empty==0
Data.XTest = cat(3,Data.XTest(:,:,1,:),Data.XTest(:,:,2,:),Data.XTest(:,:,1,:));
end
if Parm.ValidRatio>0
Data.XValidation = cat(3,Data.XValidation(:,:,1,:),Data.XValidation(:,:,2,:),Data.XValidation(:,:,1,:));
end
end
Data = rmfield(Data,'XTrain');
Data = rmfield(Data,'YTrain');
%Data = rmfield(Data,'XValidation');
%Data = rmfield(Data,'YValidation');
if Test_Empty==0
[Accuracy,AUC,C,prob_test] = DeepInsight_test_CAM(Data,model);
prb.test=prob_test;
prb.YTest=Data.YTest;
end
% NOTE: AUC is for two class problem only, otherwise its value would be 'NaN
%find validation probabilities
if Parm.ValidRatio>0
Data.XTest = Data.XValidation;
Data.YTest = Data.YValidation;
[Accuracy_val,AUC_val,C_val,prob_val] = DeepInsight_test_CAM(Data,model);
prb.val=prob_val;
prb.YValidation=Data.YValidation;
if Test_Empty==1
fprintf('\nNOTE: Test set is NOT available!\n');
fprintf('Performance measures are for Validation SET\n');
Accuracy=Accuracy_val;
AUC=AUC_val;
C=C_val;
end
else
Auccuracy_val=[]; AUC_val=[]; C_val=[]; prb.val=[];prb.YValidation=[];
end
% %find train probabilities
% Data.XTest = Data.XTrain;
% Data.YTest = Data.YTrain;
% [Accuracy_train,AUC_train,C_train,prob_train] = DeepInsight_test_CAM(Data,model);
% prb.train=prob_train;
% prb.YTrain=Data.YTrain;
fclose(Parm.fid);
ValErr = model.valError;
cd DeepResults
f=load(model.fileName);
cd ..
warning off;
if isfield(struct(f.options),'Momentum')==1
Momentum = f.options.Momentum;
else
Momentum=0;
end
L2Reg = f.options.L2Regularization;
InitLR = f.options.InitialLearnRate;
if isempty(AUC)==1
AUC=nan;
end