FEA Kernel decision trees with user-defined kernel dictionaries #302

Draft · wants to merge 12 commits into main
6 changes: 6 additions & 0 deletions examples/manifold/README.txt
@@ -0,0 +1,6 @@
.. _manifold_examples:

Manifold Learning with Decision Trees
-------------------------------------

Examples demonstrating manifold learning using random forest variants.
106 changes: 106 additions & 0 deletions examples/manifold/plot_kernel_decision_tree.py
@@ -0,0 +1,106 @@
"""
======================================
Custom Kernel Decision Tree Classifier
======================================

This example shows how to build a manifold oblique decision tree classifier
using a custom, user-defined kernel/filter library, such as Gaussian or Gabor
kernels.

The example demonstrates the classifier on a 2D dataset of structured images:
the MNIST digits, where each sample is a 28x28 image. Each image is downsampled
to 14x14 and then flattened into a 196-dimensional vector, and the dataset is
split into training and testing sets.

See :ref:`sphx_glr_auto_examples_plot_projection_matrices` for more information on
projection matrices and how they are sampled.
"""

# %%
# Import the necessary modules
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

from treeple.tree import KernelDecisionTreeClassifier

# %%
# Load the Dataset
# ----------------
# We need to load the dataset and split it into training and testing sets.

# Load the dataset as NumPy arrays (as_frame=False) so we can reshape it
X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)

# Downsample each 28x28 image to 14x14 by keeping every other pixel,
# then flatten it to a 196-dimensional vector
X = X.reshape((-1, 28, 28))
X = X[:, ::2, ::2]
X = X.reshape((-1, 196))

# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# %%
# Setting up the Custom Kernel Decision Tree Model
# -------------------------------------------------
# To set up the custom kernel decision tree model, we define the typical
# hyperparameters for a decision tree classifier, such as the maximum depth of
# the tree and the minimum number of samples required to split an internal node.
# For the kernel decision tree model, we also define the kernel function and its
# parameters.

max_depth = 10
min_samples_split = 2

# Next, we define the hyperparameters for the custom kernels that we will use.
# For example, if we want to use a Gaussian kernel with a sigma of 1.0 and a size of 3x3:
kernel_function = "gaussian"
kernel_params = {"sigma": 1.0, "size": (3, 3)}
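
# %%
# As a quick, self-contained illustration (not part of the estimator's API),
# this is what a 3x3 Gaussian kernel with sigma=1.0 looks like when
# materialized with NumPy:
import numpy as np

coords = np.linspace(-1, 1, 3)
gx, gy = np.meshgrid(coords, coords)
gaussian_kernel = np.exp(-(gx**2 + gy**2) / (2 * 1.0**2))
gaussian_kernel /= gaussian_kernel.sum()  # normalize the weights to sum to 1
print(gaussian_kernel)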

# We can then fit the custom kernel decision tree model to the training set:
clf = KernelDecisionTreeClassifier(
    max_depth=max_depth,
    min_samples_split=min_samples_split,
    # The downsampled images are 14x14, so the data dimensions must match
    data_dims=(14, 14),
    min_patch_dims=(1, 1),
    max_patch_dims=(14, 14),
    dim_contiguous=(True, True),
    boundary=None,
    n_classes=10,
    kernel_function=kernel_function,
    n_kernels=500,
    store_kernel_library=True,
)

# Fit the decision tree classifier using the custom kernel
clf.fit(X_train, y_train)

# %%
# Evaluating the Custom Kernel Decision Tree Model
# ------------------------------------------------
# To evaluate the custom kernel decision tree model, we can use the testing set.
# We can also inspect the important kernels that the tree selected.

# Predict the labels for the testing set
y_pred = clf.predict(X_test)

# Compute the accuracy score
accuracy = accuracy_score(y_test, y_pred)

print(f"Kernel decision tree model obtained an accuracy of {accuracy} on MNIST.")

# Get the important kernels from the fitted decision tree classifier
important_kernels = clf.kernel_arr_
kernel_dims = clf.kernel_dims_
fitted_kernel_params = clf.kernel_params_  # avoid shadowing the input kernel_params
kernel_library = clf.kernel_library_
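
# %%
# A quick sanity check on the fitted attributes. We assume here that
# ``kernel_arr_`` is a sequence of 2D kernel arrays and that ``kernel_dims_``
# stores their shapes; adjust accordingly if the fitted attributes differ.
print(f"The tree selected {len(important_kernels)} kernels.")
print(f"Kernel dimensions: {kernel_dims}")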

# Plot the important kernels
fig, axes = plt.subplots(
    nrows=len(important_kernels),
    ncols=1,
    figsize=(6, 4 * len(important_kernels)),
    squeeze=False,  # keep a 2D axes array even if only one kernel is selected
)
for i, kernel in enumerate(important_kernels):
    axes[i, 0].imshow(kernel, cmap="gray")
    axes[i, 0].set_title(f"Kernel {i + 1}")
plt.show()
77 changes: 77 additions & 0 deletions examples/splitters/plot_kernel_splitters.py
@@ -0,0 +1,77 @@
"""
====================
Random Kernel Splits
====================

This example shows how to build a manifold oblique decision tree classifier using
a custom set of user-defined kernel/filter library, such as the Gaussian, or Gabor
kernels.

The example demonstrates superior performance on a 2D dataset with structured images
as samples. The dataset is the downsampled MNIST dataset, where each sample is a
28x28 image. The dataset is downsampled to 14x14, and then flattened to a 196
dimensional vector. The dataset is then split into a training and testing set.

See :ref:`sphx_glr_auto_examples_plot_projection_matrices` for more information on
projection matrices and the way they can be sampled.
"""

import matplotlib.pyplot as plt
import numpy as np
from scipy.sparse import csr_matrix

from treeple.tree.manifold._kernel_splitter import Kernel2D

# %%
# Create a synthetic image
image_height, image_width = 50, 50
image = np.random.rand(image_height, image_width).astype(np.float32)

# Generate a Gaussian kernel (example)
kernel_size = 7
x = np.linspace(-2, 2, kernel_size)
y = np.linspace(-2, 2, kernel_size)
x, y = np.meshgrid(x, y)
kernel = np.exp(-(x**2 + y**2))
kernel = kernel / kernel.sum() # Normalize the kernel

# Vectorize and create a sparse CSR matrix
kernel_vector = kernel.flatten().astype(np.float32)
kernel_indices = np.arange(kernel_vector.size)
kernel_indptr = np.array([0, kernel_vector.size])
kernel_csr = csr_matrix(
    (kernel_vector, kernel_indices, kernel_indptr), shape=(1, kernel_vector.size)
)
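
# %%
# Sanity check: the single CSR row should round-trip back to the dense kernel.
reconstructed = kernel_csr.toarray().reshape(kernel_size, kernel_size)
assert np.allclose(reconstructed, kernel)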

# %%
# Initialize the Kernel2D class
kernel_sizes = np.array([kernel_size], dtype=np.intp)
random_state = np.random.RandomState(42)
kernel_2d = Kernel2D(kernel_csr, kernel_sizes, random_state)

# Apply the kernel to the image
result_value = kernel_2d.apply_kernel_py(image, 0)

# %%
# Plot the original image, kernel, and result
fig, axs = plt.subplots(1, 3, figsize=(15, 5))

axs[0].imshow(image, cmap="gray")
axs[0].set_title("Original Image")

axs[1].imshow(kernel, cmap="viridis")
axs[1].set_title("Gaussian Kernel")

# Highlight a kernel-sized region. Note: these coordinates are drawn fresh from
# the RNG for illustration and need not match the location that
# ``apply_kernel_py`` actually used internally.
start_x = random_state.randint(0, image_width - kernel_size + 1)
start_y = random_state.randint(0, image_height - kernel_size + 1)
image_with_kernel = image.copy()
image_with_kernel[start_y : start_y + kernel_size, start_x : start_x + kernel_size] *= kernel

axs[2].imshow(image_with_kernel, cmap="gray")
axs[2].set_title(f"Result: {result_value:.4f}")

plt.tight_layout()
plt.show()
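
# %%
# For reference, the value the splitter computes can be reproduced with plain
# NumPy. This is a sketch assuming ``apply_kernel_py`` returns the sum of the
# elementwise product between the kernel and an image patch; the location below
# is the illustrative patch above, not necessarily the one used internally.
patch = image[start_y : start_y + kernel_size, start_x : start_x + kernel_size]
manual_value = float((patch * kernel).sum())
print(f"Weighted patch sum at ({start_x}, {start_y}): {manual_value:.4f}")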
1 change: 1 addition & 0 deletions pyproject.toml
@@ -195,6 +195,7 @@ ignore_missing_imports = true
no_site_packages = true
exclude = [
'treeple/_lib/',
'treeple/_lib/*',
'benchmarks_nonasv/'
]

2 changes: 2 additions & 0 deletions treeple/tree/__init__.py
@@ -15,6 +15,7 @@
UnsupervisedObliqueDecisionTree,
)
from ._honest_tree import HonestTreeClassifier
from ._kernel import KernelDecisionTreeClassifier
from ._multiview import MultiViewDecisionTreeClassifier
from ._neighbors import compute_forest_similarity_matrix

@@ -34,4 +35,5 @@
"ExtraTreeClassifier",
"ExtraTreeRegressor",
"MultiViewDecisionTreeClassifier",
"KernelDecisionTreeClassifier",
]