From ea157401996931081850818cef07a9e6d2c75e61 Mon Sep 17 00:00:00 2001 From: Robert Craigie Date: Sat, 24 Feb 2024 19:46:54 +0000 Subject: [PATCH] chore(internal): add change case util functions Co-Authored-By: lxxonx --- src/prisma/generator/utils.py | 46 ++++++++++++++++++++++ tests/test_generation/test_utils.py | 61 ++++++++++++++++++++++++++++- 2 files changed, 106 insertions(+), 1 deletion(-) diff --git a/src/prisma/generator/utils.py b/src/prisma/generator/utils.py index c7eca3443..dadc8dc14 100644 --- a/src/prisma/generator/utils.py +++ b/src/prisma/generator/utils.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import os +import re import shutil from typing import TYPE_CHECKING, Any, Dict, List, Union, TypeVar, Iterator from pathlib import Path @@ -122,3 +125,46 @@ def clean_multiline(string: str) -> str: assert string, 'Expected non-empty string' lines = string.splitlines() return '\n'.join([dedent(lines[0]), *lines[1:]]) + + +# https://github.com/nficano/humps/blob/master/humps/main.py + +ACRONYM_RE = re.compile(r'([A-Z\d]+)(?=[A-Z\d]|$)') +PASCAL_RE = re.compile(r'([^\-_]+)') +SPLIT_RE = re.compile(r'([\-_]*[A-Z][^A-Z]*[\-_]*)') +UNDERSCORE_RE = re.compile(r'(?<=[^\-_])[\-_]+[^\-_]') + + +def to_snake_case(input_str: str) -> str: + if to_camel_case(input_str) == input_str or to_pascal_case(input_str) == input_str: # if camel case or pascal case + input_str = ACRONYM_RE.sub(lambda m: m.group(0).title(), input_str) + input_str = '_'.join(s for s in SPLIT_RE.split(input_str) if s) + return input_str.lower() + else: + input_str = re.sub(r'[^a-zA-Z0-9]', '_', input_str) + input_str = input_str.lower().strip('_') + + return input_str + + +def to_camel_case(input_str: str) -> str: + if len(input_str) != 0 and not input_str[:2].isupper(): + input_str = input_str[0].lower() + input_str[1:] + return UNDERSCORE_RE.sub(lambda m: m.group(0)[-1].upper(), input_str) + + +def to_pascal_case(input_str: str) -> str: + def _replace_fn(match: re.Match[str]) -> str: + return match.group(1)[0].upper() + match.group(1)[1:] + + input_str = to_camel_case(PASCAL_RE.sub(_replace_fn, input_str)) + return input_str[0].upper() + input_str[1:] if len(input_str) != 0 else input_str + + +def to_constant_case(input_str: str) -> str: + """Converts to snake case + uppercase, examples: + + foo_bar -> FOO_BAR + fooBar -> FOO_BAR + """ + return to_snake_case(input_str).upper() diff --git a/tests/test_generation/test_utils.py b/tests/test_generation/test_utils.py index a77569cb4..207954ef8 100644 --- a/tests/test_generation/test_utils.py +++ b/tests/test_generation/test_utils.py @@ -1,4 +1,11 @@ -from prisma.generator.utils import copy_tree +import pytest +from prisma.generator.utils import ( + copy_tree, + to_camel_case, + to_constant_case, + to_snake_case, + to_pascal_case, +) from ..utils import Testdir @@ -26,3 +33,55 @@ def test_copy_tree_ignores_files(testdir: Testdir) -> None: assert files[0].name == 'bar.py' assert files[1].name == 'foo.py' assert files[2].name == 'hello.py' + + +@pytest.mark.parametrize( + 'input_str,expected', + [ + ('snake_case_test', 'snake_case_test'), + ('PascalCaseTest', 'pascal_case_test'), + ('camelCaseTest', 'camel_case_test'), + ('Mixed_Case_Test', 'mixed_case_test'), + ], +) +def test_to_snake_case(input_str: str, expected: str) -> None: + assert to_snake_case(input_str) == expected + + +@pytest.mark.parametrize( + 'input_str,expected', + [ + ('snake_case_test', 'SnakeCaseTest'), + ('PascalCaseTest', 'PascalCaseTest'), + ('camelCaseTest', 'CamelCaseTest'), + ('Mixed_Case_Test', 'MixedCaseTest'), + ], +) +def test_to_pascal_case(input_str: str, expected: str) -> None: + assert to_pascal_case(input_str) == expected + + +@pytest.mark.parametrize( + 'input_str,expected', + [ + ('snake_case_test', 'snakeCaseTest'), + ('PascalCaseTest', 'pascalCaseTest'), + ('camelCaseTest', 'camelCaseTest'), + ('Mixed_Case_Test', 'mixedCaseTest'), + ], +) +def test_to_camel_case(input_str: str, expected: str) -> None: + assert to_camel_case(input_str) == expected + + +@pytest.mark.parametrize( + 'input_str,expected', + [ + ('snake_case_test', 'SNAKE_CASE_TEST'), + ('PascalCaseTest', 'PASCAL_CASE_TEST'), + ('camelCaseTest', 'CAMEL_CASE_TEST'), + ('Mixed_Case_Test', 'MIXED_CASE_TEST'), + ], +) +def test_to_constant_case(input_str: str, expected: str) -> None: + assert to_constant_case(input_str) == expected