Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix profile access in groupwise registration #444

Merged
merged 5 commits into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release/release_v1.6.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@

- The atlas transformer (`atlas_refiner.transpose_img`) provides a more comprehensive set of typical transformations before atlas refinement or registration, such as rotation to any angle, flipping along any axis, and resizing (#195, #214)
- Edge/perimeter thickness can be customized (#307)
- Fixed groupwise registration for current atlas profiles, turned off default cropping (#444)

#### Atlas registration

Expand Down
77 changes: 45 additions & 32 deletions magmap/atlas/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,26 +989,33 @@ def _crop_image(img_np, labels_img, axis, eraser=None):
return img_crop, i


def register_group(img_files, rotate=None, show_imgs=True,
write_imgs=True, name_prefix=None, scale=None):
def register_group(
img_files: Sequence[str], rotate: Optional[Sequence[int]] = None,
show_imgs: bool = True, write_imgs: bool = True,
name_prefix: Optional[str] = None, scale: Optional[float] = None):
"""Group registers several images to one another.

Uses the first channel in :attr:`config.channel` or the first channel
in each image.

Registration parameters are assumed to be in a "b-spline"
:class:`magmap.settings.atlas_prof.RegParamMap`.

Args:
img_files: Paths to image files to register.
rotate (List[int]): List of number of 90 degree rotations for images
img_files: Paths to image files to register. A minimum of 4 images
is required for groupwise registration.
rotate: List of number of 90 degree rotations for images
corresponding to ``img_files``; defaults to None, in which
case the `config.transform` rotate attribute will be used.
show_imgs: True if the output images should be displayed; defaults to
show_imgs: True if the output images should be displayed; defaults to
True.
write_imgs: True if the images should be written to file; defaults to
write_imgs: True if the images should be written to file; defaults to
True.
name_prefix: Path with base name where registered files will be output;
name_prefix: Path with base name where registered files will be output;
defaults to None, in which case the fixed_file path will be used.
scale: Rescaling factor as a scalar value, used to find the rescaled,
scale: Rescaling factor as a scalar value, used to find the rescaled,
smaller images corresponding to ``img_files``. Defaults to None.

"""
start_time = time()
if name_prefix is None:
Expand All @@ -1018,15 +1025,15 @@ def register_group(img_files, rotate=None, show_imgs=True,
target_size = config.atlas_profile["target_size"]

'''
# TESTING: assuming first file is a raw groupwise registered image,
# TESTING: assuming first file is a raw groupwise registered image,
# import it for post-processing
img = sitk.ReadImage(img_files[0])
img_np = sitk.GetArrayFromImage(img)
print("thresh mean: {}".format(filters.threshold_mean(img_np)))
carve_threshold = config.register_settings["carve_threshold"]
holes_area = config.register_settings["holes_area"]
img_np, img_np_unfilled = plot_3d.carve(
img_np, thresh=carve_threshold, holes_area=holes_area,
img_np, thresh=carve_threshold, holes_area=holes_area,
return_unfilled=True)
sitk.Show(sitk_io.replace_sitk_with_numpy(img, img_np_unfilled))
sitk.Show(sitk_io.replace_sitk_with_numpy(img, img_np))
Expand All @@ -1053,25 +1060,29 @@ def register_group(img_files, rotate=None, show_imgs=True,
img_np = sitk.GetArrayFromImage(img)
if img_np_template is None:
img_np_template = np.copy(img_np)

# crop y-axis based on registered labels to ensure that sample images
# have the same structures since variable amount of tissue posteriorly;
# cropping appears to work better than erasing for groupwise reg,
# preventing some images from being stretched into the erased space
labels_img = sitk_io.load_registered_img(
img_files[i], config.RegNames.IMG_LABELS_TRUNC.value)
img_np, y_cropped = _crop_image(img_np, labels_img, 1)#, eraser=0)

y_cropped = 0
try:
# crop y-axis based on registered labels so that sample images,
# which appears to work better than erasing for groupwise reg by
# preventing some images from being stretched into the erased space
labels_img = sitk_io.load_registered_img(
img_files[i], config.RegNames.IMG_LABELS_TRUNC.value)
_logger.info("Cropping image based on labels trunction image")
img_np, y_cropped = _crop_image(img_np, labels_img, 1)#, eraser=0)
except FileNotFoundError:
pass
'''
# crop anterior region
rotated = np.rot90(img_np, 2, (1, 2))
rotated, _ = _crop_image(rotated, np.rot90(labels_img, 2, (1, 2)), 1)
img_np = np.rot90(rotated, 2, (1, 2))
'''

# force all images into same size and origin as first image
# force all images into same size and origin as first image
# to avoid groupwise registration error on physical space mismatch
if size_cropped is not None:
# use default interpolation, but should change to nearest neighbor
# use default interpolation, but should change to nearest neighbor
# if using for labels
img_np = transform.resize(
img_np, size_cropped[::-1], anti_aliasing=True, mode="reflect")
Expand All @@ -1085,9 +1096,9 @@ def register_group(img_files, rotate=None, show_imgs=True,
start_y = y_cropped
print("size_cropped: ", size_cropped, ", size_orig", size_orig)
else:
# force images into space of first image; may not be exactly
# correct but should be close since resized to match first image,
# and spacing of resized images and atlases largely ignored in
# force images into space of first image; may not be exactly
# correct but should be close since resized to match first image,
# and spacing of resized images and atlases largely ignored in
# favor of comparing shapes of large original and registered images
img.SetOrigin(origin)
img.SetSpacing(spacing)
Expand All @@ -1098,19 +1109,21 @@ def register_group(img_files, rotate=None, show_imgs=True,
#sitk.ProcessObject.SetGlobalDefaultCoordinateTolerance(100)
img_combined = sitk.JoinSeries(img_vector)

# add b-spline registration parameter map
settings = config.atlas_profile
reg = settings["reg_bspline"]
elastix_img_filter = sitk.ElastixImageFilter()
elastix_img_filter.SetFixedImage(img_combined)
elastix_img_filter.SetMovingImage(img_combined)
param_map = sitk.GetDefaultParameterMap("groupwise")
param_map["FinalGridSpacingInVoxels"] = [
settings["bspline_grid_space_voxels"]]
del param_map["FinalGridSpacingInPhysicalUnits"] # avoid conflict with vox
param_map["MaximumNumberOfIterations"] = [settings["groupwise_iter_max"]]
reg["grid_space_voxels"]]
del param_map["FinalGridSpacingInPhysicalUnits"] # avoid conflict with vox
param_map["MaximumNumberOfIterations"] = [reg["max_iter"]]
# TESTING:
#param_map["MaximumNumberOfIterations"] = ["0"]
_config_reg_resolutions(
settings["grid_spacing_schedule"], param_map, img_np_template.ndim)
reg["grid_spacing_schedule"], param_map, img_np_template.ndim)
elastix_img_filter.SetParameterMap(param_map)
elastix_img_filter.PrintParameterMap()
transform_filter = elastix_img_filter.Execute()
Expand All @@ -1124,10 +1137,10 @@ def register_group(img_files, rotate=None, show_imgs=True,
imgs = []
num_images = len(img_files)
for i in range(num_images):
extract_filter.SetIndex([0, 0, 0, i]) # x, y, z, t
extract_filter.SetIndex([0, 0, 0, i]) # x, y, z, t
img = extract_filter.Execute(transformed_img)
img_np = sitk.GetArrayFromImage(img)
# resize to original shape of first image, all aligned to position
# resize to original shape of first image, all aligned to position
# of subject within first image
img_large_np = np.zeros(size_orig[::-1])
img_large_np[:, start_y:start_y+img_np.shape[1]] = img_np
Expand All @@ -1140,7 +1153,7 @@ def register_group(img_files, rotate=None, show_imgs=True,
extend_borders = settings["extend_borders"]
carve_threshold = settings["carve_threshold"]
if extend_borders and carve_threshold:
# merge in specified border region from first image for pixels below
# merge in specified border region from first image for pixels below
# carving threshold to prioritize groupwise image
slices = []
for border in extend_borders[::-1]:
Expand All @@ -1158,7 +1171,7 @@ def register_group(img_files, rotate=None, show_imgs=True,
holes_area = settings["holes_area"]
if carve_threshold and holes_area:
img_mean, _, img_mean_unfilled = cv_nd.carve(
img_mean, thresh=carve_threshold, holes_area=holes_area,
img_mean, thresh=carve_threshold, holes_area=holes_area,
return_unfilled=True)
img_unfilled = sitk_io.replace_sitk_with_numpy(
transformed_img, img_mean_unfilled)
Expand All @@ -1173,7 +1186,7 @@ def register_group(img_files, rotate=None, show_imgs=True,

#transformed_img = img_raw
if write_imgs:
# write both the .mhd and Numpy array files to a separate folder to
# write both the .mhd and Numpy array files to a separate folder to
# mimic the atlas folder format
out_path = os.path.join(name_prefix, config.RegNames.IMG_GROUPED.value)
if not os.path.exists(name_prefix):
Expand Down
12 changes: 11 additions & 1 deletion magmap/io/sitk_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,17 @@ def replace_sitk_with_numpy(
# transfer original settings to new sitk Image
img_sitk_back.SetSpacing(spacing)
img_sitk_back.SetOrigin(origin)
img_sitk_back.SetDirection(direction)

try:
img_sitk_back.SetDirection(direction)
except RuntimeError:
# direction format and length may not be directly transferable, such
# as direction from groupwise reg image
_logger.warn(
"Could not replace image direction with: %s\n"
"Leaving default direction: %s",
direction, img_sitk_back.GetDirection())

return img_sitk_back


Expand Down