diff --git a/torax/transport_model/qlknn_10d.py b/torax/transport_model/qlknn_10d.py index 1f8aae77..3f8c8e1f 100644 --- a/torax/transport_model/qlknn_10d.py +++ b/torax/transport_model/qlknn_10d.py @@ -28,6 +28,7 @@ from torax.transport_model import base_qlknn_model from torax.transport_model import qualikiz_based_transport_model +# Internal import. # Internal import. # Move this to common lib. @@ -95,9 +96,7 @@ def __init__( self._model = MLP(hidden_sizes=hidden_sizes, activations=activations) def _load_prescale(self, key: str, names: list[str]) -> np.ndarray: - return np.array([self._model_config[key][k] for k in names])[ - np.newaxis, : - ] + return np.array([self._model_config[key][k] for k in names])[np.newaxis, :] def __call__( self, @@ -161,16 +160,12 @@ def predict( model_output = {} model_output['qi_itg'] = self.net_itgleading(inputs).clip(0) - model_output['qe_itg'] = ( - self.net_itgqediv(inputs) * model_output['qi_itg'] - ) + model_output['qe_itg'] = self.net_itgqediv(inputs) * model_output['qi_itg'] model_output['pfe_itg'] = ( self.net_itgpfediv(inputs) * model_output['qi_itg'] ) model_output['qe_tem'] = self.net_temleading(inputs).clip(0) - model_output['qi_tem'] = ( - self.net_temqidiv(inputs) * model_output['qe_tem'] - ) + model_output['qi_tem'] = self.net_temqidiv(inputs) * model_output['qe_tem'] model_output['pfe_tem'] = ( self.net_tempfediv(inputs) * model_output['qe_tem'] )