-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfit_NN.m
56 lines (50 loc) · 1.92 KB
/
fit_NN.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
function [net,tr] = fit_NN(inputs, targets, model_no)
%% FIT_NN Solve an Input-Output Fitting problem with a feed-forward Neural Network
%
% Usage:
% [net, tr] = fit_NN(inputs, targets, model_no)
%
% Inputs:
% inputs:   The Ndata x Dim_input matrix of model inputs;
% targets:  The Ndata x Dim_output matrix of model outputs;
% model_no: Scalar integer (1 to 8) selecting the ANN architecture to use;
%
% Output:
% net: The trained ANN model function;
% tr:  The structure of the ANN training statistics;
%
% Errors:
% fit_NN:invalidModelNo - model_no is not an integer scalar that matches
%                         one of the architectures listed below.
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Define the function:
% Table of available architectures; entry k holds the hidden-layer sizes
% used when model_no == k.
% Advice - Do not run more than 3 hidden layers (note that architecture 8
% deliberately exceeds this and is expensive to train).
architectures = { ...
    18, ...                                      % ANN architecture type 1
    [18, 9], ...                                 % ANN architecture type 2
    [27, 18, 9], ...                             % ANN architecture type 3
    [64, 32, 16], ...                            % ANN architecture type 4
    32, ...                                      % ANN architecture type 5
    64, ...                                      % ANN architecture type 6
    [64, 32, 8], ...                             % ANN architecture type 7
    [400, 350, 300, 250, 200, 150, 100, 50]};    % ANN architecture type 8
n_architectures = numel(architectures);

% Reject unsupported selectors up front. Without this check an unknown
% model_no would leave the layer sizes undefined and fail later with an
% unrelated "undefined variable" error.
is_valid_model = (isnumeric(model_no) || islogical(model_no)) ...
    && isscalar(model_no) ...
    && isfinite(model_no) ...
    && model_no == fix(model_no) ...
    && model_no >= 1 ...
    && model_no <= n_architectures;
if ~is_valid_model
    error('fit_NN:invalidModelNo', ...
        'model_no must be an integer scalar between 1 and %d.', n_architectures);
end
hidden_nodes_vec = architectures{double(model_no)};

% Specify the type of ANN training and epochs:
net = feedforwardnet(hidden_nodes_vec); % Training type: Feed-forward net
net.trainParam.epochs = 1000;

% Set activation function for each hidden layer as RELU (Rectified Linear Unit activation).
% Only the first numel(hidden_nodes_vec) layers are touched, so the final
% (output) layer keeps its default transfer function.
for i = 1:numel(hidden_nodes_vec)
    net.layers{i}.transferFcn = 'poslin';
end

% Set up Division of Data for Training, Validation and Testing:
net.divideParam.trainRatio = 0.7;
net.divideParam.testRatio = 0.15;
net.divideParam.valRatio = 0.15;

% Train the Network. The data are stored one sample per row, so both
% matrices are transposed to give train() one sample per column.
[net,tr] = train(net,inputs',targets');
end