Skip to content

Commit

Permalink
FEATURE: Support inheritance from Fmi2Slave (#217)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jorgelmh authored Sep 26, 2024
1 parent e12ac47 commit 917abf9
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 36 deletions.
22 changes: 16 additions & 6 deletions pythonfmu/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import importlib
import itertools
import logging
import re
import shutil
import sys
import tempfile
import zipfile
import inspect
from pathlib import Path
from typing import Iterable, Optional, Tuple, Union
from xml.dom.minidom import parseString
Expand All @@ -21,11 +21,21 @@
logger = logging.getLogger(__name__)


def get_class_name(file_name: Path) -> str:
with open(str(file_name), 'r') as file:
data = file.read()
return re.search(r'class (\w+)\(\s*Fmi2Slave\s*\)\s*:', data).group(1)
def get_class_name(interface) -> str:
"""Returns the name of the class derived from Fmi2Slave in the given interface module.
Args:
interface: The module containing the classes to be inspected.
Returns:
str: The name of the class derived from Fmi2Slave, or None if no such class is found.
"""
candidate, mro = None, []
for cl in [x for x in dir(interface) if inspect.isclass(getattr(interface, x))]: # get all classes in module and go through them
if any(m.__name__ == 'Fmi2Slave' for m in inspect.getmro(getattr(interface, cl))): # inspect the class hierarchy and return if 'Fmi2Slave' found
if getattr(interface, cl) not in mro: # must be a sub-class of the already registered (or first)
candidate, mro = cl, inspect.getmro(getattr(interface, cl))

return candidate

def get_model_description(filepath: Path, module_name: str) -> Tuple[str, Element]:
"""Extract the FMU model description as XML.
Expand All @@ -45,7 +55,7 @@ def get_model_description(filepath: Path, module_name: str) -> Tuple[str, Elemen
fmu_interface = importlib.util.module_from_spec(spec)
spec.loader.exec_module(fmu_interface)
# Instantiate the interface
class_name = get_class_name(filepath)
class_name = get_class_name(fmu_interface)
instance = getattr(fmu_interface, class_name)(instance_name="dummyInstance", resources=str(filepath.parent))
finally:
sys.path.remove(str(filepath.parent)) # remove inserted temporary path
Expand Down
122 changes: 92 additions & 30 deletions pythonfmu/pythonfmu-export/src/pythonfmu/PySlaveInstance.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

#include "pythonfmu/IPyState.hpp"
#include "pythonfmu/PySlaveInstance.hpp"

Expand All @@ -8,6 +7,8 @@

#include <fstream>
#include <functional>
#include <filesystem>
#include <string>
#include <mutex>
#include <regex>
#include <sstream>
Expand All @@ -24,19 +25,96 @@ inline std::string getline(const std::string& fileName)
return line;
}

inline std::string findClassName(const std::string& fileName)
{
PyObject* findClass(const std::string& resources, const std::string& moduleName) {
// Initialize the Python interpreter
std::string filename = resources + "/" + moduleName + ".py";
std::string deepestFile = "";
int deepestChain = 0;

// Read and execute the Python file
std::ifstream file;
file.open(filename);

if (!file.is_open()) {
return nullptr;
}

std::stringstream fileContents;
std::string line;
std::ifstream infile(fileName);
std::string regexStr(R"(^class (\w+)\(\s*Fmi2Slave\s*\)\s*:)");
while (getline(infile, line)) {
std::smatch m;
std::regex re(regexStr);
if (std::regex_search(line, m, re)) {
return m[1];

while (std::getline(file, line)) {
fileContents << line << "\n";
}

// Compile python code so classes are added to the namespace
PyObject* pyModule = PyImport_ImportModule(moduleName.c_str());

if (pyModule == nullptr) {
return nullptr;
}
PyObject* pGlobals = PyModule_GetDict(pyModule);
PyObject* pLocals = PyDict_New();
PyObject* pCode = Py_CompileString(fileContents.str().c_str(), moduleName.c_str(), Py_file_input);

if (pCode != NULL) {
PyObject* pResult = PyEval_EvalCode(pCode, pGlobals, pLocals);
Py_XDECREF(pResult);
} else {
PyErr_Print(); // Handle compilation error
Py_Finalize();
Py_DECREF(pGlobals);
Py_DECREF(pyModule);
Py_DECREF(pLocals);
Py_DECREF(pCode);
file.close();
return nullptr;
}

fileContents.clear();
PyObject* key, * value;
Py_ssize_t pos = 0;

while (PyDict_Next(pLocals, &pos, &key, &value)) {
// Check if element in namespace is a class
if (!PyType_Check(value)) {
continue;
}

PyObject* pMroAttribute = PyObject_GetAttrString(value, "__mro__");

if (pMroAttribute != NULL && PySequence_Check(pMroAttribute)) {
std::regex pattern ("<class '[^']+\\.([^']+)'");
PyObject* pMROList = PySequence_List(pMroAttribute);

for (Py_ssize_t i = 0; i < PyList_Size(pMROList); ++i) {
PyObject* pItem = PyList_GetItem(pMROList, i);
std::smatch match;
const char* className = PyBytes_AsString(PyUnicode_AsUTF8String(PyObject_Repr(pItem)));

std::string str (className);
bool isMatch = std::regex_search(str, match, pattern);

// If regex match is successfull, and found Fmi2Slave at the deepest level then update state
if (isMatch && i > deepestChain && match[1] == "Fmi2Slave") {
deepestFile = PyBytes_AsString(PyUnicode_AsUTF8String(key));
deepestChain = i;
}
}
}
Py_DECREF(pMroAttribute);
}
return "";

PyObject* pyClassName = Py_BuildValue("s", deepestFile.c_str());
PyObject* pyClass = PyObject_GetAttr(pyModule, pyClassName);

// Clean up Python objects
Py_DECREF(pCode);
Py_DECREF(pLocals);
Py_DECREF(pyModule);
Py_DECREF(pGlobals);
file.close();
Py_DECREF(pyClassName);
return pyClass;
}

inline void py_safe_run(const std::function<void(PyGILState_STATE gilState)>& f)
Expand Down Expand Up @@ -65,33 +143,17 @@ PySlaveInstance::PySlaveInstance(std::string instanceName, std::string resources
handle_py_exception("[ctor] PyObject_GetAttrString", gilState);
}
int success = PyList_Insert(sys_path, 0, PyUnicode_FromString(resources_.c_str()));

Py_DECREF(sys_path);
if (success != 0) {
handle_py_exception("[ctor] PyList_Insert", gilState);
}

std::string moduleName = getline(resources_ + "/slavemodule.txt");
PyObject* pModule = PyImport_ImportModule(moduleName.c_str());
if (pModule == nullptr) {
handle_py_exception("[ctor] PyImport_ImportModule", gilState);
}

std::string className = findClassName(resources_ + "/" + moduleName + ".py");
if (className.empty()) {
cleanPyObject();
throw cppfmu::FatalError("Unable to find class extending Fmi2Slave!");
}

PyObject* pyClassName = Py_BuildValue("s", className.c_str());
if (pyClassName == nullptr) {
handle_py_exception("[ctor] Py_BuildValue", gilState);
}

pClass_ = PyObject_GetAttr(pModule, pyClassName);
Py_DECREF(pyClassName);
Py_DECREF(pModule);
pClass_ = findClass(resources_, moduleName);
if (pClass_ == nullptr) {
handle_py_exception("[ctor] PyObject_GetAttr", gilState);
handle_py_exception("[ctor] findClass", gilState);
}

initialize(gilState);
Expand Down

0 comments on commit 917abf9

Please sign in to comment.