Update MTVRP #176
base: main
Awesome! Added some comments~
#durations = td["durations"][..., 1:]
time_windows = td["time_windows"][..., 1:, :]
# embeddings
demands = td["demand_linehaul"][..., None] - td["demand_backhaul"][..., None]
It makes sense; basically, if the demand is negative, the model will understand it is a backhaul. I was thinking about having a flag, but this is also good.
A flag is also a good idea!
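For reference, a minimal sketch of the two encodings discussed here, on made-up demand values. The flag variant is a hypothetical alternative and not part of this PR, and it assumes each customer is either a linehaul or a backhaul, never both.

```python
import torch

# Toy batch: 1 instance, 3 customers (depot excluded). Values are made up.
demand_linehaul = torch.tensor([[0.3, 0.0, 0.2]])
demand_backhaul = torch.tensor([[0.0, 0.4, 0.0]])

# Signed encoding from the diff: linehaul minus backhaul, with a trailing feature dim.
signed_demand = demand_linehaul[..., None] - demand_backhaul[..., None]

# Hypothetical alternative: keep the magnitude and add an explicit 0/1 backhaul flag.
is_backhaul = (demand_backhaul > 0).float()[..., None]
flagged_demand = torch.cat([signed_demand.abs(), is_backhaul], dim=-1)
```

The signed version keeps the feature dimension at 1, while the flag version costs one extra input feature but separates "how much" from "which direction".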
capacity = super()._state_embedding(embeddings, td)
current_time = td["current_time"]
current_length = td["current_route_length"]
How does the model understand whether there is a limit?
If there is no limit (say, CVRP), the input will look the same as for VRPL, since the model does not know whether the constraint will be enforced or not.
You are right. The route length in _state_embedding should be the remaining length, i.e., length limit - current length, instead of the current length. It seems to be a mistake.
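A minimal sketch of what the remaining-length feature could look like. The function name and the plain-tensor arguments are hypothetical and do not come from the PR. Mapping "no limit" (inf) to 0.0 mirrors the nan_to_num idea raised for time windows in this review, and it is an assumption: 0.0 then also means "budget exhausted", so a different sentinel may be preferable.

```python
import torch

def remaining_route_length(distance_limit, current_route_length):
    """Remaining budget = limit - length used so far; 0.0 when no limit applies."""
    remaining = distance_limit - current_route_length
    # inf - finite = inf, so unconstrained instances map to a fixed 0.0 feature
    return torch.nan_to_num(remaining, posinf=0.0)

# Toy usage: one instance with limit 3.0, one without a limit (inf).
limit = torch.tensor([[3.0], [float("inf")]])
used = torch.tensor([[1.0], [2.5]])
remaining = remaining_route_length(limit, used)
```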
@@ -256,7 +256,8 @@ def _default_open(td, remove):
@staticmethod
def _default_time_window(td, remove):
default_tw = torch.zeros_like(td["time_windows"])
default_tw[..., 1] = float("inf")
#default_tw[..., 1] = float("inf")
default_tw[..., 1] = 4.6 # max tw
Doesn't this influence the solution? If the default time window is 4.6, the problem is no longer a CVRP but a "relaxed" VRPTW.
The reason I suggested "inf" is that it generalizes to any scale. For the embedding, it can be handled as:
time_windows = torch.nan_to_num(td["time_windows"][..., 1:, :], posinf=0.0)
So it shouldn't influence the calculation as described in Section 4.1 (Attribute composition) in your paper.
What do you think?
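To illustrate the suggestion on made-up data (a sketch, with a plain tensor standing in for td["time_windows"]): the environment keeps inf as the "no window" marker, and only the embedding input is made finite.

```python
import torch

# Toy time windows for depot + 2 customers: [start, end]; inf marks "no window".
time_windows = torch.tensor([[[0.0, float("inf")],
                              [0.0, float("inf")],
                              [1.0, 2.5]]])

# Drop the depot (index 0) and replace +inf with 0.0 before embedding.
tw_features = torch.nan_to_num(time_windows[..., 1:, :], posinf=0.0)
```

torch.nan_to_num returns a new tensor here, so the inf values stay available to the environment for masking and validity checks.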
I agree that the default should be float("inf"), as T=4.6 should only apply as default value to the environments where we actually want to model time windows!
OK, that makes sense!
@@ -349,9 +349,14 @@ def check_solution_validity(td: TensorDict, actions: torch.Tensor):
curr_time = torch.max(
    curr_time + dist, gather_by_index(td["time_windows"], next_node)[..., 0]
)

new_shape = curr_time.size()
skip_open_end = td["open_route"].view(*new_shape) & (next_node == 0).view(*new_shape)
Makes sense, good catch.
Anyways, I recommend setting check_solution to False when training; otherwise, the solution will be checked at each step and it can be a bit slow. I will add a warning.
Actually, I don't think this is necessary. Since skip_open_end will only be true if next_node == 0, and since the depot has the highest time window end, curr_time <= gather_by_index(td["time_windows"], next_node)[..., 1] should always be True, except when curr_time is already very close to the max time and the duration at that last node is long enough to go over the time limit. Is that something we want to allow?
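The two behaviours under discussion can be compared with a small sketch. The helper below is hypothetical: it takes the window end as a plain tensor instead of calling gather_by_index, and the 4.6 depot bound is only the value quoted in this PR.

```python
import torch

def time_window_ok(curr_time, tw_end, next_node, open_route, skip_open_end=True):
    """Check curr_time <= window end, optionally ignoring the depot return of open routes."""
    within = curr_time <= tw_end
    if skip_open_end:
        # Open routes do not really return to the depot, so that step is exempt.
        within = within | (open_route & (next_node == 0))
    return within

# Three instances all stepping to the depot (node 0) with depot window end 4.6.
curr_time = torch.tensor([4.0, 5.0, 5.0])
tw_end = torch.tensor([4.6, 4.6, 4.6])
next_node = torch.tensor([0, 0, 0])
open_route = torch.tensor([False, False, True])

with_skip = time_window_ok(curr_time, tw_end, next_node, open_route)
without_skip = time_window_ok(curr_time, tw_end, next_node, open_route, skip_open_end=False)
```

Only the third instance (open route, late arrival) differs between the two settings, which is exactly the edge case raised above.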
I agree with ngastzepeda: the last node in the route should also satisfy the time window constraints (i.e., it should be possible to return to the depot in time, even for OVRP). However, I found some outliers when training the MTVRP (i.e., the time at the last node of an OVRP route may exceed the max time window). I do not yet know the exact cause; it could be the instance generation or the masking procedure.
I will have a check.
Good job! 🚀
@@ -281,7 +281,7 @@ def get_action_mask(td: TensorDict) -> torch.Tensor:
& ~exceeds_dist_limit
& ~td["visited"]
)

#print(can_visit)
[Minor] Debugging comments could be removed.
has_tw = (td["time_windows"][:, :, 1] != float("inf")).any(-1)
has_tw = (td["time_windows"][:, :, 1] != 4.6).any(-1)
Same as the discussion on _default_time_window(): changing this bound in the settings will require modifying this part as well. Any reason for hardcoding it?
It was to avoid numerical issues during training, since the value goes through the embedding. I will check, though; inf would be more general.
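One scale-independent way to write the check, as a sketch: the helper name is hypothetical, and it uses torch.isfinite instead of comparing against inf or 4.6. Note this is not identical to `!= float("inf")`, since NaN and -inf also count as "not finite" here.

```python
import torch

def has_time_windows(time_windows):
    """True per instance if any node has a finite window end (i.e. a real constraint)."""
    return torch.isfinite(time_windows[..., 1]).any(-1)

# Toy batch of shape [2 instances, 2 nodes, 2]: the first has no windows, the second has one.
inf = float("inf")
tw = torch.tensor([[[0.0, inf], [0.0, inf]],
                   [[0.0, inf], [1.0, 2.5]]])
has_tw = has_time_windows(tw)
```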
"""Context embedding for the Capacitated Vehicle Routing Problem (CVRP).
Project the following to the embedding space:
- current node embedding
- remaining capacity (vehicle_capacity - used_capacity)
"""
Since this also includes backhauls, we should mention this in the docs.
def forward(self, td):
depot, cities = td["locs"][:, :1, :], td["locs"][:, 1:, :]
#durations = td["durations"][..., 1:]
Why are the durations not included?