Skip to content

Commit

Permalink
Add kmeans clustering based on ray
Browse files Browse the repository at this point in the history
This includes generally three steps:
1. materialize a document's embedding
2. initialize centroids randomly
2. iterate the kmeans process until converge, this is based on ray
   dataset map group and aggregate operators.

The result centroids could be used for downstream work.
  • Loading branch information
bohou-aryn committed Dec 30, 2024
1 parent a3ac64b commit a5e26ff
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 0 deletions.
36 changes: 36 additions & 0 deletions lib/sycamore/sycamore/docset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from sycamore.plan_nodes import Node, Transform
from sycamore.transforms.augment_text import TextAugmentor
from sycamore.transforms.clustering import KMeans
from sycamore.transforms.embed import Embedder
from sycamore.transforms import DocumentStructure, Sort
from sycamore.transforms.extract_entity import EntityExtractor, OpenAIEntityExtractor
Expand Down Expand Up @@ -915,6 +916,41 @@ def map(self, f: Callable[[Document], Document], **resource_args) -> "DocSet":
mapping = Map(self.plan, f=f, **resource_args)
return DocSet(self.context, mapping)

def kmeans(self, K: int, iterations: int = 20, init_mode: str = "random", epsilon: float = 1e-4):
"""
Apply kmeans over embedding field
Args:
K: the count of centroids
iterations: the max iteration runs before converge
init_mode: how the initial centroids are select
epsilon: the condition for determining if it's converged
Return a list of max K centroids
"""

def init_embedding(row):
doc = Document.from_row(row)
return {"vector": doc.embedding, "cluster": -1}

embeddings = self.plan.execute().map(init_embedding).materialize()

initial_centroids = KMeans.init(embeddings, K, init_mode)
centroids = KMeans.update(embeddings, initial_centroids, iterations, epsilon)
return centroids

def clustering(self, centroids, cluster_field_name, **resource_args) -> "DocSet":
def cluster(doc: Document) -> Document:
idx = KMeans.closest(doc.embedding, centroids)
properties = doc.properties
properties[cluster_field_name] = idx
doc.properties = properties
return doc

from sycamore.transforms import Map

mapping = Map(self.plan, f=cluster, **resource_args)
return DocSet(self.context, mapping)

def flat_map(self, f: Callable[[Document], list[Document]], **resource_args) -> "DocSet":
"""
Applies the FlatMap transformation on the Docset.
Expand Down
65 changes: 65 additions & 0 deletions lib/sycamore/sycamore/tests/unit/transforms/test_clustering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import numpy as np
import ray.data

import sycamore
from sycamore.data import Document
from sycamore.transforms.clustering import KMeans


class TestKMeans:

def test_kmeans(self):
points = np.random.uniform(0, 40, (20, 4))
docs = [
Document(text_representation=f"Document {i}", doc_id=i, embedding=point, properties={"document_number": i})
for i, point in enumerate(points)
]
context = sycamore.init()
docset = context.read.document(docs)
centroids = docset.kmeans(3, 4)
assert len(centroids) == 3

def test_closest(self):
row = [[0, 0, 0, 0]]
centroids = [
[1, 1, 1, 1],
[2, 2, 2, 2],
[-1, -1, -1, -1],
]
assert KMeans.closest(row, centroids) == 0

def test_random(self):
points = np.random.uniform(0, 40, (20, 4))
embeddings = [{"vector": list(point), "cluster": -1} for point in points]
embeddings = ray.data.from_items(embeddings)
centroids = KMeans.random_init(embeddings, 10)
assert len(centroids) == 10

def test_converged(self):
last_ones = [[1.0, 1.0], [10.0, 10.0]]
next_ones = [[2.0, 2.0], [12.0, 12.0]]
assert KMeans.converged(last_ones, next_ones, 10).item() is True
assert KMeans.converged(last_ones, next_ones, 1).item() is False

def test_converge(self):
points = np.random.uniform(0, 10, (20, 4))
embeddings = [{"vector": list(point), "cluster": -1} for point in points]
embeddings = ray.data.from_items(embeddings)
centroids = [[2.0, 2.0, 2.0, 2.0], [8.0, 8.0, 8.0, 8.0]]
new_centroids = KMeans.update(embeddings, centroids, 2, 1e-4)
assert len(new_centroids) == 2

def test_clustering(self):
np.random.seed(2024)
points = np.random.uniform(0, 40, (20, 4))
docs = [
Document(text_representation=f"Document {i}", doc_id=i, embedding=point, properties={"document_number": i})
for i, point in enumerate(points)
]
context = sycamore.init()
docset = context.read.document(docs)
centroids = docset.kmeans(3, 4)

clustered_docs = docset.clustering(centroids, "cluster").take_all()
ids = [doc.properties["cluster"] for doc in clustered_docs]
assert all(0 <= idx < 3 for idx in ids)
74 changes: 74 additions & 0 deletions lib/sycamore/sycamore/transforms/clustering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import random

import torch
from ray.data.aggregate import AggregateFn


class KMeans:

@staticmethod
def closest(row, centroids):
row = torch.Tensor([row])
centroids = torch.Tensor(centroids)
distance = torch.cdist(row, centroids)
idx = torch.argmin(distance)
return idx

@staticmethod
def converged(last_ones, next_ones, epsilon):
distance = torch.cdist(torch.Tensor(last_ones), torch.Tensor(next_ones))
return len(last_ones) == torch.sum(distance < epsilon)

@staticmethod
def random_init(embeddings, K):
count = embeddings.count()
assert count > 0 and K < count
fraction = min(2 * K / count, 1.0)

candidates = [list(c["vector"]) for c in embeddings.random_sample(fraction).take()]
candidates.sort()
from itertools import groupby

uniques = [key for key, _ in groupby(candidates)]
assert len(uniques) >= K

centroids = random.sample(uniques, K)
return centroids

@staticmethod
def init(embeddings, K, init_mode):
if init_mode == "random":
return KMeans.random_init(embeddings, K)
else:
raise Exception("Unknown init mode")

@staticmethod
def update(embeddings, centroids, iterations, epsilon):
i = 0
d = len(centroids[0])

update_centroids = AggregateFn(
init=lambda v: ([0] * d, 0),
accumulate_row=lambda a, row: ([x + y for x, y in zip(a[0], row["vector"])], a[1] + 1),
merge=lambda a1, a2: ([x + y for x, y in zip(a1[0], a2[0])], a1[1] + a2[1]),
name="centroids",
)

while i < iterations:

def _find_cluster(row):
idx = KMeans.closest(row["vector"], centroids)
return {"vector": row["vector"], "cluster": idx}

aggregated = embeddings.map(_find_cluster).groupby("cluster").aggregate(update_centroids).take()
import numpy as np

new_centroids = [list(np.array(c["centroids"][0]) / c["centroids"][1]) for c in aggregated]

if KMeans.converged(centroids, new_centroids, epsilon):
return new_centroids
else:
i += 1
centroids = new_centroids

return centroids

0 comments on commit a5e26ff

Please sign in to comment.