From d8df684c9f20e557d3364094bf6afbc963b87156 Mon Sep 17 00:00:00 2001 From: Tian Gao Date: Mon, 18 Nov 2024 17:00:45 -0800 Subject: [PATCH] Do not use mutable as default arg for plugins (#524) --- docs/source/viztracer.rst | 6 +++--- src/viztracer/vizplugin.py | 31 ++++++++++++++++--------------- src/viztracer/viztracer.py | 2 +- 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/docs/source/viztracer.rst b/docs/source/viztracer.rst index 222f1265..1cc4cf22 100644 --- a/docs/source/viztracer.rst +++ b/docs/source/viztracer.rst @@ -29,7 +29,7 @@ VizTracer sanitize_function_name=False,\ process_name=None,\ output_file="result.json",\ - plugins=[]) + plugins=None) .. py:attribute:: tracer_entries :type: int @@ -339,8 +339,8 @@ VizTracer viztracer -o .. py:attribute:: plugins - :type: Sequence[Union[VizPluginBase, str]] - :value: [] + :type: Optional[Sequence[Union[VizPluginBase, str]]] + :value: None List of plugins to use. diff --git a/src/viztracer/vizplugin.py b/src/viztracer/vizplugin.py index ff9f3b64..9a82700e 100644 --- a/src/viztracer/vizplugin.py +++ b/src/viztracer/vizplugin.py @@ -51,23 +51,24 @@ def message(self, m_type: str, payload: dict) -> dict: class VizPluginManager: - def __init__(self, tracer: "VizTracer", plugins: Sequence[Union[VizPluginBase, str]]): + def __init__(self, tracer: "VizTracer", plugins: Optional[Sequence[Union[VizPluginBase, str]]]): self._tracer = tracer self._plugins = [] - for plugin in plugins: - if isinstance(plugin, VizPluginBase): - plugin_instance = plugin - elif isinstance(plugin, str): - plugin_instance = self._get_plugin_from_string(plugin) - else: - raise TypeError("Invalid plugin!") - self._plugins.append(plugin_instance) - - support_version = plugin_instance.support_version() - if compare_version(support_version, __version__) > 0: - color_print("WARNING", "The plugin support version is higher than " - "viztracer version. Consider update your viztracer") - self._send_message(plugin_instance, "event", {"when": "initialize"}) + if plugins: + for plugin in plugins: + if isinstance(plugin, VizPluginBase): + plugin_instance = plugin + elif isinstance(plugin, str): + plugin_instance = self._get_plugin_from_string(plugin) + else: + raise TypeError("Invalid plugin!") + self._plugins.append(plugin_instance) + + support_version = plugin_instance.support_version() + if compare_version(support_version, __version__) > 0: + color_print("WARNING", "The plugin support version is higher than " + "viztracer version. Consider update your viztracer") + self._send_message(plugin_instance, "event", {"when": "initialize"}) def _get_plugin_from_string(self, plugin: str) -> VizPluginBase: args = plugin.split() diff --git a/src/viztracer/viztracer.py b/src/viztracer/viztracer.py index 28bc506e..06972e26 100644 --- a/src/viztracer/viztracer.py +++ b/src/viztracer/viztracer.py @@ -51,7 +51,7 @@ def __init__(self, sanitize_function_name: bool = False, process_name: Optional[str] = None, output_file: str = "result.json", - plugins: Sequence[Union[VizPluginBase, str]] = []) -> None: + plugins: Optional[Sequence[Union[VizPluginBase, str]]] = None) -> None: super().__init__(tracer_entries) # Members of C Tracer object