Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/serde api features #107

Merged
merged 11 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/py-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ['3.9', '3.10', '3.11']
python-version: ['3.10', '3.11']

env:
PYTHON: ${{ matrix.python-version }}
Expand Down
3 changes: 1 addition & 2 deletions .github/workflows/wheels.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ jobs:
- macos
- windows
python-version:
- "9"
- "10"
- "11"
include:
Expand All @@ -36,7 +35,7 @@ jobs:
- name: set up python
uses: actions/setup-python@v4
with:
python-version: "3.9"
python-version: "3.11"

- name: set up rust
if: matrix.os != 'ubuntu'
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ authors = [
description = "Tool for modeling and optimization of advanced locomotive powertrains for freight rail decarbonization."
readme = "README.md"
license = { file = "LICENSE.md" }
requires-python = ">=3.9, <3.12"
requires-python = ">=3.10, <3.12"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: BSD License",
Expand All @@ -47,6 +47,7 @@ dependencies = [
"pyarrow",
"requests",
"PyYAML==6.0.2",
"msgpack==1.1.0",
]

[project.urls]
Expand Down
75 changes: 62 additions & 13 deletions python/altrios/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,27 +119,76 @@ def history_path_list(self, element_as_list:bool=False) -> List[str]:
item for item in self.variable_path_list(
element_as_list=element_as_list) if "history" in item_str(item)
]
return history_path_list

def to_pydict(self) -> Dict:
return history_path_list

# TODO connect to crate features
data_formats = [
'yaml',
'msg_pack',
# 'toml',
'json',
]

def to_pydict(self, data_fmt: str = "msg_pack", flatten: bool = False) -> Dict:
"""
Returns self converted to pure python dictionary with no nested Rust objects
# Arguments
- `flatten`: if True, returns dict without any hierarchy
- `data_fmt`: data format for intermediate conversion step
"""
from yaml import load
try:
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader
pydict = load(self.to_yaml(), Loader = Loader)
return pydict
data_fmt = data_fmt.lower()
assert data_fmt in data_formats, f"`data_fmt` must be one of {data_formats}"
match data_fmt:
case "msg_pack":
import msgpack
pydict = msgpack.loads(self.to_msg_pack())
case "yaml":
from yaml import load
try:
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader
pydict = load(self.to_yaml(), Loader=Loader)
case "json":
from json import loads
pydict = loads(self.to_json())

if not flatten:
return pydict
else:
return next(iter(pd.json_normalize(pydict, sep=".").to_dict(orient='records')))

@classmethod
def from_pydict(cls, pydict: Dict) -> Self:
def from_pydict(cls, pydict: Dict, data_fmt: str = "msg_pack", skip_init: bool = True) -> Self:
"""
Instantiates Self from pure python dictionary
# Arguments
- `pydict`: dictionary to be converted to ALTRIOS object
- `data_fmt`: data format for intermediate conversion step
- `skip_init`: passed to `SerdeAPI` methods to control whether initialization
is skipped
"""
import yaml
return cls.from_yaml(yaml.dump(pydict),skip_init=False)
data_fmt = data_fmt.lower()
assert data_fmt in data_formats, f"`data_fmt` must be one of {data_formats}"
match data_fmt.lower():
case "yaml":
import yaml
obj = cls.from_yaml(yaml.dump(pydict), skip_init=skip_init)
case "msg_pack":
import msgpack
try:
obj = cls.from_msg_pack(
msgpack.packb(pydict), skip_init=skip_init)
except Exception as err:
print(
f"{err}\nFalling back to YAML.")
obj = cls.from_pydict(
pydict, data_fmt="yaml", skip_init=skip_init)
case "json":
from json import dumps
obj = cls.from_json(dumps(pydict), skip_init=skip_init)

return obj

def to_dataframe(self, pandas:bool=False) -> [pd.DataFrame, pl.DataFrame, pl.LazyFrame]:
"""
Expand Down
5 changes: 4 additions & 1 deletion python/altrios/altrios_pyo3.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@ class SerdeAPI(object):
@classmethod
def from_yaml(cls) -> Self: ...
@classmethod
def from_file(cls) -> Self: ...
def from_file(cls, skip_init=False) -> Self: ...
def to_file(self): ...
def to_bincode(self) -> bytes: ...
def to_json(self) -> str: ...
def to_yaml(self) -> str: ...
def to_pydict(self, data_fmt: str = "msg_pack", flatten: bool = False) -> Dict: ...
@classmethod
def from_pydict(cls, pydict: Dict, data_fmt: str = "msg_pack") -> Self:


class Consist(SerdeAPI):
Expand Down
2 changes: 1 addition & 1 deletion python/altrios/demos/sim_manager_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
t0_import = time.perf_counter()
t0_total = time.perf_counter()

rail_vehicles=[alt.RailVehicle.from_file(vehicle_file)
rail_vehicles=[alt.RailVehicle.from_file(vehicle_file, skip_init=False)
for vehicle_file in Path(alt.resources_root() / "rolling_stock/").glob('*.yaml')]

location_map = alt.import_locations(alt.resources_root() / "networks/default_locations.csv")
Expand Down
4 changes: 2 additions & 2 deletions python/altrios/demos/version_migration_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ def migrate_network() -> Tuple[alt.Network, alt.Network]:
old_network_path = alt.resources_root() / "networks/Taconite_v0.1.6.yaml"
new_network_path = alt.resources_root() / "networks/Taconite.yaml"

network_from_old = alt.Network.from_file(old_network_path)
network_from_new = alt.Network.from_file(new_network_path)
network_from_old = alt.Network.from_file(old_network_path, skip_init=False)
network_from_new = alt.Network.from_file(new_network_path, skip_init=False)

# `network_from_old` could be used to overwrite the file in the new format with
# ```
Expand Down
4 changes: 2 additions & 2 deletions python/altrios/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,13 @@ def simulate_prescribed_rollout(
else:
demand_paths.append(demand_file)

rail_vehicles=[alt.RailVehicle.from_file(vehicle_file)
rail_vehicles=[alt.RailVehicle.from_file(vehicle_file, skip_init=False)
for vehicle_file in Path(alt.resources_root() / "rolling_stock/").glob('*.yaml')]

location_map = alt.import_locations(
str(alt.resources_root() / "networks/default_locations.csv")
)
network = alt.Network.from_file(network_filename_path)
network = alt.Network.from_file(network_filename_path, skip_init=False)
sim_days = defaults.SIMULATION_DAYS
scenarios = []
for idx, scenario_year in enumerate(years):
Expand Down
125 changes: 125 additions & 0 deletions python/altrios/tests/test_serde.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import time
import altrios as alt

SAVE_INTERVAL = 100
def get_solved_speed_limit_train_sim():
# Build the train config
rail_vehicle_loaded = alt.RailVehicle.from_file(
alt.resources_root() / "rolling_stock/Manifest_Loaded.yaml")
rail_vehicle_empty = alt.RailVehicle.from_file(
alt.resources_root() / "rolling_stock/Manifest_Empty.yaml")

# https://docs.rs/altrios-core/latest/altrios_core/train/struct.TrainConfig.html
train_config = alt.TrainConfig(
rail_vehicles=[rail_vehicle_loaded, rail_vehicle_empty],
n_cars_by_type={
"Manifest_Loaded": 50,
"Manifest_Empty": 50,
},
train_length_meters=None,
train_mass_kilograms=None,
)

# Build the locomotive consist model
# instantiate battery model
# https://docs.rs/altrios-core/latest/altrios_core/consist/locomotive/powertrain/reversible_energy_storage/struct.ReversibleEnergyStorage.html#
res = alt.ReversibleEnergyStorage.from_file(
alt.resources_root() / "powertrains/reversible_energy_storages/Kokam_NMC_75Ah_flx_drive.yaml"
)

edrv = alt.ElectricDrivetrain(
pwr_out_frac_interp=[0., 1.],
eta_interp=[0.98, 0.98],
pwr_out_max_watts=5e9,
save_interval=SAVE_INTERVAL,
)

bel: alt.Locomotive = alt.Locomotive.build_battery_electric_loco(
reversible_energy_storage=res,
drivetrain=edrv,
loco_params=alt.LocoParams.from_dict(dict(
pwr_aux_offset_watts=8.55e3,
pwr_aux_traction_coeff_ratio=540.e-6,
force_max_newtons=667.2e3,
)))

# construct a vector of one BEL and several conventional locomotives
loco_vec = [bel.clone()] + [alt.Locomotive.default()] * 7
# instantiate consist
loco_con = alt.Consist(
loco_vec
)

# Instantiate the intermediate `TrainSimBuilder`
tsb = alt.TrainSimBuilder(
train_id="0",
origin_id="A",
destination_id="B",
train_config=train_config,
loco_con=loco_con,
)

# Load the network and construct the timed link path through the network.
network = alt.Network.from_file(
alt.resources_root() / 'networks/simple_corridor_network.yaml')

location_map = alt.import_locations(
alt.resources_root() / "networks/simple_corridor_locations.csv")
train_sim: alt.SetSpeedTrainSim = tsb.make_speed_limit_train_sim(
location_map=location_map,
save_interval=1,
)
train_sim.set_save_interval(SAVE_INTERVAL)
est_time_net, _consist = alt.make_est_times(train_sim, network)

timed_link_path = alt.run_dispatch(
network,
alt.SpeedLimitTrainSimVec([train_sim]),
[est_time_net],
False,
False,
)[0]

train_sim.walk_timed_path(
network=network,
timed_path=timed_link_path,
)
assert len(train_sim.history) > 1

return train_sim


def test_pydict():
ts = get_solved_speed_limit_train_sim()

t0 = time.perf_counter_ns()
ts_dict_msg = ts.to_pydict(flatten=False, data_fmt="msg_pack")
ts_msg = alt.SpeedLimitTrainSim.from_pydict(
ts_dict_msg, data_fmt="msg_pack")
t1 = time.perf_counter_ns()
t_msg = t1 - t0
print(f"\nElapsed time for MessagePack: {t_msg:.3e} ns ")

t0 = time.perf_counter_ns()
ts_dict_yaml = ts.to_pydict(flatten=False, data_fmt="yaml")
ts_yaml = alt.SpeedLimitTrainSim.from_pydict(ts_dict_yaml, data_fmt="yaml")
t1 = time.perf_counter_ns()
t_yaml = t1 - t0
print(f"Elapsed time for YAML: {t_yaml:.3e} ns ")
print(f"YAML time per MessagePack time: {(t_yaml / t_msg):.3e} ")

t0 = time.perf_counter_ns()
ts_dict_json = ts.to_pydict(flatten=False, data_fmt="json")
_ts_json = alt.SpeedLimitTrainSim.from_pydict(
ts_dict_json, data_fmt="json")
t1 = time.perf_counter_ns()
t_json = t1 - t0
print(f"Elapsed time for json: {t_json:.3e} ns ")
print(f"JSON time per MessagePack time: {(t_json / t_msg):.3e} ")

# `to_pydict` is necessary because of some funkiness with direct equality comparison
assert ts_msg.to_pydict() == ts.to_pydict()
assert ts_yaml.to_pydict() == ts.to_pydict()

if __name__ == "__main__":
test_pydict()
2 changes: 1 addition & 1 deletion python/altrios/train_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1129,7 +1129,7 @@ def run_train_planner(

if __name__ == "__main__":

rail_vehicles=[alt.RailVehicle.from_file(vehicle_file)
rail_vehicles=[alt.RailVehicle.from_file(vehicle_file, skip_init=False)
for vehicle_file in Path(alt.resources_root() / "rolling_stock/").glob('*.yaml')]

location_map = alt.import_locations(
Expand Down
23 changes: 23 additions & 0 deletions rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion rust/altrios-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ rust-version = { workspace = true }
[dependencies]
csv = "1.1.6"
serde = { version = "1.0.136", features = ["derive"] }
rmp-serde = { version = "1.3.0", optional = true }
serde_yaml = "0.8.23"
serde_json = "1.0"
uom = { workspace = true, features = ["use_serde"] }
Expand Down Expand Up @@ -56,9 +57,13 @@ tempfile = "3.10.1"
derive_more = { version = "1.0.0", features = ["from_str", "from", "is_variant", "try_into"] }

[features]
default = []
default = ["serde-default"]
## Enables several text file formats for serialization and deserialization
serde-default = ["msgpack"]
## Exposes ALTRIOS structs, methods, and functions to Python.
pyo3 = ["dep:pyo3"]
## Enables message pack serialization and deserialization via `rmp-serde`
msgpack = ["dep:rmp-serde"]

[lints.rust]
# `'cfg(debug_advance_rewind)'` is expected for debugging in `advance_rewind.rs`
Expand Down
Loading
Loading