Skip to content

Commit

Permalink
Fix 'all' in load_headers
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexeyKozhevin committed Sep 5, 2024
1 parent be93d32 commit c586c63
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 21 deletions.
28 changes: 8 additions & 20 deletions segfast/memmap_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,18 +180,11 @@ def load_headers(self, headers, indices=None, reconstruct_tsf=True, sort_columns
dst_headers_dtype = np.dtype(dst_headers_dtype).newbyteorder("=")

# Calculate the number of requested traces, chunks and a list of trace indices/slices for each chunk
if indices is None:
n_traces = self.n_traces
n_chunks, last_chunk_size = divmod(n_traces, chunk_size)
if last_chunk_size:
n_chunks += 1
chunk_indices_list = [slice(i * chunk_size, (i + 1) * chunk_size) for i in range(n_chunks)]
else:
n_traces = len(indices)
n_chunks, last_chunk_size = divmod(n_traces, chunk_size)
if last_chunk_size:
n_chunks += 1
chunk_indices_list = np.array_split(indices, n_chunks)
n_traces = self.n_traces if indices is None else len(indices)
n_chunks, last_chunk_size = divmod(n_traces, chunk_size)
if last_chunk_size:
n_chunks += 1
chunk_slices_list = [slice(i * chunk_size, (i + 1) * chunk_size) for i in range(n_chunks)]

# Process `max_workers` and select executor
max_workers = os.cpu_count() if max_workers is None else max_workers
Expand All @@ -203,24 +196,19 @@ def load_headers(self, headers, indices=None, reconstruct_tsf=True, sort_columns

with Notifier(pbar, total=n_traces) as progress_bar:
with executor_class(max_workers=max_workers) as executor:
start = 0
# Done-callback for a single chunk-read future: copies the headers returned by the
# worker into `buffer` and advances the progress bar by the number of rows copied.
# `start` is the offset in `buffer` where this chunk's rows begin; it is bound
# per future with `partial(callback, start=...)` at registration time.
# NOTE(review): if `executor_class` is a `concurrent.futures` executor, an exception
# raised inside a done-callback (including one re-raised by `future.result()`) is
# logged and ignored rather than propagated to the caller -- confirm this is intended.
def callback(future, start):
# Result of `read_chunk` for this chunk.
chunk_headers = future.result()
# Use the actual length of the result rather than the nominal chunk size,
# since the last chunk can be shorter. This name is local to the callback and
# shadows the enclosing function's `chunk_size` only here.
chunk_size = len(chunk_headers)
# Each future is given its own `start`, so completion order does not affect
# where the rows land in `buffer`.
buffer[start : start + chunk_size] = chunk_headers
progress_bar.update(chunk_size)

for i, chunk_indices in enumerate(chunk_indices_list):
for i, chunk_slice in enumerate(chunk_slices_list):
chunk_indices = indices[chunk_slice] if indices is not None else chunk_slice
future = executor.submit(read_chunk, path=self.path, shape=self.n_traces,
offset=self.file_traces_offset, mmap_dtype=mmap_trace_dtype,
buffer_dtype=dst_headers_dtype, headers=headers, indices=chunk_indices)

future.add_done_callback(partial(callback, start=start))

if isinstance(chunk_indices, slice):
start += chunk_size
else:
start += len(chunk_indices)
future.add_done_callback(partial(callback, start=i * chunk_size))

# Convert to pd.DataFrame, optionally add TSF and sort
dataframe = pd.DataFrame(buffer, copy=False)
Expand Down
2 changes: 1 addition & 1 deletion segfast/segyio_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def make_headers_specs(self, headers):
byteorder = self.ENDIANNESS_TO_SYMBOL[self.endian]

if headers == 'all':
return [TraceHeaderSpec(start_byte, byteorder=byteorder)
return [TraceHeaderSpec(start_byte=start_byte, byteorder=byteorder)
for start_byte in TraceHeaderSpec.STANDARD_BYTE_TO_HEADER]

headers_ = []
Expand Down

0 comments on commit c586c63

Please sign in to comment.