forked from ecmwf/anemoi-graphs
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'develop' into feature/edge_attr_from_node_attr
- Loading branch information
Showing
8 changed files
with
215 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
####################################### | ||
From latitude & longitude coordinates | ||
####################################### | ||
|
||
Nodes can also be created directly using latitude and longitude | ||
coordinates. Below is an example demonstrating how to add these nodes to | ||
a graph: | ||
|
||
.. code:: python | ||
from anemoi.graphs.nodes import LatLonNodes | ||
... | ||
lats = np.array([45.0, 45.0, 40.0, 40.0]) | ||
lons = np.array([5.0, 10.0, 10.0, 5.0]) | ||
graph = LatLonNodes(latitudes=lats, longitudes=lons, name="my_nodes").update_graph(graph) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
################ | ||
From text file | ||
################ | ||
|
||
To define the `node coordinates` based on a `.txt` file, you can | ||
configure the `.yaml` as follows: | ||
|
||
.. code:: yaml | ||
nodes: | ||
data: # name of the nodes | ||
node_builder: | ||
_target_: anemoi.graphs.nodes.TextNodes | ||
dataset: my_file.txt | ||
idx_lon: 0 | ||
idx_lat: 1 | ||
Here, dataset refers to the path of the `.txt` file that contains the | ||
latitude and longitude values in the columns specified by `idx_lat` and | ||
`idx_lon`, respectively. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# (C) Copyright 2024 Anemoi contributors. | ||
# | ||
# This software is licensed under the terms of the Apache Licence Version 2.0 | ||
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
# | ||
# In applying this licence, ECMWF does not waive the privileges and immunities | ||
# granted to it by virtue of its status as an intergovernmental organisation | ||
# nor does it submit to any jurisdiction. | ||
|
||
from __future__ import annotations | ||
|
||
import logging | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from anemoi.graphs.nodes.builders.base import BaseNodeBuilder | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
class LatLonNodes(BaseNodeBuilder): | ||
"""Nodes from its latitude and longitude positions (in numpy arrays). | ||
Attributes | ||
---------- | ||
latitudes : list | np.ndarray | ||
The latitude of the nodes, in degrees. | ||
longitudes : list | np.ndarray | ||
The longitude of the nodes, in degrees. | ||
Methods | ||
------- | ||
get_coordinates() | ||
Get the lat-lon coordinates of the nodes. | ||
register_nodes(graph, name) | ||
Register the nodes in the graph. | ||
register_attributes(graph, name, config) | ||
Register the attributes in the nodes of the graph specified. | ||
update_graph(graph, name, attrs_config) | ||
Update the graph with new nodes and attributes. | ||
""" | ||
|
||
def __init__(self, latitudes: list[float] | np.ndarray, longitudes: list[float] | np.ndarray, name: str) -> None: | ||
super().__init__(name) | ||
self.latitudes = latitudes if isinstance(latitudes, np.ndarray) else np.array(latitudes) | ||
self.longitudes = longitudes if isinstance(longitudes, np.ndarray) else np.array(longitudes) | ||
|
||
assert len(self.latitudes) == len( | ||
self.longitudes | ||
), f"Lenght of latitudes and longitudes must match but {len(self.latitudes)}!={len(self.longitudes)}." | ||
assert self.latitudes.ndim == 1 or ( | ||
self.latitudes.ndim == 2 and self.latitudes.shape[1] == 1 | ||
), "latitudes must have shape (N, ) or (N, 1)." | ||
assert self.longitudes.ndim == 1 or ( | ||
self.longitudes.ndim == 2 and self.longitudes.shape[1] == 1 | ||
), "longitudes must have shape (N, ) or (N, 1)." | ||
|
||
def get_coordinates(self) -> torch.Tensor: | ||
"""Get the coordinates of the nodes. | ||
Returns | ||
------- | ||
torch.Tensor of shape (num_nodes, 2) | ||
A 2D tensor with the coordinates, in radians. | ||
""" | ||
return self.reshape_coords(self.latitudes, self.longitudes) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# (C) Copyright 2024 Anemoi contributors. | ||
# | ||
# This software is licensed under the terms of the Apache Licence Version 2.0 | ||
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
# | ||
# In applying this licence, ECMWF does not waive the privileges and immunities | ||
# granted to it by virtue of its status as an intergovernmental organisation | ||
# nor does it submit to any jurisdiction. | ||
|
||
import pytest | ||
import torch | ||
from torch_geometric.data import HeteroData | ||
|
||
from anemoi.graphs.nodes.attributes import AreaWeights | ||
from anemoi.graphs.nodes.attributes import UniformWeights | ||
from anemoi.graphs.nodes.builders.from_vectors import LatLonNodes | ||
|
||
lats = [45.0, 45.0, 40.0, 40.0] | ||
lons = [5.0, 10.0, 10.0, 5.0] | ||
|
||
|
||
def test_init(): | ||
"""Test LatLonNodes initialization.""" | ||
node_builder = LatLonNodes(latitudes=lats, longitudes=lons, name="test_nodes") | ||
assert isinstance(node_builder, LatLonNodes) | ||
|
||
|
||
def test_fail_init_length_mismatch(): | ||
"""Test LatLonNodes initialization with invalid argument.""" | ||
lons = [5.0, 10.0, 10.0, 5.0, 5.0] | ||
|
||
with pytest.raises(AssertionError): | ||
LatLonNodes(latitudes=lats, longitudes=lons, name="test_nodes") | ||
|
||
|
||
def test_fail_init_missing_argument(): | ||
"""Test NPZFileNodes initialization with missing argument.""" | ||
with pytest.raises(TypeError): | ||
LatLonNodes(name="test_nodes") | ||
|
||
|
||
def test_register_nodes(): | ||
"""Test LatLonNodes register correctly the nodes.""" | ||
graph = HeteroData() | ||
node_builder = LatLonNodes(latitudes=lats, longitudes=lons, name="test_nodes") | ||
graph = node_builder.register_nodes(graph) | ||
|
||
assert graph["test_nodes"].x is not None | ||
assert isinstance(graph["test_nodes"].x, torch.Tensor) | ||
assert graph["test_nodes"].x.shape == (len(lats), 2) | ||
assert graph["test_nodes"].node_type == "LatLonNodes" | ||
|
||
|
||
@pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights]) | ||
def test_register_attributes(graph_with_nodes: HeteroData, attr_class): | ||
"""Test LatLonNodes register correctly the weights.""" | ||
node_builder = LatLonNodes(latitudes=lats, longitudes=lons, name="test_nodes") | ||
config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.attributes.{attr_class.__name__}"}} | ||
|
||
graph = node_builder.register_attributes(graph_with_nodes, config) | ||
|
||
assert graph["test_nodes"]["test_attr"] is not None | ||
assert isinstance(graph["test_nodes"]["test_attr"], torch.Tensor) | ||
assert graph["test_nodes"]["test_attr"].shape[0] == graph["test_nodes"].x.shape[0] |