Skip to content

Commit

Permalink
Merge pull request #2 from edahelsinki/threading
Browse files Browse the repository at this point in the history
Threading
  • Loading branch information
Aggrathon authored Mar 24, 2022
2 parents 2433f8c + fdccc21 commit fd17836
Show file tree
Hide file tree
Showing 8 changed files with 213 additions and 54 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/python-pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
name: Test Python Package

on:
push:
branches: [ master ]
pull_request:
branches: [ master ]

Expand Down
5 changes: 4 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = slise
version = 2.0.0
version = 2.1.0
author = Anton Björklund
author_email = anton.bjorklund@helsinki.fi
description = The SLISE algorithm for robust regression and explanations of black box models
Expand Down Expand Up @@ -35,3 +35,6 @@ install_requires =
exclude =
examples
tests

[options.extras_require]
tbb = tbb
15 changes: 13 additions & 2 deletions slise/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,27 @@
constraints used to generate the data, e.g., the laws of physics).
More in-depth details about the algorithm can be found in the paper:
More in-depth details about the algorithm can be found in the papers:
Björklund A., Henelius A., Oikarinen E., Kallonen K., Puolamäki K.
Sparse Robust Regression for Explaining Classifiers.
Discovery Science (DS 2019).
Lecture Notes in Computer Science, vol 11828, Springer.
https://doi.org/10.1007/978-3-030-33778-0_27
Björklund A., Henelius A., Oikarinen E., Kallonen K., Puolamäki K.
Robust regression via error tolerance.
Data Mining and Knowledge Discovery (2022).
https://doi.org/10.1007/s10618-022-00819-2
"""

from slise.slise import SliseRegression, regression, SliseExplainer, explain
from slise.slise import (
SliseRegression,
regression,
SliseExplainer,
explain,
SliseWarning,
)
from slise.utils import limited_logit as logit
from slise.data import normalise_robust
29 changes: 28 additions & 1 deletion slise/optimisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np
from lbfgs import LBFGSError, fmin_lbfgs
from numba import jit
from numba import jit, get_num_threads, set_num_threads, threading_layer
from scipy.optimize import brentq

from slise.utils import (
Expand Down Expand Up @@ -375,6 +375,33 @@ def debug_log(
)


def set_threads(num: int = -1) -> int:
"""Set the number of numba threads
Args:
num (int, optional): The number of threads. Defaults to -1.
Returns:
int: The old number of theads (or -1 if unchanged).
"""
if num > 0:
old = get_num_threads()
set_num_threads(num)
return old
return -1


def check_threading_layer():
"""Check which numba threading_layer is active, and warn if it is "workqueue".
"""
loss_residuals(np.ones(1), np.ones(1), 1)
if threading_layer() == "workqueue":
warn(
'Using `numba.threading_layer()=="workqueue"` can be devastatingly slow! See https://numba.pydata.org/numba-doc/latest/user/threading-layer.html for alternatives.',
SliseWarning,
)


def graduated_optimisation(
alpha: np.ndarray,
X: np.ndarray,
Expand Down
Loading

0 comments on commit fd17836

Please sign in to comment.