Skip to content

Commit

Permalink
Merge pull request #2 from ssec-jhu/scarliles/regression-benchmark
Browse files Browse the repository at this point in the history
added regression forest benchmark
  • Loading branch information
SamuelCarliles3 authored Apr 22, 2024
2 parents 775f0b7 + 7a70a0b commit d9ad68a
Showing 1 changed file with 44 additions and 1 deletion.
45 changes: 44 additions & 1 deletion asv_benchmarks/benchmarks/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,58 @@
GradientBoostingClassifier,
HistGradientBoostingClassifier,
RandomForestClassifier,
RandomForestRegressor
)

from .common import Benchmark, Estimator, Predictor
from .datasets import (
_20newsgroups_highdim_dataset,
_20newsgroups_lowdim_dataset,
_synth_classification_dataset,
_synth_regression_dataset,
_synth_regression_sparse_dataset
)
from .utils import make_gen_classif_scorers
from .utils import make_gen_classif_scorers, make_gen_reg_scorers


class RandomForestRegressorBenchmark(Predictor, Estimator, Benchmark):
"""
Benchmarks for RandomForestRegressor.
"""

param_names = ["representation", "n_jobs"]
params = (["dense", "sparse"], Benchmark.n_jobs_vals)

def setup_cache(self):
super().setup_cache()

def make_data(self, params):
representation, n_jobs = params

if representation == "sparse":
data = _synth_regression_sparse_dataset()
else:
data = _synth_regression_dataset()

return data

def make_estimator(self, params):
representation, n_jobs = params

n_estimators = 500 if Benchmark.data_size == "large" else 100

estimator = RandomForestRegressor(
n_estimators=n_estimators,
min_samples_split=10,
max_features="log2",
n_jobs=n_jobs,
random_state=0,
)

return estimator

def make_scorers(self):
make_gen_reg_scorers(self)


class RandomForestClassifierBenchmark(Predictor, Estimator, Benchmark):
Expand Down

0 comments on commit d9ad68a

Please sign in to comment.