Skip to content

Commit

Permalink
refac(train-cloud): generalize tensor/tensor-range
Browse files Browse the repository at this point in the history
  • Loading branch information
rouson committed Oct 18, 2024
1 parent 0b0e16a commit ba04e9f
Show file tree
Hide file tree
Showing 3 changed files with 319 additions and 91 deletions.
149 changes: 68 additions & 81 deletions demo/app/train-cloud-microphysics.F90
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ program train_cloud_microphysics
!! Internal dependencies:
use phase_space_bin_m, only : phase_space_bin_t
use NetCDF_file_m, only: NetCDF_file_t
use NetCDF_variable_m, only: NetCDF_variable_t
use NetCDF_variable_m, only: NetCDF_variable_t, tensors
implicit none

character(len=*), parameter :: usage = new_line('a') // new_line('a') // &
Expand All @@ -31,7 +31,7 @@ program train_cloud_microphysics
'The presence of a file named "stop" halts execution gracefully.'

type command_line_arguments_t
integer num_epochs, start_step, stride, num_bins, report_interval
integer num_epochs, start_step, stride, num_bins, report_step
integer, allocatable :: end_step
character(len=:), allocatable :: base_name
real cost_tolerance
Expand Down Expand Up @@ -105,7 +105,7 @@ function get_command_line_arguments() result(command_line_arguments)
base_name, epochs_string, start_string, end_string, stride_string, bins_string, report_string, tolerance_string
real cost_tolerance
integer, allocatable :: end_step
integer num_epochs, num_bins, start_step, stride, report_interval
integer num_epochs, num_bins, start_step, stride, report_step

base_name = command_line%flag_value("--base") ! gfortran 13 seg faults if this is an association
epochs_string = command_line%flag_value("--epochs")
Expand All @@ -122,54 +122,24 @@ function get_command_line_arguments() result(command_line_arguments)

read(epochs_string,*) num_epochs

if (len(stride_string)==0) then
stride = 1
else
read(stride_string,*) stride
end if

if (len(start_string)==0) then
start_step = 1
else
read(start_string,*) start_step
end if

if (len(report_string)==0) then
report_interval = 1
else
read(report_string,*) report_interval
end if

if (len(bins_string)/=0) then
read(bins_string,*) num_bins
else
num_bins = 1
end if
stride = default_integer_or_read(1, stride_string)
start_step = default_integer_or_read(1, start_string)
report_step = default_integer_or_read(1, report_string)
num_bins = default_integer_or_read(1, bins_string)
cost_tolerance = default_real_or_read(5E-8, tolerance_string)

if (len(end_string)/=0) then
allocate(end_step)
read(end_string,*) end_step
end if

if (len(start_string)==0) then
start_step = 1
else
read(start_string,*) start_step
end if

if (len(tolerance_string)==0) then
cost_tolerance = 5.0E-08
else
read(tolerance_string,*) cost_tolerance
end if

if (allocated(end_step)) then
command_line_arguments = command_line_arguments_t( &
num_epochs, start_step, stride, num_bins, report_interval, end_step, base_name, cost_tolerance &
num_epochs, start_step, stride, num_bins, report_step, end_step, base_name, cost_tolerance &
)
else
command_line_arguments = command_line_arguments_t( &
num_epochs, start_step, stride, num_bins, report_interval, null(), base_name, cost_tolerance &
num_epochs, start_step, stride, num_bins, report_step, null(), base_name, cost_tolerance &
)
end if

Expand All @@ -194,15 +164,17 @@ subroutine read_train_write(training_configuration, args, plot_file)
enumerator :: dpotential_temperature_t=1, dqv_dt, dqc_dt, dqr_dt, dqs_dt
end enum

associate(input_names => &
[string_t("pressure"), string_t("potential_temperature"), string_t("temperature"), &
string_t("qv"), string_t("qc"), string_t("qr"), string_t("qs")] &
)
!associate(input_names => &
! [string_t("pressure"), string_t("potential_temperature"), string_t("temperature"), &
! string_t("qv"), string_t("qc"), string_t("qr"), string_t("qs")] &
!)
associate(input_names => [string_t("qv"), string_t("qc")])

allocate(input_variable(size(input_names)))

associate(input_file_name => args%base_name // "_input.nc")

print *,"Reading network inputs from " // input_file_name
print *,"Reading physics-based model inputs from " // input_file_name

associate(input_file => netCDF_file_t(input_file_name))

Expand All @@ -222,13 +194,14 @@ subroutine read_train_write(training_configuration, args, plot_file)
end associate
end associate

associate(output_names => [string_t("potential_temperature"),string_t("qv"), string_t("qc"), string_t("qr"), string_t("qs")])
!associate(output_names => [string_t("potential_temperature"),string_t("qv"), string_t("qc"), string_t("qr"), string_t("qs")])
associate(output_names => [string_t("qv"), string_t("qc")])

allocate(output_variable(size(output_names)))

associate(output_file_name => args%base_name // "_output.nc")

print *,"Reading network outputs from " // output_file_name
print *,"Reading physics-based model outputs from " // output_file_name

associate(output_file => netCDF_file_t(output_file_name))

Expand All @@ -252,7 +225,7 @@ subroutine read_train_write(training_configuration, args, plot_file)
block
type(NetCDF_variable_t) derivative(size(output_variable))

print *,"Calculating time derivatives"
print *,"Calculating desired neural-network model outputs"

associate(dt => NetCDF_variable_t(output_time - input_time, "dt"))
do v = 1, size(derivative)
Expand All @@ -275,45 +248,33 @@ subroutine read_train_write(training_configuration, args, plot_file)
type(input_output_pair_t), allocatable :: input_output_pairs(:)
type(tensor_t), allocatable, dimension(:) :: inputs, outputs
real, allocatable :: cost(:)
integer i, lon, lat, level, time, network_unit, io_status, epoch, end_step
integer i, network_unit, io_status, epoch, end_step
integer(int64) start_training, finish_training

associate( network_file => args%base_name // "_network.json")

open(newunit=network_unit, file=network_file, form='formatted', status='old', iostat=io_status, action='read')

! if (allocated(args%end_step)) then
! end_step = args%end_step
! else
! end_step = t_end
! end if

! print *,"Defining tensors from time step", args%start_step, "through", end_step, "with strides of", args%stride
if (allocated(args%end_step)) then
end_step = args%end_step
else
end_step = input_variable(1)%end_step()
end if

! ! The following temporary copies are required by gfortran bug 100650 and possibly 49324
! ! See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=100650 and https://gcc.gnu.org/bugzilla/show_bug.cgi?id=49324
! inputs = [( [( [( [( &
! tensor_t( &
! [ pressure_in(lon,lat,level,time), potential_temperature_in(lon,lat,level,time), temperature_in(lon,lat,level,time), &
! qv_in(lon,lat,level,time), qc_in(lon,lat,level,time), qr_in(lon,lat,level,time), qs_in(lon,lat,level,time) &
! ] &
! ), lon = 1, size(qv_in,1))], lat = 1, size(qv_in,2))], level = 1, size(qv_in,3))], &
! time = args%start_step, end_step, args%stride)]

! outputs = [( [( [( [( &
! tensor_t( &
! [dpt_dt(lon,lat,level,time), dqv_dt(lon,lat,level,time), dqc_dt(lon,lat,level,time), dqr_dt(lon,lat,level,time), &
! dqs_dt(lon,lat,level,time) &
! ] &
! ), lon = 1, size(qv_in,1))], lat = 1, size(qv_in,2))], level = 1, size(qv_in,3))], &
! time = args%start_step, end_step, args%stride)]
print *,"Defining input tensors starting from time step", args%start_step, "through", end_step, "with strides of", args%stride
inputs = tensors(input_variable, step_start = args%start_step, step_end = end_step, step_stride = args%stride)

print *,"Defining output tensors starting from time step", args%start_step, "through", end_step, "with strides of", args%stride
outputs = tensors(output_variable, step_start = args%start_step, step_end = end_step, step_stride = args%stride)

! print *,"Calculating output tensor component ranges."
! output_extrema: &
! associate( &
! output_minima => [minval(dpt_dt), minval(dqv_dt), minval(dqc_dt), minval(dqr_dt), minval(dqs_dt)], &
! output_maxima => [maxval(dpt_dt), maxval(dqv_dt), maxval(dqc_dt), maxval(dqr_dt), maxval(dqs_dt)] &
! )
print *,"Calculating output tensor component ranges."
tensor_extrema: &
associate( &
input_minima => [( input_variable(v)%minimum(), v=1,size( input_variable) )] &
,input_maxima => [( input_variable(v)%maximum(), v=1,size( input_variable) )] &
,output_minima => [( output_variable(v)%minimum(), v=1,size(output_variable) )] &
,output_maxima => [( output_variable(v)%maximum(), v=1,size(output_variable) )] &
)
! output_map: &
! associate( output_map => tensor_map_t(layer = "outputs", minima = output_minima, maxima = output_maxima))
! read_or_initialize_network: &
Expand Down Expand Up @@ -376,7 +337,7 @@ subroutine read_train_write(training_configuration, args, plot_file)
! " in ", count(occupied)," out of ", size(occupied, kind=int64), " bins."
! end block
! end associate output_map
! end associate output_extrema
end associate tensor_extrema

! print *,"Normalizing the remaining input and output tensors"
! input_output_pairs = trainable_network%map_to_training_ranges(input_output_pairs)
Expand Down Expand Up @@ -423,7 +384,7 @@ subroutine read_train_write(training_configuration, args, plot_file)
! associate(converged => average_cost <= args%cost_tolerance)

! image_1_maybe_writes: &
! if (me==1 .and. any([converged, epoch==[first_epoch,last_epoch], mod(epoch,args%report_interval)==0])) then
! if (me==1 .and. any([converged, epoch==[first_epoch,last_epoch], mod(epoch,args%report_step)==0])) then

! print *, epoch, average_cost
! write(plot_file%plot_unit,*) epoch, average_cost
Expand Down Expand Up @@ -471,4 +432,30 @@ subroutine read_train_write(training_configuration, args, plot_file)

end subroutine read_train_write

pure function default_integer_or_read(default, string) result(set_value)
integer, intent(in) :: default
character(len=*), intent(in) :: string
integer set_value

if (len(string)==0) then
set_value = default
else
read(string,*) set_value
end if

end function

pure function default_real_or_read(default, string) result(set_value)
real, intent(in) :: default
character(len=*), intent(in) :: string
real set_value

if (len(string)==0) then
set_value = default
else
read(string,*) set_value
end if

end function

end program train_cloud_microphysics
83 changes: 73 additions & 10 deletions demo/src/NetCDF_variable_m.f90
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ module NetCDF_variable_m
use NetCDF_file_m, only : NetCDF_file_t
use kind_parameters_m, only : default_real, double_precision
use julienne_m, only : string_t
use fiats_m, only : tensor_t
implicit none

private
public :: NetCDF_variable_t
public :: tensors

type NetCDF_variable_t(k)
integer, kind :: k = default_real
Expand All @@ -19,16 +21,22 @@ module NetCDF_variable_m
procedure, private, non_overridable :: default_real_input, double_precision_input, default_real_input_character_name, double_precision_input_character_name
generic :: conformable_with => default_real_conformable_with, double_precision_conformable_with
procedure, private, non_overridable :: default_real_conformable_with, double_precision_conformable_with
generic :: rank => default_real_rank, double_precision_rank
procedure, private, non_overridable :: default_real_rank, double_precision_rank
generic :: any_nan => default_real_any_nan, double_precision_any_nan
procedure, private, non_overridable :: default_real_any_nan, double_precision_any_nan
generic :: operator(-) => default_real_subtract, double_precision_subtract
procedure, private, non_overridable :: default_real_subtract, double_precision_subtract
generic :: operator(/) => default_real_divide, double_precision_divide
procedure, private, non_overridable :: default_real_divide, double_precision_divide
generic :: assignment(=) => default_real_assign, double_precision_assign
procedure, private, non_overridable :: default_real_assign, double_precision_assign
generic :: rank => default_real_rank , double_precision_rank
procedure, private, non_overridable :: default_real_rank , double_precision_rank
generic :: end_step => default_real_end_step , double_precision_end_step
procedure, private, non_overridable :: default_real_end_step , double_precision_end_step
generic :: any_nan => default_real_any_nan , double_precision_any_nan
procedure, private, non_overridable :: default_real_any_nan , double_precision_any_nan
generic :: minimum => default_real_minimum , double_precision_minimum
procedure, private, non_overridable :: default_real_minimum , double_precision_minimum
generic :: maximum => default_real_maximum , double_precision_maximum
procedure, private, non_overridable :: default_real_maximum , double_precision_maximum
generic :: operator(-) => default_real_subtract , double_precision_subtract
procedure, private, non_overridable :: default_real_subtract , double_precision_subtract
generic :: operator(/) => default_real_divide , double_precision_divide
procedure, private, non_overridable :: default_real_divide , double_precision_divide
generic :: assignment(=) => default_real_assign , double_precision_assign
procedure, private, non_overridable :: default_real_assign , double_precision_assign
end type

interface NetCDF_variable_t
Expand Down Expand Up @@ -121,6 +129,18 @@ elemental module function double_precision_rank(self) result(my_rank)
integer my_rank
end function

elemental module function default_real_end_step(self) result(end_step)
implicit none
class(NetCDF_variable_t), intent(inout) :: self
integer end_step
end function

elemental module function double_precision_end_step(self) result(end_step)
implicit none
class(NetCDF_variable_t(double_precision)), intent(inout) :: self
integer end_step
end function

elemental module function default_real_subtract(lhs, rhs) result(difference)
implicit none
class(NetCDF_variable_t), intent(in) :: lhs, rhs
Expand Down Expand Up @@ -169,6 +189,49 @@ elemental module function double_precision_any_nan(self) result(any_nan)
logical any_nan
end function

elemental module function default_real_minimum(self) result(minimum)
implicit none
class(NetCDF_variable_t), intent(in) :: self
real minimum
end function

elemental module function double_precision_minimum(self) result(minimum)
implicit none
class(NetCDF_variable_t(double_precision)), intent(in) :: self
real minimum
end function

elemental module function default_real_maximum(self) result(maximum)
implicit none
class(NetCDF_variable_t), intent(in) :: self
real maximum
end function

elemental module function double_precision_maximum(self) result(maximum)
implicit none
class(NetCDF_variable_t(double_precision)), intent(in) :: self
real maximum
end function

module function tensors(NetCDF_variables, step_start, step_end, step_stride)
implicit none
type(NetCDF_variable_t), intent(in) :: NetCDF_variables(:)
type(tensor_t), allocatable :: tensors(:)
integer, optional :: step_start, step_end, step_stride
end function

elemental module function default_real_end_time(self) result(end_time)
implicit none
class(NetCDF_variable_t), intent(inout) :: self
integer end_time
end function

elemental module function double_precision_end_time(self) result(end_time)
implicit none
class(NetCDF_variable_t), intent(inout) :: self
integer end_time
end function

end interface

end module
Loading

0 comments on commit ba04e9f

Please sign in to comment.