-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathct_viewer.py
179 lines (141 loc) · 7.75 KB
/
ct_viewer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import warnings
warnings.filterwarnings('ignore') # Ignore warnings
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import interact, interactive_output
from IPython.display import display, clear_output
import time
# Names of the DataFrame columns whose values CTScanViewer.update_info_display
# renders in the scan-info panel; columns not listed here are not shown.
DICOM_TAGS_TO_DISPLAY = ['patient_id', 'age']
def load_nifti_ras(file_path):
    """Read a NIfTI file and return its voxel array reoriented to RAS+.

    Args:
        file_path: Location of the NIfTI file, passed straight to ``nib.load``.

    Returns:
        The floating-point voxel data (``get_fdata()``) with its axes
        permuted/flipped so that they follow the RAS+ convention.
    """
    image = nib.load(file_path)
    voxels = image.get_fdata()

    # How the stored axes currently map onto world axes, derived from the affine.
    source_ornt = nib.orientations.io_orientation(image.affine)

    # RAS+ target: axes 0, 1, 2 in order, none of them flipped.
    target_ornt = np.array([[0, 1], [1, 1], [2, 1]])

    # Transform that takes the stored orientation to the target one.
    reorient = nib.orientations.ornt_transform(source_ornt, target_ornt)
    return nib.orientations.apply_orientation(voxels, reorient)
def clip_hu_values(ct_scan, min_hu, max_hu):
    """Limit the intensities of a CT volume to the window [min_hu, max_hu].

    Values below ``min_hu`` become ``min_hu`` and values above ``max_hu``
    become ``max_hu``. The input is not modified; a new array is returned.
    """
    return np.clip(ct_scan, min_hu, max_hu)
class CTScanViewer:
    """Interactive notebook viewer for CT scans with a segmentation overlay.

    The viewer walks through the rows of a DataFrame. For each row it loads a
    CT volume and a segmentation volume (both via ``load_nifti_ras``), clips
    the CT intensities to ``[HU_min, HU_max]`` and shows one 2-D slice with
    the segmentation drawn on top. Widgets let the user pick the slice, the
    view plane, the overlay opacity, and advance to the next row.

    Constructing an instance immediately displays the widgets and loads the
    first row.

    Args:
        df: Table of scans; must support ``len()`` and ``.iloc[int]``.
        ct_scan_col: Column holding the CT scan file paths.
        segmentation_col: Column holding the segmentation file paths.
        HU_min: Lower clipping bound applied to the CT volume.
        HU_max: Upper clipping bound applied to the CT volume.
    """

    # Array axis that is indexed to extract a slice for each view plane.
    _PLANE_AXES = {'axial': 2, 'sagittal': 0, 'coronal': 1}

    def __init__(self, df, ct_scan_col, segmentation_col, HU_min=-100, HU_max=400):
        self.df = df  # DataFrame containing scan data
        self.ct_scan_col = ct_scan_col  # Column name for CT scan file paths
        self.segmentation_col = segmentation_col  # Column name for segmentation file paths
        self.HU_min = HU_min  # Minimum HU value for clipping
        self.HU_max = HU_max  # Maximum HU value for clipping
        self.current_index = 0  # Index of the current scan
        self.view_plane = 'axial'  # Initial view plane
        self.slice_idx = 0  # Index of displayed slice
        # Tiny placeholder volumes: the output widget renders once as soon as
        # it is created, before any real scan has been loaded.
        self.ct_scan = np.zeros([2, 2, 2])
        self.segmentation = np.zeros([2, 2, 2])
        self.num_slices = self.ct_scan.shape[self._PLANE_AXES[self.view_plane]]
        # Set by load_data and cleared by update_display; used to detect that
        # a freshly loaded scan was not rendered by any widget callback.
        self._needs_redraw = False
        self.init_widgets()  # Initialize widgets
        self.load_data()  # Load the initial scan data

    def _plane_axis(self, view_plane):
        """Return the array axis sliced for ``view_plane``.

        Raises:
            ValueError: If ``view_plane`` is not one of the known planes.
        """
        try:
            return self._PLANE_AXES[view_plane]
        except KeyError:
            raise ValueError(
                f"Unknown view plane {view_plane!r}; "
                f"expected one of {sorted(self._PLANE_AXES)}"
            ) from None

    def init_widgets(self):
        """Create all widgets, wire up their callbacks and display them."""
        self.slice_slider = widgets.IntSlider(
            min=0, max=100, step=1, value=0, description='Slice ',
            layout=widgets.Layout(width='600px'))
        self.slice_slider.observe(self.on_slice_change, names='value')  # Update slice on slider change

        self.alpha_slider = widgets.FloatSlider(
            value=0.3, min=0, max=1, step=0.1, description='α', orientation='vertical')

        self.plane_selector = widgets.ToggleButtons(
            options=['axial', 'sagittal', 'coronal'], description='Plane ')
        # Registered before the output widget below so that the slider range
        # is adjusted to the new plane before the plane change is rendered.
        self.plane_selector.observe(self.on_plane_change, names='value')

        self.next_button = widgets.Button(description="Next")
        self.next_button.layout.object_position = 'right'
        self.next_button.on_click(self.on_next)  # Load next scan on button click

        self.progress_bar = widgets.FloatProgress(
            value=0, min=0, max=1, description='Loading:', bar_style='info')
        self.info_display = widgets.HTML(value="")  # HTML widget to display scan info

        ui_top = widgets.VBox([self.plane_selector, self.slice_slider])  # Top UI elements
        # Kept on the instance so load_data can force a redraw into it.
        self.output = widgets.interactive_output(
            self.update_display,
            {'slice_idx': self.slice_slider,
             'view_plane': self.plane_selector,
             'alpha': self.alpha_slider})
        ui_bot = widgets.HBox([
            self.output, self.alpha_slider, self.info_display,
            self.next_button, self.progress_bar])  # Bottom UI elements
        display(ui_top, ui_bot)  # Display the widgets

    def load_data(self):
        """Load the CT scan and segmentation of the current row and show them.

        The progress bar is advanced while loading. If loading fails, the bar
        is switched to an error state, the previously loaded volumes are kept
        and the exception is re-raised.
        """
        bar = self.progress_bar
        bar.layout.visibility = 'visible'
        bar.value = 0
        bar.bar_style = 'info'
        bar.description = 'Loading...'
        try:
            row = self.df.iloc[self.current_index]  # Get the current scan data
            bar.value = 0.1
            ct_scan = load_nifti_ras(row[self.ct_scan_col])  # Load CT scan
            bar.value = 0.4
            ct_scan = clip_hu_values(ct_scan, self.HU_min, self.HU_max)  # Clip HU values
            bar.value = 0.6
            segmentation = load_nifti_ras(row[self.segmentation_col])  # Load segmentation
            bar.value = 0.8
        except Exception:
            # Do not leave the bar stuck in the "Loading..." state.
            bar.bar_style = 'danger'
            bar.description = 'Failed'
            raise

        # Commit both volumes together so a failed segmentation load can never
        # leave a new CT scan paired with the previous scan's segmentation.
        self.ct_scan = ct_scan
        self.segmentation = segmentation

        self.update_info_display()  # Update scan info display

        # Moving the slider normally triggers a redraw, but only when its
        # value actually changes. If the new scan's preferred slice equals the
        # current slider value, nothing would be rendered, so redraw manually.
        self._needs_redraw = True
        self.update_slice_slider()  # Update the slice slider
        if self._needs_redraw:
            self._refresh_display()

        bar.value = 1
        bar.bar_style = 'success'
        bar.description = 'Loaded'
        time.sleep(0.5)  # Keep the "Loaded" state visible briefly before hiding
        bar.layout.visibility = 'hidden'

    def _refresh_display(self):
        """Re-render the current slice into the output widget."""
        with self.output:
            clear_output(wait=True)
            self.update_display(
                self.slice_slider.value,
                self.plane_selector.value,
                self.alpha_slider.value)

    def update_slice_slider(self):
        """Fit the slice slider to the current view plane and volume.

        The slider maximum becomes the last slice index of the CT volume along
        the plane's axis, and the slider is moved to the slice in which the
        segmentation has the largest sum.
        """
        axis = self._plane_axis(self.view_plane)
        self.num_slices = self.ct_scan.shape[axis]

        # Sum the segmentation over the two in-plane axes -> one value per slice.
        in_plane_axes = tuple(a for a in range(3) if a != axis)
        per_slice_sum = np.sum(self.segmentation, axis=in_plane_axes)
        best_slice = int(np.argmax(per_slice_sum))

        # Guard against a segmentation that is larger than the CT volume.
        self.slice_idx = max(0, min(best_slice, self.num_slices - 1))

        self.slice_slider.max = self.num_slices - 1
        self.slice_slider.value = self.slice_idx

    def update_display(self, slice_idx, view_plane, alpha=0.5):
        """Draw one slice of the CT scan with the segmentation overlaid.

        Args:
            slice_idx: Index of the slice along the plane's axis; values
                outside the CT volume are clamped to the nearest valid index.
            view_plane: ``'axial'``, ``'sagittal'`` or ``'coronal'``.
            alpha: Opacity of the segmentation overlay and its contour.
        """
        self.view_plane = view_plane
        self._needs_redraw = False

        axis = self._plane_axis(view_plane)
        last_index = self.ct_scan.shape[axis] - 1
        # The slider range can briefly lag behind the loaded volume.
        slice_idx = max(0, min(int(slice_idx), last_index))

        ct_slice = np.take(self.ct_scan, slice_idx, axis=axis)
        seg_slice = np.take(self.segmentation, slice_idx, axis=axis)

        plt.figure(figsize=(8, 8))
        plt.imshow(ct_slice.T, cmap='gray', origin='lower')
        # An all-zero slice has nothing to overlay and no contour levels.
        if np.any(seg_slice):
            overlay = np.ma.masked_where(seg_slice == 0, seg_slice)
            plt.imshow(overlay.T, cmap='jet', alpha=alpha, origin='lower')
            plt.contour(seg_slice.T, colors='red', linewidths=0.5, alpha=alpha, origin='lower')
        plt.show()

    def update_info_display(self):
        """Show the selected columns of the current row in the info panel."""
        from html import escape  # local import: only needed here

        row = self.df.iloc[self.current_index]
        parts = ["<b>Scan Info:</b><br>"]
        # Escape names and values so characters such as '<' or '&' are shown
        # literally instead of being interpreted as markup.
        parts.extend(
            f"<b>{escape(str(column))}:</b> {escape(str(row[column]))}<br>"
            for column in row.index
            if column in DICOM_TAGS_TO_DISPLAY
        )
        self.info_display.value = "".join(parts)

    def on_slice_change(self, change):
        """Handle slice slider change event by recording the new slice index."""
        self.slice_idx = self.slice_slider.value

    def on_plane_change(self, change):
        """Handle view plane change event."""
        self.view_plane = self.plane_selector.value  # Update view plane
        self.update_slice_slider()  # Update the slice slider

    def on_next(self, button):
        """Handle next button click event; wraps around after the last row."""
        self.current_index = (self.current_index + 1) % len(self.df)  # Increment scan index
        self.load_data()  # Load the next scan data
# # Example usage (run in a Jupyter notebook; needs `import pandas as pd`)
# df = pd.DataFrame({
# 'ct_scan_path': ['path_to_ct_scan1.nii', 'path_to_ct_scan2.nii'],
# 'segmentation_path': ['path_to_segmentation1.nii', 'path_to_segmentation2.nii'],
# 'patient_id': [1, 2],
# 'age': [65, 70],
# 'sex': ['M', 'F']
# })
# viewer = CTScanViewer(df, 'ct_scan_path', 'segmentation_path')