-
Notifications
You must be signed in to change notification settings - Fork 246
libm: implement accelerated computation of (x << e) % y
#1012
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,251 @@ | ||
/* SPDX-License-Identifier: MIT OR Apache-2.0 */ | ||
|
||
//! To keep the equations somewhat concise, the following conventions are used: | ||
//! - all integer operations are in the mathematical sense, without overflow | ||
//! - concatenation means multiplication: `2xq = 2 * x * q` | ||
//! - `R = (1 << U::BITS)` is the modulus of wrapping arithmetic in `U` | ||
|
||
use crate::support::int_traits::NarrowingDiv; | ||
use crate::support::{DInt, HInt, Int}; | ||
Comment on lines
+8
to
+9
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you add a module-level doc comment with some of the common names used here? E.g. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If you've authored this, add |
||
|
||
/// Compute the remainder `(x << e) % y` with unbounded integers. | ||
/// Requires `x < 2y` and `y.leading_zeros() >= 2` | ||
/// | ||
/// # Panics | ||
/// | ||
/// Panics if `y > U::MAX >> 2` or if `x >= 2y`; both preconditions are | ||
/// checked with `assert!`. `y == 0` is rejected too, since no `x` is `< 0`. | ||
// NOTE(review): no reason is recorded for this `allow`; presumably the | ||
// function has no non-test caller yet -- confirm, and drop it once it does. | ||
#[allow(dead_code)] | ||
pub fn linear_mul_reduction<U>(x: U, mut e: u32, mut y: U) -> U | ||
where | ||
U: HInt + Int<Unsigned = U>, | ||
U::D: NarrowingDiv, | ||
{ | ||
// Enforce the documented preconditions up front. | ||
assert!(y <= U::MAX >> 2); | ||
assert!(x < (y << 1)); | ||
let _0 = U::ZERO; | ||
let _1 = U::ONE; | ||
|
||
// power of two divisors | ||
// (`y & (y - 1)` is zero exactly when `y` has a single bit set; `y != 0` | ||
// is already guaranteed by the `x < 2y` assertion above) | ||
if (y & (y - _1)).is_zero() { | ||
if e < U::BITS { | ||
// shift and only keep low bits | ||
return (x << e) & (y - _1); | ||
} else { | ||
// would shift out all the bits | ||
return _0; | ||
} | ||
} | ||
|
||
// Use the identity `(x << e) % y == ((x << (e + s)) % (y << s)) >> s` | ||
// to shift the divisor so it has exactly two leading zeros to satisfy | ||
// the precondition of `Reducer::new` | ||
let s = y.leading_zeros() - 2; | ||
// NOTE(review): this is a plain `u32` addition, so an `e` within `s` of | ||
// `u32::MAX` would overflow here -- confirm callers never pass such values. | ||
e += s; | ||
y <<= s; | ||
|
||
// `m: Reducer` keeps track of the remainder `x` in a form that makes it | ||
// very efficient to do `x <<= k` modulo `y` for integers `k < U::BITS` | ||
let mut m = Reducer::new(x, y); | ||
|
||
// Use the faster special case with constant `k == U::BITS - 1` while we can | ||
while e >= U::BITS - 1 { | ||
m.word_reduce(); | ||
e -= U::BITS - 1; | ||
} | ||
// Finish with the variable shift operation | ||
m.shift_reduce(e); | ||
|
||
// The partial remainder is in `[0, 2y)` ... | ||
let r = m.partial_remainder(); | ||
// ... so check and correct, and compensate for the earlier shift. | ||
r.checked_sub(y).unwrap_or(r) >> s | ||
Comment on lines
+54
to
+57
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That minor tweak: |
||
} | ||
|
||
/// Helper type for computing the reductions. The implementation has a number | ||
/// of seemingly weird choices, but everything is aimed at streamlining | ||
/// `Reducer::word_reduce` into its current form. | ||
/// | ||
/// Notation: `R = (1 << U::BITS)` and juxtaposition means multiplication, | ||
/// so `RR = R * R` and `2xq = 2 * x * q`. | ||
/// | ||
/// Implicitly contains: | ||
/// n in (R/8, R/4) | ||
/// x in [0, 2n) | ||
/// The value of `n` is fixed for a given `Reducer`, | ||
/// but the value of `x` is modified by the methods. | ||
#[derive(Debug, Clone, PartialEq, Eq)] | ||
struct Reducer<U: HInt> { | ||
// m = 2n | ||
m: U, | ||
// q = (RR/2) / m | ||
// r = (RR/2) % m | ||
// Then RR/2 = qm + r, where `0 <= r < m` | ||
// The value `q` is only needed during construction, so isn't saved. | ||
r: U, | ||
// The value `x` is implicitly stored as `2 * q * x`: | ||
// (kept in the wider `U::D`, since `2xq` can be anything below `RR`) | ||
_2xq: U::D, | ||
} | ||
|
||
impl<U> Reducer<U> | ||
where | ||
U: HInt, | ||
U: Int<Unsigned = U>, | ||
{ | ||
/// Construct a reducer for `(x << _) mod n`. | ||
/// | ||
/// Requires `R/8 < n < R/4` and `x < 2n`. | ||
/// | ||
/// # Panics | ||
/// | ||
/// Panics if either requirement is violated (checked by the `assert!`s below). | ||
fn new(x: U, n: U) -> Self | ||
where | ||
U::D: NarrowingDiv, | ||
{ | ||
let _1 = U::ONE; | ||
// `1 << (U::BITS - 3)` is `R/8` and `1 << (U::BITS - 2)` is `R/4`. | ||
assert!(n > (_1 << (U::BITS - 3))); | ||
assert!(n < (_1 << (U::BITS - 2))); | ||
let m = n << 1; | ||
assert!(x < m); | ||
|
||
// We need q and r s.t. RR/2 = qm + r, and `0 <= r < m` | ||
// As R/4 < m < R/2, | ||
// we have R <= q < 2R | ||
// so let q = R + f | ||
// RR/2 = (R + f)m + r | ||
// R(R/2 - m) = fm + r | ||
|
||
// v = R/2 - m < R/4 < m | ||
let v = (_1 << (U::BITS - 1)) - m; | ||
// NOTE(review): the `unwrap` presumably cannot fail because `v < m` keeps | ||
// the quotient `f` within `U` -- confirm against `NarrowingDiv`. | ||
let (f, r) = v.widen_hi().checked_narrowing_div_rem(m).unwrap(); | ||
|
||
// xq < qm <= RR/2 | ||
// 2xq < RR | ||
// 2xq = 2xR + 2xf; | ||
let _2x: U = x << 1; | ||
let _2xq = _2x.widen_hi() + _2x.widen_mul(f); | ||
Self { m, r, _2xq } | ||
} | ||
|
||
/// Extract the current remainder in the range `[0, 2n)` | ||
fn partial_remainder(&self) -> U { | ||
// RR/2 = qm + r, 0 <= r < m | ||
// 2xq = uR + v, 0 <= v < R | ||
// muR = 2mxq - mv | ||
// = xRR - 2xr - mv | ||
// mu + (2xr + mv)/R == xR | ||
|
||
// 0 <= 2xq < RR | ||
// R <= q < 2R | ||
// 0 <= x < R/2 | ||
// R/4 < m < R/2 | ||
// 0 <= r < m | ||
// 0 <= mv < mR | ||
// 0 <= 2xr < rR < mR | ||
|
||
// 0 <= (2xr + mv)/R < 2m | ||
// Add `mu` to each term to obtain: | ||
// mu <= xR < mu + 2m | ||
|
||
// Since `0 <= 2m < R`, `xR` is the only multiple of `R` between | ||
// `mu` and `m(u+2)`, so the high half of `m(u+2)` must equal `x`. | ||
let _1 = U::ONE; | ||
// `_1 + _1` spells the constant 2 for the generic `U`. | ||
self.m.widen_mul(self._2xq.hi() + (_1 + _1)).hi() | ||
} | ||
|
||
/// Replace the remainder `x` with `(x << k) - un`, | ||
/// for a suitable quotient `u`, which is returned. | ||
/// | ||
/// # Panics | ||
/// | ||
/// Panics if `k >= U::BITS`. | ||
fn shift_reduce(&mut self, k: u32) -> U { | ||
assert!(k < U::BITS); | ||
// 2xq << k = aRR/2 + b; | ||
// `a` is the top `k + 1` bits of the high word of `2xq`. | ||
let a = self._2xq.hi() >> (U::BITS - 1 - k); | ||
let (low, high) = (self._2xq << k).lo_hi(); | ||
// Clearing the top bit of `high` leaves `b < RR/2`. | ||
let b = U::D::from_lo_hi(low, high & (U::MAX >> 1)); | ||
|
||
// (2xq << k) - aqm | ||
// = aRR/2 + b - aqm | ||
// = a(RR/2 - qm) + b | ||
// = ar + b | ||
self._2xq = a.widen_mul(self.r) + b; | ||
a | ||
} | ||
|
||
/// Replace the remainder `x` with `x(R/2) - un`, | ||
/// for a suitable quotient `u`, which is returned. | ||
/// | ||
/// The unit test `reducer_ops` checks (for `u8`) that this matches | ||
/// `shift_reduce(U::BITS - 1)`. | ||
fn word_reduce(&mut self) -> U { | ||
// 2xq = uR + v | ||
let (v, u) = self._2xq.lo_hi(); | ||
// xqR - uqm | ||
// = uRR/2 + vR/2 - uRR/2 + ur | ||
// = ur + (v/2)R | ||
// (`v >> 1` is the `v/2` term, placed in the high word by `widen_hi`) | ||
self._2xq = u.widen_mul(self.r) + U::widen_hi(v >> 1); | ||
u | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod test { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could this spot check another integer size as well? Just using constants |
||
use crate::support::linear_mul_reduction; | ||
use crate::support::modular::Reducer; | ||
|
||
/// Exhaustively exercises `Reducer<u8>`: every `n` in `33..=63` (that is, | ||
/// `R/8 < n < R/4` for `u8`), every `x < 2n`, and every shift `k` in `0..=7`. | ||
/// Checks the round trip through `partial_remainder`, the value and range | ||
/// of the remainder after `shift_reduce`, and that `word_reduce` agrees | ||
/// with `shift_reduce(7)`. | ||
#[test] | ||
fn reducer_ops() { | ||
for n in 33..=63_u8 { | ||
for x in 0..2 * n { | ||
let temp = Reducer::new(x, n); | ||
// Widen to `u32` so the reference arithmetic below cannot overflow. | ||
let n = n as u32; | ||
let x0 = temp.partial_remainder() as u32; | ||
assert_eq!(x as u32, x0); | ||
for k in 0..=7 { | ||
let mut red = temp.clone(); | ||
let u = red.shift_reduce(k) as u32; | ||
let x1 = red.partial_remainder() as u32; | ||
assert_eq!(x1, (x0 << k) - u * n); | ||
assert!(x1 < 2 * n); | ||
// The stored value is `2 * q * x1`, hence a multiple of `2 * x1`. | ||
assert!((red._2xq as u32).is_multiple_of(2 * x1)); | ||
|
||
// `word_reduce` is equivalent to | ||
// `shift_reduce(U::BITS - 1)` | ||
if k == 7 { | ||
let mut alt = temp.clone(); | ||
let w = alt.word_reduce(); | ||
assert_eq!(u, w as u32); | ||
assert_eq!(alt, red); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
/// Compares `linear_mul_reduction` on `u8` against a remainder maintained | ||
/// by repeated doubling, for every `y` in `1..64`, `x < 2y` and `e < 100`. | ||
#[test] | ||
fn reduction_u8() { | ||
for y in 1..64u8 { | ||
for x in 0..2 * y { | ||
let mut r = x % y; | ||
for e in 0..100 { | ||
assert_eq!(r, linear_mul_reduction(x, e, y)); | ||
// maintain the correct expected remainder | ||
r <<= 1; | ||
if r >= y { | ||
r -= y; | ||
} | ||
} | ||
} | ||
} | ||
} | ||
/// Spot checks on `u128`: a directly computable case, a power-of-two | ||
/// divisor, and a 1000-step comparison against repeated doubling. | ||
#[test] | ||
fn reduction_u128() { | ||
assert_eq!( | ||
linear_mul_reduction::<u128>(17, 100, 123456789), | ||
(17 << 100) % 123456789 | ||
); | ||
|
||
// power-of-two divisor | ||
assert_eq!( | ||
linear_mul_reduction(0xdead_beef, 100, 1_u128 << 116), | ||
0xbeef << 100 | ||
); | ||
|
||
let x = 10_u128.pow(37); | ||
let y = 11_u128.pow(36); | ||
assert!(x < y); | ||
let mut r = x; | ||
for e in 0..1000 { | ||
assert_eq!(r, linear_mul_reduction(x, e, y)); | ||
// maintain the correct expected remainder | ||
r <<= 1; | ||
if r >= y { | ||
r -= y; | ||
} | ||
assert!(r != 0); | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this commit should be able to allow the `#[allow(dead_code)]` on `NarrowingDiv` to be removed.