-
Notifications
You must be signed in to change notification settings - Fork 86
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update MTVRP #176
base: main
Are you sure you want to change the base?
Update MTVRP #176
Changes from all commits
d742f5b
1ac4516
096fc8e
044c789
99af6f9
991c516
1fd55d6
4e02795
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -281,7 +281,7 @@ def get_action_mask(td: TensorDict) -> torch.Tensor: | |
& ~exceeds_dist_limit | ||
& ~td["visited"] | ||
) | ||
|
||
#print(can_visit) | ||
# Mask depot: don't visit depot if coming from there and there are still customer nodes I can visit | ||
can_visit[:, 0] = ~((curr_node == 0) & (can_visit[:, 1:].sum(-1) > 0)) | ||
return can_visit | ||
|
@@ -349,9 +349,14 @@ def check_solution_validity(td: TensorDict, actions: torch.Tensor): | |
curr_time = torch.max( | ||
curr_time + dist, gather_by_index(td["time_windows"], next_node)[..., 0] | ||
) | ||
|
||
new_shape = curr_time.size() | ||
skip_open_end = td["open_route"].view(*new_shape) & (next_node == 0).view(*new_shape) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Makes sense, good catch. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Actually, I don't think this is necessary. Since There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I agree with ngastzepeda, as the last node in the route should also satisfy the time window constraints (i.e., the vehicle should be able to return to the depot in time even when the problem is an OVRP). However, I found some outliers when training the MTVRP (i.e., the time window of the last node of an OVRP route may exceed the max time window). I do not yet know the exact cause: it could be the instance generation or the masking procedure. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I will check. |
||
|
||
assert torch.all( | ||
curr_time <= gather_by_index(td["time_windows"], next_node)[..., 1] | ||
(curr_time <= gather_by_index(td["time_windows"], next_node)[..., 1]) | skip_open_end | ||
), "vehicle cannot start service before deadline" | ||
|
||
curr_time = curr_time + gather_by_index(td["service_time"], next_node) | ||
curr_node = next_node | ||
curr_time[curr_node == 0] = 0.0 # reset time for depot | ||
|
@@ -450,7 +455,7 @@ def _make_spec(self, td_params: TensorDict): | |
def check_variants(td): | ||
"""Check if the problem has the variants""" | ||
has_open = td["open_route"].squeeze(-1) | ||
has_tw = (td["time_windows"][:, :, 1] != float("inf")).any(-1) | ||
has_tw = (td["time_windows"][:, :, 1] != 4.6).any(-1) | ||
Comment on lines
-453
to
+458
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same as the discussion with the There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The finite value avoids numerical issues during training, since it goes through the embedding. But I will check; inf would be more general. |
||
has_limit = (td["distance_limit"] != float("inf")).squeeze(-1) | ||
has_backhaul = (td["demand_backhaul"] != 0).any(-1) | ||
return has_open, has_tw, has_limit, has_backhaul | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -256,7 +256,8 @@ def _default_open(td, remove): | |
@staticmethod | ||
def _default_time_window(td, remove): | ||
default_tw = torch.zeros_like(td["time_windows"]) | ||
default_tw[..., 1] = float("inf") | ||
#default_tw[..., 1] = float("inf") | ||
default_tw[..., 1] = 4.6 # max tw | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Doesn't this influence the solution? If the default time window is 4.6, the problem is no longer a CVRP but a "relaxed" VRPTW. The reason I suggested "inf" is that it generalizes to any scale. For the embedding, this can be set as: time_windows = torch.nan_to_num(td["time_windows"][..., 1:, :], posinf=0.0) So it shouldn't influence the calculation as described in Section 4.1 (Attribute composition) of your paper. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I agree that the default should be There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. OK, that makes sense! |
||
td["time_windows"][remove] = default_tw[remove] | ||
td["service_time"][remove] = torch.zeros_like(td["service_time"][remove]) | ||
return td | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,6 +32,7 @@ def env_context_embedding(env_name: str, config: dict) -> nn.Module: | |
"mtsp": MTSPContext, | ||
"smtwtp": SMTWTPContext, | ||
"mdcpdp": MDCPDPContext, | ||
"mtvrp": MTVRPContext | ||
} | ||
|
||
if env_name not in embedding_registry: | ||
|
@@ -146,6 +147,50 @@ def _state_embedding(self, embeddings, td): | |
state_embedding = td["vehicle_capacity"] - td["used_capacity"] | ||
return state_embedding | ||
|
||
class VRPBContext(EnvContext): | ||
"""Context embedding for the Capacitated Vehicle Routing Problem (CVRP). | ||
Project the following to the embedding space: | ||
- current node embedding | ||
- remaining capacity (vehicle_capacity - used_capacity) | ||
""" | ||
Comment on lines
+151
to
+155
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since this also includes backhauls, we should mention this in the docs. |
||
|
||
def __init__(self, embed_dim): | ||
super(VRPContext, self).__init__( | ||
embed_dim=embed_dim, step_context_dim=embed_dim + 1 | ||
) | ||
|
||
def _state_embedding(self, embeddings, td): | ||
mask = (td["used_capacity_backhaul"] == 0) | ||
used_capacity = torch.where(mask, td["used_capacity_linehaul"], td["used_capacity_backhaul"]) | ||
state_embedding = td["vehicle_capacity"] - used_capacity | ||
return state_embedding | ||
|
||
class MTVRPContext(VRPBContext): | ||
"""Context embedding for the Capacitated Vehicle Routing Problem (CVRP). | ||
Project the following to the embedding space: | ||
- current node embedding | ||
- remaining capacity (vehicle_capacity - used_capacity) | ||
- current time | ||
- current route length | ||
- if route should be open | ||
""" | ||
|
||
def __init__(self, embed_dim): | ||
super(VRPBContext, self).__init__( | ||
embed_dim=embed_dim, step_context_dim=embed_dim + 4 | ||
) | ||
|
||
def _state_embedding(self, embeddings, td): | ||
|
||
capacity = super()._state_embedding(embeddings, td) | ||
current_time = td["current_time"] | ||
current_length = td["current_route_length"] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. How does the model understand whether there is a limit? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You are right. The route length in the state_embedding should be the remaining length, i.e., (length limit - current length), instead of the current length. It seems to be a mistake. |
||
is_open = td["open_route"] | ||
is_open_tensor = torch.zeros_like(is_open, dtype=torch.float) | ||
is_open_tensor[is_open] = 1 | ||
|
||
return torch.cat([capacity, current_time, current_length, is_open_tensor], -1) | ||
|
||
|
||
class VRPTWContext(VRPContext): | ||
"""Context embedding for the Capacitated Vehicle Routing Problem (CVRP). | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,10 +2,8 @@ | |
import torch.nn as nn | ||
|
||
from tensordict.tensordict import TensorDict | ||
|
||
from rl4co.models.nn.ops import PositionalEncoding | ||
|
||
|
||
def env_init_embedding(env_name: str, config: dict) -> nn.Module: | ||
"""Get environment initial embedding. The init embedding is used to initialize the | ||
general embedding of the problem nodes without any solution information. | ||
|
@@ -33,6 +31,7 @@ def env_init_embedding(env_name: str, config: dict) -> nn.Module: | |
"smtwtp": SMTWTPInitEmbedding, | ||
"mdcpdp": MDCPDPInitEmbedding, | ||
"fjsp": FJSPFeatureEmbedding, | ||
"mtvrp":MTVRPInitEmbedding, | ||
} | ||
|
||
if env_name not in embedding_registry: | ||
|
@@ -146,6 +145,28 @@ def forward(self, td): | |
) | ||
) | ||
return torch.cat((depot_embedding, node_embeddings), -2) | ||
|
||
|
||
class MTVRPInitEmbedding(VRPInitEmbedding): | ||
def __init__(self, embed_dim, linear_bias=True, node_dim: int = 5): | ||
# node_dim = 5: x, y, demand, tw start, tw end | ||
super(MTVRPInitEmbedding, self).__init__(embed_dim, linear_bias, node_dim) | ||
|
||
def forward(self, td): | ||
depot, cities = td["locs"][:, :1, :], td["locs"][:, 1:, :] | ||
#durations = td["durations"][..., 1:] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are the durations not included? |
||
time_windows = td["time_windows"][..., 1:, :] | ||
# embeddings | ||
demands = td["demand_linehaul"][..., None] - td["demand_backhaul"][..., None] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It makes sense; basically, if it's "-", the model will understand it is a backhaul. I was thinking about having a flag, but this is also good There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A flag is also a good idea! |
||
|
||
depot_embedding = self.init_embed_depot(depot) | ||
node_embeddings = self.init_embed( | ||
torch.cat( | ||
(cities, demands[:,1:], time_windows), -1 | ||
) | ||
) | ||
|
||
return torch.cat((depot_embedding, node_embeddings), -2) | ||
|
||
|
||
class SVRPInitEmbedding(nn.Module): | ||
|
@@ -383,7 +404,6 @@ def forward(self, td): | |
# concatenate on graph size dimension | ||
return torch.cat([depot_embeddings, pick_embeddings, delivery_embeddings], -2) | ||
|
||
|
||
class FJSPFeatureEmbedding(nn.Module): | ||
def __init__(self, embed_dim, linear_bias=True, norm_coef: int = 100): | ||
super().__init__() | ||
|
@@ -443,4 +463,4 @@ def _stepwise_operations_embed(self, td: TensorDict): | |
raise NotImplementedError("Stepwise encoding not yet implemented") | ||
|
||
def _stepwise_machine_embed(self, td: TensorDict): | ||
raise NotImplementedError("Stepwise encoding not yet implemented") | ||
raise NotImplementedError("Stepwise encoding not yet implemented") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[Minor] Debugging comments could be removed.