diff --git a/src/arithmetic/bigint.rs b/src/arithmetic/bigint.rs index 5336b0bb8..50e4d3a14 100644 --- a/src/arithmetic/bigint.rs +++ b/src/arithmetic/bigint.rs @@ -161,6 +161,7 @@ where // r *= 2. fn elem_double(r: &mut Elem, m: &Modulus) { limb::limbs_double_mod(&mut r.limbs, m.limbs()) + .unwrap_or_else(unwrap_impossible_len_mismatch_error) } // TODO: This is currently unused, but we intend to eventually use this to diff --git a/src/arithmetic/bigint/modulus.rs b/src/arithmetic/bigint/modulus.rs index 053bfeb7f..ef7a32f09 100644 --- a/src/arithmetic/bigint/modulus.rs +++ b/src/arithmetic/bigint/modulus.rs @@ -12,7 +12,10 @@ // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -use super::{super::montgomery::Unencoded, BoxedLimbs, Elem, OwnedModulusValue, PublicModulus, N0}; +use super::{ + super::montgomery::Unencoded, unwrap_impossible_len_mismatch_error, BoxedLimbs, Elem, + OwnedModulusValue, PublicModulus, N0, +}; use crate::{ bits::BitLength, cpu, error, @@ -164,6 +167,7 @@ impl Modulus<'_, M> { // to 2**r (mod m). for _ in 0..leading_zero_bits_in_m { limb::limbs_double_mod(out, self.limbs) + .unwrap_or_else(unwrap_impossible_len_mismatch_error); } } diff --git a/src/arithmetic/inout.rs b/src/arithmetic/inout.rs index c6f36b5b9..4dda6542f 100644 --- a/src/arithmetic/inout.rs +++ b/src/arithmetic/inout.rs @@ -15,6 +15,81 @@ pub(crate) use crate::error::LenMismatchError; use core::num::NonZeroUsize; +pub(crate) trait AliasingSlices2 { + /// The pointers passed to `f` will be valid and non-null, and will not + /// be dangling, so they can be passed to C functions. + /// + /// The first pointer, `r`, may be pointing to uninitialized memory for + /// `expected_len` elements of type `T`, properly aligned and writable. + /// `f` must not read from `r` before writing to it. + /// + /// The second & third pointers, `a` and `b`, point to `expected_len` + /// values of type `T`, properly aligned. + /// + /// `r`, `a`, and/or `b` may alias each other only in the following ways: + /// `ptr::eq(r, a)`, `ptr::eq(r, b)`, and/or `ptr::eq(a, b)`; i.e. they + /// will not be "overlapping." + /// + /// Implementations of this trait shouldn't override this default + /// implementation. + #[inline(always)] + fn with_non_dangling_non_null_pointers_ra( + self, + expected_len: NonZeroUsize, + f: impl FnOnce(*mut T, *const T) -> R, + ) -> Result + where + Self: Sized, + { + self.with_potentially_dangling_non_null_pointers_ra(expected_len.into(), f) + } + + /// If `expected_len == 0` then the pointers passed to `f` may be + /// dangling pointers, which should not be passed to C functions. In all + /// other respects, this works like + /// `Self::with_non_dangling_non_null_pointers_rab`. + /// + /// Implementations of this trait should implement this method and not + /// `with_non_dangling_non_null_pointers_rab`. Users of this trait should + /// use `with_non_dangling_non_null_pointers_rab` and not this. + fn with_potentially_dangling_non_null_pointers_ra( + self, + expected_len: usize, + f: impl FnOnce(*mut T, *const T) -> R, + ) -> Result; +} + +impl AliasingSlices2 for &mut [T] { + fn with_potentially_dangling_non_null_pointers_ra( + self, + expected_len: usize, + f: impl FnOnce(*mut T, *const T) -> R, + ) -> Result { + let r = self; + if r.len() != expected_len { + return Err(LenMismatchError::new(r.len())); + } + Ok(f(r.as_mut_ptr(), r.as_ptr())) + } +} + +impl AliasingSlices2 for (&mut [T], &[T]) { + fn with_potentially_dangling_non_null_pointers_ra( + self, + expected_len: usize, + f: impl FnOnce(*mut T, *const T) -> R, + ) -> Result { + let (r, a) = self; + if r.len() != expected_len { + return Err(LenMismatchError::new(r.len())); + } + if a.len() != expected_len { + return Err(LenMismatchError::new(a.len())); + } + Ok(f(r.as_mut_ptr(), a.as_ptr())) + } +} + pub(crate) trait AliasingSlices3 { /// The pointers passed to `f` will all be non-null and properly aligned, /// and will not be dangling. @@ -65,47 +140,38 @@ impl AliasingSlices3 for &mut [T] { expected_len: usize, f: impl FnOnce(*mut T, *const T, *const T) -> R, ) -> Result { - let r = self; - if r.len() != expected_len { - return Err(LenMismatchError::new(r.len())); - } - Ok(f(r.as_mut_ptr(), r.as_ptr(), r.as_ptr())) + >::with_potentially_dangling_non_null_pointers_ra( + self, + expected_len, + |r, a| f(r, r, a), + ) } } -impl AliasingSlices3 for (&mut [T], &[T]) { +impl AliasingSlices3 for (&mut [T], &[T], &[T]) { fn with_potentially_dangling_non_null_pointers_rab( self, expected_len: usize, f: impl FnOnce(*mut T, *const T, *const T) -> R, ) -> Result { - let (r, a) = self; - if r.len() != expected_len { - return Err(LenMismatchError::new(r.len())); - } - if a.len() != expected_len { - return Err(LenMismatchError::new(a.len())); - } - Ok(f(r.as_mut_ptr(), r.as_ptr(), a.as_ptr())) + let (r, a, b) = self; + ((r, a), b).with_potentially_dangling_non_null_pointers_rab(expected_len, f) } } -impl AliasingSlices3 for (&mut [T], &[T], &[T]) { +impl AliasingSlices3 for (RA, &[T]) +where + RA: AliasingSlices2, +{ fn with_potentially_dangling_non_null_pointers_rab( self, expected_len: usize, f: impl FnOnce(*mut T, *const T, *const T) -> R, ) -> Result { - let (r, a, b) = self; - if r.len() != expected_len { - return Err(LenMismatchError::new(r.len())); - } - if a.len() != expected_len { - return Err(LenMismatchError::new(a.len())); - } + let (ra, b) = self; if b.len() != expected_len { return Err(LenMismatchError::new(b.len())); } - Ok(f(r.as_mut_ptr(), a.as_ptr(), b.as_ptr())) + ra.with_potentially_dangling_non_null_pointers_ra(expected_len, |r, a| f(r, a, b.as_ptr())) } } diff --git a/src/limb.rs b/src/limb.rs index 2a82eb7cb..7480600e2 100644 --- a/src/limb.rs +++ b/src/limb.rs @@ -19,7 +19,7 @@ //! limbs use the native endianness. use crate::{ - arithmetic::inout::AliasingSlices3, + arithmetic::inout::{AliasingSlices2, AliasingSlices3}, c, constant_time, error::{self, LenMismatchError}, polyfill::{sliceutil, usize_from_u32, ArrayFlatMap}, @@ -349,14 +349,22 @@ pub(crate) fn limbs_add_assign_mod( } // r *= 2 (mod m). -pub(crate) fn limbs_double_mod(r: &mut [Limb], m: &[Limb]) { - assert_eq!(r.len(), m.len()); +pub(crate) fn limbs_double_mod(r: &mut [Limb], m: &[Limb]) -> Result<(), LenMismatchError> { prefixed_extern! { - fn LIMBS_shl_mod(r: *mut Limb, a: *const Limb, m: *const Limb, num_limbs: c::size_t); - } - unsafe { - LIMBS_shl_mod(r.as_mut_ptr(), r.as_ptr(), m.as_ptr(), m.len()); + // `r` and `a` may alias. + fn LIMBS_shl_mod( + r: *mut Limb, + a: *const Limb, + m: *const Limb, + num_limbs: c::NonZero_size_t); } + let num_limbs = NonZeroUsize::new(m.len()).ok_or_else(|| LenMismatchError::new(m.len()))?; + r.with_non_dangling_non_null_pointers_ra(num_limbs, |r, a| { + let m = m.as_ptr(); // Also non-dangling because num_limbs > 0. + unsafe { + LIMBS_shl_mod(r, a, m, num_limbs); + } + }) } // *r = -a, assuming a is odd.