CKKS-RNS 完全理解
📝 Background
https://eprint.iacr.org/2018/931.pdf
Ciphertext.Mult(mk, c1, c2)
# c': RNSPoly
# q: c1.log_cap
# Q: params.log_lo_modulus
# P: params.log_hi_modulus
c1a, c1b, c2a, c2b, mka', mkb' = c1, c2, mk
# with q
c1a', c1b', c2a', c2b' = to(c1a, c1b, c2a, c2b)
# with q
axax, bxbx = from(c1a' * c2a'), from(c1b' * c2b')
ra1', ra2' = c1a' + c1b', c2a' + c2b'
axbx = from(ra1' * ra2') # with q; (a1 + b1) * (a2 + b2)
#=
In the paper, we want to find `(P^(-1) * x * y) mod q`,
where `P = 2^log_hi_modulus`, `Q = 2^log_lo_modulus`, `x` is a polynomial modulo `q <= Q`,
and `y` is a polynomial modulo `P * Q`.
So, we perform the multiplication using the full range `q * Q * P`
(times the polynomial length to prevent overflow during NTT in RNS),
then drop the high `Q` bits (during the conversion from RNS to big integer),
then shift by `P`.
=#
raa' = to(axax) # extra range P + Q, because the key is a polynomial modulo P * Q
# >>(x, shift) = (signbit(x) ? -1 : 1) * ((abs(x) + 2^(shift-1)) >> shift)
ax, bx = from(raa' * mka') >> P, from(raa' * mkb') >> P # with q + P
ax, bx = ax + axbx - bxbx - axax, bx + bxbx
return (ax, bx) # prec = c1.prec + c2.prec
Ciphertext.CMult(c1, p)
# q: c1.log_cap
c1a, c1b = c1
# with q
p', c1a', c1b' = to(p, c1a, c1b)
ca, cb = from(c1a' * p'), from(c1b' * p')
return (ca, cb) # prec = c1.prec + prec
Convert RNS
"""
    _log_modulus_mul(log_modulus1, log_modulus2, polynomial_length) -> Int

Return the number of bits needed to hold any coefficient of the product of two
polynomials of `polynomial_length` coefficients without wraparound. The two inputs have
signed coefficients of `log_modulus1` and `log_modulus2` bits, respectively.

Throws `ArgumentError` if `polynomial_length < 1`.
"""
function _log_modulus_mul(log_modulus1::Int, log_modulus2::Int, polynomial_length::Int)
    polynomial_length >= 1 || throw(ArgumentError(
        "polynomial_length must be positive, got $polynomial_length"))
    # The result is the product of polynomials, so the maximum number we can encounter
    # is the polynomial-length sum of products of maximum numbers from `x` and `y`
    # (perhaps that's even a bit too conservative for negacyclic polynomials).
    # Because this is an upper bound, it must round *up*. Compute
    # `ceil(log2(polynomial_length))` exactly on integers; for powers of two it
    # equals `log2(polynomial_length)`.
    log_plen = 8 * sizeof(polynomial_length) - leading_zeros(polynomial_length - 1)
    # `-1` is because each coefficient's magnitude is <= 2^(log_modulus-1),
    # so the maximum is <= 2^(log_modulus1-1) * 2^(log_modulus2-1) * 2^log_plen.
    # `+1` because we need to fit both positive and negative numbers of that range.
    return (log_modulus1 - 1) + (log_modulus2 - 1) + log_plen + 1
end
function to_rns_transformed(plan::RNSPlan, x::Polynomial{BinModuloInt{T, Q}}, add_range::Int=0)::RNSPolynomialTransformed where {T, Q}
    # Convert `x` to RNS form, then apply the forward NTT.
    #
    # The RNS range is chosen so that this polynomial can later be multiplied once
    # by another polynomial whose `log_modulus <= add_range`.
    # `Q` (the coefficient type parameter) is the input's log_modulus; the chosen
    # range is expected to satisfy `Q <= log_range`.
    needed_range = _log_modulus_mul(Q, add_range, length(x.coeffs))
    untransformed = _to_rns(plan, x, needed_range)
    return _ntt_forward(untransformed)
end
function _to_rns(plan::RNSPlan, x::Polynomial{BinModuloInt{T, Q}}, log_range::Int) where {T, Q}
    # `min_nprimes` picks how many primes to use for `log_range` bits (presumably the
    # fewest that suffice). `max_log_modulus` then gives the range that this many
    # primes actually provides, and that range is what gets recorded.
    nprimes = min_nprimes(plan, log_range)
    actual_range = max_log_modulus(plan, nprimes)
    coeffs = x.coeffs
    residuals = Array{UInt64}(undef, length(coeffs), nprimes)
    for (row, c) in enumerate(coeffs)
        # Row `row` holds the residues of one signed coefficient, one column per prime.
        residuals[row, :] .= to_rns_signed(plan, c, nprimes)
    end
    return RNSPolynomial(plan, residuals, Q, actual_range, x.modulus)
end
"""
    from_rns_transformed(x::RNSPolynomialTransformed, log_modulus::Int=0)

Apply the inverse NTT to `x`, then reconstruct a polynomial with
`BinModuloInt{BigInt, log_modulus}` coefficients from its residues.
"""
function from_rns_transformed(x::RNSPolynomialTransformed, log_modulus::Int=0)::Polynomial{BinModuloInt{BigInt, log_modulus}}
    # The coefficient type depends on the runtime value `log_modulus`. The return
    # annotation therefore spells it from the argument; a `where` type parameter such
    # as `Q` is not bound here.
    # NOTE(review): the default `log_modulus = 0` looks suspicious; confirm whether
    # callers expect `x.log_modulus` instead.
    return _from_rns(Polynomial{BinModuloInt{BigInt, log_modulus}}, _ntt_inverse(x))
end
function _from_rns(::Type{Polynomial{BinModuloInt{T, Q}}}, x::RNSPolynomial) where {T, Q}
    # Internal invariant: the target modulus `Q` must not exceed the range that the
    # stored residues can represent.
    @assert Q <= x.log_range
    E = BinModuloInt{T, Q}
    residuals = x.residuals
    coeffs = Array{E}(undef, size(residuals, 1))
    for row in axes(residuals, 1)
        # Each row holds the residues of one coefficient; recombine them as a signed value.
        coeffs[row] = from_rns_signed(x.plan, E, residuals[row, :])
    end
    return Polynomial(coeffs, x.polynomial_modulus)
end
# BinModuloInt{T, Q} == T % 2^Q
# The numbers stored are guaranteed to fit into `log_modulus` bits,
# that is they lie in `[P - 2^(log_modulus-1) + 1, P - 1] U [0, 2^(log_modulus-1)]`,
# where `P = prod(plan.primes[1:np])`, and `np` is the second dimension of `residuals`
# (and, of course, `P > 2^log_modulus - 1`, so these intervals don't intersect).
log_modulus :: Int
# The maximum possible range for stored numbers for the given number of residuals.
# `log_modulus` can't get greater than this.
# `log_range` has a one-to-one correspondence with the number of primes used.
log_range :: Int
RNSPolynomialTransformed Operation
RNSPolynomialTransformed.Mult
# x, y: (N, r, log_range, log_modulus) = (1024, 22, 1385, 689)
# new_log_modulus = (689 - 1) + (689 - 1) + log2(1024) + 1 = 1387 > 1385 = log_range, so the capacity assertion fails
"""
    rns_mult(x, y)

Multiply two transformed RNS polynomials by multiplying their residues element-wise,
modulo each prime.

`x` and `y` must share the same polynomial modulus. The result keeps only the primes
that both inputs have (the smaller residue count). Its `log_modulus` is the worst-case
bound from `_log_modulus_mul`, and that bound must fit into the smaller of the two
`log_range`s.
"""
function rns_mult(x, y)
    @assert x.polynomial_modulus == y.polynomial_modulus "polynomials must share the same polynomial modulus"
    plen = size(x.residuals, 1)
    new_log_modulus = _log_modulus_mul(x.log_modulus, y.log_modulus, plen)
    new_log_range = min(x.log_range, y.log_range)
    # The product must remain representable in the available RNS range.
    @assert new_log_modulus <= new_log_range "product needs $new_log_modulus bits, but only $new_log_range are available"
    plan = x.plan
    np = min(size(x.residuals, 2), size(y.residuals, 2))
    res = Array{UInt64}(undef, plen, np)
    # One column per prime. `@views` avoids copying each residual column, and `.=`
    # writes straight into `res`.
    @views for j in 1:np
        res[:, j] .= mulmod.(x.residuals[:, j], y.residuals[:, j], plan.primes[j])
    end
    return RNSPolynomialTransformed(plan, res, new_log_modulus, new_log_range, x.polynomial_modulus)
end