258 lines
6.1 KiB
Rust
258 lines
6.1 KiB
Rust
use super::monty::monty_modpow;
|
|
use super::BigUint;
|
|
|
|
use crate::big_digit::{self, BigDigit};
|
|
|
|
use num_integer::Integer;
|
|
use num_traits::{One, Pow, ToPrimitive, Zero};
|
|
|
|
impl Pow<&BigUint> for BigUint {
|
|
type Output = BigUint;
|
|
|
|
#[inline]
|
|
fn pow(self, exp: &BigUint) -> BigUint {
|
|
if self.is_one() || exp.is_zero() {
|
|
BigUint::one()
|
|
} else if self.is_zero() {
|
|
BigUint::zero()
|
|
} else if let Some(exp) = exp.to_u64() {
|
|
self.pow(exp)
|
|
} else if let Some(exp) = exp.to_u128() {
|
|
self.pow(exp)
|
|
} else {
|
|
// At this point, `self >= 2` and `exp >= 2¹²⁸`. The smallest possible result given
|
|
// `2.pow(2¹²⁸)` would require far more memory than 64-bit targets can address!
|
|
panic!("memory overflow")
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Pow<BigUint> for BigUint {
    type Output = BigUint;

    /// Owned-exponent form: lends `exp` to the `Pow<&BigUint>` impl for `BigUint`.
    #[inline]
    fn pow(self, exp: BigUint) -> BigUint {
        <BigUint as Pow<&BigUint>>::pow(self, &exp)
    }
}
|
|
|
|
impl Pow<&BigUint> for &BigUint {
|
|
type Output = BigUint;
|
|
|
|
#[inline]
|
|
fn pow(self, exp: &BigUint) -> BigUint {
|
|
if self.is_one() || exp.is_zero() {
|
|
BigUint::one()
|
|
} else if self.is_zero() {
|
|
BigUint::zero()
|
|
} else {
|
|
self.clone().pow(exp)
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Pow<BigUint> for &BigUint {
    type Output = BigUint;

    /// Borrowed-base, owned-exponent form: lends `exp` to the `Pow<&BigUint>` impl
    /// for `&BigUint`.
    #[inline]
    fn pow(self, exp: BigUint) -> BigUint {
        let exp_ref = &exp;
        <&BigUint as Pow<&BigUint>>::pow(self, exp_ref)
    }
}
|
|
|
|
/// Implements `Pow<$T>` and `Pow<&$T>` for both `BigUint` and `&BigUint`, where `$T`
/// is a primitive unsigned integer exponent type.
///
/// The owned `Pow<$T> for BigUint` impl performs right-to-left binary exponentiation
/// (square-and-multiply); the other three impls delegate to it.
macro_rules! pow_impl {
    ($T:ty) => {
        impl Pow<$T> for BigUint {
            type Output = BigUint;

            fn pow(self, mut exp: $T) -> BigUint {
                // x^0 == 1 for every x, including zero.
                if exp == 0 {
                    return BigUint::one();
                }

                // Each trailing zero bit of the exponent only squares the base.
                let mut square = self;
                let low_zeros = exp.trailing_zeros();
                for _ in 0..low_zeros {
                    square = &square * &square;
                }
                exp >>= low_zeros;

                // `exp` is now odd; if it is exactly one, the running square is the answer.
                if exp == 1 {
                    return square;
                }

                // The lowest set bit contributes the current square; walk the higher bits,
                // squaring once per bit and multiplying in the square for each set bit.
                let mut product = square.clone();
                while exp > 1 {
                    exp >>= 1;
                    square = &square * &square;
                    if exp & 1 == 1 {
                        product *= &square;
                    }
                }
                product
            }
        }

        impl Pow<&$T> for BigUint {
            type Output = BigUint;

            /// Copies the borrowed exponent and defers to `Pow<$T>`.
            #[inline]
            fn pow(self, exp: &$T) -> BigUint {
                let exp: $T = *exp;
                <BigUint as Pow<$T>>::pow(self, exp)
            }
        }

        impl Pow<$T> for &BigUint {
            type Output = BigUint;

            /// Borrowed base: a zero exponent is answered without cloning; otherwise the
            /// base is cloned and handed to the owned `Pow<$T>` impl.
            #[inline]
            fn pow(self, exp: $T) -> BigUint {
                if exp != 0 {
                    <BigUint as Pow<$T>>::pow(self.clone(), exp)
                } else {
                    BigUint::one()
                }
            }
        }

        impl Pow<&$T> for &BigUint {
            type Output = BigUint;

            /// Copies the borrowed exponent and defers to `Pow<$T> for &BigUint`.
            #[inline]
            fn pow(self, exp: &$T) -> BigUint {
                <Self as Pow<$T>>::pow(self, *exp)
            }
        }
    };
}
|
|
|
|
pow_impl!(u8);
|
|
pow_impl!(u16);
|
|
pow_impl!(u32);
|
|
pow_impl!(u64);
|
|
pow_impl!(usize);
|
|
pow_impl!(u128);
|
|
|
|
pub(super) fn modpow(x: &BigUint, exponent: &BigUint, modulus: &BigUint) -> BigUint {
|
|
assert!(
|
|
!modulus.is_zero(),
|
|
"attempt to calculate with zero modulus!"
|
|
);
|
|
|
|
if modulus.is_odd() {
|
|
// For an odd modulus, we can use Montgomery multiplication in base 2^32.
|
|
monty_modpow(x, exponent, modulus)
|
|
} else {
|
|
// Otherwise do basically the same as `num::pow`, but with a modulus.
|
|
plain_modpow(x, &exponent.data, modulus)
|
|
}
|
|
}
|
|
|
|
/// Computes `base^exp mod modulus` by right-to-left binary exponentiation.
///
/// The exponent is supplied as its raw digit slice `exp_data`, least-significant digit
/// first (index `i` carries weight `2^(BITS * i)`), e.g. `BigUint::data`. Every product
/// is reduced modulo `modulus`, so no restriction is placed on `modulus` beyond being
/// non-zero; `modpow` uses this for even moduli.
///
/// Returns `BigUint::one()` when every digit is zero (including an empty slice).
/// NOTE(review): that `1` is not reduced, so `modulus == 1` would give 1 rather than 0;
/// `modpow` only calls this with even (hence >= 2) moduli.
///
/// Expects `exp_data` to be normalized (most significant digit non-zero): a zero top
/// digit trips the `debug_assert_ne!` below in debug builds.
///
/// # Panics
///
/// Panics if `modulus` is zero.
fn plain_modpow(base: &BigUint, exp_data: &[BigDigit], modulus: &BigUint) -> BigUint {
    assert!(
        !modulus.is_zero(),
        "attempt to calculate with zero modulus!"
    );

    // Index of the lowest non-zero digit; if there is none the exponent is 0 and x^0 = 1.
    let i = match exp_data.iter().position(|&r| r != 0) {
        None => return BigUint::one(),
        Some(i) => i,
    };

    // Reduce up front so every product below is of two values smaller than `modulus`.
    let mut base = base % modulus;
    // Each all-zero low digit contributes BITS squarings: base becomes base^(2^(BITS*i)).
    for _ in 0..i {
        for _ in 0..big_digit::BITS {
            base = &base * &base % modulus;
        }
    }

    // Strip the trailing zero bits of the first non-zero digit; `b` counts how many bits
    // of that digit have been consumed so far.
    let mut r = exp_data[i];
    let mut b = 0u8;
    while r.is_even() {
        base = &base * &base % modulus;
        r >>= 1;
        b += 1;
    }

    // Fast path: the exponent is a single power of two, so `base` already holds the result.
    let mut exp_iter = exp_data[i + 1..].iter();
    if exp_iter.len() == 0 && r.is_one() {
        return base;
    }

    // The lowest set bit contributes the current `base`; consume that bit.
    let mut acc = base.clone();
    r >>= 1;
    b += 1;

    // Scoped so the closure's mutable borrows of `base` and `acc` end before `acc` is returned.
    {
        // One square-and-multiply step for the next-higher exponent bit.
        let mut unit = |exp_is_odd| {
            base = &base * &base % modulus;
            if exp_is_odd {
                acc *= &base;
                acc %= modulus;
            }
        };

        if let Some(&last) = exp_iter.next_back() {
            // consume exp_data[i]
            for _ in b..big_digit::BITS {
                unit(r.is_odd());
                r >>= 1;
            }

            // consume all other digits before the last
            for &r in exp_iter {
                let mut r = r;
                for _ in 0..big_digit::BITS {
                    unit(r.is_odd());
                    r >>= 1;
                }
            }
            r = last;
        }

        // `r` is now the most significant digit (or what remains of digit `i` when it was
        // the only non-zero one). Stop at its highest set bit: leading zero bits would only
        // square `base` without contributing to `acc`.
        debug_assert_ne!(r, 0);
        while !r.is_zero() {
            unit(r.is_odd());
            r >>= 1;
        }
    }
    acc
}
|
|
|
|
/// Checks `plain_modpow` against `pow` followed by `%`.
///
/// Multi-digit exponents are given as raw digit vectors, whose value depends on the digit
/// width `big_digit::BITS`. The modulus `0x1100 = 2^8 * 17` makes those checks
/// width-independent. For any `k >= 8`, `2^k mod (2^8 * 17) = 2^8 * (2^(k-8) mod 17)`.
/// Since 2 has multiplicative order 8 modulo 17, the residue depends only on `k mod 8`.
/// Because `2^BITS ≡ 0 (mod 8)`, a digit vector is congruent mod 8 to its lowest digit,
/// which matches the low byte of the literal exponent on the left-hand side.
#[test]
fn test_plain_modpow() {
    let two = &BigUint::from(2u32);
    let modulus = BigUint::from(0x1100u32);

    // All-zero (or empty) exponent: x^0 == 1.
    assert_eq!(BigUint::one(), plain_modpow(two, &[], &modulus));
    assert_eq!(BigUint::one(), plain_modpow(two, &[0, 0], &modulus));

    // Single-digit exponents: exact comparison, no width dependence. These cover the
    // power-of-two early return ([1], [8]) and the no-higher-digit loop ([12]).
    assert_eq!(two.pow(0b1_u32) % &modulus, plain_modpow(two, &[0b1], &modulus));
    assert_eq!(two.pow(0b1000_u32) % &modulus, plain_modpow(two, &[0b1000], &modulus));
    assert_eq!(two.pow(0b1100_u32) % &modulus, plain_modpow(two, &[0b1100], &modulus));

    // A base larger than the modulus is reduced before exponentiation.
    let big = BigUint::from(0x1234_5678u32);
    assert_eq!(Pow::pow(&big, 3u32) % &modulus, plain_modpow(&big, &[3], &modulus));

    // Multi-digit exponents (see the doc comment for why these literals are valid).
    let exp = vec![0, 0b1];
    assert_eq!(
        two.pow(0b1_00000000_u32) % &modulus,
        plain_modpow(two, &exp, &modulus)
    );
    let exp = vec![0, 0b10];
    assert_eq!(
        two.pow(0b10_00000000_u32) % &modulus,
        plain_modpow(two, &exp, &modulus)
    );
    let exp = vec![0, 0b110010];
    assert_eq!(
        two.pow(0b110010_00000000_u32) % &modulus,
        plain_modpow(two, &exp, &modulus)
    );
    let exp = vec![0b1, 0b1];
    assert_eq!(
        two.pow(0b1_00000001_u32) % &modulus,
        plain_modpow(two, &exp, &modulus)
    );
    let exp = vec![0b1100, 0, 0b1];
    assert_eq!(
        two.pow(0b1_00000000_00001100_u32) % &modulus,
        plain_modpow(two, &exp, &modulus)
    );
}
|
|
|
|
/// Exercises every `Pow<BigUint>` / `Pow<&BigUint>` impl for `BigUint` and `&BigUint`,
/// including the zero-exponent, zero-base and one-base short-circuits.
///
/// Calls go through `Pow::pow` explicitly so an inherent `pow(&self, u32)` on `BigUint`
/// cannot capture the borrowed-receiver calls.
#[test]
fn test_pow_biguint() {
    let base = BigUint::from(5u8);
    let exponent = BigUint::from(3u8);
    let expected = BigUint::from(125u8);

    // All four owned/borrowed combinations agree.
    assert_eq!(expected, Pow::pow(&base, &exponent));
    assert_eq!(expected, Pow::pow(&base, exponent.clone()));
    assert_eq!(expected, Pow::pow(base.clone(), &exponent));
    assert_eq!(expected, Pow::pow(base.clone(), exponent));

    let zero = BigUint::zero();
    let one = BigUint::one();

    // x^0 == 1, including 0^0.
    assert_eq!(one, Pow::pow(&base, &zero));
    assert_eq!(one, Pow::pow(&zero, &zero));

    // 0^n == 0 and 1^n == 1 even when n is too large to narrow to a primitive.
    let huge = BigUint::one() << 200usize;
    assert_eq!(zero, Pow::pow(&zero, &huge));
    assert_eq!(one, Pow::pow(&one, &huge));
    assert_eq!(zero, Pow::pow(zero.clone(), huge.clone()));
    assert_eq!(one, Pow::pow(one.clone(), huge));
}

/// A base >= 2 raised to an exponent that does not fit in `u128` is rejected.
#[test]
#[should_panic(expected = "memory overflow")]
fn test_pow_biguint_huge_exponent_panics() {
    let huge = BigUint::one() << 200usize;
    let _ = Pow::pow(BigUint::from(2u8), &huge);
}
|