From a0d0889e78086f41654841c72f86f127560f458d Mon Sep 17 00:00:00 2001 From: wpbonelli Date: Fri, 17 Jan 2025 08:37:39 -0500 Subject: [PATCH] fix filters --- flopy/mf6/utils/codegen/__init__.py | 2 + flopy/mf6/utils/codegen/filters.py | 38 +++++++++++++++---- .../mf6/utils/codegen/templates/macros.jinja | 12 +++--- 3 files changed, 39 insertions(+), 13 deletions(-) diff --git a/flopy/mf6/utils/codegen/__init__.py b/flopy/mf6/utils/codegen/__init__.py index 6cf8a1fb8..4485fde2f 100644 --- a/flopy/mf6/utils/codegen/__init__.py +++ b/flopy/mf6/utils/codegen/__init__.py @@ -32,6 +32,8 @@ def _get_template_env(): env.filters["init"] = Filters.init env.filters["untag"] = Filters.untag env.filters["type"] = Filters.type + env.filters["children"] = Filters.children + env.filters["default"] = Filters.default env.filters["safe_name"] = Filters.safe_name env.filters["escape_trailing_underscore"] = ( Filters.escape_trailing_underscore diff --git a/flopy/mf6/utils/codegen/filters.py b/flopy/mf6/utils/codegen/filters.py index 651464b32..323333ae6 100644 --- a/flopy/mf6/utils/codegen/filters.py +++ b/flopy/mf6/utils/codegen/filters.py @@ -125,7 +125,7 @@ def untag(var: dict) -> dict: """ name = var["name"] tagged = var.get("tagged", False) - fields = var.get("children", None) + fields = var.get("fields", None) if not fields: return var @@ -149,7 +149,7 @@ def untag(var: dict) -> dict: if keyword: fields.pop(keyword) - var["children"] = fields + var["fields"] = fields return var def type(var: dict) -> str: @@ -160,7 +160,7 @@ def type(var: dict) -> str: """ _type = var["type"] shape = var.get("shape", None) - children = var.get("children", None) + children = Filters.children(var) if children: if _type == "list": if len(children) == 1: @@ -178,19 +178,38 @@ def type(var: dict) -> str: return f"({children})" elif _type == "union": return " | ".join([v["name"] for v in children.values()]) - if shape: + elif shape: return f"[{_type}]" return var["type"] + def children(var: dict) -> Optional[dict]: + _type = var["type"] + items = var.get("items", None) + fields = var.get("fields", None) + choices = var.get("choices", None) + if items: + assert _type == "list" + return items + if fields: + assert _type == "record" + return fields + if choices: + assert _type == "union" + return choices + return None + + def default(var: dict) -> Any: + _default = var.get("default", None) + if _default: + return _default + return None + @pass_context def attrs(ctx, vars_) -> List[str]: """ Map the context's input variables to corresponding class attributes, where applicable. TODO: this should get much simpler if we can drop all the `ListTemplateGenerator`/`ArrayTemplateGenerator` attributes. - Ultimately I (WPB) think we can aim for context classes consisting - of just a class attr for each variable, with anything complicated - happening in a decorator or base class. """ from modflow_devtools.dfn import _MF6_SCALARS @@ -274,7 +293,10 @@ def _args(): "namespace", "macros", "name", - "vars" + "vars", + "description", + "title", + "parent" ] dfn = {k: v for k, v in ctx.items() if k not in dfn_skip} if base == "MFPackage": diff --git a/flopy/mf6/utils/codegen/templates/macros.jinja b/flopy/mf6/utils/codegen/templates/macros.jinja index 318929015..4b3ddd5ff 100644 --- a/flopy/mf6/utils/codegen/templates/macros.jinja +++ b/flopy/mf6/utils/codegen/templates/macros.jinja @@ -2,7 +2,7 @@ {% for name, var in vars.items() if name not in skip %} {% set v = var|untag %} {% set n = (name if alias else v.name)|safe_name %} -{{ n }}{% if v.default is defined %}={{ v.default|value }}{% endif %}, +{{ n }}={{ v|default|value }}, {% endfor %} {% endmacro %} @@ -10,15 +10,17 @@ {% for var in vars.values() recursive %} {% set v = var|untag %} {% set n = v.name|safe_name|escape_trailing_underscore %} +{% set children = v|children %} {% if loop.depth > 1 %}* {% endif %}{{ n }} : {{ v|type }} {% if v.description is defined and v.description is not none %} {{ v.description|clean|math|wordwrap|indent(loop.depth * 4, first=true) }} {% endif %} -{% if recurse and v.children is defined and v.children is not none %} -{% if v.type == "list" and v.children|length == 1 and (v.children.values()|first).type in ["record", "union"] %} -{{ loop((v.children.values()|first).children.values())|indent(loop.depth * 4, first=true) }} +{% if recurse and children is not none %} +{% if v.type == "list" and children|length == 1 and (children.values()|first).type in ["record", "union"] %} +{% set grandchildren = (children.values()|first)|children %} +{{ loop(grandchildren.values())|indent(loop.depth * 4, first=true) }} {% else %} -{{ loop(v.children.values())|indent(loop.depth * 4, first=true) }} +{{ loop(children.values())|indent(loop.depth * 4, first=true) }} {% endif %} {% endif %} {% endfor %}