diff --git a/src/ftorch.F90 b/src/ftorch.F90 index 48749c90..4771dd69 100644 --- a/src/ftorch.F90 +++ b/src/ftorch.F90 @@ -253,7 +253,7 @@ subroutine torch_tensor_empty(tensor, ndims, tensor_shape, dtype, & use, intrinsic :: iso_c_binding, only : c_bool, c_int, c_int64_t type(torch_tensor), intent(out) :: tensor !! Returned tensor integer(c_int), intent(in) :: ndims !! Number of dimensions of the tensor - integer(c_int64_t), intent(in) :: tensor_shape(*) !! Shape of the tensor + integer(c_int64_t), intent(in) :: tensor_shape(:) !! Shape of the tensor integer(c_int), intent(in) :: dtype !! Data type of the tensor integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer, optional, intent(in) :: device_index !! Device index to use for `torch_kCUDA` case @@ -304,7 +304,7 @@ subroutine torch_tensor_zeros(tensor, ndims, tensor_shape, dtype, & use, intrinsic :: iso_c_binding, only : c_bool, c_int, c_int64_t type(torch_tensor), intent(out) :: tensor !! Returned tensor integer(c_int), intent(in) :: ndims !! Number of dimensions of the tensor - integer(c_int64_t), intent(in) :: tensor_shape(*) !! Shape of the tensor + integer(c_int64_t), intent(in) :: tensor_shape(:) !! Shape of the tensor integer(c_int), intent(in) :: dtype !! Data type of the tensor integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer, optional, intent(in) :: device_index !! Device index to use for `torch_kCUDA` case @@ -355,7 +355,7 @@ subroutine torch_tensor_ones(tensor, ndims, tensor_shape, dtype, & use, intrinsic :: iso_c_binding, only : c_bool, c_int, c_int64_t type(torch_tensor), intent(out) :: tensor !! Returned tensor integer(c_int), intent(in) :: ndims !! Number of dimensions of the tensor - integer(c_int64_t), intent(in) :: tensor_shape(*) !! Shape of the tensor + integer(c_int64_t), intent(in) :: tensor_shape(:) !! Shape of the tensor integer(c_int), intent(in) :: dtype !! Data type of the tensor integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer, optional, intent(in) :: device_index !! Device index to use for `torch_kCUDA` case @@ -409,8 +409,8 @@ subroutine torch_tensor_from_blob(tensor, data, ndims, tensor_shape, layout, dty type(torch_tensor), intent(out) :: tensor !! Returned tensor type(c_ptr), intent(in) :: data !! Pointer to data integer(c_int), intent(in) :: ndims !! Number of dimensions of the tensor - integer(c_int64_t), intent(in) :: tensor_shape(*) !! Shape of the tensor - integer(c_int), intent(in) :: layout(*) !! Layout for strides for accessing data + integer(c_int64_t), intent(in) :: tensor_shape(:) !! Shape of the tensor + integer(c_int), intent(in) :: layout(:) !! Layout for strides for accessing data integer(c_int), intent(in) :: dtype !! Data type of the tensor integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer, optional, intent(in) :: device_index !! Device index to use for `torch_kCUDA` case @@ -427,6 +427,7 @@ subroutine torch_tensor_from_blob(tensor, data, ndims, tensor_shape, layout, dty requires_grad_value = requires_grad end if + strides(:) = 0 do i = 1, ndims if (i == 1) then strides(layout(i)) = 1 diff --git a/src/ftorch.fypp b/src/ftorch.fypp index d3e060d8..1bf5e5dd 100644 --- a/src/ftorch.fypp +++ b/src/ftorch.fypp @@ -208,7 +208,7 @@ contains use, intrinsic :: iso_c_binding, only : c_bool, c_int, c_int64_t type(torch_tensor), intent(out) :: tensor !! Returned tensor integer(c_int), intent(in) :: ndims !! Number of dimensions of the tensor - integer(c_int64_t), intent(in) :: tensor_shape(*) !! Shape of the tensor + integer(c_int64_t), intent(in) :: tensor_shape(:) !! Shape of the tensor integer(c_int), intent(in) :: dtype !! Data type of the tensor integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer, optional, intent(in) :: device_index !! Device index to use for `torch_kCUDA` case @@ -259,7 +259,7 @@ contains use, intrinsic :: iso_c_binding, only : c_bool, c_int, c_int64_t type(torch_tensor), intent(out) :: tensor !! Returned tensor integer(c_int), intent(in) :: ndims !! Number of dimensions of the tensor - integer(c_int64_t), intent(in) :: tensor_shape(*) !! Shape of the tensor + integer(c_int64_t), intent(in) :: tensor_shape(:) !! Shape of the tensor integer(c_int), intent(in) :: dtype !! Data type of the tensor integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer, optional, intent(in) :: device_index !! Device index to use for `torch_kCUDA` case @@ -310,7 +310,7 @@ contains use, intrinsic :: iso_c_binding, only : c_bool, c_int, c_int64_t type(torch_tensor), intent(out) :: tensor !! Returned tensor integer(c_int), intent(in) :: ndims !! Number of dimensions of the tensor - integer(c_int64_t), intent(in) :: tensor_shape(*) !! Shape of the tensor + integer(c_int64_t), intent(in) :: tensor_shape(:) !! Shape of the tensor integer(c_int), intent(in) :: dtype !! Data type of the tensor integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer, optional, intent(in) :: device_index !! Device index to use for `torch_kCUDA` case @@ -364,8 +364,8 @@ contains type(torch_tensor), intent(out) :: tensor !! Returned tensor type(c_ptr), intent(in) :: data !! Pointer to data integer(c_int), intent(in) :: ndims !! Number of dimensions of the tensor - integer(c_int64_t), intent(in) :: tensor_shape(*) !! Shape of the tensor - integer(c_int), intent(in) :: layout(*) !! Layout for strides for accessing data + integer(c_int64_t), intent(in) :: tensor_shape(:) !! Shape of the tensor + integer(c_int), intent(in) :: layout(:) !! Layout for strides for accessing data integer(c_int), intent(in) :: dtype !! Data type of the tensor integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer, optional, intent(in) :: device_index !! Device index to use for `torch_kCUDA` case @@ -382,6 +382,7 @@ contains requires_grad_value = requires_grad end if + strides(:) = 0 do i = 1, ndims if (i == 1) then strides(layout(i)) = 1