diff --git a/training/gen2_architecture/train_blind_neural_network.py b/training/gen2_architecture/train_blind_neural_network.py
index 91b992d..c511b40 100644
--- a/training/gen2_architecture/train_blind_neural_network.py
+++ b/training/gen2_architecture/train_blind_neural_network.py
@@ -148,7 +148,7 @@ def forward(self, x: torch.Tensor):
 loss_weights[3] *= 0.25  # CM
 loss_weights[4] *= 0.25  # Top Xtr
 loss_weights[5] *= 0.25  # Bot Xtr
-loss_weights[6:] *= 1 / 32  # Lower the weight on all boundary layer outputs
+loss_weights[6:] *= 5e-3  # Lower the weight on all boundary layer outputs
 
 loss_weights = loss_weights / torch.sum(loss_weights) * 1000
 
@@ -174,23 +174,23 @@ def loss_function(y_pred, y_data, return_individual_loss_components=False):
     #     dim=0
     # )
 
-    # other_loss_components = torch.mean(
-    #     torch.nn.functional.huber_loss(
-    #         y_pred[:, 1:], y_data[:, 1:],
-    #         reduction='none',
-    #         delta=1
-    #     ),
-    #     dim=0
-    # )
-
     other_loss_components = torch.mean(
-        torch.nn.functional.mse_loss(
+        torch.nn.functional.huber_loss(
             y_pred[:, 1:], y_data[:, 1:],
             reduction='none',
+            delta=0.05
         ),
         dim=0
     )
 
+    # other_loss_components = torch.mean(
+    #     torch.nn.functional.mse_loss(
+    #         y_pred[:, 1:], y_data[:, 1:],
+    #         reduction='none',
+    #     ),
+    #     dim=0
+    # )
+
     unweighted_loss_components = torch.concatenate([
         analysis_confidence_loss,
         other_loss_components
@@ -206,7 +206,6 @@ def loss_function(y_pred, y_data, return_individual_loss_components=False):
     # raise Exception
 
 print(f"Training...")
-unweighted_epoch_loss_components = torch.ones(N_outputs, dtype=torch.float32).to(device)
 
 n_batches_per_epoch = len(train_loader)
 
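
Note (not part of the patch): the two substantive changes are (1) switching the penalty on the outputs other than analysis confidence (columns 1:) from MSE to Huber loss with delta=0.05, and (2) further down-weighting the boundary layer outputs (1 / 32 -> 5e-3). Because loss_weights is renormalized to sum to 1000 immediately afterward, the weight change redistributes emphasis toward the non-boundary-layer outputs rather than shrinking the overall loss scale. Below is a minimal sketch of the Huber-vs-MSE behavior, using made-up residuals rather than anything from the training data:

import torch

# Hypothetical per-output errors; values chosen only to illustrate the delta=0.05 knee.
y_pred = torch.tensor([[0.02, 0.10, 2.00]])
y_data = torch.zeros_like(y_pred)

# Old penalty: quadratic everywhere, so the 2.0 outlier dominates the loss and gradient.
mse = torch.nn.functional.mse_loss(y_pred, y_data, reduction='none')

# New penalty: quadratic only within +/-0.05 of the target, linear beyond that.
huber = torch.nn.functional.huber_loss(y_pred, y_data, reduction='none', delta=0.05)

print(mse)    # approx [[4.0e-04, 1.0e-02, 4.0e+00]]
print(huber)  # approx [[2.0e-04, 3.75e-03, 9.875e-02]]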