diff --git a/pythonfmu/builder.py b/pythonfmu/builder.py index a96d51b..da9275d 100644 --- a/pythonfmu/builder.py +++ b/pythonfmu/builder.py @@ -3,11 +3,11 @@ import importlib import itertools import logging -import re import shutil import sys import tempfile import zipfile +import inspect from pathlib import Path from typing import Iterable, Optional, Tuple, Union from xml.dom.minidom import parseString @@ -21,11 +21,21 @@ logger = logging.getLogger(__name__) -def get_class_name(file_name: Path) -> str: - with open(str(file_name), 'r') as file: - data = file.read() - return re.search(r'class (\w+)\(\s*Fmi2Slave\s*\)\s*:', data).group(1) +def get_class_name(interface) -> str: + """Returns the name of the class derived from Fmi2Slave in the given interface module. + + Args: + interface: The module containing the classes to be inspected. + Returns: + str: The name of the class derived from Fmi2Slave, or None if no such class is found. + """ + candidate, mro = None, [] + for cl in [x for x in dir(interface) if inspect.isclass(getattr(interface, x))]: # get all classes in module and go through them + if any(m.__name__ == 'Fmi2Slave' for m in inspect.getmro(getattr(interface, cl))): # inspect the class hierarchy and return if 'Fmi2Slave' found + if getattr(interface, cl) not in mro: # must be a sub-class of the already registered (or first) + candidate, mro = cl, inspect.getmro(getattr(interface, cl)) + return candidate def get_model_description(filepath: Path, module_name: str) -> Tuple[str, Element]: """Extract the FMU model description as XML. @@ -45,7 +55,7 @@ def get_model_description(filepath: Path, module_name: str) -> Tuple[str, Elemen fmu_interface = importlib.util.module_from_spec(spec) spec.loader.exec_module(fmu_interface) # Instantiate the interface - class_name = get_class_name(filepath) + class_name = get_class_name(fmu_interface) instance = getattr(fmu_interface, class_name)(instance_name="dummyInstance", resources=str(filepath.parent)) finally: sys.path.remove(str(filepath.parent)) # remove inserted temporary path diff --git a/pythonfmu/pythonfmu-export/src/pythonfmu/PySlaveInstance.cpp b/pythonfmu/pythonfmu-export/src/pythonfmu/PySlaveInstance.cpp index 4112808..c0bd190 100644 --- a/pythonfmu/pythonfmu-export/src/pythonfmu/PySlaveInstance.cpp +++ b/pythonfmu/pythonfmu-export/src/pythonfmu/PySlaveInstance.cpp @@ -1,4 +1,3 @@ - #include "pythonfmu/IPyState.hpp" #include "pythonfmu/PySlaveInstance.hpp" @@ -8,6 +7,8 @@ #include #include +#include +#include #include #include #include @@ -24,19 +25,96 @@ inline std::string getline(const std::string& fileName) return line; } -inline std::string findClassName(const std::string& fileName) -{ +PyObject* findClass(const std::string& resources, const std::string& moduleName) { + // Initialize the Python interpreter + std::string filename = resources + "/" + moduleName + ".py"; + std::string deepestFile = ""; + int deepestChain = 0; + + // Read and execute the Python file + std::ifstream file; + file.open(filename); + + if (!file.is_open()) { + return nullptr; + } + + std::stringstream fileContents; std::string line; - std::ifstream infile(fileName); - std::string regexStr(R"(^class (\w+)\(\s*Fmi2Slave\s*\)\s*:)"); - while (getline(infile, line)) { - std::smatch m; - std::regex re(regexStr); - if (std::regex_search(line, m, re)) { - return m[1]; + + while (std::getline(file, line)) { + fileContents << line << "\n"; + } + + // Compile python code so classes are added to the namespace + PyObject* pyModule = PyImport_ImportModule(moduleName.c_str()); + + if (pyModule == nullptr) { + return nullptr; + } + PyObject* pGlobals = PyModule_GetDict(pyModule); + PyObject* pLocals = PyDict_New(); + PyObject* pCode = Py_CompileString(fileContents.str().c_str(), moduleName.c_str(), Py_file_input); + + if (pCode != NULL) { + PyObject* pResult = PyEval_EvalCode(pCode, pGlobals, pLocals); + Py_XDECREF(pResult); + } else { + PyErr_Print(); // Handle compilation error + Py_Finalize(); + Py_DECREF(pGlobals); + Py_DECREF(pyModule); + Py_DECREF(pLocals); + Py_DECREF(pCode); + file.close(); + return nullptr; + } + + fileContents.clear(); + PyObject* key, * value; + Py_ssize_t pos = 0; + + while (PyDict_Next(pLocals, &pos, &key, &value)) { + // Check if element in namespace is a class + if (!PyType_Check(value)) { + continue; + } + + PyObject* pMroAttribute = PyObject_GetAttrString(value, "__mro__"); + + if (pMroAttribute != NULL && PySequence_Check(pMroAttribute)) { + std::regex pattern (" deepestChain && match[1] == "Fmi2Slave") { + deepestFile = PyBytes_AsString(PyUnicode_AsUTF8String(key)); + deepestChain = i; + } + } } + Py_DECREF(pMroAttribute); } - return ""; + + PyObject* pyClassName = Py_BuildValue("s", deepestFile.c_str()); + PyObject* pyClass = PyObject_GetAttr(pyModule, pyClassName); + + // Clean up Python objects + Py_DECREF(pCode); + Py_DECREF(pLocals); + Py_DECREF(pyModule); + Py_DECREF(pGlobals); + file.close(); + Py_DECREF(pyClassName); + return pyClass; } inline void py_safe_run(const std::function& f) @@ -65,33 +143,17 @@ PySlaveInstance::PySlaveInstance(std::string instanceName, std::string resources handle_py_exception("[ctor] PyObject_GetAttrString", gilState); } int success = PyList_Insert(sys_path, 0, PyUnicode_FromString(resources_.c_str())); + Py_DECREF(sys_path); if (success != 0) { handle_py_exception("[ctor] PyList_Insert", gilState); } std::string moduleName = getline(resources_ + "/slavemodule.txt"); - PyObject* pModule = PyImport_ImportModule(moduleName.c_str()); - if (pModule == nullptr) { - handle_py_exception("[ctor] PyImport_ImportModule", gilState); - } - - std::string className = findClassName(resources_ + "/" + moduleName + ".py"); - if (className.empty()) { - cleanPyObject(); - throw cppfmu::FatalError("Unable to find class extending Fmi2Slave!"); - } - - PyObject* pyClassName = Py_BuildValue("s", className.c_str()); - if (pyClassName == nullptr) { - handle_py_exception("[ctor] Py_BuildValue", gilState); - } - pClass_ = PyObject_GetAttr(pModule, pyClassName); - Py_DECREF(pyClassName); - Py_DECREF(pModule); + pClass_ = findClass(resources_, moduleName); if (pClass_ == nullptr) { - handle_py_exception("[ctor] PyObject_GetAttr", gilState); + handle_py_exception("[ctor] findClass", gilState); } initialize(gilState);