-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add tests for token escaper class (#69)
Adds a set of unit tests on both the underlying token escaping class as well as the `Tag` filterable fields that utilize it. --------- Co-authored-by: Sam Partee <sam.partee@redis.com>
- Loading branch information
1 parent
6345cc1
commit 4e3de2d
Showing
5 changed files
with
210 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import re | ||
from typing import Optional, Pattern | ||
|
||
|
||
class TokenEscaper: | ||
""" | ||
Escape punctuation within an input string. Adapted from RedisOM Python. | ||
""" | ||
|
||
# Characters that RediSearch requires us to escape during queries. | ||
# Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization | ||
DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]" | ||
|
||
def __init__(self, escape_chars_re: Optional[Pattern] = None): | ||
if escape_chars_re: | ||
self.escaped_chars_re = escape_chars_re | ||
else: | ||
self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS) | ||
|
||
def escape(self, value: str) -> str: | ||
if not isinstance(value, str): | ||
raise TypeError( | ||
f"Value must be a string object for token escaping, got type {type(value)}" | ||
) | ||
|
||
def escape_symbol(match): | ||
value = match.group(0) | ||
return f"\\{value}" | ||
|
||
return self.escaped_chars_re.sub(escape_symbol, value) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
import pytest | ||
|
||
from redisvl.utils.token_escaper import TokenEscaper | ||
|
||
|
||
@pytest.fixture | ||
def escaper(): | ||
return TokenEscaper() | ||
|
||
|
||
@pytest.mark.parametrize( | ||
("test_input,expected"), | ||
[ | ||
(r"a [big] test.", r"a\ \[big\]\ test\."), | ||
(r"hello, world!", r"hello\,\ world\!"), | ||
( | ||
r'special "quotes" (and parentheses)', | ||
r"special\ \"quotes\"\ \(and\ parentheses\)", | ||
), | ||
( | ||
r"& symbols, like * and ?", | ||
r"\&\ symbols\,\ like\ \*\ and\ ?", | ||
), # TODO: question marks are not caught? | ||
# underscores are ignored | ||
(r"-dashes_and_underscores-", r"\-dashes_and_underscores\-"), | ||
], | ||
ids=[ | ||
"brackets", | ||
"commas", | ||
"quotes", | ||
"symbols", | ||
"underscores" | ||
] | ||
) | ||
def test_escape_text_chars(escaper, test_input, expected): | ||
assert escaper.escape(test_input) == expected | ||
|
||
|
||
@pytest.mark.parametrize( | ||
("test_input,expected"), | ||
[ | ||
# Simple tags | ||
("user:name", r"user\:name"), | ||
("123#comment", r"123\#comment"), | ||
("hyphen-separated", r"hyphen\-separated"), | ||
# Tags with special characters | ||
("price$", r"price\$"), | ||
("super*star", r"super\*star"), | ||
("tag&value", r"tag\&value"), | ||
("@username", r"\@username"), | ||
# Space-containing tags often used in search scenarios | ||
("San Francisco", r"San\ Francisco"), | ||
("New Zealand", r"New\ Zealand"), | ||
# Multi-special-character tags | ||
("complex/tag:value", r"complex\/tag\:value"), | ||
("$special$tag$", r"\$special\$tag\$"), | ||
("tag-with-hyphen", r"tag\-with\-hyphen"), | ||
# Tags with less common, but legal characters | ||
("_underscore_", r"_underscore_"), | ||
("dot.tag", r"dot\.tag"), | ||
# ("pipe|tag", r"pipe\|tag"), #TODO - pipes are not caught? | ||
# More edge cases with special characters | ||
("(parentheses)", r"\(parentheses\)"), | ||
("[brackets]", r"\[brackets\]"), | ||
("{braces}", r"\{braces\}"), | ||
# ("question?mark", r"question\?mark"), #TODO - question marks are not caught? | ||
# Unicode characters in tags | ||
("你好", r"你好"), # Assuming non-Latin characters don't need escaping | ||
("emoji:😊", r"emoji\:😊"), | ||
# ...other cases as needed... | ||
], | ||
ids=[ | ||
":", | ||
"#", | ||
"-", | ||
"$", | ||
"*", | ||
"&", | ||
"@", | ||
"space", | ||
"space-2", | ||
"complex", | ||
"special", | ||
"hyphen", | ||
"underscore", | ||
"dot", | ||
"parentheses", | ||
"brackets", | ||
"braces", | ||
"non-latin", | ||
"emoji" | ||
] | ||
) | ||
def test_escape_tag_like_values(escaper, test_input, expected): | ||
assert escaper.escape(test_input) == expected | ||
|
||
|
||
@pytest.mark.parametrize("test_input", [123, 45.67, None, [], {}]) | ||
def test_escape_non_string_input(escaper, test_input): | ||
with pytest.raises(TypeError): | ||
escaper.escape(test_input) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"test_input,expected", | ||
[ | ||
# ('你好,世界!', r'你好\,世界\!'), # TODO - non latin chars? | ||
("😊 ❤️ 👍", r"😊\ ❤️\ 👍"), | ||
# ...other cases as needed... | ||
], | ||
ids=[ | ||
"emoji" | ||
] | ||
) | ||
def test_escape_unicode_characters(escaper, test_input, expected): | ||
assert escaper.escape(test_input) == expected | ||
|
||
|
||
def test_escape_empty_string(escaper): | ||
assert escaper.escape("") == "" | ||
|
||
|
||
def test_escape_long_string(escaper): | ||
# Construct a very long string | ||
long_str = "a," * 1000 # This creates a string "a,a,a,a,...a," | ||
expected = r"a\," * 1000 # Expected escaped string | ||
|
||
# Use pytest's benchmark fixture to check performance | ||
escaped = escaper.escape(long_str) | ||
assert escaped == expected |