Skip to content

Commit

Permalink
fixed bugs on class definitions
Browse files Browse the repository at this point in the history
  • Loading branch information
azimonti committed Sep 3, 2024
1 parent e6673d8 commit ef4b253
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 25 deletions.
44 changes: 29 additions & 15 deletions double_slit_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ def __init__(self, outfile, wavepacket: WavepacketSimulation):
# initialize variables
self.perc = None
self.start_time = None
self.x_min = wavepacket.x_min
self.x_max = wavepacket.x_max
self.y_min = wavepacket.y_min
self.y_max = wavepacket.y_max
self.dx = wavepacket.dx
self.dy = wavepacket.dy
self.num_frames = int(p.total_duration * p.fps)
self.Nx = wavepacket.Nx
self.Ny = wavepacket.Ny

@property
def outfile(self):
Expand All @@ -47,17 +56,20 @@ def outfile(self):
def outfile(self, value):
self._outfile = value

def introspection(self):
print(f"Number of points: {len(self.screen_data_total)}")

def line_cells_crossed(self):
self.crossed_nx = []
self.crossed_ny = []
# Unpack capture data from the global cfg
(x1, y1), (x2, y2) = cfg.capture_data

# Calculate the grid coordinates of the endpoints
x1_idx = int((x1 - self.wp.x_min) // self.wp.dx)
y1_idx = int((y1 - self.wp.y_min) // self.wp.dy)
x2_idx = int((x2 - self.wp.x_min) // self.wp.dx)
y2_idx = int((y2 - self.wp.y_min) // self.wp.dy)
x1_idx = int((x1 - self.x_min) // self.dx)
y1_idx = int((y1 - self.y_min) // self.dy)
x2_idx = int((x2 - self.x_min) // self.dx)
y2_idx = int((y2 - self.y_min) // self.dy)

# Bresenham's line algorithm adapted to this grid
cells_crossed = []
Expand All @@ -75,8 +87,8 @@ def line_cells_crossed(self):
self.crossed_nx.append(x)
self.crossed_ny.append(y)
# compute cell mid point
x_mid = self.wp.x_min + x * self.wp.dx + 0.5 * self.wp.dx
y_mid = self.wp.y_min + y * self.wp.dy + 0.5 * self.wp.dy
x_mid = self.x_min + x * self.dx + 0.5 * self.dx
y_mid = self.y_min + y * self.dy + 0.5 * self.dy
cell = (x_mid, y_mid)
if cell in seen_cells:
raise ValueError(f"Duplicate cell detected at {cell}")
Expand All @@ -101,7 +113,7 @@ def compute(self):
# Loop over each snapshot in psi_plot
for psi in self.wp.psi_plot:
# Reshape the 1D wavefunction array to 2D
psi = psi.reshape(self.wp.Ny, self.wp.Nx)
psi = psi.reshape(self.Ny, self.Nx)
# initialize a temporary array to accumulate data for this snapshot
temp_data = np.zeros(len(self.crossed_nx))
# loop over each cell in the crossed path
Expand Down Expand Up @@ -130,14 +142,14 @@ def __init_plot(self):
ax.xaxis.set_ticks_position('none')
ax.yaxis.set_ticks_position('none')
# the screen is assumed vertical in the y direction
ax.set_xlim(self.wp.y_min, self.wp.y_max)
ax.set_xlim(self.y_min, self.y_max)
ax.set_ylim(0, max(self.screen_data_total) * 1.01)
ax.set_yticklabels([])
ax.set_xticklabels([])
# init the total data
self.screen_data_total_tmp = np.zeros(len(self.crossed_nx))
# Calculate real distances along the y-axis
real_distances = self.wp.y_min + np.array(self.crossed_ny) * self.wp.dy
real_distances = self.y_min + np.array(self.crossed_ny) * self.dy
# Sort by real distances
self.sorted_indices = np.argsort(real_distances)
self.sorted_distances = real_distances[self.sorted_indices]
Expand All @@ -159,16 +171,16 @@ def __init_plot2(self):
ax.xaxis.set_ticks_position('none')
ax.yaxis.set_ticks_position('none')
# Convert grid indices to real x and y coordinates
self.crossed_x = [self.wp.x_min + nx * self.wp.dx + 0.5 * self.wp.dx
self.crossed_x = [self.x_min + nx * self.dx + 0.5 * self.dx
for nx in self.crossed_nx]
self.crossed_y = [self.wp.y_min + ny * self.wp.dy + 0.5 * self.wp.dy
self.crossed_y = [self.y_min + ny * self.dy + 0.5 * self.dy
for ny in self.crossed_ny]
self.curve3 = ax.scatter(self.crossed_x, self.crossed_y,
c=np.zeros_like(self.crossed_x), cmap='hot',
vmin=0, vmax=max(self.screen_data_total))
# the screen is assumed vertical in the y direction
ax.set_xlim(self.wp.x_min, self.wp.x_max)
ax.set_ylim(self.wp.y_min, self.wp.y_max)
ax.set_xlim(self.x_min, self.x_max)
ax.set_ylim(self.y_min, self.y_max)
ax.set_yticklabels([])
ax.set_xticklabels([])
plt.tight_layout()
Expand All @@ -184,7 +196,7 @@ def __animate_frame(self, frame, is_animation=True, is_pngexport=False):
ptext = "the animation"
else:
ptext = "png export"
perc = (frame + 1) / self.wp.num_frames * 100
perc = (frame + 1) / self.num_frames * 100
if perc // 10 > self.perc // 10:
self.perc = perc
elapsed_time = time.time() - self.start_time
Expand All @@ -203,6 +215,8 @@ def __animate_frame(self, frame, is_animation=True, is_pngexport=False):
return (self.curve1,)

def plot(self, nframe=cfg.frame_id):
self.perc = 0
self.start_time = time.time()
self.__init_plot()
for n in range(nframe):
self.__animate_frame(n, False, True)
Expand All @@ -214,7 +228,7 @@ def animate(self):
self.start_time = time.time()
self.__init_plot()
anim = FuncAnimation(
self.fig, self.__animate_frame, frames=self.wp.num_frames,
self.fig, self.__animate_frame, frames=self.num_frames,
interval=1000 / p.fps, blit=True)
if cfg.save_anim:
base, ext = self._outfile.rsplit('.', 1)
Expand Down
15 changes: 5 additions & 10 deletions schrodinger_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import sys
import time

from mod_config_2d import cfg, p2, p2_changes_load
from mod_config_2d import cfg, p2, p2_changes_load_s2d
from mod_config import palette

if cfg.use_pickle:
Expand Down Expand Up @@ -82,7 +82,6 @@ def __init__(self, outfile, Nx, Ny,
# time parameters
self.dt = dt
self.t_max = t_max
# self.num_frames = int(t_max / dt)
self.num_frames = int(p.total_duration * p.fps)
tsteps = int(t_max / dt)
if tsteps < self.num_frames:
Expand Down Expand Up @@ -289,16 +288,12 @@ def compute(self):
def __init_plot(self):
plot_psi = self.psi_plot[0]
cgray = (0.83, 0.83, 0.83)
dpi = 300 if cfg.high_res_plot else 100
if cfg.fig_4k:
if cfg.high_res_plot:
self.fig, ax = plt.subplots(figsize=(12.8, 7.2), dpi=300)
else:
self.fig, ax = plt.subplots(figsize=(12.8, 7.2), dpi=300)
figsize = (3840 / dpi, 2160 / dpi)
else:
if cfg.high_res_plot:
self.fig, ax = plt.subplots(dpi=300)
else:
self.fig, ax = plt.subplots()
figsize = (1920 / dpi, 1080 / dpi)
self.fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
# specific visualization option for absorbing boundaries
if not p.infinite_barrier:
Expand Down

0 comments on commit ef4b253

Please sign in to comment.