-
-
Notifications
You must be signed in to change notification settings - Fork 178
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #42 from nidhaloff/dev
implemented hyperparameter search
- Loading branch information
Showing
8 changed files
with
132 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from igel import Igel | ||
|
||
""" | ||
The goal of igel is to use ML without writing code. Therefore, the right and simplest way to use igel is from terminal. | ||
You can run ` igel fit -dp path_to_dataset -yml path_to_yaml_file`. | ||
Alternatively, you can write code if you want. This example below demonstrates how to use igel if you want to write code. | ||
However, I suggest you try and use the igel CLI. Type igel -h in your terminal to know more. | ||
=============================================================================================================== | ||
This example fits a machine learning model on the indian-diabetes dataset | ||
- default model here is the neural network and the configuration are provided in neural-network.yaml file | ||
- You can switch to random forest by providing the random-forest.yaml as the config file in the parameters | ||
""" | ||
|
||
mock_fit_params = {'data_path': '../data/indian-diabetes/train-indians-diabetes.csv', | ||
'yaml_path': './igel.yaml', | ||
'cmd': 'fit'} | ||
|
||
Igel(**mock_fit_params) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# dataset operations | ||
dataset: | ||
type: csv | ||
split: # split options | ||
test_size: 0.2 # 0.2 means 20% for the test data, so 80% are automatically for training | ||
shuffle: True # whether to shuffle the data before/while splitting | ||
|
||
preprocess: | ||
scale: # scaling options | ||
method: standard # standardization will scale values to have a 0 mean and 1 standard deviation | you can also try minmax | ||
target: inputs # scale inputs. | other possible values: [outputs, all] # if you choose all then all values in the dataset will be scaled | ||
|
||
|
||
# model definition | ||
model: | ||
type: classification | ||
algorithm: RandomForest | ||
hyperparameter_search: | ||
method: random_search | ||
parameter_grid: | ||
max_depth: [6, 10] | ||
n_estimators: [100, 300] | ||
max_features: [auto, sqrt] | ||
arguments: | ||
cv: 5 | ||
refit: true | ||
return_train_score: false | ||
verbose: 0 | ||
|
||
# target you want to predict | ||
target: | ||
- sick |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,4 +4,4 @@ | |
|
||
__author__ = "Nidhal Baccouri" | ||
__email__ = 'nidhalbacc@gmail.com' | ||
__version__ = '0.2.7' | ||
__version__ = '0.2.8' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV | ||
|
||
|
||
def hyperparameter_search(model, | ||
method, | ||
params, | ||
x_train, | ||
y_train, | ||
**kwargs): | ||
|
||
search = None | ||
if method == 'grid_search': | ||
search = GridSearchCV(model, | ||
params, | ||
**kwargs) | ||
|
||
elif method == 'random_search': | ||
search = RandomizedSearchCV(model, | ||
params, | ||
**kwargs) | ||
else: | ||
raise Exception("hyperparameter method must be grid_search or random_search") | ||
|
||
search.fit(x_train, y_train) | ||
return search.best_estimator_, search.best_score_, search.best_params_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
[bumpversion] | ||
current_version = 0.1.8 | ||
current_version = 0.2.8 | ||
commit = True | ||
tag = True | ||
|
||
|