Skip to content

Commit

Permalink
Fix filter logic bugs (#1753)
Browse files Browse the repository at this point in the history
* Add tests on integers in filtering

* Metadata.<> fix

* Check in incremental work
  • Loading branch information
NolanTrem authored Jan 3, 2025
1 parent 670b21c commit 9af89c4
Show file tree
Hide file tree
Showing 4 changed files with 247 additions and 169 deletions.
132 changes: 129 additions & 3 deletions js/sdk/__tests__/DocumentsIntegrationSuperUser.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { r2rClient } from "../src/index";
import { describe, test, beforeAll, expect, afterAll } from "@jest/globals";
import { assert } from "console";
import fs from "fs";
import path from "path";

Expand All @@ -9,6 +10,7 @@ const TEST_OUTPUT_DIR = path.join(__dirname, "test-output");
/**
* marmeladov.txt will have an id of 83ef5342-4275-5b75-92d6-692fa32f8523
* The untitled document will have an id of 5556836e-a51c-57c7-916a-de76c79df2b6
* The default collection id is 122fdf6a-e116-546b-a8f6-e4cb2e2c0a09
*/
describe("r2rClient V3 Documents Integration Tests", () => {
let client: r2rClient;
Expand All @@ -35,7 +37,7 @@ describe("r2rClient V3 Documents Integration Tests", () => {
test("Create document with file path", async () => {
const response = await client.documents.create({
file: { path: "examples/data/marmeladov.txt", name: "marmeladov.txt" },
metadata: { title: "marmeladov.txt" },
metadata: { title: "marmeladov.txt", numericId: 123 },
});

expect(response.results.documentId).toBeDefined();
Expand All @@ -45,7 +47,7 @@ describe("r2rClient V3 Documents Integration Tests", () => {
test("Create document with content", async () => {
const response = await client.documents.create({
raw_text: "This is a test document",
metadata: { title: "Test Document" },
metadata: { title: "Test Document", numericId: 456 },
});

expect(response.results.documentId).toBeDefined();
Expand Down Expand Up @@ -175,7 +177,131 @@ describe("r2rClient V3 Documents Integration Tests", () => {
).rejects.toThrow(/Only one of file, raw_text, or chunks may be provided/);
});

test("Delete Raskolnikov.txt", async () => {
test("Search with $lte filter should only return documents with numericId <= 200", async () => {
  // Filter on the bare key (no "metadata." prefix) with an inclusive upper bound.
  const { results } = await client.retrieval.search({
    query: "Test query",
    searchSettings: {
      filters: {
        numericId: { $lte: 200 },
      },
    },
  });

  expect(results.chunkSearchResults).toBeDefined();

  // Every returned chunk must carry a numericId at or below the bound; a chunk
  // without a numericId fails the comparison and therefore fails the test.
  const allWithinBound = results.chunkSearchResults.every(
    (chunk) => chunk.metadata?.numericId <= 200,
  );
  expect(allWithinBound).toBe(true);
});

test("Search with $gte filter should only return documents with metadata.numericId >= 400", async () => {
  // Same field as the $lte test, but addressed through the explicit
  // "metadata." prefix and with an inclusive lower bound.
  const { results } = await client.retrieval.search({
    query: "Test query",
    searchSettings: {
      filters: {
        "metadata.numericId": { $gte: 400 },
      },
    },
  });

  expect(results.chunkSearchResults).toBeDefined();

  // A chunk without a numericId compares as false and fails the check.
  const allAtOrAboveBound = results.chunkSearchResults.every(
    (chunk) => chunk.metadata?.numericId >= 400,
  );
  expect(allAtOrAboveBound).toBe(true);
});

test("Search with $eq filter should only return exact matches", async () => {
  // Exact-match filter on the bare key (no "metadata." prefix).
  const { results } = await client.retrieval.search({
    query: "Test query",
    searchSettings: {
      filters: {
        numericId: { $eq: 123 },
      },
    },
  });

  expect(results.chunkSearchResults).toBeDefined();

  // Strict equality: the value must come back as the number 123, not "123".
  const allExactMatches = results.chunkSearchResults.every(
    (chunk) => chunk.metadata?.numericId === 123,
  );
  expect(allExactMatches).toBe(true);
});

test("Search with range filter should return documents within range", async () => {
  // Inclusive bounds shared by the request and the assertion so the two
  // cannot drift apart.
  const lowerBound = 100;
  const upperBound = 500;

  // Previously the request sent only `$gte: 500` while the assertion checked
  // 100..500, so the filter under test was never a range. Combine both bounds
  // with `$and` so the server actually evaluates a two-sided range.
  const response = await client.retrieval.search({
    query: "Test query",
    searchSettings: {
      filters: {
        $and: [
          { "metadata.numericId": { $gte: lowerBound } },
          { "metadata.numericId": { $lte: upperBound } },
        ],
      },
    },
  });

  const chunks = response.results.chunkSearchResults;
  expect(chunks).toBeDefined();

  // Every returned chunk must fall inside the requested range; a chunk with
  // no numericId fails both comparisons.
  const allInRange = chunks.every((chunk) => {
    const numericId = chunk.metadata?.numericId;
    return numericId >= lowerBound && numericId <= upperBound;
  });
  expect(allInRange).toBe(true);
});

test("Search without filters should return both documents", async () => {
  const response = await client.retrieval.search({
    query: "Test query",
  });

  const chunks = response.results.chunkSearchResults;
  expect(chunks).toBeDefined();
  expect(chunks.length).toBeGreaterThan(0);

  // Collect the numeric ids that came back. `??` (not `||`) is used so that a
  // legitimate id of 0 is kept rather than treated as missing.
  // NOTE(review): the lowercase `numericid` fallback presumably covers a
  // backend that lowercases metadata keys — confirm whether it is still needed.
  const numericIds = chunks
    .map((chunk) => chunk.metadata?.numericId ?? chunk.metadata?.numericid)
    .filter((id) => id !== undefined);

  // Both documents created earlier in this suite (ids 123 and 456) must appear.
  expect(numericIds).toContain(123);
  expect(numericIds).toContain(456);
});

// test("Filter on collection_id", async () => {
// const response = await client.retrieval.search({
// query: "Test query",
// searchSettings: {
// filters: {
// collection_ids: {
// $in: ["122fdf6a-e116-546b-a8f6-e4cb2e2c0a09"],
// },
// },
// },
// });
// expect(response.results.chunkSearchResults).toBeDefined();
// expect(response.results.chunkSearchResults.length).toBeGreaterThan(0);
// expect(response.results.chunkSearchResults[0].collectionIds).toContain(
// "122fdf6a-e116-546b-a8f6-e4cb2e2c0a09",
// );
// });

test("Filter on non-existant column should return empty", async () => {
  // The previous version wrapped the call in `await expect(...)` with no
  // matcher, which asserts nothing and does not even wait for the request.
  // Await the search itself and check the outcome the test title promises.
  const response = await client.retrieval.search({
    query: "Test query",
    searchSettings: {
      filters: {
        nonExistentColumn: {
          $eq: ["122fdf6a-e116-546b-a8f6-e4cb2e2c0a09"],
        },
      },
    },
  });

  // NOTE(review): this encodes the contract stated in the title (an unknown
  // column matches nothing). If the server is meant to reject unknown columns
  // instead, replace this with `await expect(...).rejects.toThrow()`.
  expect(response.results.chunkSearchResults).toBeDefined();
  expect(response.results.chunkSearchResults).toHaveLength(0);
});

test("Delete marmeladov.txt", async () => {
const response = await client.documents.delete({
id: "83ef5342-4275-5b75-92d6-692fa32f8523",
});
Expand Down
104 changes: 67 additions & 37 deletions py/core/database/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,7 @@ def __init__(
else:
self.top_level_columns = set(top_level_columns)
self.json_column = json_column
self.params: list[Any] = (
params # params are mutated during construction
)
self.params: list[Any] = params # mutated during construction
self.mode = mode

def build(self, expr: FilterExpression) -> Tuple[str, list[Any]]:
Expand Down Expand Up @@ -171,9 +169,8 @@ def _build_expression(self, expr: FilterExpression) -> str:
@staticmethod
def _psql_quote_literal(value: str) -> str:
"""
Safely quote a string literal for PostgreSQL to prevent SQL injection.
This is a simple implementation - in production, you should use proper parameterization
or your database driver's quoting functions.
Simple quoting for demonstration. In production, use parameterized queries or
your DB driver's quoting function instead.
"""
return "'" + value.replace("'", "''") + "'"

Expand All @@ -183,31 +180,81 @@ def _build_condition(self, cond: FilterCondition) -> str:
op = cond.operator
val = cond.value

# Handle special logic for collection_id
# 1. If the filter references "parent_id", handle it as a single-UUID column for graphs:
if key == "parent_id":
return self._build_parent_id_condition(op, val)

# 2. If the filter references "collection_id", handle it as an array column (chunks)
if key == "collection_id":
return self._build_collection_id_condition(op, val)

# 3. Otherwise, decide if it's top-level or metadata:
if field_is_metadata:
return self._build_metadata_condition(key, op, val)
else:
return self._build_column_condition(key, op, val)

def _build_parent_id_condition(self, op: str, val: Any) -> str:
    """
    Build the SQL fragment for a filter on ``parent_id``.

    ``parent_id`` is compared as a single UUID value rather than an array.
    ``$eq`` / ``$ne`` take one UUID string; ``$in`` / ``$nin`` take a list of
    UUID strings. The value is appended to ``self.params`` and referenced in
    the returned SQL through its positional placeholder.

    Raises:
        FilterError: if ``val`` has the wrong type for ``op``, or ``op`` is
            not one of the four supported operators.
    """
    # The placeholder index is the position the value will occupy once it has
    # been appended, so it is computed before the append below.
    placeholder = f"${len(self.params) + 1}"

    # Operators that compare against exactly one UUID string.
    single_value_sql = {
        "$eq": "parent_id = {}::uuid",
        "$ne": "parent_id != {}::uuid",
    }
    # Operators that compare against a list of UUID strings.
    list_value_sql = {
        "$in": "parent_id = ANY({}::uuid[])",
        "$nin": "parent_id != ALL({}::uuid[])",
    }

    if op in single_value_sql:
        if not isinstance(val, str):
            raise FilterError(
                f"{op} for parent_id expects a single UUID string"
            )
        template = single_value_sql[op]
    elif op in list_value_sql:
        if not isinstance(val, list):
            raise FilterError(
                f"{op} for parent_id expects a list of UUID strings"
            )
        template = list_value_sql[op]
    else:
        # Further operators (e.g. $gt, $lt) are not supported for parent_id.
        raise FilterError(f"Unsupported operator {op} for parent_id")

    self.params.append(val)
    return template.format(placeholder)

def _build_collection_id_condition(self, op: str, val: Any) -> str:
"""
For the 'chunks' table, collection_ids is an array of UUIDs.
This logic stays exactly as you had it.
"""
param_idx = len(self.params) + 1

# Handle operations
if op == "$eq":
# Expect a single UUID, ensure val is a string
if not isinstance(val, str):
raise FilterError(
"$eq for collection_id expects a single UUID string"
)
self.params.append(val)
# Check if val is in the collection_ids array
return f"${param_idx}::uuid = ANY(collection_ids)"

elif op == "$ne":
# Not equal means val is not in collection_ids
if not isinstance(val, str):
raise FilterError(
"$ne for collection_id expects a single UUID string"
Expand All @@ -216,31 +263,25 @@ def _build_collection_id_condition(self, op: str, val: Any) -> str:
return f"NOT (${param_idx}::uuid = ANY(collection_ids))"

elif op == "$in":
# Expect a list of UUIDs, any of which may match
if not isinstance(val, list):
raise FilterError(
"$in for collection_id expects a list of UUID strings"
)
self.params.append(val)
# Use overlap to check if any of the given IDs are in collection_ids
return f"collection_ids && ${param_idx}::uuid[]"

elif op == "$nin":
# None of the given UUIDs should be in collection_ids
if not isinstance(val, list):
raise FilterError(
"$nin for collection_id expects a list of UUID strings"
)
self.params.append(val)
# Negate overlap condition
return f"NOT (collection_ids && ${param_idx}::uuid[])"

elif op == "$contains":
# If someone tries "$contains" with a single collection_id, we can check if collection_ids fully contain it
# Usually $contains might mean we want to see if collection_ids contain a certain element.
# That's basically $eq logic. For a single value:
if isinstance(val, str):
self.params.append([val]) # Array of one element
# single string -> array with one element
self.params.append([val])
return f"collection_ids @> ${param_idx}::uuid[]"
elif isinstance(val, list):
self.params.append(val)
Expand Down Expand Up @@ -278,7 +319,6 @@ def _build_column_condition(self, col: str, op: str, val: Any) -> str:
self.params.append(val)
return f"{col} @> ${param_idx}"
elif op == "$any":
# If col == "collection_ids" handle special case
if col == "collection_ids":
self.params.append(f"%{val}%")
return f"array_to_string({col}, ',') LIKE ${param_idx}"
Expand All @@ -296,8 +336,7 @@ def _build_metadata_condition(self, key: str, op: str, val: Any) -> str:
json_col = self.json_column

# Strip "metadata." prefix if present
if key.startswith("metadata."):
key = key[len("metadata.") :]
key = key.removeprefix("metadata.")

# Split on '.' to handle nested keys
parts = key.split(".")
Expand All @@ -310,41 +349,34 @@ def _build_metadata_condition(self, key: str, op: str, val: Any) -> str:
"$gte",
"$eq",
"$ne",
)
) and isinstance(val, (int, float, str))
if op == "$in" or op == "$contains" or isinstance(val, (list, dict)):
use_text_extraction = False

# Build the JSON path expression
if len(parts) == 1:
# Single part key
if use_text_extraction:
path_expr = f"{json_col}->>'{parts[0]}'"
else:
path_expr = f"{json_col}->'{parts[0]}'"
else:
# Multiple segments
inner_parts = parts[:-1]
last_part = parts[-1]
# Build chain for the inner parts
path_expr = json_col
for p in inner_parts:
for p in parts[:-1]:
path_expr += f"->'{p}'"
# Last part
last_part = parts[-1]
if use_text_extraction:
path_expr += f"->>'{last_part}'"
else:
path_expr += f"->'{last_part}'"

# Convert numeric values to strings for text comparison
def prepare_value(v):
if isinstance(v, (int, float)):
return str(v)
return v
return str(v) if isinstance(v, (int, float)) else v

# Now apply the operator logic
if op == "$eq":
if use_text_extraction:
self.params.append(prepare_value(val))
prepared_val = prepare_value(val)
self.params.append(prepared_val)
return f"{path_expr} = ${param_idx}"
else:
self.params.append(json.dumps(val))
Expand Down Expand Up @@ -372,7 +404,6 @@ def prepare_value(v):
if not isinstance(val, list):
raise FilterError("argument to $in filter must be a list")

# For regular scalar values, use ANY with text extraction
if use_text_extraction:
str_vals = [
str(v) if isinstance(v, (int, float)) else v for v in val
Expand Down Expand Up @@ -413,7 +444,6 @@ def apply_filters(
"""
Apply filters with consistent WHERE clause handling
"""

if not filters:
return "", params

Expand Down
Loading

0 comments on commit 9af89c4

Please sign in to comment.