Skip to content

Commit

Permalink
correl nan_indices_max with Pythran
Browse files Browse the repository at this point in the history
  • Loading branch information
paugier committed Apr 16, 2024
1 parent a05fc01 commit 57e5484
Showing 1 changed file with 93 additions and 9 deletions.
102 changes: 93 additions & 9 deletions src/fluidimage/calcul/correl.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from numpy.fft import fft2, ifft2
from scipy.ndimage import correlate
from scipy.signal import correlate2d
from transonic import boost
from transonic import Array, Type, boost

from .correl_pycuda import correl_pycuda
from .errors import PIVError
Expand All @@ -75,8 +75,64 @@ def parse_displacement_max(displ_max, im0_shape):
return displ_max


def _compute_indices_max(correl, norm):
iy, ix = np.unravel_index(np.nanargmax(correl), correl.shape)
A2D = Array[Type(np.float32, np.float64), "2d", "C"]


@boost
def nan_indices_max(
correl: A2D,
i0_start: np.int32,
i0_stop: np.int32,
i1_start: np.int32,
i1_stop: np.int32,
):

correl_max = np.nan

# first, get the first non nan value
n0, n1 = correl.shape
correl_flatten = correl.ravel()
for i_flat in range(i0_start * n1 + i1_start, n0 * n1):
value = correl_flatten[i_flat]
if not np.isnan(value):
correl_max = value
break

assert not np.isnan(correl_max)

i0_max = 0
i1_max = 0

for i0 in range(i0_start, i0_stop):
for i1 in range(i1_start, i1_stop):
value = correl[i0, i1]
if np.isnan(value):
continue
if value >= correl_max:
correl_max = value
i0_max = i0
i1_max = i1

return i0_max, i1_max


def _compute_indices_max(
correl, norm, start_stop_for_search0, start_stop_for_search1
):
"""Compute the indices of the maximum correlation
Warning: important for perf (~25% for PIV)
"""
i0_start, i0_stop = start_stop_for_search0
i1_start, i1_stop = start_stop_for_search1

if i0_stop is None:
i0_stop, i1_stop = correl.shape

iy, ix = nan_indices_max(correl, i0_start, i0_stop, i1_start, i1_stop)

# iy, ix = np.unravel_index(np.nanargmax(correl), correl.shape)

if norm == 0:
# I hope it is ok (Pierre)
Expand Down Expand Up @@ -133,10 +189,12 @@ def __init__(
self.particle_radius = particle_radius
self.nb_peaks_to_search = nb_peaks_to_search

self.start_stop_for_search0 = [0, None]
self.start_stop_for_search1 = [0, None]
self._init2()

def _init2(self):
pass
"""Finalize initialization"""

def compute_displacement_from_indices(self, ix, iy):
"""Compute the displacement from a couple of indices."""
Expand All @@ -148,11 +206,16 @@ def compute_indices_from_displacement(self, dx, dy):
def get_indices_no_displacement(self):
return self.iy0, self.ix0

def _compute_indices_max(self, correl, norm):
return _compute_indices_max(
correl, norm, self.start_stop_for_search0, self.start_stop_for_search1
)

def compute_displacements_from_correl(self, correl, norm=1.0):
"""Compute the displacement from a correlation."""

try:
ix, iy, correl_max = _compute_indices_max(correl, norm)
ix, iy, correl_max = self._compute_indices_max(correl, norm)
except PIVError as piv_error:
ix, iy, correl_max = piv_error.results
# second chance to find a better peak...
Expand All @@ -161,7 +224,7 @@ def compute_displacements_from_correl(self, correl, norm=1.0):
ix - self.particle_radius : ix + self.particle_radius + 1,
] = np.nan
try:
ix2, iy2, correl_max2 = _compute_indices_max(correl, norm)
ix2, iy2, correl_max2 = self._compute_indices_max(correl, norm)
except PIVError as _piv_error:
dx, dy = self.compute_displacement_from_indices(ix, iy)
_piv_error.results = (dx, dy, correl_max)
Expand All @@ -183,7 +246,9 @@ def compute_displacements_from_correl(self, correl, norm=1.0):
ix - self.particle_radius : ix + self.particle_radius + 1,
] = np.nan
try:
ix, iy, correl_max_other = _compute_indices_max(correl, norm)
ix, iy, correl_max_other = self._compute_indices_max(
correl, norm
)
except PIVError:
break

Expand Down Expand Up @@ -308,7 +373,7 @@ def correl_numpy(im0: A, im1: A, disp_max: int):

class CorrelPythran(CorrelBase):
"""Correlation using pythran.
Correlation class by hands with with numpy.
Correlation class by hands with numpy.
"""

_tag = "pythran"
Expand Down Expand Up @@ -636,6 +701,25 @@ def _init2(self):

self.where_large_displacement = where_large_displacement

n0, n1 = where_large_displacement.shape
for i0_start in range(n0):
if not all(where_large_displacement[i0_start, :]):
break
for i1_start in range(n1):
if not all(where_large_displacement[:, i1_start]):
break
for i0_stop in range(n0 - 1, -1, -1):
if not all(where_large_displacement[i0_stop, :]):
break
i0_stop += 1
for i1_stop in range(n1 - 1, -1, -1):
if not all(where_large_displacement[:, i1_stop]):
break
i1_stop += 1

self.start_stop_for_search0 = (i0_start, i0_stop)
self.start_stop_for_search1 = (i1_start, i1_stop)

self._check_im_shape(self.im0_shape, self.im1_shape)

def _check_im_shape(self, im0_shape, im1_shape):
Expand Down Expand Up @@ -668,7 +752,7 @@ def __call__(self, im0, im1):
norm = np.sqrt(np.sum(im1**2) * np.sum(im0**2))
corr = ifft2(fft2(im0).conj() * fft2(im1)).real
correl = np.fft.fftshift(corr[::-1, ::-1])
return correl, norm
return np.ascontiguousarray(correl), norm


class CorrelFFTW(CorrelFFTBase):
Expand Down

0 comments on commit 57e5484

Please sign in to comment.