Skip to content

Commit

Permalink
Adapt how the ray class works; Massively speed up scattering
Browse files Browse the repository at this point in the history
  • Loading branch information
jgray-19 committed Feb 12, 2024
1 parent 3787aab commit 2bb61c7
Show file tree
Hide file tree
Showing 8 changed files with 205 additions and 181 deletions.
23 changes: 12 additions & 11 deletions src/detectors/curved_detector.m
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
pixel_height (1, 1) double % The height of each pixel in z
end
methods (Access=private, Static)
function generator = get_ray_generator_static(ny_pixels, nz_pixels, pixel_angle, pixel_height, dist_to_detector, to_source_vec, source_pos)
function generator = get_ray_generator_static(ny_pixels, nz_pixels, pixel_angle, pixel_height, dist_to_detector, to_source_vec, source_pos, voxels)
generator = @get_ray_attrs; % Create the function which returns the rays
function [xray] = get_ray_attrs(y_pixel, z_pixel)
assert(y_pixel <= ny_pixels && y_pixel > 0 && z_pixel <= nz_pixels && z_pixel > 0, ...
Expand All @@ -15,7 +15,7 @@
final_length = sqrt(dist_to_detector.^2 + z_shift.^2);
pixel_vec = (rotz(pixel_angle * (y_pixel - (ny_pixels+1)/2)) * to_source_vec.*dist_to_detector - ...
[0;0;z_shift]) ./ final_length;
xray = ray(source_pos, -pixel_vec, final_length);
xray = ray(source_pos, -pixel_vec, final_length, voxels);
end
end
end
Expand Down Expand Up @@ -52,30 +52,32 @@ function reset(self)
self.source_position = self.init_source_pos;
end

function ray_generator = get_ray_generator(self, ray_per_pixel)
function ray_generator = get_ray_generator(self, voxels, ray_per_pixel)
% Create a function which returns the rays which should be fired to hit each pixel.
% Only 1 ray per pixel is supported at the moment, as anti-aliasing techniques are not yet implemented.
arguments
self curved_detector
voxels voxel_array
ray_per_pixel int32 = 1
end
assert(nargin==1, "Only 1 ray per pixel is supported at the moment, as anti-aliasing techniques are not yet implemented.")
assert(nargin==2, "Only 1 ray per pixel is supported at the moment, as anti-aliasing techniques are not yet implemented.")

% Create the function which returns the rays
ray_generator = self.get_ray_generator_static(...
self.ny_pixels, self.nz_pixels, self.pixel_angle, self.pixel_height, ...
self.dist_to_detector, self.to_source_vec, self.source_position...
self.dist_to_detector, self.to_source_vec, self.source_position, voxels...
);
end
function pixel_generator = get_pixel_generator(self, angle_index, ray_per_pixel)
function pixel_generator = get_pixel_generator(self, angle_index, voxels, ray_per_pixel)
% Create a function which returns the rays which should be fired to hit each pixel.
% Only 1 ray per pixel is supported at the moment, as anti-aliasing techniques are not yet implemented.
arguments
self curved_detector
angle_index double
voxels voxel_array
ray_per_pixel int32 = 1
end
assert(nargin==2, "Only 1 ray per pixel is supported at the moment, as anti-aliasing techniques are not yet implemented.")
assert(nargin==3, "Only 1 ray per pixel is supported at the moment, as anti-aliasing techniques are not yet implemented.")

if angle_index == 1; rot_mat = eye(3);
else ; rot_mat = rotz(self.rot_angle * (angle_index - 1));
Expand All @@ -87,11 +89,10 @@ function reset(self)
pixel_generator = @generator;
static_ray_generator = curved_detector.get_ray_generator_static(...
self.ny_pixels, self.nz_pixels, self.pixel_angle, self.pixel_height, ...
self.dist_to_detector, to_source_vec, source_pos...
self.dist_to_detector, to_source_vec, source_pos, voxels...
);
function pixel_value = generator(y_pixel, z_pixel, voxels)
xray = static_ray_generator(y_pixel, z_pixel);
pixel_value = xray.calculate_mu(voxels);
function pixel_value = generator(y_pixel, z_pixel)
pixel_value = static_ray_generator(y_pixel, z_pixel).calculate_mu();
end
end

Expand Down
66 changes: 33 additions & 33 deletions src/detectors/detector.m
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
rotate(self)
reset(self)
hit_pixel(self, ray)
get_ray_generator(self, ray_type, ray_per_pixel)
get_pixel_generator(self, angle_index, ray_type, ray_per_pixel)
get_ray_generator(self, voxels, ray_type, ray_per_pixel)
get_pixel_generator(self, angle_index, voxels, ray_type, ray_per_pixel)
end

methods
Expand Down Expand Up @@ -87,12 +87,12 @@ function check_voxels(self, voxels)
self.check_voxels(voxels);
image = zeros(self.ny_pixels, self.nz_pixels, self.num_rotations);
for i = 1:self.num_rotations
ray_generator = self.get_ray_generator();
ray_generator = self.get_ray_generator(voxels);
for j = 1:self.nz_pixels
for k = 1:self.ny_pixels
ray = ray_generator(k, j);

mu = ray.calculate_mu(voxels);
mu = ray.calculate_mu();

image(k, j, i) = mu;
end
Expand All @@ -111,10 +111,10 @@ function check_voxels(self, voxels)
image = zeros(self.ny_pixels, self.nz_pixels, self.num_rotations);
get_pixel_generator = @self.get_pixel_generator;
for k = 1:self.num_rotations
pixel_calc = get_pixel_generator(k);
pixel_calc = get_pixel_generator(k, voxels);
for j = 1:self.nz_pixels
parfor i = 1:self.ny_pixels
image(i, j, k) = feval(pixel_calc, i, j, voxels);
image(i, j, k) = feval(pixel_calc, i, j);
end
end
end
Expand All @@ -131,7 +131,7 @@ function check_voxels(self, voxels)
elseif self.scatter_type == 1 % Fast scatter
scatter = self.conv_scatter(image);
elseif self.scatter_type == 2 % Slow scatter
scatter = self.slow_scatter_p(voxels);
scatter = (self.slow_scatter_p(voxels)-image) / 2;
end
end

Expand All @@ -141,12 +141,12 @@ function check_voxels(self, voxels)
air = voxel_object(@(i,j,k) i==i, material_attenuation("air"));
array = voxel_array(zeros(3, 1), zeros(3,1)+1e6, dtd/10, air);
scan = zeros(self.ny_pixels, self.nz_pixels, self.num_rotations);
ray_generator = self.get_ray_generator();
ray_generator = self.get_ray_generator(array);
image_at_angle = zeros(self.ny_pixels, self.nz_pixels);
for k = 1:self.nz_pixels
for j = 1:self.ny_pixels
ray = ray_generator(j, k);
mu = ray.calculate_mu(array);
mu = ray.calculate_mu();
image_at_angle(j, k) = mu;
end
end
Expand Down Expand Up @@ -178,26 +178,25 @@ function check_voxels(self, voxels)
% Do some Monte Carlo simulation of scatter
ny = self.ny_pixels; nz = self.nz_pixels;
scatter = zeros(ny, nz, self.num_rotations);
for sample = 1:self.scatter_factor
for k = 1:self.num_rotations
pixel_calc = self.get_pixel_generator(k, @scatter_ray);
scatter_idxs = zeros(ny, nz, 2); scatter_vals = zeros(ny, nz);
for j = 1:nz
parfor i = 1:ny
[pval, pixel, ~] = feval(pixel_calc, i, j, voxels);
if pval < inf % Collect the scatter values for adding later
scatter_idxs(i, j, :) = pixel;
scatter_vals(i, j) = pval;
end
end
for k = 1:self.num_rotations
pixel_calc = self.get_pixel_generator(k, voxels, @scatter_ray);
scatter_idxs = zeros(ny, nz, self.scatter_factor, 2);
scatter_vals = zeros(ny, nz, self.scatter_factor);
for j = 1:nz
parfor i = 1:ny
[pval, pixel, ~] = feval(pixel_calc, i, j);
scatter_idxs(i, j, :, :) = pixel;
scatter_vals(i, j, :) = pval;
end
end

% Add the scatter to the image
% Add the scatter to the image
for sf = 1:self.scatter_factor
for j = 1:nz
for i = 1:ny
if scatter_vals(i, j) > 0
scatter(scatter_idxs(i, j, 1), scatter_idxs(i, j, 2), k) = ...
scatter(scatter_idxs(i, j, 1), scatter_idxs(i, j, 2), k) + scatter_vals(i, j);
if ~isnan(scatter_vals(i, j, sf))
scatter(scatter_idxs(i, j, sf, 1), scatter_idxs(i, j, sf, 2), k) = ...
scatter(scatter_idxs(i, j, sf, 1), scatter_idxs(i, j, sf, 2), k) + scatter_vals(i, j, sf);
end
end
end
Expand All @@ -211,14 +210,15 @@ function check_voxels(self, voxels)
self.reset(); % Reset the detector to the initial position
ny = self.ny_pixels; nz = self.nz_pixels;
scatter = zeros(ny, nz, self.num_rotations);
for sample = 1:self.scatter_factor
for k = 1:self.num_rotations
pixel_calc = self.get_pixel_generator(k, @scatter_ray);
for j = 1:nz
for i = 1:ny
[pval, pixel, ~] = pixel_calc(i, j, voxels);
if pval < inf % Collect the attenuation values
scatter(pixel(1), pixel(2), k) = scatter(pixel(1), pixel(2), k) + pval;
for k = 1:self.num_rotations
pixel_calc = self.get_pixel_generator(k, voxels, @scatter_ray);
for j = 1:nz
for i = 1:ny
[pval, pixel, ~] = pixel_calc(i, j);
for sf = 1:self.scatter_factor
if ~isnan(pval(sf))
scatter(pixel(sf, 1), pixel(sf, 2), k) = ...
scatter(pixel(sf, 1), pixel(sf, 2), k) + pval(sf);
end
end
end
Expand Down
54 changes: 33 additions & 21 deletions src/detectors/parallel_detector.m
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
detector_vec (3, 1) double = [1;0;0] % Vector from left to right corner of detector
end
methods (Access=private, Static)
function generator = get_ray_generator_static(ray, ny_pixels, nz_pixels, centre, detector_vec, pixel_width, pixel_height, dist_to_detector, to_source_vec)
function generator = get_ray_generator_static(ray, ny_pixels, nz_pixels, centre, detector_vec, pixel_width, pixel_height, dist_to_detector, to_source_vec, voxels)
generator = @get_ray_attrs; % Create the function which returns the rays
function [xray] = get_ray_attrs(y_pixel, z_pixel)
assert(y_pixel <= ny_pixels && y_pixel > 0 && z_pixel <= nz_pixels && z_pixel > 0, ...
Expand All @@ -17,7 +17,7 @@
detector_vec .* (y_pixel - (ny_pixels+1)/2) .* pixel_width + ...
[0;0;pixel_height] .* (z_pixel - (nz_pixels+1)/2);
source_position = pixel_centre + to_source_vec .* dist_to_detector;
xray = ray(source_position, -to_source_vec, dist_to_detector);
xray = ray(source_position, -to_source_vec, dist_to_detector, voxels);
end
end
end
Expand Down Expand Up @@ -54,31 +54,34 @@ function reset(self)
self.centre = self.init_centre;
end

function ray_generator = get_ray_generator(self, ray_type, ray_per_pixel)
function ray_generator = get_ray_generator(self, voxels, ray_type, ray_per_pixel)
% Create a function which returns the rays which should be fired to hit each pixel.
% Only 1 ray per pixel is supported at the moment, as anti-aliasing techniques are not yet implemented.
arguments
self parallel_detector
voxels voxel_array
ray_type = @ray
ray_per_pixel int32 = 1
end
assert(nargin<3, "Only 1 ray per pixel is supported at the moment, as anti-aliasing techniques are not yet implemented.")
assert(nargin<4, "Only 1 ray per pixel is supported at the moment, as anti-aliasing techniques are not yet implemented.")
% Create the function which returns the rays
ray_generator = parallel_detector.get_ray_generator_static(...
ray_type, self.ny_pixels, self.nz_pixels, self.centre, self.detector_vec, ...
self.pixel_dims(1), self.pixel_dims(2), self.dist_to_detector, self.to_source_vec...
self.pixel_dims(1), self.pixel_dims(2), self.dist_to_detector, ...
self.to_source_vec, voxels...
);
end
function pixel_generator = get_pixel_generator(self, angle_index, ray_type, ray_per_pixel)
function pixel_generator = get_pixel_generator(self, angle_index, voxels, ray_type, ray_per_pixel)
% Create a function which returns the rays which should be fired to hit each pixel.
% Only 1 ray per pixel is supported at the moment, as anti-aliasing techniques are not yet implemented.
arguments
self parallel_detector
angle_index double
voxels voxel_array
ray_type = @ray
ray_per_pixel int32 = 1
end
assert(nargin<4, "Only 1 ray per pixel is supported at the moment, as anti-aliasing techniques are not yet implemented.")
assert(nargin<5, "Only 1 ray per pixel is supported at the moment, as anti-aliasing techniques are not yet implemented.")

if angle_index == 1; rot_mat = eye(3);
else ; rot_mat = rotz(self.rot_angle * (angle_index - 1));
Expand All @@ -90,7 +93,8 @@ function reset(self)
% Create the function which returns the rays
static_ray_generator = parallel_detector.get_ray_generator_static(...
ray_type, self.ny_pixels, self.nz_pixels, current_c, current_dv, ...
self.pixel_dims(1), self.pixel_dims(2), self.dist_to_detector, current_sv...
self.pixel_dims(1), self.pixel_dims(2), self.dist_to_detector, ...
current_sv, voxels...
);

ray_type_str = func2str(ray_type); % Get the name of the ray type
Expand All @@ -99,24 +103,32 @@ function reset(self)
else; error('parallel_detector:InvalidRayType', "Must be either 'ray' or 'scatter_ray'.");
end

function pixel_value = generator(y_pixel, z_pixel, voxels)
function pixel_value = generator(y_pixel, z_pixel)
xray = static_ray_generator(y_pixel, z_pixel);
pixel_value = xray.calculate_mu(voxels);
pixel_value = xray.calculate_mu();
end

function [pixel_value, pixel, scattered] = scatter_generator(y_pixel, z_pixel, voxels)
function [pixel_values, pixels, scattered] = scatter_generator(y_pixel, z_pixel)
xray = static_ray_generator(y_pixel, z_pixel);
xray = xray.calculate_mu(voxels);

mu = xray.mu; scattered = xray.scatter_event > 0;

hit = true;
if scattered; [pixel, hit] = self.hit_pixel(xray, current_dv);
else; pixel = [y_pixel, z_pixel];
end
pixel_values = zeros(self.scatter_factor, 1);
pixels = zeros(self.scatter_factor, 2);
scattered = zeros(self.scatter_factor, 1);

if hit; pixel_value = mu;
else; pixel_value = inf;
for i = 1:self.scatter_factor
new_ray = xray.calculate_mu();

this_scatter = new_ray.scatter_event > 0;

hit = true;
if this_scatter; [pixel, hit] = self.hit_pixel(new_ray, current_dv);
else; pixel = [y_pixel, z_pixel];
end
pixels(i, :) = pixel;
scattered(i) = this_scatter;
if hit; pixel_values(i) = new_ray.mu;
else; pixel_values(i) = NaN;
end
% xray = xray.randomise_n_mfp();
end
end
end
Expand Down
34 changes: 20 additions & 14 deletions src/ray_tracing/ray.m
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,36 @@
start_point (3, 1) double % 3D point
v1_to_v2 (3, 1) double % Vector from start_point to end_point
energy double % energy of the ray in KeV

voxels voxel_array % The voxel array the ray will be traced through
mu_dict % Dictionary of mu values for each material
lengths % Length of the ray in each voxel
indices % Indices of the voxels the ray intersects
end

properties (Access=private, Constant)
use_mex = ~~exist('ray_trace_mex', 'file'); % Use the MEX implementation of the photon_attenuation package if available
use_mex = ~~exist('ray_trace_mex', 'file'); % Use the MEX implementation of the ray tracing if available
end

methods
function self = ray(start_point, direction, dist_to_detector, energy)
function self = ray(start_point, direction, dist_to_detector, voxels, energy)
arguments
start_point (3, 1) double
direction (3, 1) double
dist_to_detector double
voxels voxel_array
energy double = 30 %KeV
end
self.start_point = start_point;
self.v1_to_v2 = direction .* dist_to_detector;

self.energy = energy;

% Using the assumption that the ray's energy is constant
% Create a dictionary of the values of mu for each material
self.mu_dict = voxels.get_mu_dict(energy);
[self.lengths, self.indices] = self.get_intersections(voxels);
self.voxels = voxels;
end

function [lengths, indices] = get_intersections(self, voxels)
Expand All @@ -41,21 +53,15 @@
end
end

function mu = calculate_mu (self, voxels)
arguments
self ray
voxels voxel_array
end
% Using the assumption that the ray's energy is constant
% Create a dictionary of the values of mu for each material
mu_dict = voxels.get_mu_dict(self.energy);

function mu = calculate_mu (self)
% I don't see the point of this function if we precalculate everything else
% Maybe for testing and consistency with the scattered ray class

% Calculate the mu of the ray
[lengths, indices] = self.get_intersections(voxels);
if isempty(lengths) % No intersections
if isempty(self.lengths) % No intersections
mu = 0;
else
mu = sum(lengths .* voxels.get_saved_mu(indices, mu_dict));
mu = sum(self.lengths .* self.voxels.get_saved_mu(self.indices, self.mu_dict));
end
end
end
Expand Down
Loading

0 comments on commit 2bb61c7

Please sign in to comment.