Skip to content

Commit

Permalink
down to 3 errors!
Browse files Browse the repository at this point in the history
  • Loading branch information
calbaker committed Jan 13, 2025
1 parent 4593609 commit c933318
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions rust/altrios-core/src/train/train_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1028,7 +1028,7 @@ pub fn run_speed_limit_train_sims(
)?;
let arrival_locos = arrival_times.filter(&idx_mask)?;
let arrival_loco_ids = arrival_locos.column("Locomotive_ID")?.u32()?;
let arrival_loco_mask = is_in(
let arrival_loco_mask: Result<ChunkedArray<BooleanType>, PolarsError> = is_in(
loco_pool
.column("Locomotive_ID")?
.as_series()
Expand All @@ -1040,6 +1040,7 @@ pub fn run_speed_limit_train_sims(
let arrival_loco_indices: Vec<usize> = arrival_loco_mask
.into_iter()
.enumerate()
// TODO: Matt, help Chad figure out what to replace `unwrap_or_default` with
.filter(|(_, val)| val.unwrap_or_default())
.map(|(i, _)| i)
.collect();
Expand All @@ -1055,14 +1056,14 @@ pub fn run_speed_limit_train_sims(
.lazy()
.with_columns(vec![
when(lit(arrival_loco_mask
.as_series()
.with_context(|| format_dbg!())?))
.with_context(|| format_dbg!())?
.into()))
.then(lit(Series::new("SOC_J".into(), all_current_socs)))
.otherwise(col("SOC_J"))
.alias("SOC_J"),
when(lit(arrival_loco_mask
.as_series()
.with_context(|| format_dbg!())))
.with_context(|| format_dbg!())?
.into()))
.then(lit(Series::new("Trip_Energy_J".into(), all_energy_j)))
.otherwise(col("Trip_Energy_J"))
.alias("Trip_Energy_J"),
Expand All @@ -1080,18 +1081,22 @@ pub fn run_speed_limit_train_sims(
let refueling_mask = (loco_pool)
.column("Status")?
.equal(loco_pool.column("Refueling")?)?;
let refueling_finished_mask =
refueling_mask & (loco_pool).column("Ready_Time_Est")?.equal(current_time)?;
let refueling_finished_mask = refueling_mask
& (loco_pool).column("Ready_Time_Est")?.equal(
// TODO: Matt, remove this `Column::new` if you find a better way
&Column::new(
"current_time_const".into(),
vec![current_time; refueling_mask.len()],
),
)?;
let refueling_finished = loco_pool.clone().filter(&refueling_finished_mask)?;
if refueling_finished_mask.sum().unwrap_or_default() > 0 {
loco_pool = loco_pool
.lazy()
.with_columns(vec![when(lit(refueling_finished_mask
.as_series()
.with_context(|| format_dbg!())))
.then(lit("Ready"))
.otherwise(col("Status"))
.alias("Status")])
.with_columns(vec![when(lit(refueling_finished_mask.into()))
.then(lit("Ready"))
.otherwise(col("Status"))
.alias("Status")])
.collect()
.with_context(|| format_dbg!())?;
}
Expand Down

0 comments on commit c933318

Please sign in to comment.