From 0b0e16a0fd338e57914b2acef252539c62b76bc7 Mon Sep 17 00:00:00 2001 From: Damian Rouson Date: Thu, 17 Oct 2024 23:51:05 -0700 Subject: [PATCH] feat(train_cloud_micro): assert no NaN derivatives --- demo/app/train-cloud-microphysics.F90 | 50 +++++++++------------------ demo/src/NetCDF_variable_m.f90 | 14 ++++++++ demo/src/NetCDF_variable_s.f90 | 33 ++++++++++++++++++ 3 files changed, 63 insertions(+), 34 deletions(-) diff --git a/demo/app/train-cloud-microphysics.F90 b/demo/app/train-cloud-microphysics.F90 index 58123a0f9..eea7565c4 100644 --- a/demo/app/train-cloud-microphysics.F90 +++ b/demo/app/train-cloud-microphysics.F90 @@ -6,7 +6,6 @@ program train_cloud_microphysics !! https://github.com/BerkeleyLab/icar. !! Intrinic modules : - use ieee_arithmetic, only : ieee_is_nan use iso_fortran_env, only : int64, real64 !! External dependencies: @@ -260,6 +259,7 @@ subroutine read_train_write(training_configuration, args, plot_file) associate(derivative_name => "d" // output_names(v)%string() // "/dt") print *,"- " // derivative_name derivative(v) = NetCDF_variable_t( input_variable(v) - output_variable(v) / dt, derivative_name) + call assert(.not. derivative(v)%any_nan(), "train_cloud_microhphysics: non NaN's") end associate end do end associate @@ -267,38 +267,20 @@ subroutine read_train_write(training_configuration, args, plot_file) end associate - !t_end = size(time_in) - - !associate(dt => real(time_out - time_in)) - ! do concurrent(t = 1:t_end) - ! dpt_dt(:,:,:,t) = (potential_temperature_out(:,:,:,t) - potential_temperature_in(:,:,:,t))/dt(t) - ! dqv_dt(:,:,:,t) = (qv_out(:,:,:,t)- qv_in(:,:,:,t))/dt(t) - ! dqc_dt(:,:,:,t) = (qc_out(:,:,:,t)- qc_in(:,:,:,t))/dt(t) - ! dqr_dt(:,:,:,t) = (qr_out(:,:,:,t)- qr_in(:,:,:,t))/dt(t) - ! dqs_dt(:,:,:,t) = (qs_out(:,:,:,t)- qs_in(:,:,:,t))/dt(t) - ! end do - !end associate - - !call assert(.not. any(ieee_is_nan(dpt_dt)), ".not. any(ieee_is_nan(dpt_dt)") - !call assert(.not. any(ieee_is_nan(dqv_dt)), ".not. any(ieee_is_nan(dqv_dt)") - !call assert(.not. any(ieee_is_nan(dqc_dt)), ".not. any(ieee_is_nan(dqc_dt)") - !call assert(.not. any(ieee_is_nan(dqr_dt)), ".not. any(ieee_is_nan(dqr_dt)") - !call assert(.not. any(ieee_is_nan(dqs_dt)), ".not. any(ieee_is_nan(dqs_dt)") - - !train_network: & - !block - ! type(trainable_network_t) trainable_network - ! type(mini_batch_t), allocatable :: mini_batches(:) - ! type(bin_t), allocatable :: bins(:) - ! type(input_output_pair_t), allocatable :: input_output_pairs(:) - ! type(tensor_t), allocatable, dimension(:) :: inputs, outputs - ! real, allocatable :: cost(:) - ! integer i, lon, lat, level, time, network_unit, io_status, epoch, end_step - ! integer(int64) start_training, finish_training + train_network: & + block + type(trainable_network_t) trainable_network + type(mini_batch_t), allocatable :: mini_batches(:) + type(bin_t), allocatable :: bins(:) + type(input_output_pair_t), allocatable :: input_output_pairs(:) + type(tensor_t), allocatable, dimension(:) :: inputs, outputs + real, allocatable :: cost(:) + integer i, lon, lat, level, time, network_unit, io_status, epoch, end_step + integer(int64) start_training, finish_training - ! associate( network_file => args%base_name // "_network.json") - - ! open(newunit=network_unit, file=network_file, form='formatted', status='old', iostat=io_status, action='read') + associate( network_file => args%base_name // "_network.json") + + open(newunit=network_unit, file=network_file, form='formatted', status='old', iostat=io_status, action='read') ! if (allocated(args%end_step)) then ! end_step = args%end_step @@ -482,8 +464,8 @@ subroutine read_train_write(training_configuration, args, plot_file) ! print *,"Training time: ", real(finish_training - start_training, real64)/real(clock_rate, real64),"for", & ! args%num_epochs,"epochs" - !end associate ! network_file - !end block train_network + end associate ! network_file + end block train_network !close(plot_file%plot_unit) diff --git a/demo/src/NetCDF_variable_m.f90 b/demo/src/NetCDF_variable_m.f90 index 9855bbfdc..c55b446b0 100644 --- a/demo/src/NetCDF_variable_m.f90 +++ b/demo/src/NetCDF_variable_m.f90 @@ -21,6 +21,8 @@ module NetCDF_variable_m procedure, private, non_overridable :: default_real_conformable_with, double_precision_conformable_with generic :: rank => default_real_rank, double_precision_rank procedure, private, non_overridable :: default_real_rank, double_precision_rank + generic :: any_nan => default_real_any_nan, double_precision_any_nan + procedure, private, non_overridable :: default_real_any_nan, double_precision_any_nan generic :: operator(-) => default_real_subtract, double_precision_subtract procedure, private, non_overridable :: default_real_subtract, double_precision_subtract generic :: operator(/) => default_real_divide, double_precision_divide @@ -155,6 +157,18 @@ elemental module subroutine double_precision_assign(lhs, rhs) type(NetCDF_variable_t(double_precision)), intent(in) :: rhs end subroutine + elemental module function default_real_any_nan(self) result(any_nan) + implicit none + class(NetCDF_variable_t), intent(in) :: self + logical any_nan + end function + + elemental module function double_precision_any_nan(self) result(any_nan) + implicit none + class(NetCDF_variable_t(double_precision)), intent(in) :: self + logical any_nan + end function + end interface end module \ No newline at end of file diff --git a/demo/src/NetCDF_variable_s.f90 b/demo/src/NetCDF_variable_s.f90 index 15b47864a..cdd4e9fc3 100644 --- a/demo/src/NetCDF_variable_s.f90 +++ b/demo/src/NetCDF_variable_s.f90 @@ -1,6 +1,7 @@ ! Copyright (c), The Regents of the University of California ! Terms of use are as specified in LICENSE.txt submodule(NetCDF_variable_m) NetCDF_variable_s + use ieee_arithmetic, only : ieee_is_nan use kind_parameters_m, only : default_real use assert_m, only : assert, intrinsic_array_t implicit none @@ -344,4 +345,36 @@ pure function double_precision_upper_bounds(NetCDF_variable) result(ubounds) call assert(lhs%rank()==rhs%rank(), "NetCDF_variable_s(double_precision_assign): ranks match)") end procedure + module procedure default_real_any_nan + + select case(self%rank()) + case(1) + any_nan = any(ieee_is_nan(self%values_1D_)) + case(2) + any_nan = any(ieee_is_nan(self%values_2D_)) + case(3) + any_nan = any(ieee_is_nan(self%values_3D_)) + case(4) + any_nan = any(ieee_is_nan(self%values_4D_)) + case default + error stop "NetCDF_variable_s(default_real_any_nan): unsupported rank)" + end select + end procedure + + module procedure double_precision_any_nan + + select case(self%rank()) + case(1) + any_nan = any(ieee_is_nan(self%values_1D_)) + case(2) + any_nan = any(ieee_is_nan(self%values_2D_)) + case(3) + any_nan = any(ieee_is_nan(self%values_3D_)) + case(4) + any_nan = any(ieee_is_nan(self%values_4D_)) + case default + error stop "NetCDF_variable_s(double_precision_any_nan): unsupported rank)" + end select + end procedure + end submodule NetCDF_variable_s \ No newline at end of file