Skip to content

Commit

Permalink
T7042: drop use of inspect module in favor of ast for source analysis
Browse files Browse the repository at this point in the history
This avoids importing the config mode script as a module, with requisite
dependencies, which may be inconvenient.
  • Loading branch information
jestabro committed Jan 11, 2025
1 parent 9f6a986 commit d5b1bfc
Showing 1 changed file with 134 additions and 76 deletions.
210 changes: 134 additions & 76 deletions src/tests/test_configd_inspect.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (C) 2020-2024 VyOS maintainers and contributors
# Copyright (C) 2020-2025 VyOS maintainers and contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 or later as
Expand All @@ -12,93 +12,151 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import os
import re
import ast
import json

import warnings
import importlib.util
from inspect import signature
from inspect import getsource
from functools import wraps
from unittest import TestCase

INC_FILE = 'data/configd-include.json'
CONF_DIR = 'src/conf_mode'

f_list = ['get_config', 'verify', 'generate', 'apply']

def import_script(s):
path = os.path.join(CONF_DIR, s)
name = os.path.splitext(s)[0].replace('-', '_')
spec = importlib.util.spec_from_file_location(name, path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module

# importing conf_mode scripts imports jinja2 with deprecation warning
def ignore_deprecation_warning(f):
@wraps(f)
def decorated_function(*args, **kwargs):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
f(*args, **kwargs)
return decorated_function
funcs = ['get_config', 'verify', 'generate', 'apply']


class FunctionSig(ast.NodeVisitor):
def __init__(self):
self.func_sig_len = dict.fromkeys(funcs, None)
self.get_config_default_values = []

def visit_FunctionDef(self, node):
func_name = node.name
if func_name in funcs:
self.func_sig_len[func_name] = len(node.args.args)

if func_name == 'get_config':
for default in node.args.defaults:
if isinstance(default, ast.Constant):
self.get_config_default_values.append(default.value)

self.generic_visit(node)

def get_sig_lengths(self):
return self.func_sig_len

def get_config_default(self):
return self.get_config_default_values[0]


class LegacyCall(ast.NodeVisitor):
def __init__(self):
self.legacy_func_count = 0

def visit_Constant(self, node):
value = node.value
if isinstance(value, str):
if 'my_set' in value or 'my_delete' in value:
self.legacy_func_count += 1

self.generic_visit(node)

def get_legacy_func_count(self):
return self.legacy_func_count


class ConfigInstance(ast.NodeVisitor):
def __init__(self):
self.count = 0

def visit_Call(self, node):
if isinstance(node.func, ast.Name):
name = node.func.id
if name == 'Config':
self.count += 1
self.generic_visit(node)

def get_count(self):
return self.count


class FunctionConfigInstance(ast.NodeVisitor):
def __init__(self):
self.func_config_instance = dict.fromkeys(funcs, 0)

def visit_FunctionDef(self, node):
func_name = node.name
if func_name in funcs:
config_instance = ConfigInstance()
config_instance.visit(node)
self.func_config_instance[func_name] = config_instance.get_count()
self.generic_visit(node)

def get_func_config_instance(self):
return self.func_config_instance


class TestConfigdInspect(TestCase):
def setUp(self):
self.ast_list = []

with open(INC_FILE) as f:
self.inc_list = json.load(f)

@ignore_deprecation_warning
def test_signatures(self):
for s in self.inc_list:
m = import_script(s)
for i in f_list:
f = getattr(m, i, None)
self.assertIsNotNone(f, f"'{s}': missing function '{i}'")
sig = signature(f)
par = sig.parameters
l = len(par)
self.assertEqual(l, 1,
f"'{s}': '{i}' incorrect signature")
if i == 'get_config':
for p in par.values():
self.assertTrue(p.default is None,
f"'{s}': '{i}' incorrect signature")

@ignore_deprecation_warning
def test_function_instance(self):
for s in self.inc_list:
m = import_script(s)
for i in f_list:
f = getattr(m, i, None)
if not f:
continue
str_f = getsource(f)
# Regex not XXXConfig() T3108
n = len(re.findall(r'[^a-zA-Z]Config\(\)', str_f))
if i == 'get_config':
self.assertEqual(n, 1,
f"'{s}': '{i}' no instance of Config")
if i != 'get_config':
self.assertEqual(n, 0,
f"'{s}': '{i}' instance of Config")

@ignore_deprecation_warning
def test_file_instance(self):
for s in self.inc_list:
m = import_script(s)
str_m = getsource(m)
# Regex not XXXConfig T3108
n = len(re.findall(r'[^a-zA-Z]Config\(\)', str_m))
self.assertEqual(n, 1,
f"'{s}' more than one instance of Config")

@ignore_deprecation_warning
s_path = f'{CONF_DIR}/{s}'
with open(s_path) as f:
s_str = f.read()
s_tree = ast.parse(s_str)
self.ast_list.append((s, s_tree))

def test_signatures(self):
for s, t in self.ast_list:
visitor = FunctionSig()
visitor.visit(t)
sig_lens = visitor.get_sig_lengths()

for f in funcs:
self.assertIsNotNone(sig_lens[f], f"'{s}': '{f}' missing")
self.assertEqual(sig_lens[f], 1, f"'{s}': '{f}' incorrect signature")

self.assertEqual(
visitor.get_config_default(),
None,
f"'{s}': 'get_config' incorrect signature",
)

def test_file_config_instance(self):
for s, t in self.ast_list:
visitor = ConfigInstance()
visitor.visit(t)
count = visitor.get_count()

self.assertEqual(count, 1, f"'{s}' more than one instance of Config")

def test_function_config_instance(self):
for s, t in self.ast_list:
visitor = FunctionConfigInstance()
visitor.visit(t)
func_config_instance = visitor.get_func_config_instance()

for f in funcs:
if f == 'get_config':
self.assertTrue(
func_config_instance[f] > 0,
f"'{s}': '{f}' no instance of Config",
)
self.assertTrue(
func_config_instance[f] < 2,
f"'{s}': '{f}' more than one instance of Config",
)
else:
self.assertEqual(
func_config_instance[f], 0, f"'{s}': '{f}' instance of Config"
)

def test_config_modification(self):
for s in self.inc_list:
m = import_script(s)
str_m = getsource(m)
n = str_m.count('my_set')
self.assertEqual(n, 0, f"'{s}' modifies config")
for s, t in self.ast_list:
visitor = LegacyCall()
visitor.visit(t)
legacy_func_count = visitor.get_legacy_func_count()

self.assertEqual(legacy_func_count, 0, f"'{s}' modifies config")

0 comments on commit d5b1bfc

Please sign in to comment.