diff --git a/examples/6_Autograd/autograd.f90 b/examples/6_Autograd/autograd.f90 index 42f4111f..6b95c6fb 100644 --- a/examples/6_Autograd/autograd.f90 +++ b/examples/6_Autograd/autograd.f90 @@ -4,8 +4,9 @@ program example use, intrinsic :: iso_fortran_env, only : sp => real32 ! Import our library for interfacing with PyTorch's Autograd module - use ftorch, only : torch_tensor, torch_kCPU, & - torch_tensor_from_array, torch_tensor_to_array, torch_tensor_delete + use ftorch, only: assignment(=), operator(+), operator(-), operator(*), & + operator(/), operator(**), torch_kCPU, torch_tensor, torch_tensor_delete, & + torch_tensor_from_array, torch_tensor_to_array ! Import our tools module for testing utils use ftorch_test_utils, only : assert_allclose @@ -16,8 +17,9 @@ program example integer, parameter :: wp = sp ! Set up Fortran data structures - integer, parameter :: n=2, m=5 - real(wp), dimension(n,m), target :: in_data + integer, parameter :: n=2, m=1 + real(wp), dimension(n,m), target :: in_data1 + real(wp), dimension(n,m), target :: in_data2 real(wp), dimension(:,:), pointer :: out_data real(wp), dimension(n,m) :: expected integer :: tensor_layout(2) = [1, 2] @@ -27,45 +29,78 @@ program example logical :: test_pass ! Set up Torch data structures - type(torch_tensor) :: tensor + type(torch_tensor) :: a, b, Q - ! initialize in_data with some fake data - do j = 1, m - do i = 1, n - in_data(i,j) = ((i-1)*m + j) * 1.0_wp - end do - end do + ! Initialise input arrays as in Python example + in_data1(:,1) = [2.0_wp, 3.0_wp] + in_data2(:,1) = [6.0_wp, 4.0_wp] ! Construct a Torch Tensor from a Fortran array - call torch_tensor_from_array(tensor, in_data, tensor_layout, torch_kCPU) + ! TODO: Implement requires_grad=.true. + call torch_tensor_from_array(a, in_data1, tensor_layout, torch_kCPU) + call torch_tensor_from_array(b, in_data2, tensor_layout, torch_kCPU) ! check tensor rank and shape match those of in_data - if (tensor%get_rank() /= 2) then + if ((a%get_rank() /= 2) .or. (b%get_rank() /= 2)) then print *, "Error :: rank should be 2" stop 1 end if - if (any(tensor%get_shape() /= [2, 5])) then - print *, "Error :: shape should be (2, 5)" + if (any(a%get_shape() /= [n, m]) .or. any(b%get_shape() /= [n, m])) then + write(6,"('Error :: shape should be (',i1,', ',i1,')')") n, m stop 1 end if + ! Check arithmetic operations work for torch_tensors + write (*,*) "a = ", in_data1(:,1) + write (*,*) "b = ", in_data2(:,1) + Q = 3 * (a**3 - b * b / 3) + ! Extract a Fortran array from a Torch tensor - call torch_tensor_to_array(tensor, out_data, shape(in_data)) + call torch_tensor_to_array(Q, out_data, shape(in_data1)) + write (*,*) "Q = 3 * (a ** 3 - b * b / 2) =", out_data(:,1) ! Check output tensor matches expected value - expected(:,:) = in_data + expected(:,1) = [-12.0_wp, 65.0_wp] test_pass = assert_allclose(out_data, expected, test_name="torch_tensor_to_array", rtol=1e-5) + if (.not. test_pass) then + call clean_up() + print *, "Error :: out_data does not match expected value" + stop 999 + end if - ! Check that the data match + ! Check first input array is unchanged by the arithmetic operations + expected(:,1) = [2.0_wp, 3.0_wp] + test_pass = assert_allclose(in_data1, expected, test_name="torch_tensor_to_array", rtol=1e-5) if (.not. test_pass) then - print *, "Error :: in_data does not match out_data" + call clean_up() + print *, "Error :: in_data1 was changed during arithmetic operations" stop 999 end if - ! Cleanup - nullify(out_data) - call torch_tensor_delete(tensor) + ! Check second input array is unchanged by the arithmetic operations + expected(:,1) = [6.0_wp, 4.0_wp] + test_pass = assert_allclose(in_data2, expected, test_name="torch_tensor_to_array", rtol=1e-5) + if (.not. test_pass) then + call clean_up() + print *, "Error :: in_data2 was changed during arithmetic operations" + stop 999 + end if + + ! Back-propagation + ! TODO: Requires API extension + ! Cleanup + call clean_up() write (*,*) "Autograd example ran successfully" + contains + + ! Subroutine for freeing memory and nullifying pointers used in the example + subroutine clean_up() + nullify(out_data) + call torch_tensor_delete(a) + call torch_tensor_delete(b) + call torch_tensor_delete(Q) + end subroutine clean_up + end program example diff --git a/examples/6_Autograd/autograd.py b/examples/6_Autograd/autograd.py index cc7ee753..9fdd816c 100644 --- a/examples/6_Autograd/autograd.py +++ b/examples/6_Autograd/autograd.py @@ -5,7 +5,7 @@ a = torch.tensor([2.0, 3.0], requires_grad=True) b = torch.tensor([6.0, 4.0], requires_grad=True) -Q = 3 * a**3 - b**2 +Q = 3 * (a**3 - b * b / 3) print(Q) expect = torch.tensor([-12.0, 65.0]) if not torch.allclose(Q, expect): diff --git a/pages/autograd.md b/pages/autograd.md new file mode 100644 index 00000000..92fce98e --- /dev/null +++ b/pages/autograd.md @@ -0,0 +1,42 @@ +title: Online training + +[TOC] + +## Current state + +FTorch has supported offline training of ML models for some time. We are +currently working on extending its functionality to support online training, +too. This will involve exposing the automatic differentiation and +back-propagation functionality in PyTorch/LibTorch. + +In the following, we document a workplan of the related functionality. Each step +below will be updated upon completion. + +### Operator overloading + +Mathematical operators involving Tensors are overloaded, so that we can compute +expressions involving outputs from one or more ML models. + +Whilst it's possible to import such functionality with a bare +```fortran +use ftorch +``` +statement, the best practice is to import specifically the operators that you +wish to use. Note that the assignment operator `=` has a slightly different +notation: +``` +use ftorch, only: assignment(=), operator(+), operator(-), operator(*), & + operator(/), operator(**) +``` + +For a concrete example of how to compute mathematical expressions involving +Torch tensors, see the associated +[worked example](https://github.com/Cambridge-ICCS/FTorch/tree/main/examples/6_Autograd). + +### The `requires_grad` property + +*Not yet implemented.* + +### The `backward` operator + +*Not yet implemented.* diff --git a/pages/developer.md b/pages/developer.md index 16735f68..7b3f710c 100644 --- a/pages/developer.md +++ b/pages/developer.md @@ -77,6 +77,15 @@ and many of our users wish to _"clone-and-go"_ rather than develop, we provide b Development should only take place in `ftorch.fypp`, however._ +### Torch C++ API + +When extending or modifying functionality related to C++ header and/or source +files `src/ctorch.h` and `src/ctorch.cpp`, we refer to the Torch +[C++ documentation](https://pytorch.org/cppdocs) and more specifically the +[C++ API documentation](https://pytorch.org/cppdocs/api/library_root.html) +pages on the PyTorch website for details. + + ### git hook In order to streamline the process of uploading we provide a pre-commit hook in diff --git a/pages/examples.md b/pages/examples.md index 9d774705..dcbaf9d4 100644 --- a/pages/examples.md +++ b/pages/examples.md @@ -187,9 +187,16 @@ data to multiple GPU devices. considers a variant of the SimpleNet demo, which demonstrates how to account for multiple input tensors and multiple output tensors. -#### 7) Autograd +#### 5) Looping -[This worked example](https://github.com/Cambridge-ICCS/FTorch/tree/main/examples/5_Autograd) +[This worked example](https://github.com/Cambridge-ICCS/FTorch/tree/main/examples/5_Looping) +demonstrates best practices for performing inference on the same network with +different input multiple times in the same workflow. + +#### 6) Autograd + +[This worked example](https://github.com/Cambridge-ICCS/FTorch/tree/main/examples/6_Autograd) is currently under development. Eventually, it will demonstrate how to perform automatic differentiation in FTorch by leveraging PyTorch's Autograd module. -Currently, it just demonstrates how to use `torch_tensor_to_array`. +Currently, it just demonstrates how to use `torch_tensor_to_array` and compute +mathematical expressions involving Torch tensors. diff --git a/run_integration_tests.sh b/run_integration_tests.sh index 94b27e8a..6afeb2ae 100755 --- a/run_integration_tests.sh +++ b/run_integration_tests.sh @@ -11,7 +11,12 @@ set -eu -EXAMPLES="1_SimpleNet 2_ResNet18 4_MultiIO 6_Autograd" +EXAMPLES=" + 1_SimpleNet + 2_ResNet18 + 4_MultiIO + 6_Autograd +" BUILD_DIR=src/build for EXAMPLE in ${EXAMPLES}; do diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 9ee7e2bb..6c00aceb 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -116,12 +116,12 @@ if(CMAKE_BUILD_TESTS) DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/test/examples) file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/../examples/2_ResNet18 DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/test/examples) - # file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/../examples/3_MultiGPU DESTINATION - # ${CMAKE_CURRENT_SOURCE_DIR}/test/examples ) + # file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/../examples/3_MultiGPU + # DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/test/examples) file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/../examples/4_MultiIO DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/test/examples) - # file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/../examples/5_Looping DESTINATION - # ${CMAKE_CURRENT_SOURCE_DIR}/test/examples ) + # file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/../examples/5_Looping + # DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/test/examples) file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/../examples/6_Autograd DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/test/examples) add_subdirectory(test/examples) diff --git a/src/ctorch.cpp b/src/ctorch.cpp index 1f3d8892..50f77a39 100644 --- a/src/ctorch.cpp +++ b/src/ctorch.cpp @@ -1,3 +1,8 @@ +/* + * For more details on the Torch Tensor C++ API, we refer to the Torch C++ documentation + * (https://pytorch.org/cppdocs) and more specifically the C++ API documentation + * (https://pytorch.org/cppdocs/api/library_root.html) pages on the PyTorch website. + */ #include #include @@ -233,6 +238,91 @@ void torch_tensor_delete(torch_tensor_t tensor) { delete t; } +torch_tensor_t torch_tensor_assign(const torch_tensor_t input) { + auto in = reinterpret_cast(input); + torch::AutoGradMode enable_grad(in->requires_grad()); + torch::Tensor *output = nullptr; + output = new torch::Tensor; + *output = in->detach().clone(); + return output; +} + +torch_tensor_t torch_tensor_add(const torch_tensor_t tensor1, + const torch_tensor_t tensor2) { + auto t1 = reinterpret_cast(tensor1); + auto t2 = reinterpret_cast(tensor2); + torch::Tensor *output = nullptr; + output = new torch::Tensor; + *output = *t1 + *t2; + return output; +} + +torch_tensor_t torch_tensor_subtract(const torch_tensor_t tensor1, + const torch_tensor_t tensor2) { + auto t1 = reinterpret_cast(tensor1); + auto t2 = reinterpret_cast(tensor2); + torch::Tensor *output = nullptr; + output = new torch::Tensor; + *output = *t1 - *t2; + return output; +} + +torch_tensor_t torch_tensor_multiply(const torch_tensor_t tensor1, + const torch_tensor_t tensor2) { + auto t1 = reinterpret_cast(tensor1); + auto t2 = reinterpret_cast(tensor2); + torch::Tensor *output = nullptr; + output = new torch::Tensor; + *output = *t1 * *t2; + return output; +} + +torch_tensor_t torch_tensor_premultiply(const torch_data_t scalar, + const torch_tensor_t tensor) { + auto t = reinterpret_cast(tensor); + torch::Tensor *output = nullptr; + output = new torch::Tensor; + *output = scalar * *t; + return output; +} + +torch_tensor_t torch_tensor_postmultiply(const torch_tensor_t tensor, + const torch_data_t scalar) { + auto t = reinterpret_cast(tensor); + torch::Tensor *output = nullptr; + output = new torch::Tensor; + *output = *t * scalar; + return output; +} + +torch_tensor_t torch_tensor_divide(const torch_tensor_t tensor1, + const torch_tensor_t tensor2) { + auto t1 = reinterpret_cast(tensor1); + auto t2 = reinterpret_cast(tensor2); + torch::Tensor *output = nullptr; + output = new torch::Tensor; + *output = *t1 / *t2; + return output; +} + +torch_tensor_t torch_tensor_postdivide(const torch_tensor_t tensor, + const torch_data_t scalar) { + auto t = reinterpret_cast(tensor); + torch::Tensor *output = nullptr; + output = new torch::Tensor; + *output = *t / scalar; + return output; +} + +torch_tensor_t torch_tensor_power(const torch_tensor_t tensor, + const torch_data_t exponent) { + auto t = reinterpret_cast(tensor); + torch::Tensor *output = nullptr; + output = new torch::Tensor; + *output = pow(*t, exponent); + return output; +} + torch_jit_script_module_t torch_jit_load(const char *filename, const torch_device_t device_type = torch_kCPU, const int device_index = -1, diff --git a/src/ctorch.h b/src/ctorch.h index 0b5532e5..a4a6310a 100644 --- a/src/ctorch.h +++ b/src/ctorch.h @@ -136,6 +136,85 @@ EXPORT_C const long long int *torch_tensor_get_sizes(const torch_tensor_t tensor */ EXPORT_C void torch_tensor_delete(torch_tensor_t tensor); +/** + * Overloads the assignment operator for Torch Tensor + * @param input Tensor + * @return copy of input Tensor + */ +EXPORT_C torch_tensor_t torch_tensor_assign(const torch_tensor_t input); + +/** + * Overloads the addition operator for two Torch Tensors + * @param first Tensor to be added + * @param second Tensor to be added + * @return sum of the Tensors + */ +EXPORT_C torch_tensor_t torch_tensor_add(const torch_tensor_t tensor1, + const torch_tensor_t tensor2); + +/** + * Overloads the subtraction operator for two Torch Tensors + * @param first Tensor to be subtracted + * @param second Tensor to be subtracted + * @return difference of the Tensors + */ +EXPORT_C torch_tensor_t torch_tensor_subtract(const torch_tensor_t tensor1, + const torch_tensor_t tensor2); + +/** + * Overloads the multiplication operator for two Torch Tensors + * @param first Tensor to be multiplied + * @param second Tensor to be multiplied + * @return product of the Tensors + */ +EXPORT_C torch_tensor_t torch_tensor_multiply(const torch_tensor_t tensor1, + const torch_tensor_t tensor2); + +/** + * Overloads the premultiplication operator for a scalar and a Torch Tensor + * @param scalar to multiply by + * @param Tensor to be multiplied + * @return product of the scalar and Tensor + */ +EXPORT_C torch_tensor_t torch_tensor_premultiply(const torch_data_t scalar, + const torch_tensor_t tensor); + +/** + * Overloads the postmultiplication operator for a Torch Tensor and a scalar + * @param Tensor to be multiplied + * @param scalar to multiply by + * @return product of the Tensor and scalar + */ +EXPORT_C torch_tensor_t torch_tensor_postmultiply(const torch_tensor_t tensor, + const torch_data_t scalar); + +/** + * Overloads the division operator for two Torch Tensors + * @param first Tensor to be divided + * @param second Tensor to be divided + * @return quotient of the Tensors + */ +EXPORT_C torch_tensor_t torch_tensor_divide(const torch_tensor_t tensor1, + const torch_tensor_t tensor2); + +/** + * Overloads the post-division operator for a Torch Tensor and a scalar + * @param Tensor to be divided + * @param scalar to divide by + * @return quotient of the Tensor and scalar + */ +EXPORT_C torch_tensor_t torch_tensor_postdivide(const torch_tensor_t tensor, + const torch_data_t scalar); + +/** + * Overloads the exponentiation operator for two Torch Tensors + * @param Tensor to take the power of + * @param scalar exponent + * @return power of the Tensor + */ +EXPORT_C torch_tensor_t torch_tensor_power(const torch_tensor_t tensor, + const torch_data_t exponent); + // ===================================================================================== // Module API // ===================================================================================== diff --git a/src/ftorch.F90 b/src/ftorch.F90 index 13b662a9..80884d56 100644 --- a/src/ftorch.F90 +++ b/src/ftorch.F90 @@ -148,6 +148,53 @@ function torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, & end function torch_from_blob_c end interface + interface assignment (=) + module procedure torch_tensor_assign + end interface + + interface operator (+) + module procedure torch_tensor_add + end interface + + interface operator (-) + module procedure torch_tensor_subtract + end interface + + interface operator (*) + module procedure torch_tensor_multiply + module procedure torch_tensor_premultiply_int8 + module procedure torch_tensor_postmultiply_int8 + module procedure torch_tensor_premultiply_int16 + module procedure torch_tensor_postmultiply_int16 + module procedure torch_tensor_premultiply_int32 + module procedure torch_tensor_postmultiply_int32 + module procedure torch_tensor_premultiply_int64 + module procedure torch_tensor_postmultiply_int64 + module procedure torch_tensor_premultiply_real32 + module procedure torch_tensor_postmultiply_real32 + module procedure torch_tensor_premultiply_real64 + module procedure torch_tensor_postmultiply_real64 + end interface + + interface operator (/) + module procedure torch_tensor_divide + module procedure torch_tensor_postdivide_int8 + module procedure torch_tensor_postdivide_int16 + module procedure torch_tensor_postdivide_int32 + module procedure torch_tensor_postdivide_int64 + module procedure torch_tensor_postdivide_real32 + module procedure torch_tensor_postdivide_real64 + end interface + + interface operator (**) + module procedure torch_tensor_power_int8 + module procedure torch_tensor_power_int16 + module procedure torch_tensor_power_int32 + module procedure torch_tensor_power_int64 + module procedure torch_tensor_power_real32 + module procedure torch_tensor_power_real64 + end interface + interface function torch_to_blob_c(tensor, dtype) result(data) & bind(c, name = 'torch_to_blob') @@ -419,6 +466,611 @@ end subroutine torch_tensor_delete_c call torch_tensor_delete_c(tensor%p) end subroutine torch_tensor_delete + !> Overloads assignment operator for tensors. + subroutine torch_tensor_assign(output, input) + type(torch_tensor), intent(out) :: output + type(torch_tensor), intent(in) :: input + + interface + function torch_tensor_assign_c(input_c) result(output_c) & + bind(c, name = 'torch_tensor_assign') + use, intrinsic :: iso_c_binding, only : c_ptr + implicit none + type(c_ptr), value, intent(in) :: input_c + type(c_ptr) :: output_c + end function torch_tensor_assign_c + end interface + + output%p = torch_tensor_assign_c(input%p) + end subroutine torch_tensor_assign + + !> Overloads addition operator for two tensors. + function torch_tensor_add(tensor1, tensor2) result(output) + type(torch_tensor), intent(in) :: tensor1 + type(torch_tensor), intent(in) :: tensor2 + type(torch_tensor) :: output + + interface + function torch_tensor_add_c(tensor1_c, tensor2_c) result(output_c) & + bind(c, name = 'torch_tensor_add') + use, intrinsic :: iso_c_binding, only : c_ptr + implicit none + type(c_ptr), value, intent(in) :: tensor1_c + type(c_ptr), value, intent(in) :: tensor2_c + type(c_ptr) :: output_c + end function torch_tensor_add_c + end interface + + output%p = torch_tensor_add_c(tensor1%p, tensor2%p) + end function torch_tensor_add + + !> Overloads subtraction operator for two tensors. + function torch_tensor_subtract(tensor1, tensor2) result(output) + type(torch_tensor), intent(in) :: tensor1 + type(torch_tensor), intent(in) :: tensor2 + type(torch_tensor) :: output + + interface + function torch_tensor_subtract_c(tensor1_c, tensor2_c) result(output_c) & + bind(c, name = 'torch_tensor_subtract') + use, intrinsic :: iso_c_binding, only : c_ptr + implicit none + type(c_ptr), value, intent(in) :: tensor1_c + type(c_ptr), value, intent(in) :: tensor2_c + type(c_ptr) :: output_c + end function torch_tensor_subtract_c + end interface + + output%p = torch_tensor_subtract_c(tensor1%p, tensor2%p) + end function torch_tensor_subtract + + !> Overloads multiplication operator for two tensors. + function torch_tensor_multiply(tensor1, tensor2) result(output) + type(torch_tensor), intent(in) :: tensor1 + type(torch_tensor), intent(in) :: tensor2 + type(torch_tensor) :: output + + interface + function torch_tensor_multiply_c(tensor1_c, tensor2_c) result(output_c) & + bind(c, name = 'torch_tensor_multiply') + use, intrinsic :: iso_c_binding, only : c_ptr + implicit none + type(c_ptr), value, intent(in) :: tensor1_c + type(c_ptr), value, intent(in) :: tensor2_c + type(c_ptr) :: output_c + end function torch_tensor_multiply_c + end interface + + output%p = torch_tensor_multiply_c(tensor1%p, tensor2%p) + end function torch_tensor_multiply + + !> Overloads multiplication operator for a scalar of type int8 and a tensor. + function torch_tensor_premultiply_int8(scalar, tensor) result(output) + integer(kind=int8), intent(in) :: scalar + type(torch_tensor), intent(in) :: tensor + type(torch_tensor) :: output + + interface + function torch_tensor_premultiply_c(scalar_c, tensor_c) result(output_c) & + bind(c, name = 'torch_tensor_premultiply') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : int8 + implicit none + integer(kind=int8), value, intent(in) :: scalar_c + type(c_ptr), value, intent(in) :: tensor_c + type(c_ptr) :: output_c + end function torch_tensor_premultiply_c + end interface + + output%p = torch_tensor_premultiply_c(scalar, tensor%p) + end function torch_tensor_premultiply_int8 + + !> Overloads multiplication operator for a scalar of type int16 and a tensor. + function torch_tensor_premultiply_int16(scalar, tensor) result(output) + integer(kind=int16), intent(in) :: scalar + type(torch_tensor), intent(in) :: tensor + type(torch_tensor) :: output + + interface + function torch_tensor_premultiply_c(scalar_c, tensor_c) result(output_c) & + bind(c, name = 'torch_tensor_premultiply') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : int16 + implicit none + integer(kind=int16), value, intent(in) :: scalar_c + type(c_ptr), value, intent(in) :: tensor_c + type(c_ptr) :: output_c + end function torch_tensor_premultiply_c + end interface + + output%p = torch_tensor_premultiply_c(scalar, tensor%p) + end function torch_tensor_premultiply_int16 + + !> Overloads multiplication operator for a scalar of type int32 and a tensor. + function torch_tensor_premultiply_int32(scalar, tensor) result(output) + integer(kind=int32), intent(in) :: scalar + type(torch_tensor), intent(in) :: tensor + type(torch_tensor) :: output + + interface + function torch_tensor_premultiply_c(scalar_c, tensor_c) result(output_c) & + bind(c, name = 'torch_tensor_premultiply') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : int32 + implicit none + integer(kind=int32), value, intent(in) :: scalar_c + type(c_ptr), value, intent(in) :: tensor_c + type(c_ptr) :: output_c + end function torch_tensor_premultiply_c + end interface + + output%p = torch_tensor_premultiply_c(scalar, tensor%p) + end function torch_tensor_premultiply_int32 + + !> Overloads multiplication operator for a scalar of type int64 and a tensor. + function torch_tensor_premultiply_int64(scalar, tensor) result(output) + integer(kind=int64), intent(in) :: scalar + type(torch_tensor), intent(in) :: tensor + type(torch_tensor) :: output + + interface + function torch_tensor_premultiply_c(scalar_c, tensor_c) result(output_c) & + bind(c, name = 'torch_tensor_premultiply') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : int64 + implicit none + integer(kind=int64), value, intent(in) :: scalar_c + type(c_ptr), value, intent(in) :: tensor_c + type(c_ptr) :: output_c + end function torch_tensor_premultiply_c + end interface + + output%p = torch_tensor_premultiply_c(scalar, tensor%p) + end function torch_tensor_premultiply_int64 + + !> Overloads multiplication operator for a scalar of type real32 and a tensor. + function torch_tensor_premultiply_real32(scalar, tensor) result(output) + real(kind=real32), intent(in) :: scalar + type(torch_tensor), intent(in) :: tensor + type(torch_tensor) :: output + + interface + function torch_tensor_premultiply_c(scalar_c, tensor_c) result(output_c) & + bind(c, name = 'torch_tensor_premultiply') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : real32 + implicit none + real(kind=real32), value, intent(in) :: scalar_c + type(c_ptr), value, intent(in) :: tensor_c + type(c_ptr) :: output_c + end function torch_tensor_premultiply_c + end interface + + output%p = torch_tensor_premultiply_c(scalar, tensor%p) + end function torch_tensor_premultiply_real32 + + !> Overloads multiplication operator for a scalar of type real64 and a tensor. + function torch_tensor_premultiply_real64(scalar, tensor) result(output) + real(kind=real64), intent(in) :: scalar + type(torch_tensor), intent(in) :: tensor + type(torch_tensor) :: output + + interface + function torch_tensor_premultiply_c(scalar_c, tensor_c) result(output_c) & + bind(c, name = 'torch_tensor_premultiply') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : real64 + implicit none + real(kind=real64), value, intent(in) :: scalar_c + type(c_ptr), value, intent(in) :: tensor_c + type(c_ptr) :: output_c + end function torch_tensor_premultiply_c + end interface + + output%p = torch_tensor_premultiply_c(scalar, tensor%p) + end function torch_tensor_premultiply_real64 + + + !> Overloads multiplication operator for a tensor and a scalar of type int8. + function torch_tensor_postmultiply_int8(tensor, scalar) result(output) + type(torch_tensor), intent(in) :: tensor + integer(kind=int8), intent(in) :: scalar + type(torch_tensor) :: output + + interface + function torch_tensor_postmultiply_c(tensor_c, scalar_c) & + result(output_c) bind(c, name = 'torch_tensor_postmultiply') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : int8 + implicit none + type(c_ptr), value, intent(in) :: tensor_c + integer(kind=int8), value, intent(in) :: scalar_c + type(c_ptr) :: output_c + end function torch_tensor_postmultiply_c + end interface + + output%p = torch_tensor_postmultiply_c(tensor%p, scalar) + end function torch_tensor_postmultiply_int8 + + !> Overloads multiplication operator for a tensor and a scalar of type int16. + function torch_tensor_postmultiply_int16(tensor, scalar) result(output) + type(torch_tensor), intent(in) :: tensor + integer(kind=int16), intent(in) :: scalar + type(torch_tensor) :: output + + interface + function torch_tensor_postmultiply_c(tensor_c, scalar_c) & + result(output_c) bind(c, name = 'torch_tensor_postmultiply') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : int16 + implicit none + type(c_ptr), value, intent(in) :: tensor_c + integer(kind=int16), value, intent(in) :: scalar_c + type(c_ptr) :: output_c + end function torch_tensor_postmultiply_c + end interface + + output%p = torch_tensor_postmultiply_c(tensor%p, scalar) + end function torch_tensor_postmultiply_int16 + + !> Overloads multiplication operator for a tensor and a scalar of type int32. + function torch_tensor_postmultiply_int32(tensor, scalar) result(output) + type(torch_tensor), intent(in) :: tensor + integer(kind=int32), intent(in) :: scalar + type(torch_tensor) :: output + + interface + function torch_tensor_postmultiply_c(tensor_c, scalar_c) & + result(output_c) bind(c, name = 'torch_tensor_postmultiply') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : int32 + implicit none + type(c_ptr), value, intent(in) :: tensor_c + integer(kind=int32), value, intent(in) :: scalar_c + type(c_ptr) :: output_c + end function torch_tensor_postmultiply_c + end interface + + output%p = torch_tensor_postmultiply_c(tensor%p, scalar) + end function torch_tensor_postmultiply_int32 + + !> Overloads multiplication operator for a tensor and a scalar of type int64. + function torch_tensor_postmultiply_int64(tensor, scalar) result(output) + type(torch_tensor), intent(in) :: tensor + integer(kind=int64), intent(in) :: scalar + type(torch_tensor) :: output + + interface + function torch_tensor_postmultiply_c(tensor_c, scalar_c) & + result(output_c) bind(c, name = 'torch_tensor_postmultiply') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : int64 + implicit none + type(c_ptr), value, intent(in) :: tensor_c + integer(kind=int64), value, intent(in) :: scalar_c + type(c_ptr) :: output_c + end function torch_tensor_postmultiply_c + end interface + + output%p = torch_tensor_postmultiply_c(tensor%p, scalar) + end function torch_tensor_postmultiply_int64 + + !> Overloads multiplication operator for a tensor and a scalar of type real32. + function torch_tensor_postmultiply_real32(tensor, scalar) result(output) + type(torch_tensor), intent(in) :: tensor + real(kind=real32), intent(in) :: scalar + type(torch_tensor) :: output + + interface + function torch_tensor_postmultiply_c(tensor_c, scalar_c) & + result(output_c) bind(c, name = 'torch_tensor_postmultiply') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : real32 + implicit none + type(c_ptr), value, intent(in) :: tensor_c + real(kind=real32), value, intent(in) :: scalar_c + type(c_ptr) :: output_c + end function torch_tensor_postmultiply_c + end interface + + output%p = torch_tensor_postmultiply_c(tensor%p, scalar) + end function torch_tensor_postmultiply_real32 + + !> Overloads multiplication operator for a tensor and a scalar of type real64. + function torch_tensor_postmultiply_real64(tensor, scalar) result(output) + type(torch_tensor), intent(in) :: tensor + real(kind=real64), intent(in) :: scalar + type(torch_tensor) :: output + + interface + function torch_tensor_postmultiply_c(tensor_c, scalar_c) & + result(output_c) bind(c, name = 'torch_tensor_postmultiply') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : real64 + implicit none + type(c_ptr), value, intent(in) :: tensor_c + real(kind=real64), value, intent(in) :: scalar_c + type(c_ptr) :: output_c + end function torch_tensor_postmultiply_c + end interface + + output%p = torch_tensor_postmultiply_c(tensor%p, scalar) + end function torch_tensor_postmultiply_real64 + + !> Overloads division operator for two tensors. + function torch_tensor_divide(tensor1, tensor2) result(output) + type(torch_tensor), intent(in) :: tensor1 + type(torch_tensor), intent(in) :: tensor2 + type(torch_tensor) :: output + + interface + function torch_tensor_divide_c(tensor1_c, tensor2_c) result(output_c) & + bind(c, name = 'torch_tensor_divide') + use, intrinsic :: iso_c_binding, only : c_ptr + implicit none + type(c_ptr), value, intent(in) :: tensor1_c + type(c_ptr), value, intent(in) :: tensor2_c + type(c_ptr) :: output_c + end function torch_tensor_divide_c + end interface + + output%p = torch_tensor_divide_c(tensor1%p, tensor2%p) + end function torch_tensor_divide + + !> Overloads division operator for a tensor and a scalar of type int8. + function torch_tensor_postdivide_int8(tensor, scalar) result(output) + type(torch_tensor), intent(in) :: tensor + integer(kind=int8), intent(in) :: scalar + type(torch_tensor) :: output + + interface + function torch_tensor_postdivide_c(tensor_c, scalar_c) & + result(output_c) bind(c, name = 'torch_tensor_postdivide') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : int8 + implicit none + type(c_ptr), value, intent(in) :: tensor_c + integer(kind=int8), value, intent(in) :: scalar_c + type(c_ptr) :: output_c + end function torch_tensor_postdivide_c + end interface + + output%p = torch_tensor_postdivide_c(tensor%p, scalar) + end function torch_tensor_postdivide_int8 + + !> Overloads division operator for a tensor and a scalar of type int16. + function torch_tensor_postdivide_int16(tensor, scalar) result(output) + type(torch_tensor), intent(in) :: tensor + integer(kind=int16), intent(in) :: scalar + type(torch_tensor) :: output + + interface + function torch_tensor_postdivide_c(tensor_c, scalar_c) & + result(output_c) bind(c, name = 'torch_tensor_postdivide') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : int16 + implicit none + type(c_ptr), value, intent(in) :: tensor_c + integer(kind=int16), value, intent(in) :: scalar_c + type(c_ptr) :: output_c + end function torch_tensor_postdivide_c + end interface + + output%p = torch_tensor_postdivide_c(tensor%p, scalar) + end function torch_tensor_postdivide_int16 + + !> Overloads division operator for a tensor and a scalar of type int32. + function torch_tensor_postdivide_int32(tensor, scalar) result(output) + type(torch_tensor), intent(in) :: tensor + integer(kind=int32), intent(in) :: scalar + type(torch_tensor) :: output + + interface + function torch_tensor_postdivide_c(tensor_c, scalar_c) & + result(output_c) bind(c, name = 'torch_tensor_postdivide') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : int32 + implicit none + type(c_ptr), value, intent(in) :: tensor_c + integer(kind=int32), value, intent(in) :: scalar_c + type(c_ptr) :: output_c + end function torch_tensor_postdivide_c + end interface + + output%p = torch_tensor_postdivide_c(tensor%p, scalar) + end function torch_tensor_postdivide_int32 + + !> Overloads division operator for a tensor and a scalar of type int64. + function torch_tensor_postdivide_int64(tensor, scalar) result(output) + type(torch_tensor), intent(in) :: tensor + integer(kind=int64), intent(in) :: scalar + type(torch_tensor) :: output + + interface + function torch_tensor_postdivide_c(tensor_c, scalar_c) & + result(output_c) bind(c, name = 'torch_tensor_postdivide') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : int64 + implicit none + type(c_ptr), value, intent(in) :: tensor_c + integer(kind=int64), value, intent(in) :: scalar_c + type(c_ptr) :: output_c + end function torch_tensor_postdivide_c + end interface + + output%p = torch_tensor_postdivide_c(tensor%p, scalar) + end function torch_tensor_postdivide_int64 + + !> Overloads division operator for a tensor and a scalar of type real32. + function torch_tensor_postdivide_real32(tensor, scalar) result(output) + type(torch_tensor), intent(in) :: tensor + real(kind=real32), intent(in) :: scalar + type(torch_tensor) :: output + + interface + function torch_tensor_postdivide_c(tensor_c, scalar_c) & + result(output_c) bind(c, name = 'torch_tensor_postdivide') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : real32 + implicit none + type(c_ptr), value, intent(in) :: tensor_c + real(kind=real32), value, intent(in) :: scalar_c + type(c_ptr) :: output_c + end function torch_tensor_postdivide_c + end interface + + output%p = torch_tensor_postdivide_c(tensor%p, scalar) + end function torch_tensor_postdivide_real32 + + !> Overloads division operator for a tensor and a scalar of type real64. + function torch_tensor_postdivide_real64(tensor, scalar) result(output) + type(torch_tensor), intent(in) :: tensor + real(kind=real64), intent(in) :: scalar + type(torch_tensor) :: output + + interface + function torch_tensor_postdivide_c(tensor_c, scalar_c) & + result(output_c) bind(c, name = 'torch_tensor_postdivide') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : real64 + implicit none + type(c_ptr), value, intent(in) :: tensor_c + real(kind=real64), value, intent(in) :: scalar_c + type(c_ptr) :: output_c + end function torch_tensor_postdivide_c + end interface + + output%p = torch_tensor_postdivide_c(tensor%p, scalar) + end function torch_tensor_postdivide_real64 + + + !> Overloads exponentiation operator for a tensor and a scalar of type `int8` + function torch_tensor_power_int8(tensor, power) result(output) + type(torch_tensor), intent(in) :: tensor + integer(kind=int8), intent(in) :: power + type(torch_tensor) :: output + + interface + function torch_tensor_power_c(tensor_c, power_c) result(output_c) & + bind(c, name = 'torch_tensor_power') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : int8 + implicit none + type(c_ptr), value, intent(in) :: tensor_c + integer(kind=int8), value, intent(in) :: power_c + type(c_ptr) :: output_c + end function torch_tensor_power_c + end interface + + output%p = torch_tensor_power_c(tensor%p, power) + end function torch_tensor_power_int8 + + !> Overloads exponentiation operator for a tensor and a scalar of type `int16` + function torch_tensor_power_int16(tensor, power) result(output) + type(torch_tensor), intent(in) :: tensor + integer(kind=int16), intent(in) :: power + type(torch_tensor) :: output + + interface + function torch_tensor_power_c(tensor_c, power_c) result(output_c) & + bind(c, name = 'torch_tensor_power') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : int16 + implicit none + type(c_ptr), value, intent(in) :: tensor_c + integer(kind=int16), value, intent(in) :: power_c + type(c_ptr) :: output_c + end function torch_tensor_power_c + end interface + + output%p = torch_tensor_power_c(tensor%p, power) + end function torch_tensor_power_int16 + + !> Overloads exponentiation operator for a tensor and a scalar of type `int32` + function torch_tensor_power_int32(tensor, power) result(output) + type(torch_tensor), intent(in) :: tensor + integer(kind=int32), intent(in) :: power + type(torch_tensor) :: output + + interface + function torch_tensor_power_c(tensor_c, power_c) result(output_c) & + bind(c, name = 'torch_tensor_power') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : int32 + implicit none + type(c_ptr), value, intent(in) :: tensor_c + integer(kind=int32), value, intent(in) :: power_c + type(c_ptr) :: output_c + end function torch_tensor_power_c + end interface + + output%p = torch_tensor_power_c(tensor%p, power) + end function torch_tensor_power_int32 + + !> Overloads exponentiation operator for a tensor and a scalar of type `int64` + function torch_tensor_power_int64(tensor, power) result(output) + type(torch_tensor), intent(in) :: tensor + integer(kind=int64), intent(in) :: power + type(torch_tensor) :: output + + interface + function torch_tensor_power_c(tensor_c, power_c) result(output_c) & + bind(c, name = 'torch_tensor_power') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : int64 + implicit none + type(c_ptr), value, intent(in) :: tensor_c + integer(kind=int64), value, intent(in) :: power_c + type(c_ptr) :: output_c + end function torch_tensor_power_c + end interface + + output%p = torch_tensor_power_c(tensor%p, power) + end function torch_tensor_power_int64 + + !> Overloads exponentiation operator for a tensor and a scalar of type `real32` + function torch_tensor_power_real32(tensor, power) result(output) + type(torch_tensor), intent(in) :: tensor + real(kind=real32), intent(in) :: power + type(torch_tensor) :: output + + interface + function torch_tensor_power_c(tensor_c, power_c) result(output_c) & + bind(c, name = 'torch_tensor_power') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : real32 + implicit none + type(c_ptr), value, intent(in) :: tensor_c + real(kind=real32), value, intent(in) :: power_c + type(c_ptr) :: output_c + end function torch_tensor_power_c + end interface + + output%p = torch_tensor_power_c(tensor%p, power) + end function torch_tensor_power_real32 + + !> Overloads exponentiation operator for a tensor and a scalar of type `real64` + function torch_tensor_power_real64(tensor, power) result(output) + type(torch_tensor), intent(in) :: tensor + real(kind=real64), intent(in) :: power + type(torch_tensor) :: output + + interface + function torch_tensor_power_c(tensor_c, power_c) result(output_c) & + bind(c, name = 'torch_tensor_power') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : real64 + implicit none + type(c_ptr), value, intent(in) :: tensor_c + real(kind=real64), value, intent(in) :: power_c + type(c_ptr) :: output_c + end function torch_tensor_power_c + end interface + + output%p = torch_tensor_power_c(tensor%p, power) + end function torch_tensor_power_real64 + + ! Torch Model API !> Loads a TorchScript nn.module (pre-trained PyTorch model saved with TorchScript) subroutine torch_model_load(model, filename, device_type, device_index, & @@ -555,7 +1207,7 @@ end subroutine torch_model_delete !> Return a Torch tensor pointing to data_in array of rank 1 containing data of type `int8` subroutine torch_tensor_from_array_int8_1d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : int8 @@ -565,15 +1217,15 @@ subroutine torch_tensor_from_array_int8_1d(tensor, data_in, layout, & ! inputs integer(kind=int8), intent(in), target :: data_in(:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(1) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(1) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kInt8 !! Data type - integer(c_int64_t) :: strides(1) !! Strides for accessing data - integer(c_int), parameter :: ndims = 1 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(1) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kInt8 !! Data type + integer(c_int64_t) :: strides(1) !! Strides for accessing data + integer(c_int), parameter :: ndims = 1 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -581,7 +1233,7 @@ subroutine torch_tensor_from_array_int8_1d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -593,15 +1245,15 @@ subroutine torch_tensor_from_array_int8_1d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -609,7 +1261,7 @@ end subroutine torch_tensor_from_array_int8_1d !> Return a Torch tensor pointing to data_in array of rank 2 containing data of type `int8` subroutine torch_tensor_from_array_int8_2d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : int8 @@ -619,15 +1271,15 @@ subroutine torch_tensor_from_array_int8_2d(tensor, data_in, layout, & ! inputs integer(kind=int8), intent(in), target :: data_in(:,:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(2) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(2) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kInt8 !! Data type - integer(c_int64_t) :: strides(2) !! Strides for accessing data - integer(c_int), parameter :: ndims = 2 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(2) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kInt8 !! Data type + integer(c_int64_t) :: strides(2) !! Strides for accessing data + integer(c_int), parameter :: ndims = 2 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -635,7 +1287,7 @@ subroutine torch_tensor_from_array_int8_2d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -647,15 +1299,15 @@ subroutine torch_tensor_from_array_int8_2d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -663,7 +1315,7 @@ end subroutine torch_tensor_from_array_int8_2d !> Return a Torch tensor pointing to data_in array of rank 3 containing data of type `int8` subroutine torch_tensor_from_array_int8_3d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : int8 @@ -673,15 +1325,15 @@ subroutine torch_tensor_from_array_int8_3d(tensor, data_in, layout, & ! inputs integer(kind=int8), intent(in), target :: data_in(:,:,:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(3) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(3) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kInt8 !! Data type - integer(c_int64_t) :: strides(3) !! Strides for accessing data - integer(c_int), parameter :: ndims = 3 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(3) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kInt8 !! Data type + integer(c_int64_t) :: strides(3) !! Strides for accessing data + integer(c_int), parameter :: ndims = 3 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -689,7 +1341,7 @@ subroutine torch_tensor_from_array_int8_3d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -701,15 +1353,15 @@ subroutine torch_tensor_from_array_int8_3d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -717,7 +1369,7 @@ end subroutine torch_tensor_from_array_int8_3d !> Return a Torch tensor pointing to data_in array of rank 4 containing data of type `int8` subroutine torch_tensor_from_array_int8_4d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : int8 @@ -727,15 +1379,15 @@ subroutine torch_tensor_from_array_int8_4d(tensor, data_in, layout, & ! inputs integer(kind=int8), intent(in), target :: data_in(:,:,:,:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(4) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(4) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kInt8 !! Data type - integer(c_int64_t) :: strides(4) !! Strides for accessing data - integer(c_int), parameter :: ndims = 4 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(4) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kInt8 !! Data type + integer(c_int64_t) :: strides(4) !! Strides for accessing data + integer(c_int), parameter :: ndims = 4 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -743,7 +1395,7 @@ subroutine torch_tensor_from_array_int8_4d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -755,15 +1407,15 @@ subroutine torch_tensor_from_array_int8_4d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -771,7 +1423,7 @@ end subroutine torch_tensor_from_array_int8_4d !> Return a Torch tensor pointing to data_in array of rank 5 containing data of type `int8` subroutine torch_tensor_from_array_int8_5d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : int8 @@ -781,15 +1433,15 @@ subroutine torch_tensor_from_array_int8_5d(tensor, data_in, layout, & ! inputs integer(kind=int8), intent(in), target :: data_in(:,:,:,:,:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(5) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(5) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kInt8 !! Data type - integer(c_int64_t) :: strides(5) !! Strides for accessing data - integer(c_int), parameter :: ndims = 5 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(5) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kInt8 !! Data type + integer(c_int64_t) :: strides(5) !! Strides for accessing data + integer(c_int), parameter :: ndims = 5 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -797,7 +1449,7 @@ subroutine torch_tensor_from_array_int8_5d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -809,15 +1461,15 @@ subroutine torch_tensor_from_array_int8_5d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -825,7 +1477,7 @@ end subroutine torch_tensor_from_array_int8_5d !> Return a Torch tensor pointing to data_in array of rank 1 containing data of type `int16` subroutine torch_tensor_from_array_int16_1d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : int16 @@ -835,15 +1487,15 @@ subroutine torch_tensor_from_array_int16_1d(tensor, data_in, layout, & ! inputs integer(kind=int16), intent(in), target :: data_in(:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(1) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(1) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kInt16 !! Data type - integer(c_int64_t) :: strides(1) !! Strides for accessing data - integer(c_int), parameter :: ndims = 1 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(1) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kInt16 !! Data type + integer(c_int64_t) :: strides(1) !! Strides for accessing data + integer(c_int), parameter :: ndims = 1 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -851,7 +1503,7 @@ subroutine torch_tensor_from_array_int16_1d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -863,15 +1515,15 @@ subroutine torch_tensor_from_array_int16_1d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -879,7 +1531,7 @@ end subroutine torch_tensor_from_array_int16_1d !> Return a Torch tensor pointing to data_in array of rank 2 containing data of type `int16` subroutine torch_tensor_from_array_int16_2d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : int16 @@ -889,15 +1541,15 @@ subroutine torch_tensor_from_array_int16_2d(tensor, data_in, layout, & ! inputs integer(kind=int16), intent(in), target :: data_in(:,:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(2) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(2) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kInt16 !! Data type - integer(c_int64_t) :: strides(2) !! Strides for accessing data - integer(c_int), parameter :: ndims = 2 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(2) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kInt16 !! Data type + integer(c_int64_t) :: strides(2) !! Strides for accessing data + integer(c_int), parameter :: ndims = 2 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -905,7 +1557,7 @@ subroutine torch_tensor_from_array_int16_2d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -917,15 +1569,15 @@ subroutine torch_tensor_from_array_int16_2d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -933,7 +1585,7 @@ end subroutine torch_tensor_from_array_int16_2d !> Return a Torch tensor pointing to data_in array of rank 3 containing data of type `int16` subroutine torch_tensor_from_array_int16_3d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : int16 @@ -943,15 +1595,15 @@ subroutine torch_tensor_from_array_int16_3d(tensor, data_in, layout, & ! inputs integer(kind=int16), intent(in), target :: data_in(:,:,:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(3) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(3) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kInt16 !! Data type - integer(c_int64_t) :: strides(3) !! Strides for accessing data - integer(c_int), parameter :: ndims = 3 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(3) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kInt16 !! Data type + integer(c_int64_t) :: strides(3) !! Strides for accessing data + integer(c_int), parameter :: ndims = 3 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -959,7 +1611,7 @@ subroutine torch_tensor_from_array_int16_3d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -971,15 +1623,15 @@ subroutine torch_tensor_from_array_int16_3d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -987,7 +1639,7 @@ end subroutine torch_tensor_from_array_int16_3d !> Return a Torch tensor pointing to data_in array of rank 4 containing data of type `int16` subroutine torch_tensor_from_array_int16_4d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : int16 @@ -997,15 +1649,15 @@ subroutine torch_tensor_from_array_int16_4d(tensor, data_in, layout, & ! inputs integer(kind=int16), intent(in), target :: data_in(:,:,:,:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(4) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(4) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kInt16 !! Data type - integer(c_int64_t) :: strides(4) !! Strides for accessing data - integer(c_int), parameter :: ndims = 4 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(4) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kInt16 !! Data type + integer(c_int64_t) :: strides(4) !! Strides for accessing data + integer(c_int), parameter :: ndims = 4 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -1013,7 +1665,7 @@ subroutine torch_tensor_from_array_int16_4d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -1025,15 +1677,15 @@ subroutine torch_tensor_from_array_int16_4d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -1041,7 +1693,7 @@ end subroutine torch_tensor_from_array_int16_4d !> Return a Torch tensor pointing to data_in array of rank 5 containing data of type `int16` subroutine torch_tensor_from_array_int16_5d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : int16 @@ -1051,15 +1703,15 @@ subroutine torch_tensor_from_array_int16_5d(tensor, data_in, layout, & ! inputs integer(kind=int16), intent(in), target :: data_in(:,:,:,:,:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(5) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(5) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kInt16 !! Data type - integer(c_int64_t) :: strides(5) !! Strides for accessing data - integer(c_int), parameter :: ndims = 5 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(5) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kInt16 !! Data type + integer(c_int64_t) :: strides(5) !! Strides for accessing data + integer(c_int), parameter :: ndims = 5 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -1067,7 +1719,7 @@ subroutine torch_tensor_from_array_int16_5d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -1079,15 +1731,15 @@ subroutine torch_tensor_from_array_int16_5d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -1095,7 +1747,7 @@ end subroutine torch_tensor_from_array_int16_5d !> Return a Torch tensor pointing to data_in array of rank 1 containing data of type `int32` subroutine torch_tensor_from_array_int32_1d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : int32 @@ -1105,15 +1757,15 @@ subroutine torch_tensor_from_array_int32_1d(tensor, data_in, layout, & ! inputs integer(kind=int32), intent(in), target :: data_in(:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(1) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(1) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kInt32 !! Data type - integer(c_int64_t) :: strides(1) !! Strides for accessing data - integer(c_int), parameter :: ndims = 1 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(1) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kInt32 !! Data type + integer(c_int64_t) :: strides(1) !! Strides for accessing data + integer(c_int), parameter :: ndims = 1 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -1121,7 +1773,7 @@ subroutine torch_tensor_from_array_int32_1d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -1133,15 +1785,15 @@ subroutine torch_tensor_from_array_int32_1d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -1149,7 +1801,7 @@ end subroutine torch_tensor_from_array_int32_1d !> Return a Torch tensor pointing to data_in array of rank 2 containing data of type `int32` subroutine torch_tensor_from_array_int32_2d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : int32 @@ -1159,15 +1811,15 @@ subroutine torch_tensor_from_array_int32_2d(tensor, data_in, layout, & ! inputs integer(kind=int32), intent(in), target :: data_in(:,:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(2) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(2) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kInt32 !! Data type - integer(c_int64_t) :: strides(2) !! Strides for accessing data - integer(c_int), parameter :: ndims = 2 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(2) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kInt32 !! Data type + integer(c_int64_t) :: strides(2) !! Strides for accessing data + integer(c_int), parameter :: ndims = 2 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -1175,7 +1827,7 @@ subroutine torch_tensor_from_array_int32_2d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -1187,15 +1839,15 @@ subroutine torch_tensor_from_array_int32_2d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -1203,7 +1855,7 @@ end subroutine torch_tensor_from_array_int32_2d !> Return a Torch tensor pointing to data_in array of rank 3 containing data of type `int32` subroutine torch_tensor_from_array_int32_3d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : int32 @@ -1213,15 +1865,15 @@ subroutine torch_tensor_from_array_int32_3d(tensor, data_in, layout, & ! inputs integer(kind=int32), intent(in), target :: data_in(:,:,:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(3) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(3) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kInt32 !! Data type - integer(c_int64_t) :: strides(3) !! Strides for accessing data - integer(c_int), parameter :: ndims = 3 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(3) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kInt32 !! Data type + integer(c_int64_t) :: strides(3) !! Strides for accessing data + integer(c_int), parameter :: ndims = 3 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -1229,7 +1881,7 @@ subroutine torch_tensor_from_array_int32_3d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -1241,15 +1893,15 @@ subroutine torch_tensor_from_array_int32_3d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -1257,7 +1909,7 @@ end subroutine torch_tensor_from_array_int32_3d !> Return a Torch tensor pointing to data_in array of rank 4 containing data of type `int32` subroutine torch_tensor_from_array_int32_4d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : int32 @@ -1267,15 +1919,15 @@ subroutine torch_tensor_from_array_int32_4d(tensor, data_in, layout, & ! inputs integer(kind=int32), intent(in), target :: data_in(:,:,:,:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(4) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(4) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kInt32 !! Data type - integer(c_int64_t) :: strides(4) !! Strides for accessing data - integer(c_int), parameter :: ndims = 4 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(4) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kInt32 !! Data type + integer(c_int64_t) :: strides(4) !! Strides for accessing data + integer(c_int), parameter :: ndims = 4 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -1283,7 +1935,7 @@ subroutine torch_tensor_from_array_int32_4d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -1295,15 +1947,15 @@ subroutine torch_tensor_from_array_int32_4d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -1311,7 +1963,7 @@ end subroutine torch_tensor_from_array_int32_4d !> Return a Torch tensor pointing to data_in array of rank 5 containing data of type `int32` subroutine torch_tensor_from_array_int32_5d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : int32 @@ -1321,15 +1973,15 @@ subroutine torch_tensor_from_array_int32_5d(tensor, data_in, layout, & ! inputs integer(kind=int32), intent(in), target :: data_in(:,:,:,:,:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(5) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(5) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kInt32 !! Data type - integer(c_int64_t) :: strides(5) !! Strides for accessing data - integer(c_int), parameter :: ndims = 5 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(5) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kInt32 !! Data type + integer(c_int64_t) :: strides(5) !! Strides for accessing data + integer(c_int), parameter :: ndims = 5 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -1337,7 +1989,7 @@ subroutine torch_tensor_from_array_int32_5d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -1349,15 +2001,15 @@ subroutine torch_tensor_from_array_int32_5d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -1365,7 +2017,7 @@ end subroutine torch_tensor_from_array_int32_5d !> Return a Torch tensor pointing to data_in array of rank 1 containing data of type `int64` subroutine torch_tensor_from_array_int64_1d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : int64 @@ -1375,15 +2027,15 @@ subroutine torch_tensor_from_array_int64_1d(tensor, data_in, layout, & ! inputs integer(kind=int64), intent(in), target :: data_in(:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(1) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(1) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kInt64 !! Data type - integer(c_int64_t) :: strides(1) !! Strides for accessing data - integer(c_int), parameter :: ndims = 1 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(1) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kInt64 !! Data type + integer(c_int64_t) :: strides(1) !! Strides for accessing data + integer(c_int), parameter :: ndims = 1 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -1391,7 +2043,7 @@ subroutine torch_tensor_from_array_int64_1d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -1403,15 +2055,15 @@ subroutine torch_tensor_from_array_int64_1d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -1419,7 +2071,7 @@ end subroutine torch_tensor_from_array_int64_1d !> Return a Torch tensor pointing to data_in array of rank 2 containing data of type `int64` subroutine torch_tensor_from_array_int64_2d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : int64 @@ -1429,15 +2081,15 @@ subroutine torch_tensor_from_array_int64_2d(tensor, data_in, layout, & ! inputs integer(kind=int64), intent(in), target :: data_in(:,:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(2) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(2) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kInt64 !! Data type - integer(c_int64_t) :: strides(2) !! Strides for accessing data - integer(c_int), parameter :: ndims = 2 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(2) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kInt64 !! Data type + integer(c_int64_t) :: strides(2) !! Strides for accessing data + integer(c_int), parameter :: ndims = 2 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -1445,7 +2097,7 @@ subroutine torch_tensor_from_array_int64_2d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -1457,15 +2109,15 @@ subroutine torch_tensor_from_array_int64_2d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -1473,7 +2125,7 @@ end subroutine torch_tensor_from_array_int64_2d !> Return a Torch tensor pointing to data_in array of rank 3 containing data of type `int64` subroutine torch_tensor_from_array_int64_3d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : int64 @@ -1483,15 +2135,15 @@ subroutine torch_tensor_from_array_int64_3d(tensor, data_in, layout, & ! inputs integer(kind=int64), intent(in), target :: data_in(:,:,:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(3) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(3) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kInt64 !! Data type - integer(c_int64_t) :: strides(3) !! Strides for accessing data - integer(c_int), parameter :: ndims = 3 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(3) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kInt64 !! Data type + integer(c_int64_t) :: strides(3) !! Strides for accessing data + integer(c_int), parameter :: ndims = 3 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -1499,7 +2151,7 @@ subroutine torch_tensor_from_array_int64_3d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -1511,15 +2163,15 @@ subroutine torch_tensor_from_array_int64_3d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -1527,7 +2179,7 @@ end subroutine torch_tensor_from_array_int64_3d !> Return a Torch tensor pointing to data_in array of rank 4 containing data of type `int64` subroutine torch_tensor_from_array_int64_4d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : int64 @@ -1537,15 +2189,15 @@ subroutine torch_tensor_from_array_int64_4d(tensor, data_in, layout, & ! inputs integer(kind=int64), intent(in), target :: data_in(:,:,:,:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(4) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(4) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kInt64 !! Data type - integer(c_int64_t) :: strides(4) !! Strides for accessing data - integer(c_int), parameter :: ndims = 4 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(4) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kInt64 !! Data type + integer(c_int64_t) :: strides(4) !! Strides for accessing data + integer(c_int), parameter :: ndims = 4 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -1553,7 +2205,7 @@ subroutine torch_tensor_from_array_int64_4d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -1565,15 +2217,15 @@ subroutine torch_tensor_from_array_int64_4d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -1581,7 +2233,7 @@ end subroutine torch_tensor_from_array_int64_4d !> Return a Torch tensor pointing to data_in array of rank 5 containing data of type `int64` subroutine torch_tensor_from_array_int64_5d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : int64 @@ -1591,15 +2243,15 @@ subroutine torch_tensor_from_array_int64_5d(tensor, data_in, layout, & ! inputs integer(kind=int64), intent(in), target :: data_in(:,:,:,:,:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(5) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(5) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kInt64 !! Data type - integer(c_int64_t) :: strides(5) !! Strides for accessing data - integer(c_int), parameter :: ndims = 5 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(5) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kInt64 !! Data type + integer(c_int64_t) :: strides(5) !! Strides for accessing data + integer(c_int), parameter :: ndims = 5 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -1607,7 +2259,7 @@ subroutine torch_tensor_from_array_int64_5d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -1619,15 +2271,15 @@ subroutine torch_tensor_from_array_int64_5d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -1635,7 +2287,7 @@ end subroutine torch_tensor_from_array_int64_5d !> Return a Torch tensor pointing to data_in array of rank 1 containing data of type `real32` subroutine torch_tensor_from_array_real32_1d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : real32 @@ -1645,15 +2297,15 @@ subroutine torch_tensor_from_array_real32_1d(tensor, data_in, layout, & ! inputs real(kind=real32), intent(in), target :: data_in(:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(1) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(1) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kFloat32 !! Data type - integer(c_int64_t) :: strides(1) !! Strides for accessing data - integer(c_int), parameter :: ndims = 1 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(1) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kFloat32 !! Data type + integer(c_int64_t) :: strides(1) !! Strides for accessing data + integer(c_int), parameter :: ndims = 1 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -1661,7 +2313,7 @@ subroutine torch_tensor_from_array_real32_1d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -1673,15 +2325,15 @@ subroutine torch_tensor_from_array_real32_1d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -1689,7 +2341,7 @@ end subroutine torch_tensor_from_array_real32_1d !> Return a Torch tensor pointing to data_in array of rank 2 containing data of type `real32` subroutine torch_tensor_from_array_real32_2d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : real32 @@ -1699,15 +2351,15 @@ subroutine torch_tensor_from_array_real32_2d(tensor, data_in, layout, & ! inputs real(kind=real32), intent(in), target :: data_in(:,:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(2) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(2) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kFloat32 !! Data type - integer(c_int64_t) :: strides(2) !! Strides for accessing data - integer(c_int), parameter :: ndims = 2 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(2) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kFloat32 !! Data type + integer(c_int64_t) :: strides(2) !! Strides for accessing data + integer(c_int), parameter :: ndims = 2 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -1715,7 +2367,7 @@ subroutine torch_tensor_from_array_real32_2d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -1727,15 +2379,15 @@ subroutine torch_tensor_from_array_real32_2d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -1743,7 +2395,7 @@ end subroutine torch_tensor_from_array_real32_2d !> Return a Torch tensor pointing to data_in array of rank 3 containing data of type `real32` subroutine torch_tensor_from_array_real32_3d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : real32 @@ -1753,15 +2405,15 @@ subroutine torch_tensor_from_array_real32_3d(tensor, data_in, layout, & ! inputs real(kind=real32), intent(in), target :: data_in(:,:,:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(3) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(3) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kFloat32 !! Data type - integer(c_int64_t) :: strides(3) !! Strides for accessing data - integer(c_int), parameter :: ndims = 3 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(3) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kFloat32 !! Data type + integer(c_int64_t) :: strides(3) !! Strides for accessing data + integer(c_int), parameter :: ndims = 3 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -1769,7 +2421,7 @@ subroutine torch_tensor_from_array_real32_3d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -1781,15 +2433,15 @@ subroutine torch_tensor_from_array_real32_3d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -1797,7 +2449,7 @@ end subroutine torch_tensor_from_array_real32_3d !> Return a Torch tensor pointing to data_in array of rank 4 containing data of type `real32` subroutine torch_tensor_from_array_real32_4d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : real32 @@ -1807,15 +2459,15 @@ subroutine torch_tensor_from_array_real32_4d(tensor, data_in, layout, & ! inputs real(kind=real32), intent(in), target :: data_in(:,:,:,:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(4) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(4) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kFloat32 !! Data type - integer(c_int64_t) :: strides(4) !! Strides for accessing data - integer(c_int), parameter :: ndims = 4 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(4) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kFloat32 !! Data type + integer(c_int64_t) :: strides(4) !! Strides for accessing data + integer(c_int), parameter :: ndims = 4 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -1823,7 +2475,7 @@ subroutine torch_tensor_from_array_real32_4d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -1835,15 +2487,15 @@ subroutine torch_tensor_from_array_real32_4d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -1851,7 +2503,7 @@ end subroutine torch_tensor_from_array_real32_4d !> Return a Torch tensor pointing to data_in array of rank 5 containing data of type `real32` subroutine torch_tensor_from_array_real32_5d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : real32 @@ -1861,15 +2513,15 @@ subroutine torch_tensor_from_array_real32_5d(tensor, data_in, layout, & ! inputs real(kind=real32), intent(in), target :: data_in(:,:,:,:,:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(5) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(5) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kFloat32 !! Data type - integer(c_int64_t) :: strides(5) !! Strides for accessing data - integer(c_int), parameter :: ndims = 5 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(5) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kFloat32 !! Data type + integer(c_int64_t) :: strides(5) !! Strides for accessing data + integer(c_int), parameter :: ndims = 5 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -1877,7 +2529,7 @@ subroutine torch_tensor_from_array_real32_5d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -1889,15 +2541,15 @@ subroutine torch_tensor_from_array_real32_5d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -1905,7 +2557,7 @@ end subroutine torch_tensor_from_array_real32_5d !> Return a Torch tensor pointing to data_in array of rank 1 containing data of type `real64` subroutine torch_tensor_from_array_real64_1d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : real64 @@ -1915,15 +2567,15 @@ subroutine torch_tensor_from_array_real64_1d(tensor, data_in, layout, & ! inputs real(kind=real64), intent(in), target :: data_in(:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(1) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(1) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kFloat64 !! Data type - integer(c_int64_t) :: strides(1) !! Strides for accessing data - integer(c_int), parameter :: ndims = 1 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(1) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kFloat64 !! Data type + integer(c_int64_t) :: strides(1) !! Strides for accessing data + integer(c_int), parameter :: ndims = 1 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -1931,7 +2583,7 @@ subroutine torch_tensor_from_array_real64_1d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -1943,15 +2595,15 @@ subroutine torch_tensor_from_array_real64_1d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -1959,7 +2611,7 @@ end subroutine torch_tensor_from_array_real64_1d !> Return a Torch tensor pointing to data_in array of rank 2 containing data of type `real64` subroutine torch_tensor_from_array_real64_2d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : real64 @@ -1969,15 +2621,15 @@ subroutine torch_tensor_from_array_real64_2d(tensor, data_in, layout, & ! inputs real(kind=real64), intent(in), target :: data_in(:,:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(2) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(2) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kFloat64 !! Data type - integer(c_int64_t) :: strides(2) !! Strides for accessing data - integer(c_int), parameter :: ndims = 2 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(2) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kFloat64 !! Data type + integer(c_int64_t) :: strides(2) !! Strides for accessing data + integer(c_int), parameter :: ndims = 2 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -1985,7 +2637,7 @@ subroutine torch_tensor_from_array_real64_2d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -1997,15 +2649,15 @@ subroutine torch_tensor_from_array_real64_2d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -2013,7 +2665,7 @@ end subroutine torch_tensor_from_array_real64_2d !> Return a Torch tensor pointing to data_in array of rank 3 containing data of type `real64` subroutine torch_tensor_from_array_real64_3d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : real64 @@ -2023,15 +2675,15 @@ subroutine torch_tensor_from_array_real64_3d(tensor, data_in, layout, & ! inputs real(kind=real64), intent(in), target :: data_in(:,:,:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(3) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(3) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kFloat64 !! Data type - integer(c_int64_t) :: strides(3) !! Strides for accessing data - integer(c_int), parameter :: ndims = 3 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(3) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kFloat64 !! Data type + integer(c_int64_t) :: strides(3) !! Strides for accessing data + integer(c_int), parameter :: ndims = 3 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -2039,7 +2691,7 @@ subroutine torch_tensor_from_array_real64_3d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -2051,15 +2703,15 @@ subroutine torch_tensor_from_array_real64_3d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -2067,7 +2719,7 @@ end subroutine torch_tensor_from_array_real64_3d !> Return a Torch tensor pointing to data_in array of rank 4 containing data of type `real64` subroutine torch_tensor_from_array_real64_4d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : real64 @@ -2077,15 +2729,15 @@ subroutine torch_tensor_from_array_real64_4d(tensor, data_in, layout, & ! inputs real(kind=real64), intent(in), target :: data_in(:,:,:,:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(4) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(4) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kFloat64 !! Data type - integer(c_int64_t) :: strides(4) !! Strides for accessing data - integer(c_int), parameter :: ndims = 4 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(4) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kFloat64 !! Data type + integer(c_int64_t) :: strides(4) !! Strides for accessing data + integer(c_int), parameter :: ndims = 4 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -2093,7 +2745,7 @@ subroutine torch_tensor_from_array_real64_4d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -2105,15 +2757,15 @@ subroutine torch_tensor_from_array_real64_4d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -2121,7 +2773,7 @@ end subroutine torch_tensor_from_array_real64_4d !> Return a Torch tensor pointing to data_in array of rank 5 containing data of type `real64` subroutine torch_tensor_from_array_real64_5d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : real64 @@ -2131,15 +2783,15 @@ subroutine torch_tensor_from_array_real64_5d(tensor, data_in, layout, & ! inputs real(kind=real64), intent(in), target :: data_in(:,:,:,:,:) !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(5) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(5) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = torch_kFloat64 !! Data type - integer(c_int64_t) :: strides(5) !! Strides for accessing data - integer(c_int), parameter :: ndims = 5 !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(5) !! Shape of the tensor + integer(c_int), parameter :: dtype = torch_kFloat64 !! Data type + integer(c_int64_t) :: strides(5) !! Strides for accessing data + integer(c_int), parameter :: ndims = 5 !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -2147,7 +2799,7 @@ subroutine torch_tensor_from_array_real64_5d(tensor, data_in, layout, & ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -2159,15 +2811,15 @@ subroutine torch_tensor_from_array_real64_5d(tensor, data_in, layout, & requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -2184,7 +2836,7 @@ subroutine torch_tensor_to_array_int8_1d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kInt8 !! Data type + integer(c_int), parameter :: dtype = torch_kInt8 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2199,7 +2851,7 @@ subroutine torch_tensor_to_array_int8_1d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_int8_1d @@ -2214,7 +2866,7 @@ subroutine torch_tensor_to_array_int8_2d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kInt8 !! Data type + integer(c_int), parameter :: dtype = torch_kInt8 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2229,7 +2881,7 @@ subroutine torch_tensor_to_array_int8_2d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_int8_2d @@ -2244,7 +2896,7 @@ subroutine torch_tensor_to_array_int8_3d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kInt8 !! Data type + integer(c_int), parameter :: dtype = torch_kInt8 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2259,7 +2911,7 @@ subroutine torch_tensor_to_array_int8_3d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_int8_3d @@ -2274,7 +2926,7 @@ subroutine torch_tensor_to_array_int8_4d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kInt8 !! Data type + integer(c_int), parameter :: dtype = torch_kInt8 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2289,7 +2941,7 @@ subroutine torch_tensor_to_array_int8_4d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_int8_4d @@ -2304,7 +2956,7 @@ subroutine torch_tensor_to_array_int8_5d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kInt8 !! Data type + integer(c_int), parameter :: dtype = torch_kInt8 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2319,7 +2971,7 @@ subroutine torch_tensor_to_array_int8_5d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_int8_5d @@ -2334,7 +2986,7 @@ subroutine torch_tensor_to_array_int16_1d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kInt16 !! Data type + integer(c_int), parameter :: dtype = torch_kInt16 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2349,7 +3001,7 @@ subroutine torch_tensor_to_array_int16_1d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_int16_1d @@ -2364,7 +3016,7 @@ subroutine torch_tensor_to_array_int16_2d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kInt16 !! Data type + integer(c_int), parameter :: dtype = torch_kInt16 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2379,7 +3031,7 @@ subroutine torch_tensor_to_array_int16_2d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_int16_2d @@ -2394,7 +3046,7 @@ subroutine torch_tensor_to_array_int16_3d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kInt16 !! Data type + integer(c_int), parameter :: dtype = torch_kInt16 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2409,7 +3061,7 @@ subroutine torch_tensor_to_array_int16_3d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_int16_3d @@ -2424,7 +3076,7 @@ subroutine torch_tensor_to_array_int16_4d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kInt16 !! Data type + integer(c_int), parameter :: dtype = torch_kInt16 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2439,7 +3091,7 @@ subroutine torch_tensor_to_array_int16_4d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_int16_4d @@ -2454,7 +3106,7 @@ subroutine torch_tensor_to_array_int16_5d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kInt16 !! Data type + integer(c_int), parameter :: dtype = torch_kInt16 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2469,7 +3121,7 @@ subroutine torch_tensor_to_array_int16_5d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_int16_5d @@ -2484,7 +3136,7 @@ subroutine torch_tensor_to_array_int32_1d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kInt32 !! Data type + integer(c_int), parameter :: dtype = torch_kInt32 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2499,7 +3151,7 @@ subroutine torch_tensor_to_array_int32_1d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_int32_1d @@ -2514,7 +3166,7 @@ subroutine torch_tensor_to_array_int32_2d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kInt32 !! Data type + integer(c_int), parameter :: dtype = torch_kInt32 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2529,7 +3181,7 @@ subroutine torch_tensor_to_array_int32_2d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_int32_2d @@ -2544,7 +3196,7 @@ subroutine torch_tensor_to_array_int32_3d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kInt32 !! Data type + integer(c_int), parameter :: dtype = torch_kInt32 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2559,7 +3211,7 @@ subroutine torch_tensor_to_array_int32_3d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_int32_3d @@ -2574,7 +3226,7 @@ subroutine torch_tensor_to_array_int32_4d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kInt32 !! Data type + integer(c_int), parameter :: dtype = torch_kInt32 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2589,7 +3241,7 @@ subroutine torch_tensor_to_array_int32_4d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_int32_4d @@ -2604,7 +3256,7 @@ subroutine torch_tensor_to_array_int32_5d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kInt32 !! Data type + integer(c_int), parameter :: dtype = torch_kInt32 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2619,7 +3271,7 @@ subroutine torch_tensor_to_array_int32_5d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_int32_5d @@ -2634,7 +3286,7 @@ subroutine torch_tensor_to_array_int64_1d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kInt64 !! Data type + integer(c_int), parameter :: dtype = torch_kInt64 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2649,7 +3301,7 @@ subroutine torch_tensor_to_array_int64_1d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_int64_1d @@ -2664,7 +3316,7 @@ subroutine torch_tensor_to_array_int64_2d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kInt64 !! Data type + integer(c_int), parameter :: dtype = torch_kInt64 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2679,7 +3331,7 @@ subroutine torch_tensor_to_array_int64_2d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_int64_2d @@ -2694,7 +3346,7 @@ subroutine torch_tensor_to_array_int64_3d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kInt64 !! Data type + integer(c_int), parameter :: dtype = torch_kInt64 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2709,7 +3361,7 @@ subroutine torch_tensor_to_array_int64_3d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_int64_3d @@ -2724,7 +3376,7 @@ subroutine torch_tensor_to_array_int64_4d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kInt64 !! Data type + integer(c_int), parameter :: dtype = torch_kInt64 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2739,7 +3391,7 @@ subroutine torch_tensor_to_array_int64_4d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_int64_4d @@ -2754,7 +3406,7 @@ subroutine torch_tensor_to_array_int64_5d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kInt64 !! Data type + integer(c_int), parameter :: dtype = torch_kInt64 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2769,7 +3421,7 @@ subroutine torch_tensor_to_array_int64_5d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_int64_5d @@ -2784,7 +3436,7 @@ subroutine torch_tensor_to_array_real32_1d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kFloat32 !! Data type + integer(c_int), parameter :: dtype = torch_kFloat32 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2799,7 +3451,7 @@ subroutine torch_tensor_to_array_real32_1d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_real32_1d @@ -2814,7 +3466,7 @@ subroutine torch_tensor_to_array_real32_2d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kFloat32 !! Data type + integer(c_int), parameter :: dtype = torch_kFloat32 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2829,7 +3481,7 @@ subroutine torch_tensor_to_array_real32_2d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_real32_2d @@ -2844,7 +3496,7 @@ subroutine torch_tensor_to_array_real32_3d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kFloat32 !! Data type + integer(c_int), parameter :: dtype = torch_kFloat32 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2859,7 +3511,7 @@ subroutine torch_tensor_to_array_real32_3d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_real32_3d @@ -2874,7 +3526,7 @@ subroutine torch_tensor_to_array_real32_4d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kFloat32 !! Data type + integer(c_int), parameter :: dtype = torch_kFloat32 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2889,7 +3541,7 @@ subroutine torch_tensor_to_array_real32_4d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_real32_4d @@ -2904,7 +3556,7 @@ subroutine torch_tensor_to_array_real32_5d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kFloat32 !! Data type + integer(c_int), parameter :: dtype = torch_kFloat32 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2919,7 +3571,7 @@ subroutine torch_tensor_to_array_real32_5d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_real32_5d @@ -2934,7 +3586,7 @@ subroutine torch_tensor_to_array_real64_1d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kFloat64 !! Data type + integer(c_int), parameter :: dtype = torch_kFloat64 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2949,7 +3601,7 @@ subroutine torch_tensor_to_array_real64_1d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_real64_1d @@ -2964,7 +3616,7 @@ subroutine torch_tensor_to_array_real64_2d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kFloat64 !! Data type + integer(c_int), parameter :: dtype = torch_kFloat64 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -2979,7 +3631,7 @@ subroutine torch_tensor_to_array_real64_2d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_real64_2d @@ -2994,7 +3646,7 @@ subroutine torch_tensor_to_array_real64_3d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kFloat64 !! Data type + integer(c_int), parameter :: dtype = torch_kFloat64 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -3009,7 +3661,7 @@ subroutine torch_tensor_to_array_real64_3d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_real64_3d @@ -3024,7 +3676,7 @@ subroutine torch_tensor_to_array_real64_4d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kFloat64 !! Data type + integer(c_int), parameter :: dtype = torch_kFloat64 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -3039,7 +3691,7 @@ subroutine torch_tensor_to_array_real64_4d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_real64_4d @@ -3054,7 +3706,7 @@ subroutine torch_tensor_to_array_real64_5d(tensor, data_out, sizes) integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = torch_kFloat64 !! Data type + integer(c_int), parameter :: dtype = torch_kFloat64 !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -3069,7 +3721,7 @@ subroutine torch_tensor_to_array_real64_5d(tensor, data_out, sizes) end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_real64_5d diff --git a/src/ftorch.fypp b/src/ftorch.fypp index 785d9d2a..db02e2b6 100644 --- a/src/ftorch.fypp +++ b/src/ftorch.fypp @@ -115,6 +115,39 @@ module ftorch end function torch_from_blob_c end interface + interface assignment (=) + module procedure torch_tensor_assign + end interface + + interface operator (+) + module procedure torch_tensor_add + end interface + + interface operator (-) + module procedure torch_tensor_subtract + end interface + + interface operator (*) + module procedure torch_tensor_multiply + #:for PREC in PRECISIONS + module procedure torch_tensor_premultiply_${PREC}$ + module procedure torch_tensor_postmultiply_${PREC}$ + #:endfor + end interface + + interface operator (/) + module procedure torch_tensor_divide + #:for PREC in PRECISIONS + module procedure torch_tensor_postdivide_${PREC}$ + #:endfor + end interface + + interface operator (**) + #:for PREC in PRECISIONS + module procedure torch_tensor_power_${PREC}$ + #:endfor + end interface + interface function torch_to_blob_c(tensor, dtype) result(data) & bind(c, name = 'torch_to_blob') @@ -386,6 +419,199 @@ contains call torch_tensor_delete_c(tensor%p) end subroutine torch_tensor_delete + !> Overloads assignment operator for tensors. + subroutine torch_tensor_assign(output, input) + type(torch_tensor), intent(out) :: output + type(torch_tensor), intent(in) :: input + + interface + function torch_tensor_assign_c(input_c) result(output_c) & + bind(c, name = 'torch_tensor_assign') + use, intrinsic :: iso_c_binding, only : c_ptr + implicit none + type(c_ptr), value, intent(in) :: input_c + type(c_ptr) :: output_c + end function torch_tensor_assign_c + end interface + + output%p = torch_tensor_assign_c(input%p) + end subroutine torch_tensor_assign + + !> Overloads addition operator for two tensors. + function torch_tensor_add(tensor1, tensor2) result(output) + type(torch_tensor), intent(in) :: tensor1 + type(torch_tensor), intent(in) :: tensor2 + type(torch_tensor) :: output + + interface + function torch_tensor_add_c(tensor1_c, tensor2_c) result(output_c) & + bind(c, name = 'torch_tensor_add') + use, intrinsic :: iso_c_binding, only : c_ptr + implicit none + type(c_ptr), value, intent(in) :: tensor1_c + type(c_ptr), value, intent(in) :: tensor2_c + type(c_ptr) :: output_c + end function torch_tensor_add_c + end interface + + output%p = torch_tensor_add_c(tensor1%p, tensor2%p) + end function torch_tensor_add + + !> Overloads subtraction operator for two tensors. + function torch_tensor_subtract(tensor1, tensor2) result(output) + type(torch_tensor), intent(in) :: tensor1 + type(torch_tensor), intent(in) :: tensor2 + type(torch_tensor) :: output + + interface + function torch_tensor_subtract_c(tensor1_c, tensor2_c) result(output_c) & + bind(c, name = 'torch_tensor_subtract') + use, intrinsic :: iso_c_binding, only : c_ptr + implicit none + type(c_ptr), value, intent(in) :: tensor1_c + type(c_ptr), value, intent(in) :: tensor2_c + type(c_ptr) :: output_c + end function torch_tensor_subtract_c + end interface + + output%p = torch_tensor_subtract_c(tensor1%p, tensor2%p) + end function torch_tensor_subtract + + !> Overloads multiplication operator for two tensors. + function torch_tensor_multiply(tensor1, tensor2) result(output) + type(torch_tensor), intent(in) :: tensor1 + type(torch_tensor), intent(in) :: tensor2 + type(torch_tensor) :: output + + interface + function torch_tensor_multiply_c(tensor1_c, tensor2_c) result(output_c) & + bind(c, name = 'torch_tensor_multiply') + use, intrinsic :: iso_c_binding, only : c_ptr + implicit none + type(c_ptr), value, intent(in) :: tensor1_c + type(c_ptr), value, intent(in) :: tensor2_c + type(c_ptr) :: output_c + end function torch_tensor_multiply_c + end interface + + output%p = torch_tensor_multiply_c(tensor1%p, tensor2%p) + end function torch_tensor_multiply + + #:for PREC in PRECISIONS + !> Overloads multiplication operator for a scalar of type ${PREC}$ and a tensor. + function torch_tensor_premultiply_${PREC}$(scalar, tensor) result(output) + ${f_type(PREC)}$(kind=${PREC}$), intent(in) :: scalar + type(torch_tensor), intent(in) :: tensor + type(torch_tensor) :: output + + interface + function torch_tensor_premultiply_c(scalar_c, tensor_c) result(output_c) & + bind(c, name = 'torch_tensor_premultiply') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : ${PREC}$ + implicit none + ${f_type(PREC)}$(kind=${PREC}$), value, intent(in) :: scalar_c + type(c_ptr), value, intent(in) :: tensor_c + type(c_ptr) :: output_c + end function torch_tensor_premultiply_c + end interface + + output%p = torch_tensor_premultiply_c(scalar, tensor%p) + end function torch_tensor_premultiply_${PREC}$ + + #:endfor + + #:for PREC in PRECISIONS + !> Overloads multiplication operator for a tensor and a scalar of type ${PREC}$. + function torch_tensor_postmultiply_${PREC}$(tensor, scalar) result(output) + type(torch_tensor), intent(in) :: tensor + ${f_type(PREC)}$(kind=${PREC}$), intent(in) :: scalar + type(torch_tensor) :: output + + interface + function torch_tensor_postmultiply_c(tensor_c, scalar_c) & + result(output_c) bind(c, name = 'torch_tensor_postmultiply') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : ${PREC}$ + implicit none + type(c_ptr), value, intent(in) :: tensor_c + ${f_type(PREC)}$(kind=${PREC}$), value, intent(in) :: scalar_c + type(c_ptr) :: output_c + end function torch_tensor_postmultiply_c + end interface + + output%p = torch_tensor_postmultiply_c(tensor%p, scalar) + end function torch_tensor_postmultiply_${PREC}$ + + #:endfor + !> Overloads division operator for two tensors. + function torch_tensor_divide(tensor1, tensor2) result(output) + type(torch_tensor), intent(in) :: tensor1 + type(torch_tensor), intent(in) :: tensor2 + type(torch_tensor) :: output + + interface + function torch_tensor_divide_c(tensor1_c, tensor2_c) result(output_c) & + bind(c, name = 'torch_tensor_divide') + use, intrinsic :: iso_c_binding, only : c_ptr + implicit none + type(c_ptr), value, intent(in) :: tensor1_c + type(c_ptr), value, intent(in) :: tensor2_c + type(c_ptr) :: output_c + end function torch_tensor_divide_c + end interface + + output%p = torch_tensor_divide_c(tensor1%p, tensor2%p) + end function torch_tensor_divide + + #:for PREC in PRECISIONS + !> Overloads division operator for a tensor and a scalar of type ${PREC}$. + function torch_tensor_postdivide_${PREC}$(tensor, scalar) result(output) + type(torch_tensor), intent(in) :: tensor + ${f_type(PREC)}$(kind=${PREC}$), intent(in) :: scalar + type(torch_tensor) :: output + + interface + function torch_tensor_postdivide_c(tensor_c, scalar_c) & + result(output_c) bind(c, name = 'torch_tensor_postdivide') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : ${PREC}$ + implicit none + type(c_ptr), value, intent(in) :: tensor_c + ${f_type(PREC)}$(kind=${PREC}$), value, intent(in) :: scalar_c + type(c_ptr) :: output_c + end function torch_tensor_postdivide_c + end interface + + output%p = torch_tensor_postdivide_c(tensor%p, scalar) + end function torch_tensor_postdivide_${PREC}$ + + #:endfor + + #:for PREC in PRECISIONS + !> Overloads exponentiation operator for a tensor and a scalar of type `${PREC}$` + function torch_tensor_power_${PREC}$(tensor, power) result(output) + type(torch_tensor), intent(in) :: tensor + ${f_type(PREC)}$(kind=${PREC}$), intent(in) :: power + type(torch_tensor) :: output + + interface + function torch_tensor_power_c(tensor_c, power_c) result(output_c) & + bind(c, name = 'torch_tensor_power') + use, intrinsic :: iso_c_binding, only : c_ptr + use, intrinsic :: iso_fortran_env, only : ${PREC}$ + implicit none + type(c_ptr), value, intent(in) :: tensor_c + ${f_type(PREC)}$(kind=${PREC}$), value, intent(in) :: power_c + type(c_ptr) :: output_c + end function torch_tensor_power_c + end interface + + output%p = torch_tensor_power_c(tensor%p, power) + end function torch_tensor_power_${PREC}$ + + #:endfor + ! Torch Model API !> Loads a TorchScript nn.module (pre-trained PyTorch model saved with TorchScript) subroutine torch_model_load(model, filename, device_type, device_index, & @@ -524,7 +750,7 @@ contains #:for RANK in RANKS !> Return a Torch tensor pointing to data_in array of rank ${RANK}$ containing data of type `${PREC}$` subroutine torch_tensor_from_array_${PREC}$_${RANK}$d(tensor, data_in, layout, & - c_device_type, device_index, requires_grad) + device_type, device_index, requires_grad) use, intrinsic :: iso_c_binding, only : c_bool, c_float, c_int, c_int64_t, c_loc use, intrinsic :: iso_fortran_env, only : ${PREC}$ @@ -534,15 +760,15 @@ contains ! inputs ${f_type(PREC)}$(kind=${PREC}$), intent(in), target :: data_in${ranksuffix(RANK)}$ !! Input data that tensor will point at integer(ftorch_int), intent(in) :: layout(${RANK}$) !! Control order of indices - integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) + integer(c_int), intent(in) :: device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`) integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor ! local data - integer(c_int64_t) :: c_tensor_shape(${RANK}$) !! Shape of the tensor - integer(c_int), parameter :: c_dtype = ${enum_from_prec(PREC)}$ !! Data type - integer(c_int64_t) :: strides(${RANK}$) !! Strides for accessing data - integer(c_int), parameter :: ndims = ${RANK}$ !! Number of dimension of input data + integer(c_int64_t) :: tensor_shape(${RANK}$) !! Shape of the tensor + integer(c_int), parameter :: dtype = ${enum_from_prec(PREC)}$ !! Data type + integer(c_int64_t) :: strides(${RANK}$) !! Strides for accessing data + integer(c_int), parameter :: ndims = ${RANK}$ !! Number of dimension of input data integer(ftorch_int) :: i integer(c_int) :: device_index_value logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor @@ -550,7 +776,7 @@ contains ! Process optional arguments if (present(device_index)) then device_index_value = device_index - else if (c_device_type == torch_kCPU) then + else if (device_type == torch_kCPU) then device_index_value = -1 else device_index_value = 0 @@ -562,15 +788,15 @@ contains requires_grad_value = requires_grad end if - c_tensor_shape = shape(data_in) + tensor_shape = shape(data_in) strides(layout(1)) = 1 do i = 2, ndims - strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1)) + strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1)) end do - tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, & - strides, c_dtype, c_device_type, & + tensor%p = torch_from_blob_c(c_loc(data_in), ndims, tensor_shape, & + strides, dtype, device_type, & device_index_value, & logical(requires_grad_value, c_bool)) @@ -591,7 +817,7 @@ contains integer(kind=int64), allocatable :: my_shape(:) !! Number of entries for each rank ! Local data - integer(c_int), parameter :: c_dtype = ${enum_from_prec(PREC)}$ !! Data type + integer(c_int), parameter :: dtype = ${enum_from_prec(PREC)}$ !! Data type type(c_ptr) :: cptr my_shape = tensor%get_shape() @@ -606,7 +832,7 @@ contains end if ! Have the data_out array point to the Tensor data - cptr = torch_to_blob_c(tensor%p, c_dtype) + cptr = torch_to_blob_c(tensor%p, dtype) call c_f_pointer(cptr, data_out, my_shape) end subroutine torch_tensor_to_array_${PREC}$_${RANK}$d