Commit

Bug fixes and updates to LDMatrix data structure.
shz9 committed Jun 6, 2024
1 parent ea3dc4b commit 63060d9
Showing 4 changed files with 56 additions and 28 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,19 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).


## [0.1.4] - TBD

### Changed

- Updated the data type for the index pointer in the `LDMatrix` object to `int64`. `int32`
  overflows for very large datasets with millions of variants (see the sketch below).
- Updated the way we determine the `pandas` chunksize when converting from `plink` tables to `zarr`.
- Simplified the way we compute the quantization scale in `model_utils`.

### Added

- Added extra validation checks in `LDMatrix` to ensure that the index pointer is formatted correctly.
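
For context on the `int64` change above, a back-of-the-envelope check (the figures below are invented, not taken from any real dataset): the last entry of the index pointer equals the total number of stored LD entries, which easily exceeds the `int32` ceiling of 2,147,483,647 at biobank scale.

```python
import numpy as np

# Hypothetical panel: ~10 million variants, each stored with ~1,000 LD
# neighbours on average (made-up figures for illustration only).
n_variants = 10_000_000
avg_neighbours = 1_000

total_entries = n_variants * avg_neighbours          # 10 billion stored LD values

print(total_entries > np.iinfo(np.int32).max)        # True  -> int32 would overflow
print(total_entries <= np.iinfo(np.int64).max)       # True  -> int64 is safe (~9.2e18)
```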

## [0.1.3] - 2024-05-21

### Changed
55 changes: 38 additions & 17 deletions magenpy/LDMatrix.py
@@ -159,8 +159,7 @@ def from_csr(cls,
mat.array('data', triu_mat.data.astype(dtype), dtype=dtype, compressor=compressor_name)

# Store the index pointer:
mat.array('indptr', triu_mat.indptr,
dtype=np.int32, compressor=compressor)
mat.array('indptr', triu_mat.indptr, dtype=np.int64, compressor=compressor)

return cls(z)
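
For reference, a minimal sketch of the CSR layout that `from_csr` serializes, built directly with `scipy.sparse` on a made-up 4x4 LD matrix; `triu_mat.data` and `triu_mat.indptr` are the arrays written to the Zarr group above.

```python
import numpy as np
from scipy.sparse import csr_matrix, triu

# Toy 4x4 LD matrix; only the upper triangle above the diagonal is stored.
dense = np.array([[1.0, 0.3, 0.1, 0.0],
                  [0.3, 1.0, 0.5, 0.0],
                  [0.1, 0.5, 1.0, 0.2],
                  [0.0, 0.0, 0.2, 1.0]])

triu_mat = triu(csr_matrix(dense), k=1, format='csr')

print(triu_mat.data)    # [0.3 0.1 0.5 0.2] -> the 'data' array
print(triu_mat.indptr)  # [0 2 3 4 4]       -> row i owns data[indptr[i]:indptr[i+1]]
```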

@@ -217,7 +216,7 @@ def from_plink_table(cls,
ld_chunks = [ld_chunks]

# Create a dictionary mapping SNPs to their indices:
snp_dict = dict(zip(snps, np.arange(len(snps))))
snp_idx = pd.Series(np.arange(len(snps), dtype=np.int32), index=snps)

indptr_counts = np.zeros(len(snps), dtype=np.int32)

@@ -227,7 +226,10 @@
for ld_chunk in ld_chunks:

# Create an indexed LD chunk:
ld_chunk['row_index'] = ld_chunk['SNP_A'].map(snp_dict)
row_index = snp_idx[ld_chunk['SNP_A'].values]

# Fill N/A in R before storing it:
ld_chunk['R'].fillna(0., inplace=True)

# Add LD data to the zarr array:
if np.issubdtype(dtype, np.integer):
@@ -237,16 +239,16 @@

total_len += len(ld_chunk)

# Group by the row index:
grouped_ridx = ld_chunk.groupby('row_index').size()
# Count the number of occurrences of each SNP in the chunk:
snp_counts = row_index.value_counts()

# Add the number of entries to indptr_counts:
indptr_counts[grouped_ridx.index] += grouped_ridx.values
indptr_counts[snp_counts.index] += snp_counts.values

# Get the final indptr by computing cumulative sum:
indptr = np.insert(np.cumsum(indptr_counts), 0, 0)
indptr = np.insert(np.cumsum(indptr_counts, dtype=np.int64), 0, 0)
# Store indptr in the zarr group:
mat.array('indptr', indptr, dtype=np.int32, compressor=compressor)
mat.array('indptr', indptr, dtype=np.int64, compressor=compressor)

# Resize the data array:
mat['data'].resize(total_len)
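
A compact, standalone sketch of the row-index bookkeeping used in `from_plink_table` above, with invented SNP IDs and a single toy chunk (only the mapping, counting, and index-pointer steps are reproduced; the Zarr writes are omitted):

```python
import numpy as np
import pandas as pd

snps = ['rs1', 'rs2', 'rs3']
snp_idx = pd.Series(np.arange(len(snps), dtype=np.int32), index=snps)
indptr_counts = np.zeros(len(snps), dtype=np.int32)

# One toy chunk of a plink-style LD table (SNP_A column only, invented entries):
chunk_snp_a = np.array(['rs1', 'rs1', 'rs2', 'rs3'])

# Map SNP IDs to row indices and count how many entries each row contributes:
row_index = snp_idx[chunk_snp_a]
snp_counts = row_index.value_counts()
indptr_counts[snp_counts.index] += snp_counts.values

# The cumulative sum (accumulated in int64), prepended with 0, is the index pointer:
indptr = np.insert(np.cumsum(indptr_counts, dtype=np.int64), 0, 0)
print(indptr)        # [0 2 3 4]
print(indptr.dtype)  # int64
```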
@@ -306,7 +308,7 @@ def from_dense_zarr_matrix(cls,
num_rows = dense_zarr.shape[0]
chunk_size = dense_zarr.chunks[0]

indptr_counts = np.zeros(num_rows, dtype=int)
indptr_counts = np.zeros(num_rows, dtype=np.int32)

total_len = 0

@@ -340,9 +342,9 @@ def from_dense_zarr_matrix(cls,
total_len += chunk_len

# Get the final indptr by computing cumulative sum:
indptr = np.insert(np.cumsum(indptr_counts), 0, 0)
indptr = np.insert(np.cumsum(indptr_counts, dtype=np.int64), 0, 0)
# Store indptr in the zarr array:
mat.array('indptr', indptr, compressor=compressor)
mat.array('indptr', indptr, dtype=np.int64, compressor=compressor)

# Resize the data and indices arrays:
mat['data'].resize(total_len)
@@ -405,7 +407,7 @@ def from_ragged_zarr_matrix(cls,
mat = z.create_group('matrix')
mat.empty('data', shape=num_rows ** 2, dtype=dtype, compressor=compressor)

indptr_counts = np.zeros(num_rows, dtype=int)
indptr_counts = np.zeros(num_rows, dtype=np.int64)

# Get the LD boundaries from the Zarr array attributes:
ld_boundaries = np.array(ragged_zarr.attrs['LD boundaries'])
@@ -444,9 +446,9 @@ def from_ragged_zarr_matrix(cls,
total_len += chunk_len

# Get the final indptr by computing cumulative sum:
indptr = np.insert(np.cumsum(indptr_counts), 0, 0)
indptr = np.insert(np.cumsum(indptr_counts, dtype=np.int64), 0, 0)
# Store indptr in the zarr array:
mat.array('indptr', indptr, compressor=compressor)
mat.array('indptr', indptr, dtype=np.int64, compressor=compressor)

# Resize the data and indices arrays:
mat['data'].resize(total_len)
@@ -805,7 +807,8 @@ def filter_long_range_ld_regions(self):
Which is based on the work of
> Anderson, Carl A., et al. "Data quality control in genetic case-control association studies." Nature protocols 5.9 (2010): 1564-1573.
> Anderson, Carl A., et al. "Data quality control in genetic case-control association studies."
Nature protocols 5.9 (2010): 1564-1573.
.. note ::
This method is experimental and may not work as expected for all LD matrices.
@@ -1502,9 +1505,10 @@ def validate_ld_matrix(self):
Specifically, we check that:
* The dimensions of the matrix and its associated attributes are matching.
* The masking is working properly.
* Index pointer is valid and its contents make sense.
:return: True if the matrix has the correct structure.
:raises ValueError: if the matrix is not valid.
:raises ValueError: If the matrix or some of its entries are not valid.
"""

class_attrs = ['snps', 'a1', 'a2', 'maf', 'bp_position', 'cm_position', 'ld_score']
@@ -1516,6 +1520,23 @@ def validate_ld_matrix(self):
if len(attribute) != len(self):
raise ValueError(f"Invalid LD Matrix: Dimensions for attribute {attr} are not aligned!")

# -------------------- Index pointer checks --------------------
# Check that the entries of the index pointer are all positive or zero:
indptr = self.indptr[:]

if indptr.min() < 0:
raise ValueError("The index pointer contains negative entries!")

# Check that the entries don't decrease:
indptr_diff = np.diff(indptr)
if indptr_diff.min() < 0:
raise ValueError("The index pointer entries are not increasing!")

# Check that the last entry of the index pointer matches the shape of the data:
if indptr[-1] != self.data.shape[0]:
raise ValueError("The last entry of the index pointer "
"does not match the shape of the data!")

# TODO: Add other sanity checks here?

return True
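
To make the three index-pointer conditions above concrete, the same checks applied to a pair of toy arrays (hypothetical values; the real method reads `indptr` and `data` from the Zarr-backed matrix):

```python
import numpy as np

# Toy stand-ins for the stored LD values and their index pointer:
data = np.array([0.3, 0.1, 0.5, 0.2])
indptr = np.array([0, 2, 3, 4, 4], dtype=np.int64)

assert indptr.min() >= 0               # no negative offsets
assert np.diff(indptr).min() >= 0      # offsets never decrease
assert indptr[-1] == data.shape[0]     # last offset covers every stored entry
```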
3 changes: 2 additions & 1 deletion magenpy/stats/ld/utils.py
@@ -359,7 +359,8 @@ def compute_ld_plink1p9(genotype_matrix,
# Compute the pandas chunk_size
# The goal of this is to process chunks of the LD table without overwhelming memory resources:
avg_ncols = int((ld_boundaries[1, :] - ld_boundaries[0, :]).mean())
rows_per_chunk = estimate_rows_per_chunk(ld_boundaries.shape[1], avg_ncols, dtype=dtype)
# NOTE: Estimate the rows per chunk using float32 because that's how we'll read the data from plink:
rows_per_chunk = estimate_rows_per_chunk(ld_boundaries.shape[1], avg_ncols, dtype=np.float32)

if rows_per_chunk > 0.1*ld_boundaries.shape[1]:
pandas_chunksize = None
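
The heuristic above only chunks the LD table when a memory-bounded number of rows is small relative to the total. A rough standalone illustration of the idea; the helper below is a hypothetical stand-in, not the actual `estimate_rows_per_chunk` implementation, and the numbers are invented:

```python
import numpy as np

def rows_per_chunk_sketch(avg_ncols, dtype=np.float32, mem_budget_bytes=128 * 1024**2):
    """Hypothetical helper: how many LD-table rows fit in a fixed memory budget."""
    bytes_per_row = avg_ncols * np.dtype(dtype).itemsize
    return int(mem_budget_bytes // bytes_per_row)

n_rows = 50_000      # invented number of variants
avg_ncols = 1_200    # invented average LD window size

rows = rows_per_chunk_sketch(avg_ncols, dtype=np.float32)
print(rows)                  # 27962 rows of float32 entries fit in ~128 MiB
print(rows > 0.1 * n_rows)   # True -> the whole table can be read without chunking
```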
13 changes: 3 additions & 10 deletions magenpy/utils/model_utils.py
@@ -369,16 +369,14 @@ def quantize(floats, int_dtype=np.int8):
# Infer the boundaries from the integer type
info = np.iinfo(int_dtype)

# Compute the scale and zero point
# NOTE: We add 1 to the info.min here to force the zero point to be exactly at 0.
# See discussions on Scale Quantization Mapping.
scale = 2. / (info.max - (info.min + 1))

# Use as much in-place operations as possible
# (Currently, we copy the data twice)
scaled_floats = floats / scale
scaled_floats = floats*info.max
np.round(scaled_floats, out=scaled_floats)
np.clip(scaled_floats, info.min, info.max, out=scaled_floats)
np.clip(scaled_floats, info.min + 1, info.max, out=scaled_floats)

return scaled_floats.astype(int_dtype)
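
Assuming the module is importable as below, a quick usage example of the simplified scheme: values in [-1, 1] are scaled by `info.max` (127 for `int8`), rounded, and clipped to a symmetric range so that 0.0 maps exactly to 0 (example values are made up):

```python
import numpy as np
from magenpy.utils.model_utils import quantize

floats = np.array([-1.0, -0.5, 0.0, 0.5, 1.0])
print(quantize(floats, int_dtype=np.int8))   # [-127  -64    0   64  127]
```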

@@ -394,13 +392,8 @@ def dequantize(ints, float_dtype=np.float32):
# Infer the boundaries from the integer type
info = np.iinfo(ints.dtype)

# Compute the scale and zero point
# NOTE: We add 1 to the info.min here to force the zero point to be exactly at 0.
# See discussions on Scale Quantization Mapping.
scale = 2. / (info.max - (info.min + 1))

dq = ints.astype(float_dtype)
dq *= scale # in-place multiplication
dq /= info.max # in-place division

return dq
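
And a round-trip check with the matching `dequantize` (same import assumption as above): dividing by the same `info.max` recovers the original values up to a small quantization error, on the order of 0.004 for `int8`.

```python
import numpy as np
from magenpy.utils.model_utils import quantize, dequantize

floats = np.array([-1.0, -0.5, 0.0, 0.5, 1.0])
roundtrip = dequantize(quantize(floats, int_dtype=np.int8), float_dtype=np.float32)

print(roundtrip)                                     # [-1. -0.5039...  0.  0.5039...  1.]
print(np.allclose(roundtrip, floats, atol=1 / 127))  # True: error within one quantization step
```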

