Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: MessageTemplate 禁止访问私有属性 #2509

Merged
merged 7 commits into from
Jan 4, 2024
38 changes: 35 additions & 3 deletions nonebot/internal/adapter/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,17 @@
overload,
)

from _string import formatter_field_name_split # type: ignore

if TYPE_CHECKING:
from .message import Message, MessageSegment

def formatter_field_name_split( # noqa: F811
field_name: str,
) -> Tuple[str, List[Tuple[bool, str]]]:
...


TM = TypeVar("TM", bound="Message")
TF = TypeVar("TF", str, "Message")

Expand All @@ -36,26 +44,37 @@ class MessageTemplate(Formatter, Generic[TF]):
参数:
template: 模板
factory: 消息类型工厂,默认为 `str`
private_getattr: 是否允许在模板中访问私有属性,默认为 `False`
"""

@overload
def __init__(
self: "MessageTemplate[str]", template: str, factory: Type[str] = str
self: "MessageTemplate[str]",
template: str,
factory: Type[str] = str,
private_getattr: bool = False,
) -> None:
...

@overload
def __init__(
self: "MessageTemplate[TM]", template: Union[str, TM], factory: Type[TM]
self: "MessageTemplate[TM]",
template: Union[str, TM],
factory: Type[TM],
private_getattr: bool = False,
) -> None:
...

def __init__(
self, template: Union[str, TM], factory: Union[Type[str], Type[TM]] = str
self,
template: Union[str, TM],
factory: Union[Type[str], Type[TM]] = str,
private_getattr: bool = False,
) -> None:
self.template: TF = template # type: ignore
self.factory: Type[TF] = factory # type: ignore
self.format_specs: Dict[str, FormatSpecFunc] = {}
self.private_getattr = private_getattr

def __repr__(self) -> str:
return f"MessageTemplate({self.template!r}, factory={self.factory!r})"
Expand Down Expand Up @@ -167,6 +186,19 @@ def _vformat(

return functools.reduce(self._add, results), auto_arg_index

def get_field(
self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, Any]
) -> Tuple[Any, Union[int, str]]:
first, rest = formatter_field_name_split(field_name)
obj = self.get_value(first, args, kwargs)

for is_attr, value in rest:
if not self.private_getattr and value.startswith("_"):
raise ValueError("Cannot access private attribute")
obj = getattr(obj, value) if is_attr else obj[value]

return obj, first

def format_field(self, value: Any, format_spec: str) -> Any:
formatter: Optional[FormatSpecFunc] = self.format_specs.get(format_spec)
if formatter is None and not issubclass(self.factory, str):
Expand Down
27 changes: 22 additions & 5 deletions tests/test_adapters/test_template.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from nonebot.adapters import MessageTemplate
from utils import FakeMessage, FakeMessageSegment, escape_text

Expand All @@ -15,12 +17,8 @@ def test_template_message():
def custom(input: str) -> str:
return f"{input}-custom!"

try:
with pytest.raises(ValueError, match="already exists"):
template.add_format_spec(custom)
except ValueError:
pass
else:
raise AssertionError("Should raise ValueError")

format_args = {
"a": "custom",
Expand Down Expand Up @@ -57,3 +55,22 @@ def test_message_injection():
message = template.format(name="[fake:image]")

assert message.extract_plain_text() == escape_text("[fake:image]Is Bad")


def test_malformed_template():
positive_template = FakeMessage.template("{a}{b}")
message = positive_template.format(a="a", b="b")
assert message.extract_plain_text() == "ab"

malformed_template = FakeMessage.template("{a.__init__}")
with pytest.raises(ValueError, match="private attribute"):
message = malformed_template.format(a="a")

malformed_template = FakeMessage.template("{a[__builtins__]}")
with pytest.raises(ValueError, match="private attribute"):
message = malformed_template.format(a=globals())

malformed_template = MessageTemplate(
"{a[__builtins__][__import__]}{b.__init__}", private_getattr=True
)
message = malformed_template.format(a=globals(), b="b")
Loading