Can not create flat multiscale graphs with >= 3 levels #33

Closed
joeloskarsson opened this issue Oct 13, 2024 · 4 comments · Fixed by #41
Labels
bug Something isn't working
Comments

@joeloskarsson
Contributor

There seems to be a bug in create_flat_multiscale_mesh_graph (src/weather_model_graphs/create/mesh/kinds/flat.py) that causes a crash whenever you try to create a multiscale graph by collapsing >= 3 levels of flat graphs.

Minimal example to reproduce

import weather_model_graphs as wmg
import numpy as np

x_coords = np.linspace(0, 15, 50)
y_coords = np.linspace(0, 5, 50)
meshgridded = np.meshgrid(x_coords, y_coords)
xy_grid = np.stack(meshgridded, axis=0)

graph = wmg.create.archetype.create_graphcast_graph(
    xy_grid,
    grid_refinement_factor=1,
    level_refinement_factor=3,
)

This gives the output

DEBUG    | weather_model_graphs.create.mesh.mesh:create_multirange_2d_mesh_graphs:134 - mesh_levels: 3, nleaf: [27 27]

but then crashes with

File "/home/joel/repos/weather-model-graphs/src/weather_model_graphs/create/mesh/kinds/flat.py", line 65, in create_flat_multiscale_mesh_graph
    .reshape((num_nodes_x, num_nodes_y, 2))[
ValueError: cannot reshape array of size 162 into shape (26,26,2)

Problem

The issue seems to relate to the loop

for lev in range(1, len(G_all_levels)):
    nodes = list(G_all_levels[lev - 1].nodes)
    # Last nodes always has pos (nx-1, ny-1)
    num_nodes_x = nodes[-1][0] + 1
    num_nodes_y = nodes[-1][1] + 1
    ij = (
        np.array(nodes)
        .reshape((num_nodes_x, num_nodes_y, 2))[
            level_offset::level_refinement_factor,
            level_offset::level_refinement_factor,
            :,
        ]
        .reshape(int(num_nodes_x * num_nodes_y / (level_refinement_factor**2)), 2)
    )
    ij = [tuple(x) for x in ij]
    G_all_levels[lev] = networkx.relabel_nodes(
        G_all_levels[lev], dict(zip(G_all_levels[lev].nodes, ij))
    )
    G_tot = networkx.compose(G_tot, G_all_levels[lev])

where the variables num_nodes_x and num_nodes_y are not set correctly on the second iteration. I am a bit suspicious of the lines
G_all_levels[lev] = networkx.relabel_nodes(
    G_all_levels[lev], dict(zip(G_all_levels[lev].nodes, ij))
)

which modify the graph G_all_levels[lev], which is what is used in the next iteration of the loop. It could be that this graph is being overwritten before it is used.
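A minimal numpy-only replay of the loop above can make the second iteration concrete. It assumes a 27x27 finest level (the nleaf: [27 27] from the log), level_refinement_factor=3, level_offset=1 (inferred from the (26, 26, 2) shape in the traceback, not read from flat.py), and row-major node order; plain lists of (i, j) labels stand in for the networkx graphs. Under these assumptions it suggests that after relabelling, level 1 carries labels from the finest grid, so its last label is (25, 25) rather than (8, 8), and the inferred 26x26 shape no longer matches its 81 nodes.

```python
import numpy as np

# Assumed values (see lead-in); LEVEL_OFFSET is inferred, not taken from flat.py.
LEVEL_REFINEMENT_FACTOR = 3
LEVEL_OFFSET = 1
NLEAF = 27


def grid_labels(n):
    """(i, j) labels of an n x n grid in row-major order (assumed node order)."""
    return [(i, j) for i in range(n) for j in range(n)]


def replay_loop(num_levels=3):
    """Replay the relabelling loop on plain label lists.

    Returns (node count, inferred num_nodes_x, inferred num_nodes_y) for each
    iteration, plus the reshape error message if one occurs.
    """
    levels = [grid_labels(NLEAF // LEVEL_REFINEMENT_FACTOR**k) for k in range(num_levels)]
    inferred = []
    for lev in range(1, num_levels):
        nodes = levels[lev - 1]
        # Same inference as the original loop: last label + 1
        num_nodes_x = nodes[-1][0] + 1
        num_nodes_y = nodes[-1][1] + 1
        inferred.append((len(nodes), num_nodes_x, num_nodes_y))
        try:
            ij = np.array(nodes).reshape((num_nodes_x, num_nodes_y, 2))[
                LEVEL_OFFSET::LEVEL_REFINEMENT_FACTOR,
                LEVEL_OFFSET::LEVEL_REFINEMENT_FACTOR,
                :,
            ].reshape(-1, 2)
        except ValueError as err:
            return inferred, str(err)
        # Stand-in for networkx.relabel_nodes: level `lev` now carries
        # labels taken from level `lev - 1`
        levels[lev] = [tuple(int(v) for v in x) for x in ij]
    return inferred, None


inferred, error = replay_loop()
# inferred == [(729, 27, 27), (81, 26, 26)]; error is a reshape ValueError message
```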

I encountered this while working on #32, but since this is an orthogonal issue I will not fix it there.

@joeloskarsson joeloskarsson added the bug Something isn't working label Oct 13, 2024
@joeloskarsson joeloskarsson modified the milestones: v0.3.0 (proposed), v0.3.0 Jan 12, 2025
@matschreiner

matschreiner commented Jan 14, 2025

Hi @joeloskarsson.
I just started at DMI as a new developer in the MLOps team :)
I picked up this issue as my first task.

It seems there have been some modifications since the issue was posted: as far as I understand, grid_refinement_factor has been renamed to mesh_node_distance, and the grid coordinates now have to be stacked along the last dimension instead of the first!
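For reference, these are the array shapes numpy produces for the two stacking choices, using the coordinates from the example below. That the library now expects the last-axis layout (flattened to one (x, y) row per grid point, as both updated examples in this thread do) is taken from this discussion, not from the weather_model_graphs documentation.

```python
import numpy as np

x_coords = np.linspace(0, 15, 10)
y_coords = np.linspace(0, 5, 10)
meshgridded = np.meshgrid(x_coords, y_coords)  # two (10, 10) arrays: X and Y

old_layout = np.stack(meshgridded, axis=0)   # (2, 10, 10): coordinate axis first
new_layout = np.stack(meshgridded, axis=-1)  # (10, 10, 2): coordinate axis last
xy = new_layout.reshape(-1, 2)               # (100, 2): one (x, y) row per point
```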

Running the following code doesn't crash:

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

import weather_model_graphs as wmg


def main():
    x_coords = np.linspace(0, 15, 10)
    y_coords = np.linspace(0, 5, 10)
    meshgridded = np.meshgrid(x_coords, y_coords)
    xy_grid = np.stack(meshgridded, axis=-1)
    xy_grid = xy_grid.reshape(-1, 2)

    fig, ax = plt.subplots(1, 2, figsize=(16, 8))

    for i, lrf in enumerate([3, 5]):
        graph = wmg.create.archetype.create_graphcast_graph(
            xy_grid,
            mesh_node_distance=1,
            level_refinement_factor=lrf,
        )

        wmg.visualise.nx_draw_with_pos_and_attr(graph, ax[i])
    plt.show()


main()

However, the output graph is not exactly what I would expect: I get a regular grid when level_refinement_factor=5 and an irregular one when level_refinement_factor=3:

[Image: output graphs for level_refinement_factor=3 (left) and level_refinement_factor=5 (right)]

@joeloskarsson
Contributor Author

joeloskarsson commented Jan 14, 2025

Hi @matschreiner! Happy to see you looking into this 😄

You are absolutely correct that there have been some changes, so my code above no longer works. Here is a modified example that should produce the same crash:

import weather_model_graphs as wmg
import numpy as np

x_coords = np.linspace(0, 15, 50)
y_coords = np.linspace(0, 5, 50)
meshgridded = np.meshgrid(x_coords, y_coords)
xy_grid = np.stack(meshgridded, axis=-1)
xy = xy_grid.reshape(-1, 2)

graph = wmg.create.archetype.create_graphcast_graph(
    xy,
    mesh_node_distance=0.5,
    level_refinement_factor=3,
)

With this mesh_node_distance the resulting graph contains 3 collapsed mesh levels, which is when the crash happens.

Regarding your example: the reason these examples work is that they only create one level of mesh nodes, which is then collapsed. Note in the output

mesh_levels: 1, nleaf: [9 3]

and

mesh_levels: 1, nleaf: [5 5]

So the left graph has 9x3 mesh nodes and the right one has 5x5 mesh nodes (this is a bit easier to see with wmg.visualise.nx_draw_with_pos_and_attr(graph, ax[i], edge_color_attr="component", node_color_attr="type")). I think this is actually the intended behavior, even though it may not be obvious why. Keep in mind the scaling of the x- and y-axes in the plots: the coordinates actually cover a rectangular area (15 by 5), not a square. For the graphcast archetype, the number of mesh nodes in each direction has to be level_refinement_factor^n for some whole number n > 0, and that constraint is what produces these numbers of mesh nodes in this example.
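As a rough check of the numbers in this thread, here is a hypothetical helper (mesh_level_shapes is not part of weather_model_graphs) that keeps dividing the leaf node counts by level_refinement_factor while both directions still have at least level_refinement_factor nodes. It reproduces the three mesh_levels/nleaf pairs logged above, but the library may determine the number of levels differently.

```python
def mesh_level_shapes(nleaf, level_refinement_factor):
    """Hypothetical: (nx, ny) node counts per mesh level, finest first.

    Coarsens by integer division until the smaller direction would have
    fewer than level_refinement_factor nodes.
    """
    nx, ny = nleaf
    shapes = []
    while min(nx, ny) >= level_refinement_factor:
        shapes.append((nx, ny))
        nx //= level_refinement_factor
        ny //= level_refinement_factor
    return shapes


print(mesh_level_shapes((27, 27), 3))  # [(27, 27), (9, 9), (3, 3)] -> 3 levels
print(mesh_level_shapes((9, 3), 3))    # [(9, 3)] -> 1 level
print(mesh_level_shapes((5, 5), 5))    # [(5, 5)] -> 1 level
```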

@joeloskarsson
Contributor Author

I did a quick test to see whether it was indeed the lines

G_all_levels[lev] = networkx.relabel_nodes(
    G_all_levels[lev], dict(zip(G_all_levels[lev].nodes, ij))
)

that were causing this problem, but it seems that the issue is more complex than that. My new hypothesis is that num_nodes_x and num_nodes_y are not getting set correctly, but I have not dug deeper than that yet.

@joeloskarsson
Contributor Author

I think I have figured out the problem now: num_nodes_x and num_nodes_y are indeed not getting set correctly at each iteration of the loop. Luckily, I think the correct updates to these are easy to compute.

I implemented a fix that seems to work: joeloskarsson@30bfb82, but it requires a bit more testing. I will prepare a proper PR after testing it.
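One way to get the correct counts, assuming each level has exactly level_refinement_factor times fewer nodes per axis than the one below it (27, 9, 3 here), is to carry them forward with integer division instead of reading them off the already-relabelled node labels. The sketch assumes a 27x27 finest level, level_refinement_factor=3, level_offset=1 (an inferred value), row-major node order, and plain label lists instead of networkx graphs. It illustrates the idea only and is not necessarily what joeloskarsson@30bfb82 or #41 actually does.

```python
import numpy as np

# Assumed values; LEVEL_OFFSET is inferred, not taken from flat.py.
LEVEL_REFINEMENT_FACTOR = 3
LEVEL_OFFSET = 1
NLEAF = 27


def grid_labels(n):
    """(i, j) labels of an n x n grid in row-major order (assumed node order)."""
    return [(i, j) for i in range(n) for j in range(n)]


def replay_loop_fixed(num_levels=3):
    """Relabel each coarser level with labels from the level below it,
    tracking the grid shape explicitly instead of inferring it from labels."""
    levels = [grid_labels(NLEAF // LEVEL_REFINEMENT_FACTOR**k) for k in range(num_levels)]
    num_nodes_x = num_nodes_y = NLEAF  # shape of level 0
    shapes = []
    for lev in range(1, num_levels):
        nodes = levels[lev - 1]
        ij = np.array(nodes).reshape((num_nodes_x, num_nodes_y, 2))[
            LEVEL_OFFSET::LEVEL_REFINEMENT_FACTOR,
            LEVEL_OFFSET::LEVEL_REFINEMENT_FACTOR,
            :,
        ].reshape(-1, 2)
        # Stand-in for networkx.relabel_nodes on level `lev`
        levels[lev] = [tuple(int(v) for v in x) for x in ij]
        shapes.append((num_nodes_x, num_nodes_y))
        # Shape of level `lev`, used on the next iteration
        num_nodes_x //= LEVEL_REFINEMENT_FACTOR
        num_nodes_y //= LEVEL_REFINEMENT_FACTOR
    return levels, shapes


levels, shapes = replay_loop_fixed()
# shapes == [(27, 27), (9, 9)]; the 3x3 top level gets labels (4, 4) ... (22, 22)
```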

2 participants