-
-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ENH] Cython axis-aligned multi-view splitter (#129)
* Cython multiview --------- Signed-off-by: Adam Li <adam2392@gmail.com>
- Loading branch information
Showing
43 changed files
with
2,031 additions
and
257 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
name: "Release to PyPI" | ||
|
||
concurrency: | ||
group: ${{ github.workflow }}-${{ github.event.number }}-${{ github.event.type }} | ||
cancel-in-progress: true | ||
|
||
on: | ||
workflow_run: | ||
workflows: ["build_and_test_slow"] | ||
types: | ||
- completed | ||
workflow_dispatch: | ||
|
||
env: | ||
INSTALLDIR: "build-install" | ||
CCACHE_DIR: "${{ github.workspace }}/.ccache" | ||
|
||
jobs: | ||
# release is ran when a release is made on Github | ||
release: | ||
name: Release | ||
runs-on: ubuntu-latest | ||
needs: [build_and_test_slow] | ||
if: startsWith(github.ref, 'refs/tags/') | ||
steps: | ||
- name: Checkout repository | ||
uses: actions/checkout@v4 | ||
- name: Setup Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v4.6.1 | ||
with: | ||
python-version: 3.9 | ||
architecture: "x64" | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --progress-bar off --upgrade pip setuptools wheel | ||
python -m pip install --progress-bar off build twine | ||
- name: Prepare environment | ||
run: | | ||
echo "RELEASE_VERSION=${GITHUB_REF#refs/tags/v}" >> $GITHUB_ENV | ||
echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV | ||
- name: Download package distribution files | ||
uses: actions/download-artifact@v3 | ||
with: | ||
name: package | ||
path: dist | ||
# TODO: refactor scripts to generate release notes from `whats_new.rst` file instead | ||
# - name: Generate release notes | ||
# run: | | ||
# python scripts/release_notes.py > ${{ github.workspace }}-RELEASE_NOTES.md | ||
- name: Publish package to PyPI | ||
run: | | ||
twine upload -u ${{ secrets.PYPI_USERNAME }} -p ${{ secrets.PYPI_PASSWORD }} dist/* | ||
- name: Publish GitHub release | ||
uses: softprops/action-gh-release@v1 | ||
env: | ||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} | ||
with: | ||
# body_path: ${{ github.workspace }}-RELEASE_NOTES.md | ||
prerelease: ${{ contains(env.TAG, 'rc') }} | ||
files: | | ||
dist/* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
Examples | ||
-------- | ||
======== | ||
|
||
Examples demonstrating how to use scikit-tree algorithms. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
.. _hyppo_examples: | ||
|
||
Hypothesis testing with decision trees | ||
-------------------------------------- | ||
|
||
Examples demonstrating how to use decision-trees for statistical hypothesis testing. |
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
.. _outlier_examples: | ||
|
||
Outlier-detection | ||
----------------- | ||
|
||
Examples concerning how to do outlier detection with decision trees. |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
.. _splitter_examples: | ||
|
||
Decision-tree splitters | ||
----------------------- | ||
|
||
Examples demonstrating different node-splitting strategies for decision trees. |
124 changes: 124 additions & 0 deletions
124
examples/splitters/plot_multiview_axis_aligned_splitter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
""" | ||
================================================================================= | ||
Demonstrate and visualize a multi-view projection matrix for an axis-aligned tree | ||
================================================================================= | ||
This example shows how multi-view projection matrices are generated for a decision tree, | ||
specifically the :class:`sktree.tree.MultiViewDecisionTreeClassifier`. | ||
Multi-view projection matrices operate under the assumption that the input ``X`` array | ||
consists of multiple feature-sets that are groups of features important for predicting | ||
``y``. | ||
For details on how to use the hyperparameters related to the multi-view, see | ||
:class:`sktree.tree.MultiViewDecisionTreeClassifier`. | ||
""" | ||
|
||
# import modules | ||
# .. note:: We use a private Cython module here to demonstrate what the patches | ||
# look like. This is not part of the public API. The Cython module used | ||
# is just a Python wrapper for the underlying Cython code and is not the | ||
# same as the Cython splitter used in the actual implementation. | ||
# To use the actual splitter, one should use the public API for the | ||
# relevant tree/forests class. | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
from matplotlib.cm import ScalarMappable | ||
from matplotlib.colors import ListedColormap | ||
|
||
from sktree._lib.sklearn.tree._criterion import Gini | ||
from sktree.tree._oblique_splitter import MultiViewSplitterTester | ||
|
||
criterion = Gini(1, np.array((0, 1))) | ||
max_features = 5 | ||
min_samples_leaf = 1 | ||
min_weight_leaf = 0.0 | ||
random_state = np.random.RandomState(10) | ||
|
||
# we "simulate" three feature sets, with 3, 2 and 4 features respectively | ||
feature_set_ends = np.array([3, 5, 9], dtype=np.intp) | ||
n_feature_sets = len(feature_set_ends) | ||
|
||
feature_combinations = 1 | ||
monotonic_cst = None | ||
missing_value_feature_mask = None | ||
|
||
# initialize some dummy data | ||
X = np.repeat(np.arange(feature_set_ends[-1]).astype(np.float32), 5).reshape(5, -1) | ||
y = np.array([0, 0, 0, 1, 1]).reshape(-1, 1).astype(np.float64) | ||
sample_weight = np.ones(5) | ||
|
||
print("The shape of our dataset is: ", X.shape, y.shape, sample_weight.shape) | ||
|
||
# %% | ||
# Initialize the multi-view splitter | ||
# ---------------------------------- | ||
# The multi-view splitter is a Cython class that is initialized internally | ||
# in scikit-tree. However, we expose a Python tester object to demonstrate | ||
# how the splitter works in practice. | ||
# | ||
# .. warning:: Do not use this interface directly in practice. | ||
|
||
splitter = MultiViewSplitterTester( | ||
criterion, | ||
max_features, | ||
min_samples_leaf, | ||
min_weight_leaf, | ||
random_state, | ||
monotonic_cst, | ||
feature_combinations, | ||
feature_set_ends, | ||
n_feature_sets, | ||
) | ||
splitter.init_test(X, y, sample_weight, missing_value_feature_mask) | ||
|
||
# %% | ||
# Sample the projection matrix | ||
# ---------------------------- | ||
# The projection matrix is sampled by the splitter. The projection | ||
# matrix is a (max_features, n_features) matrix that selects which features of ``X`` | ||
# to define candidate split dimensions. The multi-view | ||
# splitter's projection matrix though samples from multiple feature sets, | ||
# which are aligned contiguously over the columns of ``X``. | ||
|
||
projection_matrix = splitter.sample_projection_matrix_py() | ||
print(projection_matrix) | ||
|
||
cmap = ListedColormap(["white", "green"][:n_feature_sets]) | ||
|
||
# Create a heatmap to visualize the indices | ||
fig, ax = plt.subplots(figsize=(6, 6)) | ||
|
||
ax.imshow( | ||
projection_matrix, cmap=cmap, aspect=feature_set_ends[-1] / max_features, interpolation="none" | ||
) | ||
ax.axvline(feature_set_ends[0] - 0.5, color="black", linewidth=1, label="Feature Sets") | ||
for iend in feature_set_ends[1:]: | ||
ax.axvline(iend - 0.5, color="black", linewidth=1) | ||
|
||
ax.set(title="Sampled Projection Matrix", xlabel="Feature Index", ylabel="Projection Vector Index") | ||
ax.set_xticks(np.arange(feature_set_ends[-1])) | ||
ax.set_yticks(np.arange(max_features)) | ||
ax.set_yticklabels(np.arange(max_features, dtype=int) + 1) | ||
ax.set_xticklabels(np.arange(feature_set_ends[-1], dtype=int) + 1) | ||
ax.legend() | ||
|
||
# Create a mappable object | ||
sm = ScalarMappable(cmap=cmap) | ||
sm.set_array([]) # You can set an empty array or values here | ||
|
||
# Create a color bar with labels for each feature set | ||
colorbar = fig.colorbar(sm, ax=ax, ticks=[0.25, 0.75], format="%d") | ||
colorbar.set_label("Projection Weight (I.e. Sampled Feature From a Feature Set)") | ||
colorbar.ax.set_yticklabels(["0", "1"]) | ||
|
||
plt.show() | ||
|
||
# %% | ||
# Discussion | ||
# ---------- | ||
# As we can see, the multi-view splitter samples split candidates uniformly across the feature sets. | ||
# In contrast, the normal splitter in :class:`sklearn.tree.DecisionTreeClassifier` samples | ||
# randomly across all ``n_features`` features because it is not aware of the multi-view structure. | ||
# This is the key difference between the two splitters. |
Oops, something went wrong.