Skip to content

Commit

Permalink
Do not use mutable as default arg for plugins (#524)
Browse files Browse the repository at this point in the history
  • Loading branch information
gaogaotiantian authored Nov 19, 2024
1 parent 7dda50f commit d8df684
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 19 deletions.
6 changes: 3 additions & 3 deletions docs/source/viztracer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ VizTracer
sanitize_function_name=False,\
process_name=None,\
output_file="result.json",\
plugins=[])
plugins=None)
.. py:attribute:: tracer_entries
:type: int
Expand Down Expand Up @@ -339,8 +339,8 @@ VizTracer
viztracer -o <filepath>
.. py:attribute:: plugins
:type: Sequence[Union[VizPluginBase, str]]
:value: []
:type: Optional[Sequence[Union[VizPluginBase, str]]]
:value: None

List of plugins to use.

Expand Down
31 changes: 16 additions & 15 deletions src/viztracer/vizplugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,23 +51,24 @@ def message(self, m_type: str, payload: dict) -> dict:


class VizPluginManager:
def __init__(self, tracer: "VizTracer", plugins: Sequence[Union[VizPluginBase, str]]):
def __init__(self, tracer: "VizTracer", plugins: Optional[Sequence[Union[VizPluginBase, str]]]):
self._tracer = tracer
self._plugins = []
for plugin in plugins:
if isinstance(plugin, VizPluginBase):
plugin_instance = plugin
elif isinstance(plugin, str):
plugin_instance = self._get_plugin_from_string(plugin)
else:
raise TypeError("Invalid plugin!")
self._plugins.append(plugin_instance)

support_version = plugin_instance.support_version()
if compare_version(support_version, __version__) > 0:
color_print("WARNING", "The plugin support version is higher than "
"viztracer version. Consider update your viztracer")
self._send_message(plugin_instance, "event", {"when": "initialize"})
if plugins:
for plugin in plugins:
if isinstance(plugin, VizPluginBase):
plugin_instance = plugin
elif isinstance(plugin, str):
plugin_instance = self._get_plugin_from_string(plugin)
else:
raise TypeError("Invalid plugin!")
self._plugins.append(plugin_instance)

support_version = plugin_instance.support_version()
if compare_version(support_version, __version__) > 0:
color_print("WARNING", "The plugin support version is higher than "
"viztracer version. Consider update your viztracer")
self._send_message(plugin_instance, "event", {"when": "initialize"})

def _get_plugin_from_string(self, plugin: str) -> VizPluginBase:
args = plugin.split()
Expand Down
2 changes: 1 addition & 1 deletion src/viztracer/viztracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self,
sanitize_function_name: bool = False,
process_name: Optional[str] = None,
output_file: str = "result.json",
plugins: Sequence[Union[VizPluginBase, str]] = []) -> None:
plugins: Optional[Sequence[Union[VizPluginBase, str]]] = None) -> None:
super().__init__(tracer_entries)

# Members of C Tracer object
Expand Down

0 comments on commit d8df684

Please sign in to comment.