Manual loss weights adaptation in TF2.0 #1656

Open · wants to merge 18 commits into master
48 changes: 47 additions & 1 deletion deepxde/callbacks.py
@@ -551,7 +551,6 @@ def __init__(self, period=100, pde_points=True, bc_points=False):
self.period = period
self.pde_points = pde_points
self.bc_points = bc_points

self.num_bcs_initial = None
self.epochs_since_last_resample = 0

@@ -571,3 +570,50 @@ def on_epoch_end(self):
raise ValueError(
"`num_bcs` changed! Please update the loss function by `model.compile`."
)


class PrintLossWeight(Callback):
"""Print the loss weights every period epochs.

Args:
period: Interval (number of epochs) between printing loss weights.
"""

def __init__(self, period):
super().__init__()
self.period = period
self.initial_loss_weights = None
self.current_loss_weights = None

    def on_epoch_begin(self):
        if self.model.train_state.epoch == 0:
            self.initial_loss_weights = self.model.loss_weights.numpy().tolist()
        self.current_loss_weights = self.model.loss_weights.numpy().tolist()
        if self.model.train_state.epoch % self.period == 0:
            print("Initial loss weights:", self.initial_loss_weights)
            print("Current loss weights:", self.current_loss_weights)


class ManualDynamicLossWeight(Callback):
"""Change the loss weights at a specific epoch.

Args:
        epoch2change: The epoch at which to change the loss weight.
        value: The new value for the loss weight.
        loss_idx: The index of the loss weight to change.
"""

def __init__(self, epoch2change, value, loss_idx):
super().__init__()
self.epoch2change = epoch2change
self.value = value
self.loss_idx = loss_idx

def on_epoch_begin(self):
if self.model.train_state.epoch == self.epoch2change:
current_loss_weights = self.model.loss_weights.numpy()
current_loss_weights[self.loss_idx] = self.value
self.model.loss_weights = tf.convert_to_tensor(
current_loss_weights, dtype=config.default_float()
)
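As written, both callbacks call tf directly, which ties them to the tensorflow backend (see the review comment on model.py below). A backend-agnostic variant of on_epoch_begin could dispatch through deepxde's backend module instead; a sketch, assuming bkd.as_tensor is the generic conversion helper (verify against the backend API), and reusing utils.to_numpy, which model.py already uses:

    from deepxde import backend as bkd
    from deepxde import utils

    def on_epoch_begin(self):
        if self.model.train_state.epoch == self.epoch2change:
            # Convert to numpy, edit one entry, and convert back via the
            # backend-dispatch layer rather than calling tf directly.
            current_loss_weights = utils.to_numpy(self.model.loss_weights)
            current_loss_weights[self.loss_idx] = self.value
            self.model.loss_weights = bkd.as_tensor(
                current_loss_weights, dtype=config.default_float()
            )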
47 changes: 33 additions & 14 deletions deepxde/model.py
@@ -119,7 +119,9 @@ def compile(
print("Compiling model...")
self.opt_name = optimizer
loss_fn = losses_module.get(loss)
self.loss_weights = loss_weights
        self.loss_weights = tf.convert_to_tensor(
            loss_weights, dtype=config.default_float()
        )

@lululxvi (Owner) commented on Mar 3, 2024:
  • How about loss weights is None?
  • Using tf here will break other backends.
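One way to address both points (a sketch, not part of this PR): skip the conversion when loss_weights is None and route it through the backend module rather than tf, again assuming bkd.as_tensor is the backend-agnostic helper:

        # Keep None as-is; otherwise convert via the backend-dispatch layer.
        if loss_weights is not None:
            loss_weights = bkd.as_tensor(loss_weights, dtype=config.default_float())
        self.loss_weights = loss_weights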
if external_trainable_variables is None:
self.external_trainable_variables = []
else:
@@ -202,7 +204,9 @@ def _compile_tensorflow(self, lr, loss_fn, decay):
def outputs(training, inputs):
return self.net(inputs, training=training)

def outputs_losses(training, inputs, targets, auxiliary_vars, losses_fn):
def outputs_losses(
training, inputs, targets, auxiliary_vars, losses_fn, loss_weights
):
self.net.auxiliary_vars = auxiliary_vars
# Don't call outputs() decorated by @tf.function above, otherwise the
# gradient of outputs wrt inputs will be lost here.
@@ -218,29 +222,41 @@ def outputs_losses(training, inputs, targets, auxiliary_vars, losses_fn):
losses += [tf.math.reduce_sum(self.net.losses)]
losses = tf.convert_to_tensor(losses)
# Weighted losses
if self.loss_weights is not None:
losses *= self.loss_weights
if loss_weights is not None:
losses *= loss_weights
return outputs_, losses

@tf.function(jit_compile=config.xla_jit)
def outputs_losses_train(inputs, targets, auxiliary_vars):
def outputs_losses_train(inputs, targets, auxiliary_vars, loss_weights):
return outputs_losses(
True, inputs, targets, auxiliary_vars, self.data.losses_train
True,
inputs,
targets,
auxiliary_vars,
self.data.losses_train,
loss_weights,
)

@tf.function(jit_compile=config.xla_jit)
def outputs_losses_test(inputs, targets, auxiliary_vars):
def outputs_losses_test(inputs, targets, auxiliary_vars, loss_weights):
return outputs_losses(
False, inputs, targets, auxiliary_vars, self.data.losses_test
False,
inputs,
targets,
auxiliary_vars,
self.data.losses_test,
loss_weights,
)

opt = optimizers.get(self.opt_name, learning_rate=lr, decay=decay)

@tf.function(jit_compile=config.xla_jit)
def train_step(inputs, targets, auxiliary_vars):
def train_step(inputs, targets, auxiliary_vars, loss_weights):
# inputs and targets are np.ndarray and automatically converted to Tensor.
with tf.GradientTape() as tape:
losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
losses = outputs_losses_train(
inputs, targets, auxiliary_vars, loss_weights
)[1]
total_loss = tf.math.reduce_sum(losses)
trainable_variables = (
self.net.trainable_variables + self.external_trainable_variables
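The pattern above — threading loss_weights through as an explicit argument instead of reading self.loss_weights inside the @tf.function — matters because a tf.function captures closed-over eager tensors at trace time; rebinding the attribute afterwards (as ManualDynamicLossWeight does) would not be seen by the already-traced graph. A standalone sketch of the difference (illustration only, not deepxde code):

    import tensorflow as tf

    weights = tf.constant([1.0, 1.0])

    @tf.function
    def scale_captured(losses):
        # `weights` is captured when the function is first traced;
        # rebinding the Python name later has no effect on the graph.
        return losses * weights

    @tf.function
    def scale_arg(losses, w):
        # `w` is a call argument: each call sees the tensor passed in,
        # with no retracing as long as shape and dtype are unchanged.
        return losses * w

    losses = tf.constant([2.0, 3.0])
    print(scale_captured(losses))        # [2. 3.]
    weights = tf.constant([10.0, 10.0])  # rebind the name
    print(scale_captured(losses))        # still [2. 3.] -- stale capture
    print(scale_arg(losses, weights))    # [20. 30.]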
@@ -531,7 +547,7 @@ def _outputs(self, training, inputs):
outs = self.outputs(self.net.params, training, inputs)
return utils.to_numpy(outs)

def _outputs_losses(self, training, inputs, targets, auxiliary_vars):
def _outputs_losses(self, training, inputs, targets, auxiliary_vars, loss_weights):
if training:
outputs_losses = self.outputs_losses_train
else:
@@ -540,7 +556,7 @@ def _outputs_losses(self, training, inputs, targets, auxiliary_vars):
feed_dict = self.net.feed_dict(training, inputs, targets, auxiliary_vars)
return self.sess.run(outputs_losses, feed_dict=feed_dict)
if backend_name == "tensorflow":
outs = outputs_losses(inputs, targets, auxiliary_vars)
outs = outputs_losses(inputs, targets, auxiliary_vars, loss_weights)
elif backend_name == "pytorch":
self.net.requires_grad_(requires_grad=False)
outs = outputs_losses(inputs, targets, auxiliary_vars)
@@ -552,12 +568,12 @@ def _outputs_losses(self, training, inputs, targets, auxiliary_vars):
outs = outputs_losses(inputs, targets, auxiliary_vars)
return utils.to_numpy(outs[0]), utils.to_numpy(outs[1])

def _train_step(self, inputs, targets, auxiliary_vars):
def _train_step(self, inputs, targets, auxiliary_vars, loss_weights):
if backend_name == "tensorflow.compat.v1":
feed_dict = self.net.feed_dict(True, inputs, targets, auxiliary_vars)
self.sess.run(self.train_step, feed_dict=feed_dict)
        elif backend_name in ["tensorflow", "paddle"]:
            self.train_step(inputs, targets, auxiliary_vars)
        elif backend_name == "tensorflow":
            # Only the tensorflow backend's train_step accepts loss_weights;
            # passing it to paddle's train_step would raise a TypeError.
            self.train_step(inputs, targets, auxiliary_vars, loss_weights)
        elif backend_name == "paddle":
            self.train_step(inputs, targets, auxiliary_vars)
elif backend_name == "pytorch":
self.train_step(inputs, targets, auxiliary_vars)
elif backend_name == "jax":
@@ -669,6 +685,7 @@ def _train_sgd(self, iterations, display_every):
self.train_state.X_train,
self.train_state.y_train,
self.train_state.train_aux_vars,
self.loss_weights,
)

self.train_state.epoch += 1
@@ -827,12 +844,14 @@ def _test(self):
self.train_state.X_train,
self.train_state.y_train,
self.train_state.train_aux_vars,
self.loss_weights,
)
self.train_state.y_pred_test, self.train_state.loss_test = self._outputs_losses(
False,
self.train_state.X_test,
self.train_state.y_test,
self.train_state.test_aux_vars,
self.loss_weights,
)

if isinstance(self.train_state.y_test, (list, tuple)):
@@ -0,0 +1,81 @@
"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle"""

import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np
from deepxde.callbacks import PrintLossWeight, ManualDynamicLossWeight

dde.config.disable_xla_jit()
# The callbacks above manipulate tf tensors, so this script needs the
# tensorflow backend. Select it before running, e.g. via the DDE_BACKEND
# environment variable; calling set_default_backend() after deepxde is
# imported does not switch the backend of the current process.


def gen_traindata(num):
# generate num equally-spaced points from -1 to 1
xvals = np.linspace(-1, 1, num).reshape(num, 1)
uvals = np.sin(np.pi * xvals)
return xvals, uvals


def pde(x, y):
u, q = y[:, 0:1], y[:, 1:2]
du_xx = dde.grad.hessian(y, x, component=0, i=0, j=0)
return -du_xx + q


def sol(x):
# solution is u(x) = sin(pi*x), q(x) = -pi^2 * sin(pi*x)
return np.sin(np.pi * x)


geom = dde.geometry.Interval(-1, 1)
bc = dde.icbc.DirichletBC(geom, sol, lambda _, on_boundary: on_boundary, component=0)
ob_x, ob_u = gen_traindata(100)
observe_u = dde.icbc.PointSetBC(ob_x, ob_u, component=0)

data = dde.data.PDE(
geom,
pde,
[bc, observe_u],
num_domain=200,
num_boundary=2,
anchors=ob_x,
num_test=1000,
)

net = dde.nn.FNN([1, 40, 40, 40, 2], "tanh", "Glorot uniform")
PrintLossWeight_cb = PrintLossWeight(period=1)
ManualDynamicLossWeight_cb = ManualDynamicLossWeight(
epoch2change=5000, value=1, loss_idx=0
)
model = dde.Model(data, net)
model.compile("adam", lr=0.0001, loss_weights=[0, 100, 1000])
losshistory, train_state = model.train(
iterations=20000,
display_every=1,
callbacks=[PrintLossWeight_cb, ManualDynamicLossWeight_cb],
)
# dde.saveplot(losshistory, train_state, issave=True, isplot=True)

# view results
x = geom.uniform_points(500)
yhat = model.predict(x)
uhat, qhat = yhat[:, 0:1], yhat[:, 1:2]

utrue = np.sin(np.pi * x)
print("l2 relative error for u: " + str(dde.metrics.l2_relative_error(utrue, uhat)))
plt.figure()
plt.plot(x, utrue, "-", label="u_true")
plt.plot(x, uhat, "--", label="u_NN")
plt.legend()

qtrue = -np.pi**2 * np.sin(np.pi * x)
print("l2 relative error for q: " + str(dde.metrics.l2_relative_error(qtrue, qhat)))
plt.figure()
plt.plot(x, qtrue, "-", label="q_true")
plt.plot(x, qhat, "--", label="q_NN")
plt.legend()

plt.show()
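A quick sanity check after training (a sketch; assumes the tensorflow backend, where this PR stores loss_weights as a tensor):

# With ManualDynamicLossWeight(epoch2change=5000, value=1, loss_idx=0), the
# first weight should be 1 (changed from 0) by the end of training.
print("Final loss weights:", model.loss_weights.numpy().tolist())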