Skip to content

Commit

Permalink
fixed all remaining rust-side errors
Browse files Browse the repository at this point in the history
  • Loading branch information
calbaker committed Jan 13, 2025
1 parent c90b8e3 commit 03b4c81
Showing 1 changed file with 86 additions and 75 deletions.
161 changes: 86 additions & 75 deletions rust/altrios-core/src/train/train_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -808,9 +808,9 @@ pub fn run_speed_limit_train_sims(
}
};

let train_consist_plan: DataFrame = *train_consist_plan_py.as_ref();
let mut loco_pool: DataFrame = *loco_pool_py.as_ref();
let refuel_facilities: DataFrame = *refuel_facilities_py.as_ref();
let train_consist_plan: DataFrame = train_consist_plan_py.clone().into();
let mut loco_pool: DataFrame = loco_pool_py.clone().into();
let refuel_facilities: DataFrame = refuel_facilities_py.clone().into();

loco_pool = loco_pool
.lazy()
Expand Down Expand Up @@ -880,35 +880,48 @@ pub fn run_speed_limit_train_sims(
),
)?;
let arrivals = arrival_times.clone().filter(&arrivals_mask)?;
let arrivals_merged = loco_pool.clone().left_join(
&arrivals,
&["Locomotive_ID".into()],
&["Locomotive_ID".into()],
)?;
let arrivals_merged =
loco_pool
.clone()
.left_join(&arrivals, ["Locomotive_ID"], ["Locomotive_ID"])?;
let arrival_locations = arrivals_merged.column("Destination_ID")?;
if arrivals.height() > 0 {
let arrival_ids = arrivals.column("Locomotive_ID")?;
loco_pool = loco_pool
.lazy()
.with_columns(vec![
when(col("Locomotive_ID").is_in(lit(
*arrival_ids.as_series().with_context(|| format_dbg!())?,
)))
when(
col("Locomotive_ID").is_in(lit(arrival_ids
.clone()
.as_series()
.with_context(|| format_dbg!())?
.clone())),
)
.then(lit("Queued"))
.otherwise(col("Status"))
.alias("Status"),
when(col("Locomotive_ID").is_in(lit(
*arrival_ids.as_series().with_context(|| format_dbg!())?,
)))
when(
col("Locomotive_ID").is_in(lit(arrival_ids
.clone()
.as_series()
.with_context(|| format_dbg!())?
.clone())),
)
.then(lit(current_time))
.otherwise(col("Ready_Time_Est"))
.alias("Ready_Time_Est"),
when(col("Locomotive_ID").is_in(lit(
*arrival_ids.as_series().with_context(|| format_dbg!())?,
)))
.then(lit(*arrival_locations
when(
col("Locomotive_ID").is_in(lit(arrival_ids
.clone()
.as_series()
.with_context(|| format_dbg!())?
.clone())),
)
.then(lit(arrival_locations
.clone()
.as_series()
.with_context(|| format_dbg!())?))
.with_context(|| format_dbg!())?
.clone()))
.otherwise(col("Node"))
.alias("Node"),
])
Expand Down Expand Up @@ -938,14 +951,14 @@ pub fn run_speed_limit_train_sims(
col("SOC_Min_J"),
])?
.alias("SOC_Target_J")])
.sort("Locomotive_ID".into(), SortMultipleOptions::default())
.sort(["Locomotive_ID"], SortMultipleOptions::default())
.collect()
.with_context(|| format_dbg!())?;

let indices = arrivals.column("TrainSimVec_Index")?.u32()?.unique()?;
for index in indices.into_iter() {
let idx = index.unwrap() as usize;
let departing_soc_pct = train_consist_plan
let departing_soc_pct_iter = train_consist_plan
.clone()
.lazy()
// retain rows in which "TrainSimVec_Index" equals current `index`
Expand All @@ -960,11 +973,12 @@ pub fn run_speed_limit_train_sims(
[col("Locomotive_ID")],
JoinArgs::new(JoinType::Left),
)
.sort("Locomotive_ID".into(), SortMultipleOptions::default())
.sort(["Locomotive_ID"], SortMultipleOptions::default())
.with_columns(vec![(col("SOC_J") / col("Capacity_J")).alias("SOC_Pct")])
.collect()?
.collect()?;

let departing_soc_pct = departing_soc_pct_iter
.column("SOC_Pct")?
.clone()
.as_series()
.with_context(|| format_dbg!())?;

Expand Down Expand Up @@ -1028,13 +1042,14 @@ pub fn run_speed_limit_train_sims(
)?;
let arrival_locos = arrival_times.filter(&idx_mask)?;
let arrival_loco_ids = arrival_locos.column("Locomotive_ID")?.u32()?;
let arrival_loco_mask: Result<ChunkedArray<BooleanType>, PolarsError> = is_in(
let arrival_loco_mask: ChunkedArray<BooleanType> = is_in(
loco_pool
.column("Locomotive_ID")?
.as_series()
.with_context(|| format_dbg!())?,
&Series::from(*arrival_loco_ids),
);
&Series::from(arrival_loco_ids.clone()),
)
.with_context(|| format_dbg!())?;

// Get the indices of true values in the boolean ChunkedArray
let arrival_loco_indices: Vec<usize> = arrival_loco_mask
Expand All @@ -1055,33 +1070,29 @@ pub fn run_speed_limit_train_sims(
loco_pool = loco_pool
.lazy()
.with_columns(vec![
when(lit(arrival_loco_mask
.with_context(|| format_dbg!())?
.into()))
.then(lit(Series::new("SOC_J".into(), all_current_socs)))
.otherwise(col("SOC_J"))
.alias("SOC_J"),
when(lit(arrival_loco_mask
.with_context(|| format_dbg!())?
.into()))
.then(lit(Series::new("Trip_Energy_J".into(), all_energy_j)))
.otherwise(col("Trip_Energy_J"))
.alias("Trip_Energy_J"),
when(lit(arrival_loco_mask.clone().into_series()))
.then(lit(Series::new("SOC_J".into(), all_current_socs)))
.otherwise(col("SOC_J"))
.alias("SOC_J"),
when(lit(arrival_loco_mask.into_series()))
.then(lit(Series::new("Trip_Energy_J".into(), all_energy_j)))
.otherwise(col("Trip_Energy_J"))
.alias("Trip_Energy_J"),
])
.collect()
.with_context(|| format_dbg!())?;
}
loco_pool = loco_pool
.lazy()
.sort("Ready_Time_Est".into(), SortMultipleOptions::default())
.sort(["Ready_Time_Est"], SortMultipleOptions::default())
.collect()
.with_context(|| format_dbg!())?;
}

let refueling_mask = (loco_pool)
.column("Status")?
.equal(loco_pool.column("Refueling")?)?;
let refueling_finished_mask = refueling_mask
let refueling_finished_mask = refueling_mask.clone()
& (loco_pool).column("Ready_Time_Est")?.equal(
// TODO: Matt, remove this `Column::new` if you find a better way
&Column::new(
Expand All @@ -1093,7 +1104,7 @@ pub fn run_speed_limit_train_sims(
if refueling_finished_mask.sum().unwrap_or_default() > 0 {
loco_pool = loco_pool
.lazy()
.with_columns(vec![when(lit(refueling_finished_mask.into()))
.with_columns(vec![when(lit(refueling_finished_mask.into_series()))
.then(lit("Ready"))
.otherwise(col("Status"))
.alias("Status")])
Expand All @@ -1103,7 +1114,7 @@ pub fn run_speed_limit_train_sims(

if (arrivals.height() > 0) || (refueling_finished.height() > 0) {
// update queue
let place_in_queue = loco_pool
let place_in_queue_iter = loco_pool
.clone()
.lazy()
.select(&[((col("Status").eq(lit("Refueling")).sum().over([
Expand All @@ -1118,9 +1129,9 @@ pub fn run_speed_limit_train_sims(
"Fuel_Type",
])))
.alias("place_in_queue")])
.collect()?
.collect()?;
let place_in_queue = place_in_queue_iter
.column("place_in_queue")?
.clone()
.as_series()
.with_context(|| format_dbg!())?;
let future_times_mask = departure_times
Expand All @@ -1139,8 +1150,8 @@ pub fn run_speed_limit_train_sims(

let departures_merged = loco_pool.clone().left_join(
&next_departure_time,
&["Locomotive_ID".into()],
&["Locomotive_ID".into()],
["Locomotive_ID"],
["Locomotive_ID"],
)?;
let departure_times = departures_merged
.column("Departure_Time_Actual_Hr")?
Expand All @@ -1163,16 +1174,16 @@ pub fn run_speed_limit_train_sims(
.collect::<Vec<_>>();
let soc_target_series = Series::new("soc_target".into(), soc_target);

let refuel_end_time_ideal = loco_pool
let refuel_end_time_ideal_iter = loco_pool
.clone()
.lazy()
.select(&[(lit(current_time)
+ (max_horizontal([col("SOC_J"), col("SOC_Target_J")])? - col("SOC_J"))
/ col("Refueler_J_Per_Hr"))
.alias("Refuel_End_Time")])
.collect()?
.collect()?;
let refuel_end_time_ideal = refuel_end_time_ideal_iter
.column("Refuel_End_Time")?
.clone()
.as_series()
.with_context(|| format_dbg!())?;

Expand All @@ -1193,9 +1204,9 @@ pub fn run_speed_limit_train_sims(
loco_pool = loco_pool
.lazy()
.with_columns(vec![
lit(*place_in_queue),
lit(refuel_duration_series),
lit(refuel_end_series),
lit(place_in_queue.clone()),
lit(refuel_duration_series.clone()),
lit(refuel_end_series.clone()),
])
.collect()
.with_context(|| format_dbg!())?;
Expand All @@ -1213,22 +1224,22 @@ pub fn run_speed_limit_train_sims(
.with_context(|| format_dbg!())?;

let these_refuel_sessions = df![
"Node".into() => refuel_starting.column("Node").unwrap(),
"Locomotive_Type".into() => refuel_starting.column("Locomotive_Type").unwrap(),
"Fuel_Type".into() => refuel_starting.column("Fuel_Type").unwrap(),
"Locomotive_ID".into() => refuel_starting.column("Locomotive_ID").unwrap(),
"Refueler_J_Per_Hr".into() => refuel_starting.column("Refueler_J_Per_Hr").unwrap(),
"Refueler_Efficiency".into() => refuel_starting.column("Refueler_Efficiency").unwrap(),
"Trip_Energy_J".into() => refuel_starting.column("Trip_Energy_J").unwrap(),
"SOC_J".into() => refuel_starting.column("SOC_J").unwrap(),
"Refuel_Energy_J".into() => refuel_starting.clone().lazy().select(&[
(col("Refueler_J_Per_Hr")*col("refuel_duration")/col("Refueler_Efficiency")).alias("Refuel_Energy_J")
]).collect()?.column("Refuel_Energy_J")?.clone().as_series().with_context(|| format_dbg!()),
"Refuel_Duration_Hr".into() => refuel_starting.column("refuel_duration").unwrap(),
"Refuel_Start_Time_Hr".into() => refuel_starting.column("refuel_end_time").unwrap() -
refuel_starting.column("refuel_duration").unwrap(),
"Refuel_End_Time_Hr".into() => refuel_starting.column("refuel_end_time").unwrap()
]?;
PlSmallStr::from_str("Node") => refuel_starting.column("Node").with_context(|| format_dbg!())?.as_series().with_context(|| format_dbg!())?,
PlSmallStr::from_str("Locomotive_Type") => refuel_starting.column("Locomotive_Type").with_context(|| format_dbg!())?.as_series().with_context(|| format_dbg!())?,
PlSmallStr::from_str("Fuel_Type") => refuel_starting.column("Fuel_Type").with_context(|| format_dbg!())?.as_series().with_context(|| format_dbg!())?,
PlSmallStr::from_str("Locomotive_ID") => refuel_starting.column("Locomotive_ID").with_context(|| format_dbg!())?.as_series().with_context(|| format_dbg!())?,
PlSmallStr::from_str("Refueler_J_Per_Hr") => refuel_starting.column("Refueler_J_Per_Hr").with_context(|| format_dbg!())?.as_series().with_context(|| format_dbg!())?,
PlSmallStr::from_str("Refueler_Efficiency") => refuel_starting.column("Refueler_Efficiency").with_context(|| format_dbg!())?.as_series().with_context(|| format_dbg!())?,
PlSmallStr::from_str("Trip_Energy_J") => refuel_starting.column("Trip_Energy_J").with_context(|| format_dbg!())?.as_series().with_context(|| format_dbg!())?,
PlSmallStr::from_str("SOC_J") => refuel_starting.column("SOC_J").with_context(|| format_dbg!())?.as_series().with_context(|| format_dbg!())?,
PlSmallStr::from_str("Refuel_Energy_J") => refuel_starting.clone().lazy().select(&[
(col("Refueler_J_Per_Hr") * col("refuel_duration") / col("Refueler_Efficiency")).alias("Refuel_Energy_J")
]).collect()?.column("Refuel_Energy_J")?.clone().as_series().with_context(|| format_dbg!())?,
PlSmallStr::from_str("Refuel_Duration_Hr") => refuel_starting.column("refuel_duration").with_context(|| format_dbg!())?.as_series().with_context(|| format_dbg!())?,
PlSmallStr::from_str("Refuel_Start_Time_Hr") => (refuel_starting.column("refuel_end_time").with_context(|| format_dbg!())? -
refuel_starting.column("refuel_duration").with_context(|| format_dbg!())?).with_context(|| format_dbg!())?.as_series().with_context(|| format_dbg!())?,
PlSmallStr::from_str("Refuel_End_Time_Hr") => refuel_starting.column("refuel_end_time").with_context(|| format_dbg!())?.as_series().with_context(|| format_dbg!())?
].with_context(|| format_dbg!())?;
refuel_sessions.vstack_mut(&these_refuel_sessions)?;
// set finishedCharging times to min(max soc OR departure time)
loco_pool = loco_pool
Expand Down Expand Up @@ -1268,29 +1279,29 @@ pub fn run_speed_limit_train_sims(
loco_pool = loco_pool.drop("refuel_end_time")?;
}

let active_loco_ready_times = loco_pool
let active_loco_ready_times_iter = loco_pool
.clone()
.lazy()
.filter(col("Status").is_in(lit(active_loco_statuses.clone())))
.select(vec![col("Ready_Time_Est")])
.collect()?
.collect()?;
let active_loco_ready_times = active_loco_ready_times_iter
.column("Ready_Time_Est")?
.clone()
.as_series()
.with_context(|| format_dbg!())?;
arrival_times = arrival_times
.lazy()
.filter(col("Arrival_Time_Actual_Hr").gt(current_time))
.collect()?;
let arrival_times_remaining = arrival_times
let arrival_times_remaining_iter = arrival_times
.clone()
.lazy()
.select(vec![
col("Arrival_Time_Actual_Hr").alias("Arrival_Time_Actual_Hr")
])
.collect()?
.collect()?;
let arrival_times_remaining = arrival_times_remaining_iter
.column("Arrival_Time_Actual_Hr")?
.clone()
.as_series()
.with_context(|| format_dbg!())?;

Expand Down

0 comments on commit 03b4c81

Please sign in to comment.