-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added method to reshape dims of network. #9
base: main
Are you sure you want to change the base?
Conversation
Pull Request Test Coverage Report for Build 1693744455Warning: This coverage report may be inaccurate.This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.
Details
💛 - Coveralls |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The approach looks good, but I left some questions about the handling of the error cases.
node = random_node(dim_edges) | ||
network = Network(node[:], [], copy=False) | ||
with pytest.raises(ValueError): | ||
network.match_out_dims(target_dims) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's probably worthwhile asserting that network
is in a valid state after an error has been raised.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a good point. I will do so.
Co-authored-by: Simon Cross <hodgestar+github@gmail.com>
Thank you for the review Simon! |
I changed so that |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left a couple more suggestions. Thank you for sorting out the state of the network post errors by copying the network -- I think that was a good solution.
``target_dims``. After this function the following will hold: | ||
``network.dims[0] == target_dims`` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
``target_dims``. After this function the following will hold: | |
``network.dims[0] == target_dims`` | |
``target_dims``. The returned network will have ``network.dims[0] == target_dims``. |
``target_dims``. After this function the following will hold: | ||
``network.dims[1] == target_dims`` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
``target_dims``. After this function the following will hold: | |
``network.dims[1] == target_dims`` | |
``target_dims``. The returned network will have ``network.dims[1] == target_dims``. |
Parameters | ||
---------- | ||
target_dims: list of int | ||
Desired dimensions for ``out_edges``. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Desired dimensions for ``out_edges``. | |
Desired dimensions for ``in_edges``. |
>>> array = np.random.random(network_dim) | ||
>>> node = tn.Node(array) | ||
>>> network = Network(node[:], []) | ||
>>> network = network.match_out_dims(target_dims) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
>>> network = network.match_out_dims(target_dims) | |
>>> network = network.match_in_dims(target_dims) |
new_edges = [] | ||
|
||
if len(edges) == 0 and len(target_dims) == 0: | ||
return _edges |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this just be return edges
? _edges
does not seem to be defined any more. We should add a test for the empty edges and target_dims case.
"edges are not compatible. The dimensions for edges is " | ||
+ str(e_dims) | ||
+ ", whereas the target dimension is" | ||
+ str(_target_dims) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"edges are not compatible. The dimensions for edges is " | |
+ str(e_dims) | |
+ ", whereas the target dimension is" | |
+ str(_target_dims) | |
f"edges are not compatible. The dimensions for edges is {e_dims}" | |
f", whereas the target dimension is {_target_dims}." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess we should just use f-strings these days.
|
||
e_dims = [e.dimension for e in edges] | ||
|
||
new_edges = [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would move the initialization of new_edges
to just before the top of the loop so that it's closer to where it's used. Currently we define it, then don't use it for awhile, then suddenly use it in the middle of a loop.
This PR implements two methods,
match_in_dims
andmatch_out_dims
that allows splitting edges to match a target dimension. This is primarily meant to be used withptrace
as we need to reshape the outer nodes such that the network has the appropriate dimensions.Note that
_match_edges_by_split
is a very similar function but with different scope. The new_match_dimensions
is meant to be used with unary operators that required a change in the dimension (ptrace) whereas_match_edges_by_split
is meant to be used with binary operators (matmul). This is because for_match_dimensions
(unary version) we do not allow merging two edge dimension. However, such operation can also be achieved in_match_edges_by_split
(binary version) by just splitting the edges as two list of edges is available.The reason to not allow here in
_match_dimensions
to merge edges is because this is a relatively complex operation that will require contracting nodes. I am assuming that it is best to do the contraction of the full network in a single step such that an optimized path is used.