Skip to content

Commit

Permalink
update hyperparams based on study
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdsharpe committed Mar 10, 2024
1 parent 9d85efd commit 48e54ec
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions training/gen2_architecture/train_blind_neural_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def forward(self, x: torch.Tensor):
loss_weights[3] *= 0.25 # CM
loss_weights[4] *= 0.25 # Top Xtr
loss_weights[5] *= 0.25 # Bot Xtr
loss_weights[6:] *= 1 / 32 # Lower the weight on all boundary layer outputs
loss_weights[6:] *= 5e-3 # Lower the weight on all boundary layer outputs

loss_weights = loss_weights / torch.sum(loss_weights) * 1000

Expand All @@ -174,23 +174,23 @@ def loss_function(y_pred, y_data, return_individual_loss_components=False):
# dim=0
# )

# other_loss_components = torch.mean(
# torch.nn.functional.huber_loss(
# y_pred[:, 1:], y_data[:, 1:],
# reduction='none',
# delta=1
# ),
# dim=0
# )

other_loss_components = torch.mean(
torch.nn.functional.mse_loss(
torch.nn.functional.huber_loss(
y_pred[:, 1:], y_data[:, 1:],
reduction='none',
delta=0.05
),
dim=0
)

# other_loss_components = torch.mean(
# torch.nn.functional.mse_loss(
# y_pred[:, 1:], y_data[:, 1:],
# reduction='none',
# ),
# dim=0
# )

unweighted_loss_components = torch.concatenate([
analysis_confidence_loss,
other_loss_components
Expand All @@ -206,7 +206,6 @@ def loss_function(y_pred, y_data, return_individual_loss_components=False):

# raise Exception
print(f"Training...")
unweighted_epoch_loss_components = torch.ones(N_outputs, dtype=torch.float32).to(device)

n_batches_per_epoch = len(train_loader)

Expand Down

0 comments on commit 48e54ec

Please sign in to comment.