Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion libm/src/math/support/int_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,14 @@ int_impl!(i128, u128);

/// Trait for integers twice the bit width of another integer. This is implemented for all
/// primitives except for `u8`, because there is not a smaller primitive.
pub trait DInt: MinInt {
pub trait DInt:
MinInt
+ ops::Add<Output = Self>
+ ops::Sub<Output = Self>
+ ops::Shl<u32, Output = Self>
+ ops::Shr<u32, Output = Self>
+ Ord
{
/// Integer that is half the bit width of the integer this trait is implemented for
type H: HInt<D = Self>;

Expand Down
1 change: 0 additions & 1 deletion libm/src/math/support/int_traits/narrowing_div.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use crate::support::{CastInto, DInt, HInt, Int, MinInt, u256};
/// This is the inverse of widening multiplication:
/// - for any `x` and nonzero `y`: `x.widen_mul(y).checked_narrowing_div_rem(y) == Some((x, 0))`,
/// - and for any `r in 0..y`: `x.carrying_mul(y, r).checked_narrowing_div_rem(y) == Some((x, r))`,
#[allow(dead_code)]
pub trait NarrowingDiv: DInt + MinInt<Unsigned = Self> {
/// Computes `(self / n, self % n))`
///
Expand Down
3 changes: 3 additions & 0 deletions libm/src/math/support/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub(crate) mod feature_detect;
mod float_traits;
pub mod hex_float;
mod int_traits;
mod modular;

#[allow(unused_imports)]
pub use big::{i256, u256};
Expand All @@ -30,6 +31,8 @@ pub use hex_float::hf128;
pub use hex_float::{hf32, hf64};
#[allow(unused_imports)]
pub use int_traits::{CastFrom, CastInto, DInt, HInt, Int, MinInt, NarrowingDiv};
#[allow(unused_imports)]
pub use modular::linear_mul_reduction;

/// Hint to the compiler that the current path is cold.
pub fn cold_path() {
Expand Down
251 changes: 251 additions & 0 deletions libm/src/math/support/modular.rs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this commit should be able to allow the #[allow(dead_code)] on NarrowingDiv

Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
/* SPDX-License-Identifier: MIT OR Apache-2.0 */

//! To keep the equations somewhat concise, the following conventions are used:
//! - all integer operations are in the mathematical sense, without overflow
//! - concatenation means multiplication: `2xq = 2 * x * q`
//! - `R = (1 << U::BITS)` is the modulus of wrapping arithmetic in `U`

use crate::support::int_traits::NarrowingDiv;
use crate::support::{DInt, HInt, Int};
Comment on lines +8 to +9
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a module-level doc comment with some of the common names used here? E.g. R, RR (if that's not R*R) r, m, q, xq. I'm unfortunately a bit lost :) (but I don't need to understand it in detail)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you've authored this, add /* SPDX-License-Identifier: MIT OR Apache-2.0 */ as well (or only MIT if it's derived, as appropriate)


/// Compute the remainder `(x << e) % y` with unbounded integers.
/// Requires `x < 2y` and `y.leading_zeros() >= 2`
#[allow(dead_code)]
pub fn linear_mul_reduction<U>(x: U, mut e: u32, mut y: U) -> U
where
U: HInt + Int<Unsigned = U>,
U::D: NarrowingDiv,
{
assert!(y <= U::MAX >> 2);
assert!(x < (y << 1));
let _0 = U::ZERO;
let _1 = U::ONE;

// power of two divisors
if (y & (y - _1)).is_zero() {
if e < U::BITS {
// shift and only keep low bits
return (x << e) & (y - _1);
} else {
// would shift out all the bits
return _0;
}
}

// Use the identity `(x << e) % y == ((x << (e + s)) % (y << s)) >> s`
// to shift the divisor so it has exactly two leading zeros to satisfy
// the precondition of `Reducer::new`
let s = y.leading_zeros() - 2;
e += s;
y <<= s;

// `m: Reducer` keeps track of the remainder `x` in a form that makes it
// very efficient to do `x <<= k` modulo `y` for integers `k < U::BITS`
let mut m = Reducer::new(x, y);

// Use the faster special case with constant `k == U::BITS - 1` while we can
while e >= U::BITS - 1 {
m.word_reduce();
e -= U::BITS - 1;
}
// Finish with the variable shift operation
m.shift_reduce(e);

// The partial remainder is in `[0, 2y)` ...
let r = m.partial_remainder();
// ... so check and correct, and compensate for the earlier shift.
r.checked_sub(y).unwrap_or(r) >> s
Comment on lines +54 to +57
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That minor tweak:
Previously the right shift was done first, and then the checked_sub was with the original y. Now that's done first with the left-shifted y, and the shift is last.

}

/// Helper type for computing the reductions. The implementation has a number
/// of seemingly weird choices, but everything is aimed at streamlining
/// `Reducer::word_reduce` into its current form.
///
/// Implicitly contains:
/// n in (R/8, R/4)
/// x in [0, 2n)
/// The value of `n` is fixed for a given `Reducer`,
/// but the value of `x` is modified by the methods.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Reducer<U: HInt> {
// m = 2n
m: U,
// q = (RR/2) / m
// r = (RR/2) % m
// Then RR/2 = qm + r, where `0 <= r < m`
// The value `q` is only needed during construction, so isn't saved.
r: U,
// The value `x` is implicitly stored as `2 * q * x`:
_2xq: U::D,
}

impl<U> Reducer<U>
where
U: HInt,
U: Int<Unsigned = U>,
{
/// Construct a reducer for `(x << _) mod n`.
///
/// Requires `R/8 < n < R/4` and `x < 2n`.
fn new(x: U, n: U) -> Self
where
U::D: NarrowingDiv,
{
let _1 = U::ONE;
assert!(n > (_1 << (U::BITS - 3)));
assert!(n < (_1 << (U::BITS - 2)));
let m = n << 1;
assert!(x < m);

// We need q and r s.t. RR/2 = qm + r, and `0 <= r < m`
// As R/4 < m < R/2,
// we have R <= q < 2R
// so let q = R + f
// RR/2 = (R + f)m + r
// R(R/2 - m) = fm + r

// v = R/2 - m < R/4 < m
let v = (_1 << (U::BITS - 1)) - m;
let (f, r) = v.widen_hi().checked_narrowing_div_rem(m).unwrap();

// xq < qm <= RR/2
// 2xq < RR
// 2xq = 2xR + 2xf;
let _2x: U = x << 1;
let _2xq = _2x.widen_hi() + _2x.widen_mul(f);
Self { m, r, _2xq }
}

/// Extract the current remainder in the range `[0, 2n)`
fn partial_remainder(&self) -> U {
// RR/2 = qm + r, 0 <= r < m
// 2xq = uR + v, 0 <= v < R
// muR = 2mxq - mv
// = xRR - 2xr - mv
// mu + (2xr + mv)/R == xR

// 0 <= 2xq < RR
// R <= q < 2R
// 0 <= x < R/2
// R/4 < m < R/2
// 0 <= r < m
// 0 <= mv < mR
// 0 <= 2xr < rR < mR

// 0 <= (2xr + mv)/R < 2m
// Add `mu` to each term to obtain:
// mu <= xR < mu + 2m

// Since `0 <= 2m < R`, `xR` is the only multiple of `R` between
// `mu` and `m(u+2)`, so the high half of `m(u+2)` must equal `x`.
let _1 = U::ONE;
self.m.widen_mul(self._2xq.hi() + (_1 + _1)).hi()
}

/// Replace the remainder `x` with `(x << k) - un`,
/// for a suitable quotient `u`, which is returned.
fn shift_reduce(&mut self, k: u32) -> U {
assert!(k < U::BITS);
// 2xq << k = aRR/2 + b;
let a = self._2xq.hi() >> (U::BITS - 1 - k);
let (low, high) = (self._2xq << k).lo_hi();
let b = U::D::from_lo_hi(low, high & (U::MAX >> 1));

// (2xq << k) - aqm
// = aRR/2 + b - aqm
// = a(RR/2 - qm) + b
// = ar + b
self._2xq = a.widen_mul(self.r) + b;
a
}

/// Replace the remainder `x` with `x(R/2) - un`,
/// for a suitable quotient `u`, which is returned.
fn word_reduce(&mut self) -> U {
// 2xq = uR + v
let (v, u) = self._2xq.lo_hi();
// xqR - uqm
// = uRR/2 + vR/2 - uRR/2 + ur
// = ur + (v/2)R
self._2xq = u.widen_mul(self.r) + U::widen_hi(v >> 1);
u
}
}

#[cfg(test)]
mod test {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this spot check another integer size as well? Just using constants

use crate::support::linear_mul_reduction;
use crate::support::modular::Reducer;

#[test]
fn reducer_ops() {
for n in 33..=63_u8 {
for x in 0..2 * n {
let temp = Reducer::new(x, n);
let n = n as u32;
let x0 = temp.partial_remainder() as u32;
assert_eq!(x as u32, x0);
for k in 0..=7 {
let mut red = temp.clone();
let u = red.shift_reduce(k) as u32;
let x1 = red.partial_remainder() as u32;
assert_eq!(x1, (x0 << k) - u * n);
assert!(x1 < 2 * n);
assert!((red._2xq as u32).is_multiple_of(2 * x1));

// `word_reduce` is equivalent to
// `shift_reduce(U::BITS - 1)`
if k == 7 {
let mut alt = temp.clone();
let w = alt.word_reduce();
assert_eq!(u, w as u32);
assert_eq!(alt, red);
}
}
}
}
}
#[test]
fn reduction_u8() {
for y in 1..64u8 {
for x in 0..2 * y {
let mut r = x % y;
for e in 0..100 {
assert_eq!(r, linear_mul_reduction(x, e, y));
// maintain the correct expected remainder
r <<= 1;
if r >= y {
r -= y;
}
}
}
}
}
#[test]
fn reduction_u128() {
assert_eq!(
linear_mul_reduction::<u128>(17, 100, 123456789),
(17 << 100) % 123456789
);

// power-of-two divisor
assert_eq!(
linear_mul_reduction(0xdead_beef, 100, 1_u128 << 116),
0xbeef << 100
);

let x = 10_u128.pow(37);
let y = 11_u128.pow(36);
assert!(x < y);
let mut r = x;
for e in 0..1000 {
assert_eq!(r, linear_mul_reduction(x, e, y));
// maintain the correct expected remainder
r <<= 1;
if r >= y {
r -= y;
}
assert!(r != 0);
}
}
}
Loading