Skip to content

Commit

Permalink
Add tests for token escaper class (#69)
Browse files Browse the repository at this point in the history
Adds a set of unit tests on both the underlying token escaping class as
well as the `Tag` filterable fields that utilize it.

---------

Co-authored-by: Sam Partee <sam.partee@redis.com>
  • Loading branch information
tylerhutcherson and Sam Partee authored Oct 20, 2023
1 parent 6345cc1 commit 4e3de2d
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 35 deletions.
2 changes: 1 addition & 1 deletion redisvl/query/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Union

from redisvl.utils.utils import TokenEscaper
from redisvl.utils.token_escaper import TokenEscaper

# disable mypy error for dunder method overrides
# mypy: disable-error-code="override"
Expand Down
30 changes: 30 additions & 0 deletions redisvl/utils/token_escaper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import re
from typing import Optional, Pattern


class TokenEscaper:
"""
Escape punctuation within an input string. Adapted from RedisOM Python.
"""

# Characters that RediSearch requires us to escape during queries.
# Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization
DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]"

def __init__(self, escape_chars_re: Optional[Pattern] = None):
if escape_chars_re:
self.escaped_chars_re = escape_chars_re
else:
self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)

def escape(self, value: str) -> str:
if not isinstance(value, str):
raise TypeError(
f"Value must be a string object for token escaping, got type {type(value)}"
)

def escape_symbol(match):
value = match.group(0)
return f"\\{value}"

return self.escaped_chars_re.sub(escape_symbol, value)
26 changes: 1 addition & 25 deletions redisvl/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern
from typing import TYPE_CHECKING, Any, Dict, List

if TYPE_CHECKING:
from redis.commands.search.result import Result
Expand All @@ -8,29 +7,6 @@
import numpy as np


class TokenEscaper:
"""
Escape punctuation within an input string. Taken from RedisOM Python.
"""

# Characters that RediSearch requires us to escape during queries.
# Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization
DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]"

def __init__(self, escape_chars_re: Optional[Pattern] = None):
if escape_chars_re:
self.escaped_chars_re = escape_chars_re
else:
self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)

def escape(self, value: str) -> str:
def escape_symbol(match):
value = match.group(0)
return f"\\{value}"

return self.escaped_chars_re.sub(escape_symbol, value)


def make_dict(values: List[Any]):
# TODO make this a real function
i = 0
Expand Down
57 changes: 48 additions & 9 deletions tests/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,54 @@
from redisvl.query.filter import Geo, GeoRadius, Num, Tag, Text


def test_tag_filter():
tf = Tag("tag_field") == ["tag1", "tag2"]
assert str(tf) == "@tag_field:{tag1|tag2}"

tf = Tag("tag_field") == "tag1"
assert str(tf) == "@tag_field:{tag1}"

tf = Tag("tag_field") != ["tag1", "tag2"]
assert str(tf) == "(-@tag_field:{tag1|tag2})"
# Test cases for various scenarios of tag usage, combinations, and their string representations.
@pytest.mark.parametrize(
"operation,tags,expected",
[
# Testing single tags
("==", "simpletag", "@tag_field:{simpletag}"),
(
"==",
"tag with space",
"@tag_field:{tag\\ with\\ space}",
), # Escaping spaces within quotes
(
"==",
"special$char",
"@tag_field:{special\\$char}",
), # Escaping a special character
("!=", "negated", "(-@tag_field:{negated})"),
# Testing multiple tags
("==", ["tag1", "tag2"], "@tag_field:{tag1|tag2}"),
(
"==",
["alpha", "beta with space", "gamma$special"],
"@tag_field:{alpha|beta\\ with\\ space|gamma\\$special}",
), # Multiple tags with spaces and special chars
("!=", ["tagA", "tagB"], "(-@tag_field:{tagA|tagB})"),
# Complex tag scenarios with special characters
("==", "weird:tag", "@tag_field:{weird\\:tag}"), # Tags with colon
("==", "tag&another", "@tag_field:{tag\\&another}"), # Tags with ampersand
# Escaping various special characters within tags
("==", "tag/with/slashes", "@tag_field:{tag\\/with\\/slashes}"),
(
"==",
["hypen-tag", "under_score", "dot.tag"],
"@tag_field:{hypen\\-tag|under_score|dot\\.tag}",
),
# ...additional unique cases as desired...
],
)
def test_tag_filter_varied(operation, tags, expected):
if operation == "==":
tf = Tag("tag_field") == tags
elif operation == "!=":
tf = Tag("tag_field") != tags
else:
raise ValueError(f"Unsupported operation: {operation}")

# Verify the string representation matches the expected RediSearch query part
assert str(tf) == expected


def test_numeric_filter():
Expand Down
130 changes: 130 additions & 0 deletions tests/test_token_escaper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import pytest

from redisvl.utils.token_escaper import TokenEscaper


@pytest.fixture
def escaper():
return TokenEscaper()


@pytest.mark.parametrize(
("test_input,expected"),
[
(r"a [big] test.", r"a\ \[big\]\ test\."),
(r"hello, world!", r"hello\,\ world\!"),
(
r'special "quotes" (and parentheses)',
r"special\ \"quotes\"\ \(and\ parentheses\)",
),
(
r"& symbols, like * and ?",
r"\&\ symbols\,\ like\ \*\ and\ ?",
), # TODO: question marks are not caught?
# underscores are ignored
(r"-dashes_and_underscores-", r"\-dashes_and_underscores\-"),
],
ids=[
"brackets",
"commas",
"quotes",
"symbols",
"underscores"
]
)
def test_escape_text_chars(escaper, test_input, expected):
assert escaper.escape(test_input) == expected


@pytest.mark.parametrize(
("test_input,expected"),
[
# Simple tags
("user:name", r"user\:name"),
("123#comment", r"123\#comment"),
("hyphen-separated", r"hyphen\-separated"),
# Tags with special characters
("price$", r"price\$"),
("super*star", r"super\*star"),
("tag&value", r"tag\&value"),
("@username", r"\@username"),
# Space-containing tags often used in search scenarios
("San Francisco", r"San\ Francisco"),
("New Zealand", r"New\ Zealand"),
# Multi-special-character tags
("complex/tag:value", r"complex\/tag\:value"),
("$special$tag$", r"\$special\$tag\$"),
("tag-with-hyphen", r"tag\-with\-hyphen"),
# Tags with less common, but legal characters
("_underscore_", r"_underscore_"),
("dot.tag", r"dot\.tag"),
# ("pipe|tag", r"pipe\|tag"), #TODO - pipes are not caught?
# More edge cases with special characters
("(parentheses)", r"\(parentheses\)"),
("[brackets]", r"\[brackets\]"),
("{braces}", r"\{braces\}"),
# ("question?mark", r"question\?mark"), #TODO - question marks are not caught?
# Unicode characters in tags
("你好", r"你好"), # Assuming non-Latin characters don't need escaping
("emoji:😊", r"emoji\:😊"),
# ...other cases as needed...
],
ids=[
":",
"#",
"-",
"$",
"*",
"&",
"@",
"space",
"space-2",
"complex",
"special",
"hyphen",
"underscore",
"dot",
"parentheses",
"brackets",
"braces",
"non-latin",
"emoji"
]
)
def test_escape_tag_like_values(escaper, test_input, expected):
assert escaper.escape(test_input) == expected


@pytest.mark.parametrize("test_input", [123, 45.67, None, [], {}])
def test_escape_non_string_input(escaper, test_input):
with pytest.raises(TypeError):
escaper.escape(test_input)


@pytest.mark.parametrize(
"test_input,expected",
[
# ('你好,世界!', r'你好\,世界\!'), # TODO - non latin chars?
("😊 ❤️ 👍", r"😊\ ❤️\ 👍"),
# ...other cases as needed...
],
ids=[
"emoji"
]
)
def test_escape_unicode_characters(escaper, test_input, expected):
assert escaper.escape(test_input) == expected


def test_escape_empty_string(escaper):
assert escaper.escape("") == ""


def test_escape_long_string(escaper):
# Construct a very long string
long_str = "a," * 1000 # This creates a string "a,a,a,a,...a,"
expected = r"a\," * 1000 # Expected escaped string

# Use pytest's benchmark fixture to check performance
escaped = escaper.escape(long_str)
assert escaped == expected

0 comments on commit 4e3de2d

Please sign in to comment.