Skip to content

Commit

Permalink
Refactored following review
Browse files Browse the repository at this point in the history
  • Loading branch information
icedoom888 committed Dec 18, 2024
1 parent 1bc0113 commit 919d0ad
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 44 deletions.
3 changes: 1 addition & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ Keep it human-readable, your future self will thank you!
- feat: Support for multiple edge builders between two sets of nodes (#70)
- feat: Support for providing lon/lat coordinates from a text file (loaded with numpy loadtxt method) to build the graph `TextNodes` (#93)
- feat: Build 2D graphs with `Voronoi` in case `SphericalVoronoi` does not work well/is an overkill (LAM). Set `flat=true` in the nodes attributes to compute area weight using Voronoi with a qhull options preventing the empty region creation (#93)
- feat: Add `AttributeFromNode` edge attribute to copy attribute from source or destination node. Set `node_attr_name` and `node_type : src | dst` in the config to specify which attribute to copy from the source | destination node (#94)

- feat: Add `AttributeFromSourceNode` and `AttributeFromTargetNode` edge attribute to copy attribute from source or target node. Set `node_attr_name` in the config to specify which attribute to copy from the source | target node (#94)

# Changed

Expand Down
88 changes: 46 additions & 42 deletions src/anemoi/graphs/edges/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
class BaseEdgeAttribute(ABC, NormaliserMixin):
"""Base class for edge attributes."""

def __init__(self, norm: str | None = None) -> None:
def __init__(self, norm: str | None = None, dtype: str = "float32") -> None:
self.norm = norm
self.dtype = dtype

@abstractmethod
def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str, *args, **kwargs) -> np.ndarray: ...
Expand All @@ -35,9 +36,9 @@ def post_process(self, values: np.ndarray) -> torch.Tensor:
if values.ndim == 1:
values = values[:, np.newaxis]

normed_values = self.normalise(values)
norm_values = self.normalise(values)

return torch.tensor(normed_values, dtype=torch.float32)
return torch.tensor(norm_values.astype(self.dtype))

def compute(self, graph: HeteroData, edges_name: tuple[str, str, str], *args, **kwargs) -> torch.Tensor:
"""Compute the edge attributes."""
Expand Down Expand Up @@ -157,36 +158,18 @@ def post_process(self, values: np.ndarray) -> torch.Tensor:
return values


class BooleanBaseEdgeAttribute:
class BooleanBaseEdgeAttribute(BaseEdgeAttribute):
"""Base class for boolean edge attributes."""

def __init__(self) -> None:
pass

@abstractmethod
def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str, *args, **kwargs) -> np.ndarray: ...

def post_process(self, values: np.ndarray) -> torch.Tensor:
"""Post-process the values."""
return torch.tensor(values, dtype=torch.bool)

def compute(self, graph: HeteroData, edges_name: tuple[str, str, str], *args, **kwargs) -> torch.Tensor:
"""Compute the edge attributes."""
source_name, _, target_name = edges_name
assert (
source_name in graph.node_types
), f"Node \"{source_name}\" not found in graph. Optional nodes are {', '.join(graph.node_types)}."
assert (
target_name in graph.node_types
), f"Node \"{target_name}\" not found in graph. Optional nodes are {', '.join(graph.node_types)}."

values = self.get_raw_values(graph, source_name, target_name, *args, **kwargs)
return self.post_process(values)
super().__init__(norm=None, dtype="bool")


class AttributeFromNode(BooleanBaseEdgeAttribute):
"""
Copy an attribute of either the source or destination node to the edge.
Base class for Attribute from Node.
Copy an attribute of either the source or target node to the edge.
Accesses origin/target node attribute and propagates it to the edge.
Used for example to identify if an encoder edge originates from a LAM or global node.
Expand All @@ -195,34 +178,55 @@ class AttributeFromNode(BooleanBaseEdgeAttribute):
node_attr_name : str
Name of the node attribute to propagate.
node_type : str
Pick the node to copy from. Options: "src, dst"
Methods
-------
get_node_name(source_name, target_name)
Return the name of the node to copy.
get_raw_values(graph, source_name, target_name)
Computes the edge attribute from the source or destination node attribute.
Computes the edge attribute from the source or target node attribute.
"""

def __init__(self, node_attr_name: str, node_type: str) -> None:
def __init__(self, node_attr_name: str) -> None:
super().__init__()
self.node_attr_name = node_attr_name
assert node_type in ["src", "dst"]
self.node_type = node_type
self.idx = None

@abstractmethod
def get_node_name(self, source_name: str, target_name: str): ...

def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str) -> np.ndarray:

node_name = self.get_node_name(source_name, target_name)

edge_index = graph[(source_name, "to", target_name)].edge_index
assert hasattr(graph[node_name], self.node_attr_name)
val = getattr(graph[node_name], self.node_attr_name).numpy()[edge_index[self.idx]]
return val

if self.node_type == "src":
name_to_copy = source_name
idx = 0

else:
name_to_copy = target_name
idx = 1
class AttributeFromSourceNode(AttributeFromNode):
"""
Copy an attribute of the source node to the edge.
"""

assert hasattr(graph[name_to_copy], self.node_attr_name)
def __init__(self, node_attr_name: str) -> None:
super().__init__(node_attr_name)
self.idx = 0

val = getattr(graph[name_to_copy], self.node_attr_name).numpy()[edge_index[idx]]
def get_node_name(self, source_name: str, target_name: str):
return source_name

return val

class AttributeFromTargetNode(AttributeFromNode):
"""
Copy an attribute of the target node to the edge.
"""

def __init__(self, node_attr_name: str) -> None:
super().__init__(node_attr_name)
self.idx = 1

def get_node_name(self, source_name: str, target_name: str):
return target_name

0 comments on commit 919d0ad

Please sign in to comment.