sync changes as of train time
peterdsharpe committed Mar 8, 2024
1 parent bbebf17 commit 9e6ac8d
Showing 1 changed file with 45 additions and 23 deletions.
training/gen2_architecture/train_blind_neural_network.py (68 changes: 45 additions & 23 deletions)
@@ -21,6 +21,7 @@
 cache_file = Path(__file__).parent / "nn-xxxlarge.pth"
 print("Cache file: ", cache_file)
 
+
 # Define the model
 class Net(torch.nn.Module):
     def __init__(self):
@@ -52,33 +53,33 @@ def forward(self, x: torch.Tensor):
         # The goal here is to embed the invariant of "symmetry across alpha" into the network evaluation.
 
         x_flipped = x.clone()
-        x_flipped[:, :8] = x[:, 8:16] * -1 # switch kulfan_lower with a flipped kulfan_upper
-        x_flipped[:, 8:16] = x[:, :8] * -1 # switch kulfan_upper with a flipped kulfan_lower
-        x_flipped[:, 16] *= -1 # flip kulfan_LE_weight
-        x_flipped[:, 18] *= -1 # flip sin(2a)
+        x_flipped[:, :8] = -1 * x[:, 8:16] # switch kulfan_lower with a flipped kulfan_upper
+        x_flipped[:, 8:16] = -1 * x[:, :8] # switch kulfan_upper with a flipped kulfan_lower
+        x_flipped[:, 16] = -1 * x[:, 16] # flip kulfan_LE_weight
+        x_flipped[:, 18] = -1 * x[:, 18] # flip sin(2a)
         x_flipped[:, 23] = x[:, 24] # flip xtr_upper with xtr_lower
         x_flipped[:, 24] = x[:, 23] # flip xtr_lower with xtr_upper
 
         y_flipped = self.net(x_flipped)
 
         ### The resulting outputs will also be flipped, so we need to flip them back to their normal orientation
         y_unflipped = y_flipped.clone()
-        y_unflipped[:, 1] *= -1 # CL
-        y_unflipped[:, 3] *= -1 # CM
+        y_unflipped[:, 1] = y_flipped[:, 1] * -1 # CL
+        y_unflipped[:, 3] = y_flipped[:, 3] * -1 # CM
         y_unflipped[:, 4] = y_flipped[:, 5] # switch Top_Xtr with Bot_Xtr
         y_unflipped[:, 5] = y_flipped[:, 4] # switch Bot_Xtr with Top_Xtr
 
         # switch upper and lower Ret, H
-        y_unflipped[:, 6:6 + 32 * 2] = y_flipped[:, 6 + 32 * 3: 6 + 32 * 5]
-        y_unflipped[:, 6 + 32 * 3: 6 + 32 * 5] = y_flipped[:, 6:6 + 32 * 2]
+        y_unflipped[:, 6 + 32 * 0: 6 + 32 * 2] = y_flipped[:, 6 + 32 * 3: 6 + 32 * 5]
+        y_unflipped[:, 6 + 32 * 3: 6 + 32 * 5] = y_flipped[:, 6 + 32 * 0: 6 + 32 * 2]
 
         # switch upper_bl_ue/vinf with lower_bl_ue/vinf
         y_unflipped[:, 6 + 32 * 2: 6 + 32 * 3] = -1 * y_flipped[:, 6 + 32 * 5: 6 + 32 * 6]
         y_unflipped[:, 6 + 32 * 5: 6 + 32 * 6] = -1 * y_flipped[:, 6 + 32 * 2: 6 + 32 * 3]
 
         ### Then, average the two outputs to get the "symmetric" result
         y_fused = (y + y_unflipped) / 2
-        y_fused[:, 0] = torch.sigmoid(y_fused[:, 0]) # Analysis confidence, a binary variable
+        # y_fused[:, 0] = torch.sigmoid(y_fused[:, 0]) # Analysis confidence, a binary variable
 
         return y_fused
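
A note on the pattern in this hunk: the forward pass bakes the "symmetry across alpha" invariant into inference by evaluating the raw network twice, once on the original input and once on a mirrored copy, mapping the mirrored output back, and averaging the two results. A minimal, self-contained sketch of the same idea is below; the wrapper name and the single mirrored column are illustrative only and do not reflect the real column layout of this model.

import torch

class SymmetricWrapper(torch.nn.Module):
    # Toy setup: input column 0 changes sign under the mirrored condition,
    # and output column 0 changes sign in response (hypothetical columns).
    def __init__(self, net: torch.nn.Module):
        super().__init__()
        self.net = net

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.net(x)

        x_flipped = x.clone()
        x_flipped[:, 0] = -1 * x[:, 0]  # mirror the input

        y_flipped = self.net(x_flipped)
        y_unflipped = y_flipped.clone()
        y_unflipped[:, 0] = -1 * y_flipped[:, 0]  # map the mirrored output back

        return (y + y_unflipped) / 2  # fuse the two estimates

Because both evaluations share the same weights, the fused output satisfies the mirror relation exactly, no matter what the underlying network learned.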

@@ -97,7 +98,7 @@ def forward(self, x: torch.Tensor):
 
 # Define the optimizer
 learning_rate = 1e-3
-optimizer = torch.optim.RAdam(net.parameters(), lr=learning_rate, weight_decay=1e-5)
+optimizer = torch.optim.RAdam(net.parameters(), lr=learning_rate, weight_decay=1e-6)
 scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
     optimizer,
     factor=0.5,
@@ -109,7 +110,7 @@ def forward(self, x: torch.Tensor):
 # Define the data loader
 print(f"Preparing data...")
 
-batch_size = 128
+batch_size = 256
 train_inputs = torch.tensor(
     df_train_inputs_scaled.to_numpy(),
     dtype=torch.float32,
@@ -141,14 +142,15 @@ def forward(self, x: torch.Tensor):
 
 # Prepare the loss function
 loss_weights = torch.ones(N_outputs, dtype=torch.float32).to(device)
-loss_weights[0] *= 0.05 # Analysis confidence
+loss_weights[0] *= 0.005 # Analysis confidence
 loss_weights[1] *= 1 # CL
-loss_weights[2] *= 2 # ln(CD)
-loss_weights[3] *= 0.5 # CM
+loss_weights[2] *= 3 # ln(CD)
+loss_weights[3] *= 0.25 # CM
 loss_weights[4] *= 0.25 # Top Xtr
 loss_weights[5] *= 0.25 # Bot Xtr
-loss_weights[6:] *= 1 / (32 * 6) # Lower the weight on all boundary layer outputs
+loss_weights[6:] *= 1 / 32 # Lower the weight on all boundary layer outputs
 
+loss_weights = loss_weights / torch.sum(loss_weights) * 1000
 
 def loss_function(y_pred, y_data, return_individual_loss_components=False):
     # For data with NaN, overwrite the data with the prediction. This essentially makes the model ignore NaN data,
@@ -159,31 +161,47 @@ def loss_function(y_pred, y_data, return_individual_loss_components=False):
         y_data
     )
 
-    analysis_confidence_loss = torch.nn.functional.binary_cross_entropy(y_pred[:, 0], y_data[:, 0])
+    analysis_confidence_loss = torch.mean(
+        torch.nn.functional.binary_cross_entropy_with_logits(
+            input=y_pred[:, 0:1],
+            target=y_data[:, 0:1],
+            reduction='none',
+        ),
+        dim=0
+    )
+    # other_loss_components = torch.mean(
+    #     (y_pred[:, 1:] - y_data[:, 1:]) ** 2,
+    #     dim=0
+    # )
+
+    # other_loss_components = torch.mean(
+    #     torch.nn.functional.huber_loss(
+    #         y_pred[:, 1:], y_data[:, 1:],
+    #         reduction='none',
+    #         delta=1
+    #     ),
+    #     dim=0
+    # )
 
     other_loss_components = torch.mean(
-        torch.nn.functional.huber_loss(
+        torch.nn.functional.mse_loss(
             y_pred[:, 1:], y_data[:, 1:],
             reduction='none',
-            delta=1
         ),
         dim=0
     )
 
-    unweighted_loss_components = torch.stack([
+    unweighted_loss_components = torch.concatenate([
         analysis_confidence_loss,
-        *other_loss_components
+        other_loss_components
     ], dim=0)
 
     weighted_loss_components = unweighted_loss_components * loss_weights
-    loss = torch.sum(weighted_loss_components)
 
     if return_individual_loss_components:
         return weighted_loss_components
-    else:
-        return loss
+    return torch.sum(weighted_loss_components)
 
 
 # raise Exception
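
For context on the loss in this hunk: output 0 is treated as a binary "analysis confidence" logit (binary cross-entropy with logits), the remaining outputs as regression targets (per-element squared error), NaN targets are overwritten with the prediction so they contribute nothing, and the per-output loss_weights vector from the earlier hunk scales each component before summing. A rough standalone sketch of that recipe follows; the torch.where masking line is an assumption (the full expression is truncated in this view), and the detach is an illustrative choice.

import torch

def weighted_masked_loss(y_pred, y_data, loss_weights):
    # Where a target is NaN, substitute the (detached) prediction so that
    # element contributes zero loss; assumes the confidence target is never NaN.
    y_data = torch.where(torch.isnan(y_data), y_pred.detach(), y_data)

    # Output 0: binary "analysis confidence" logit.
    confidence_loss = torch.mean(
        torch.nn.functional.binary_cross_entropy_with_logits(
            input=y_pred[:, 0:1],
            target=y_data[:, 0:1],
            reduction='none',
        ),
        dim=0,
    )

    # Outputs 1 onward: regression targets, per-element squared error.
    regression_loss = torch.mean(
        torch.nn.functional.mse_loss(
            y_pred[:, 1:], y_data[:, 1:],
            reduction='none',
        ),
        dim=0,
    )

    components = torch.concatenate([confidence_loss, regression_loss], dim=0)
    return torch.sum(components * loss_weights)

Here loss_weights is the per-output weight vector built above (normalized so it sums to 1000), so the returned scalar is a weighted sum over all output components.
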
@@ -199,7 +217,8 @@ def loss_function(y_pred, y_data, return_individual_loss_components=False):
 
     loss_from_each_training_batch = []
 
-    for x, y_data in tqdm(train_loader):
+    # for x, y_data in tqdm(train_loader):
+    for x, y_data in train_loader:
 
         x = x.to(device)
         y_data = y_data.to(device)
@@ -237,6 +256,9 @@ def loss_function(y_pred, y_data, return_individual_loss_components=False):
             )
 
             loss_components_from_each_test_batch.append(loss_components)
+
+            y_pred[:, 0] = torch.sigmoid(y_pred[:, 0]) # Analysis confidence, a binary variable
+
             mae_from_each_test_batch.append(
                 torch.nanmean(torch.abs(y_pred - y_data), dim=0)
             )