Skip to content

Commit

Permalink
feat(train_cloud_micro): assert no NaN derivatives
Browse files Browse the repository at this point in the history
  • Loading branch information
rouson committed Oct 18, 2024
1 parent 84c24bb commit 0b0e16a
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 34 deletions.
50 changes: 16 additions & 34 deletions demo/app/train-cloud-microphysics.F90
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ program train_cloud_microphysics
!! https://github.com/BerkeleyLab/icar.

!! Intrinic modules :
use ieee_arithmetic, only : ieee_is_nan
use iso_fortran_env, only : int64, real64

!! External dependencies:
Expand Down Expand Up @@ -260,45 +259,28 @@ subroutine read_train_write(training_configuration, args, plot_file)
associate(derivative_name => "d" // output_names(v)%string() // "/dt")
print *,"- " // derivative_name
derivative(v) = NetCDF_variable_t( input_variable(v) - output_variable(v) / dt, derivative_name)
call assert(.not. derivative(v)%any_nan(), "train_cloud_microhphysics: non NaN's")
end associate
end do
end associate
end block

end associate

!t_end = size(time_in)

!associate(dt => real(time_out - time_in))
! do concurrent(t = 1:t_end)
! dpt_dt(:,:,:,t) = (potential_temperature_out(:,:,:,t) - potential_temperature_in(:,:,:,t))/dt(t)
! dqv_dt(:,:,:,t) = (qv_out(:,:,:,t)- qv_in(:,:,:,t))/dt(t)
! dqc_dt(:,:,:,t) = (qc_out(:,:,:,t)- qc_in(:,:,:,t))/dt(t)
! dqr_dt(:,:,:,t) = (qr_out(:,:,:,t)- qr_in(:,:,:,t))/dt(t)
! dqs_dt(:,:,:,t) = (qs_out(:,:,:,t)- qs_in(:,:,:,t))/dt(t)
! end do
!end associate

!call assert(.not. any(ieee_is_nan(dpt_dt)), ".not. any(ieee_is_nan(dpt_dt)")
!call assert(.not. any(ieee_is_nan(dqv_dt)), ".not. any(ieee_is_nan(dqv_dt)")
!call assert(.not. any(ieee_is_nan(dqc_dt)), ".not. any(ieee_is_nan(dqc_dt)")
!call assert(.not. any(ieee_is_nan(dqr_dt)), ".not. any(ieee_is_nan(dqr_dt)")
!call assert(.not. any(ieee_is_nan(dqs_dt)), ".not. any(ieee_is_nan(dqs_dt)")

!train_network: &
!block
! type(trainable_network_t) trainable_network
! type(mini_batch_t), allocatable :: mini_batches(:)
! type(bin_t), allocatable :: bins(:)
! type(input_output_pair_t), allocatable :: input_output_pairs(:)
! type(tensor_t), allocatable, dimension(:) :: inputs, outputs
! real, allocatable :: cost(:)
! integer i, lon, lat, level, time, network_unit, io_status, epoch, end_step
! integer(int64) start_training, finish_training
train_network: &
block
type(trainable_network_t) trainable_network
type(mini_batch_t), allocatable :: mini_batches(:)
type(bin_t), allocatable :: bins(:)
type(input_output_pair_t), allocatable :: input_output_pairs(:)
type(tensor_t), allocatable, dimension(:) :: inputs, outputs
real, allocatable :: cost(:)
integer i, lon, lat, level, time, network_unit, io_status, epoch, end_step
integer(int64) start_training, finish_training

! associate( network_file => args%base_name // "_network.json")
! open(newunit=network_unit, file=network_file, form='formatted', status='old', iostat=io_status, action='read')
associate( network_file => args%base_name // "_network.json")

open(newunit=network_unit, file=network_file, form='formatted', status='old', iostat=io_status, action='read')

! if (allocated(args%end_step)) then
! end_step = args%end_step
Expand Down Expand Up @@ -482,8 +464,8 @@ subroutine read_train_write(training_configuration, args, plot_file)
! print *,"Training time: ", real(finish_training - start_training, real64)/real(clock_rate, real64),"for", &
! args%num_epochs,"epochs"

!end associate ! network_file
!end block train_network
end associate ! network_file
end block train_network

!close(plot_file%plot_unit)

Expand Down
14 changes: 14 additions & 0 deletions demo/src/NetCDF_variable_m.f90
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ module NetCDF_variable_m
procedure, private, non_overridable :: default_real_conformable_with, double_precision_conformable_with
generic :: rank => default_real_rank, double_precision_rank
procedure, private, non_overridable :: default_real_rank, double_precision_rank
generic :: any_nan => default_real_any_nan, double_precision_any_nan
procedure, private, non_overridable :: default_real_any_nan, double_precision_any_nan
generic :: operator(-) => default_real_subtract, double_precision_subtract
procedure, private, non_overridable :: default_real_subtract, double_precision_subtract
generic :: operator(/) => default_real_divide, double_precision_divide
Expand Down Expand Up @@ -155,6 +157,18 @@ elemental module subroutine double_precision_assign(lhs, rhs)
type(NetCDF_variable_t(double_precision)), intent(in) :: rhs
end subroutine

elemental module function default_real_any_nan(self) result(any_nan)
implicit none
class(NetCDF_variable_t), intent(in) :: self
logical any_nan
end function

elemental module function double_precision_any_nan(self) result(any_nan)
implicit none
class(NetCDF_variable_t(double_precision)), intent(in) :: self
logical any_nan
end function

end interface

end module
33 changes: 33 additions & 0 deletions demo/src/NetCDF_variable_s.f90
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
! Copyright (c), The Regents of the University of California
! Terms of use are as specified in LICENSE.txt
submodule(NetCDF_variable_m) NetCDF_variable_s
use ieee_arithmetic, only : ieee_is_nan
use kind_parameters_m, only : default_real
use assert_m, only : assert, intrinsic_array_t
implicit none
Expand Down Expand Up @@ -344,4 +345,36 @@ pure function double_precision_upper_bounds(NetCDF_variable) result(ubounds)
call assert(lhs%rank()==rhs%rank(), "NetCDF_variable_s(double_precision_assign): ranks match)")
end procedure

module procedure default_real_any_nan

select case(self%rank())
case(1)
any_nan = any(ieee_is_nan(self%values_1D_))
case(2)
any_nan = any(ieee_is_nan(self%values_2D_))
case(3)
any_nan = any(ieee_is_nan(self%values_3D_))
case(4)
any_nan = any(ieee_is_nan(self%values_4D_))
case default
error stop "NetCDF_variable_s(default_real_any_nan): unsupported rank)"
end select
end procedure

module procedure double_precision_any_nan

select case(self%rank())
case(1)
any_nan = any(ieee_is_nan(self%values_1D_))
case(2)
any_nan = any(ieee_is_nan(self%values_2D_))
case(3)
any_nan = any(ieee_is_nan(self%values_3D_))
case(4)
any_nan = any(ieee_is_nan(self%values_4D_))
case default
error stop "NetCDF_variable_s(double_precision_any_nan): unsupported rank)"
end select
end procedure

end submodule NetCDF_variable_s

0 comments on commit 0b0e16a

Please sign in to comment.