diff --git a/onshape_api/robot.py b/onshape_api/robot.py index d60fee2..a549e6a 100644 --- a/onshape_api/robot.py +++ b/onshape_api/robot.py @@ -10,7 +10,7 @@ import xml.etree.ElementTree as ET from enum import Enum from pathlib import Path -from typing import Optional +from typing import Optional, Union from lxml import etree @@ -113,6 +113,8 @@ def __init__( self.assets = assets self.type = robot_type + self._parent_directory: Union[str, None] = None + self.element: ET.Element = element if element is not None else self.to_xml(robot_type=self.type) self.tree: ET.ElementTree = tree if tree is not None else ET.ElementTree(self.element) @@ -156,6 +158,18 @@ def save(self, file_path: Optional[str] = None, download_assets: bool = True) -> >>> robot = Robot( ... ) >>> robot.save() """ + if file_path and self._parent_directory: + file_path = Path(file_path) + if Path(file_path).parent != Path(self._parent_directory): + LOGGER.warning( + "Parent directory of the file path provided is " + "different from the parent directory of the URDF file." + ) + LOGGER.warning( + "If mesh files are present in the URDF file, " + "then the URDF file will not be able to find the mesh files." + ) + path = file_path if file_path else f"{self.name}.{self.type}" if download_assets: @@ -166,8 +180,12 @@ def save(self, file_path: Optional[str] = None, download_assets: bool = True) -> xml_tree = etree.fromstring(xml_str) # noqa: S320 pretty_xml_str = etree.tostring(xml_tree, pretty_print=True, encoding="unicode") + # Add XML declaration + xml_declaration = '\n' + full_xml_str = xml_declaration + pretty_xml_str + with open(path, "w", encoding="utf-8") as f: - f.write(pretty_xml_str) + f.write(full_xml_str) LOGGER.info(f"Robot model saved to {path}") @@ -249,9 +267,11 @@ def from_urdf(cls, filename: str) -> "Robot": if joint: joints[joint.name] = joint - return Robot( + _cls = cls( name=name, links=links, joints=joints, assets=None, robot_type=RobotType.URDF, element=root, tree=tree ) + _cls._parent_directory = Path(filename).parent + return _cls def get_robot(