Overload elementary operations #139
Merged
Commits (51):

- cc9c76c Overload =,+,*; consistent variable names (jwallwork23)
- 366ffda Update existing examples (jwallwork23)
- 852b154 Add autograd example (jwallwork23)
- f7bca7e Apply Black (jwallwork23)
- 5232193 Overload - and test (jwallwork23)
- a730a40 Overload / and test (jwallwork23)
- 7abeba7 Overload ** and test (jwallwork23)
- c650342 Overload scalar premultiply and test (jwallwork23)
- c0c5507 Overload scalar postmultiply and test (jwallwork23)
- 6aa3768 Tidy overloading test - FIXME (jwallwork23)
- 7b9126e Formatting (jwallwork23)
- e6c7675 Reformatting (jwallwork23)
- 9ac1087 Merge branch 'main' into autograd (jwallwork23)
- 528b4f9 Fix merge issue (jwallwork23)
- dfaba4d Include example 3 for consistency (jwallwork23)
- 202f77f DO NOT MERGE (debugging) (jwallwork23)
- 1eb6d52 Implement torch_to_blob on C++ side (jwallwork23)
- 60ce4a6 Implement Fortran interface (jwallwork23)
- 684a711 Use torch_tensor_to_array in example 1 (jwallwork23)
- d35218e Use correct data types; raise errors for unsupported cases (jwallwork23)
- b5f53f2 Revert changes to example 1 (jwallwork23)
- cf9a006 Add beginnings of autograd demo (jwallwork23)
- b8b4840 Docs for example 5 (jwallwork23)
- 806730c More detail on uint8 and float16 not being supported (jwallwork23)
- 0f6ee05 Add notes on float types (jwallwork23)
- e2f6a9d Handle allocation of pointer array (jwallwork23)
- 5881a3b Merge branch 'torch_tensor_to_array' into autograd_toarray (jwallwork23)
- cdd4433 Merge fixes (jwallwork23)
- f3bf975 Merge branch 'main' into autograd (jwallwork23)
- 6223f36 Update autograd example (jwallwork23)
- e055d8b Merge branch 'main' into autograd (jwallwork23)
- 5364800 Use assert_allclose in new code (jwallwork23)
- 80e5a71 Merge branch 'main' into autograd (jwallwork23)
- 59ffdc4 Use bare import for autograd (jwallwork23)
- f504f5e Lint (jwallwork23)
- 8379001 Lint (jwallwork23)
- a590955 Apply clang-format (jwallwork23)
- c12d5c3 Apply clang-format to header (jwallwork23)
- cc884c6 Merge branch 'main' into autograd (jwallwork23)
- d9000e1 Post-merge fixes (jwallwork23)
- c87fd97 test: make windows CI more robust (TomMelt)
- 98917d3 chore: rename variable to torch_path (TomMelt)
- f0ed978 Reformulate autograd example to test multiply and divide (jwallwork23)
- b804d54 Implement postdivide (jwallwork23)
- b239207 Merge branch 'test-simple-change' into autograd (jwallwork23)
- c0a3fe4 Point to Torch C++ API (jwallwork23)
- 1e852b4 Update docs on autograd; add reference to looping example (jwallwork23)
- 8dfa92c Revert adding example 3 to build (jwallwork23)
- 70f28a4 Merge branch 'main' into autograd (jwallwork23)
- 4406fd8 Write as LibTorch in pages/autograd.md (jwallwork23)
- e82a9db Use better links for Torch C++ docs (jwallwork23)
Diff of the autograd example program (one typo in the printed label fixed: the computed expression uses `/ 3`, matching the expected values):

```diff
@@ -4,8 +4,9 @@ program example
   use, intrinsic :: iso_fortran_env, only : sp => real32
 
   ! Import our library for interfacing with PyTorch's Autograd module
-  use ftorch, only : torch_tensor, torch_kCPU, &
-                     torch_tensor_from_array, torch_tensor_to_array, torch_tensor_delete
+  use ftorch, only: assignment(=), operator(+), operator(-), operator(*), &
+                    operator(/), operator(**), torch_kCPU, torch_tensor, torch_tensor_delete, &
+                    torch_tensor_from_array, torch_tensor_to_array
 
   ! Import our tools module for testing utils
   use ftorch_test_utils, only : assert_allclose
@@ -16,8 +17,9 @@ program example
   integer, parameter :: wp = sp
 
   ! Set up Fortran data structures
-  integer, parameter :: n=2, m=5
-  real(wp), dimension(n,m), target :: in_data
+  integer, parameter :: n=2, m=1
+  real(wp), dimension(n,m), target :: in_data1
+  real(wp), dimension(n,m), target :: in_data2
   real(wp), dimension(:,:), pointer :: out_data
   real(wp), dimension(n,m) :: expected
   integer :: tensor_layout(2) = [1, 2]
@@ -27,45 +29,78 @@ program example
   logical :: test_pass
 
   ! Set up Torch data structures
-  type(torch_tensor) :: tensor
+  type(torch_tensor) :: a, b, Q
 
-  ! initialize in_data with some fake data
-  do j = 1, m
-    do i = 1, n
-      in_data(i,j) = ((i-1)*m + j) * 1.0_wp
-    end do
-  end do
+  ! Initialise input arrays as in Python example
+  in_data1(:,1) = [2.0_wp, 3.0_wp]
+  in_data2(:,1) = [6.0_wp, 4.0_wp]
 
   ! Construct a Torch Tensor from a Fortran array
-  call torch_tensor_from_array(tensor, in_data, tensor_layout, torch_kCPU)
+  ! TODO: Implement requires_grad=.true.
+  call torch_tensor_from_array(a, in_data1, tensor_layout, torch_kCPU)
+  call torch_tensor_from_array(b, in_data2, tensor_layout, torch_kCPU)
 
   ! check tensor rank and shape match those of in_data
-  if (tensor%get_rank() /= 2) then
+  if ((a%get_rank() /= 2) .or. (b%get_rank() /= 2)) then
     print *, "Error :: rank should be 2"
     stop 1
   end if
-  if (any(tensor%get_shape() /= [2, 5])) then
-    print *, "Error :: shape should be (2, 5)"
+  if (any(a%get_shape() /= [n, m]) .or. any(b%get_shape() /= [n, m])) then
+    write(6,"('Error :: shape should be (',i1,', ',i1,')')") n, m
     stop 1
   end if
 
+  ! Check arithmetic operations work for torch_tensors
+  write (*,*) "a = ", in_data1(:,1)
+  write (*,*) "b = ", in_data2(:,1)
+  Q = 3 * (a**3 - b * b / 3)
+
   ! Extract a Fortran array from a Torch tensor
-  call torch_tensor_to_array(tensor, out_data, shape(in_data))
+  call torch_tensor_to_array(Q, out_data, shape(in_data1))
+  write (*,*) "Q = 3 * (a ** 3 - b * b / 3) =", out_data(:,1)
 
-  ! Check that the data match
-  expected(:,:) = in_data
+  ! Check output tensor matches expected value
+  expected(:,1) = [-12.0_wp, 65.0_wp]
   test_pass = assert_allclose(out_data, expected, test_name="torch_tensor_to_array", rtol=1e-5)
   if (.not. test_pass) then
-    print *, "Error :: in_data does not match out_data"
+    call clean_up()
+    print *, "Error :: out_data does not match expected value"
     stop 999
   end if
 
+  ! Check first input array is unchanged by the arithmetic operations
+  expected(:,1) = [2.0_wp, 3.0_wp]
+  test_pass = assert_allclose(in_data1, expected, test_name="torch_tensor_to_array", rtol=1e-5)
+  if (.not. test_pass) then
+    call clean_up()
+    print *, "Error :: in_data1 was changed during arithmetic operations"
+    stop 999
+  end if
+
+  ! Check second input array is unchanged by the arithmetic operations
+  expected(:,1) = [6.0_wp, 4.0_wp]
+  test_pass = assert_allclose(in_data2, expected, test_name="torch_tensor_to_array", rtol=1e-5)
+  if (.not. test_pass) then
+    call clean_up()
+    print *, "Error :: in_data2 was changed during arithmetic operations"
+    stop 999
+  end if
+
+  ! Back-propagation
+  ! TODO: Requires API extension
+
   ! Cleanup
-  nullify(out_data)
-  call torch_tensor_delete(tensor)
+  call clean_up()
   write (*,*) "Autograd example ran successfully"
 
 contains
 
+  ! Subroutine for freeing memory and nullifying pointers used in the example
+  subroutine clean_up()
+    nullify(out_data)
+    call torch_tensor_delete(a)
+    call torch_tensor_delete(b)
+    call torch_tensor_delete(Q)
+  end subroutine clean_up
+
 end program example
```

> Review comment on the back-propagation TODO: "Neither will this one. (See #158.)"
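As a quick cross-check of the expected values asserted above: with `a = [2, 3]` and `b = [6, 4]`, the elementwise expression `Q = 3 * (a**3 - b*b/3)` (which simplifies to `3*a**3 - b**2`) does give `[-12, 65]`. A pure-Python sketch, independent of FTorch:

```python
import math

# Inputs as in the Fortran example: a = [2, 3], b = [6, 4]
a = [2.0, 3.0]
b = [6.0, 4.0]

# Elementwise Q = 3 * (a**3 - b*b/3), the expression tested in the example
Q = [3.0 * (ai**3 - bi * bi / 3.0) for ai, bi in zip(a, b)]

# The Fortran test asserts against [-12, 65] with rtol=1e-5
expected = [-12.0, 65.0]
assert all(math.isclose(qi, ei, rel_tol=1e-5) for qi, ei in zip(Q, expected))
print(Q)
```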
New file (+42 lines; `pages/autograd.md`, per commit 4406fd8):

title: Online training

[TOC]

## Current state

FTorch has supported offline training of ML models for some time. We are
currently working on extending its functionality to support online training,
too. This will involve exposing the automatic differentiation and
back-propagation functionality of PyTorch/LibTorch.

In the following, we document a work plan for the related functionality. Each
step below will be updated upon completion.

### Operator overloading

Mathematical operators involving tensors are overloaded, so that we can compute
expressions involving outputs from one or more ML models.

Whilst it is possible to import such functionality with a bare

```fortran
use ftorch
```

statement, best practice is to import specifically the operators that you
wish to use. Note that the assignment operator `=` has a slightly different
notation:

```fortran
use ftorch, only: assignment(=), operator(+), operator(-), operator(*), &
                  operator(/), operator(**)
```

For a concrete example of how to compute mathematical expressions involving
Torch tensors, see the associated
[worked example](https://github.com/Cambridge-ICCS/FTorch/tree/main/examples/6_Autograd).

### The `requires_grad` property

*Not yet implemented.*

### The `backward` operator

*Not yet implemented.*
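To illustrate what the planned `backward` support corresponds to: for the expression Q = 3(a³ − b·b/3) used in the worked example (the same computation as in PyTorch's autograd tutorial), reverse-mode differentiation yields ∂Q/∂a = 9a² and ∂Q/∂b = −2b. The sketch below is purely illustrative, not FTorch or PyTorch API: it writes the gradients out analytically in plain Python, which is what `backward` would compute automatically.

```python
import math

# Inputs as in the worked example
a = [2.0, 3.0]
b = [6.0, 4.0]

def forward(ai, bi):
    # Q = 3 * (a**3 - b*b/3), elementwise
    return 3.0 * (ai**3 - bi * bi / 3.0)

def backward(ai, bi):
    # Analytic reverse-mode gradients of Q:
    # dQ/da = 9*a**2, dQ/db = -2*b
    return 9.0 * ai**2, -2.0 * bi

Q = [forward(ai, bi) for ai, bi in zip(a, b)]
grads = [backward(ai, bi) for ai, bi in zip(a, b)]
print(Q)      # close to [-12.0, 65.0]
print(grads)  # [(36.0, -12.0), (81.0, -8.0)]
```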
> Review comment: "This TODO won't be addressed in this PR."