CKKS-RNS 完全理解

📝 Background

https://eprint.iacr.org/2018/931.pdf

Ciphertext.Mult(mk, c1, c2)

# Pseudocode: multiply ciphertexts `c1 = (c1a, c1b)` and `c2 = (c2a, c2b)`, then
# relinearize the `c1a * c2a` term with the multiplication key `mk = (mka', mkb')`.
#
# Notation:
#   x'        : polynomial `x` in (NTT-transformed) RNS form (an RNSPoly)
#   to / from : conversion into / out of that form; the comment next to each call
#               names the modulus range used for it
#   q: c1.log_cap
#   Q: params.log_lo_modulus
#   P: params.log_hi_modulus   (so `>> P` is a right shift by `log_hi_modulus` bits)
(c1a, c1b), (c2a, c2b), (mka', mkb') = c1, c2, mk
# with q
c1a', c1b', c2a', c2b' = to(c1a, c1b, c2a, c2b)
# with q
axax, bxbx = from(c1a' * c2a'), from(c1b' * c2b')
s1', s2' = c1a' + c1b', c2a' + c2b'
# (c1a + c1b) * (c2a + c2b) = axax + bxbx + (c1a * c2b + c1b * c2a); the cross terms
# are recovered below as `axbx - bxbx - axax`, which saves one multiplication.
axbx = from(s1' * s2') # with q
#=
In the paper, we want to find `(P^(-1) * x * y) mod q`,
where `P = 2^log_hi_modulus`, `Q = 2^log_lo_modulus`, `x` is a polynomial modulo `q <= Q`,
and `y` is a polynomial modulo `P * Q`.
Here `x` is `axax` and `y` is one component of the multiplication key.

So, we perform the multiplication using the full range `q * Q * P`
(times the polynomial length to prevent overflow during NTT in RNS),
then drop the high `Q` bits (during the conversion from RNS to big integer),
then divide by `P` with rounding (the `>>` below).
=#
raa = to(axax) # P + Q

# Rounding right shift (rounds half away from zero):
# >>(x, shift) = (signbit(x) ? -1 : 1) * ((abs(x) + 2^(shift-1)) >> shift)
ax, bx = from(raa * mka') >> P, from(raa * mkb') >> P # with q + P
# Add the cross terms to `ax` and `c1b * c2b` to `bx`.
ax, bx = ax + axbx - bxbx - axax, bx + bxbx
return (ax, bx) # prec = c1.prec + c2.prec

Ciphertext.CMult(c1, p)

# Pseudocode: multiply ciphertext `c1 = (c1a, c1b)` by the plaintext polynomial `p`.
# Notation: `x'` is `x` converted to (NTT-transformed) RNS form by `to`; `from`
# converts back.
# q: c1.log_cap
c1a, c1b = c1
# Convert the plaintext and both ciphertext components, with a range sized for q.
p', c1a', c1b' = to(p, c1a, c1b)
# Each ciphertext component is multiplied by the same plaintext.
ca, cb = from(c1a' * p'), from(c1b' * p')
return (ca, cb) # prec = c1.prec + prec (the second `prec` presumably being the plaintext's precision — confirm)

Converting to and from RNS

"""
    _log_modulus_mul(log_modulus1, log_modulus2, polynomial_length)

Return the number of bits needed to hold any coefficient of the product of two
polynomials of length `polynomial_length`. The coefficients of the two factors are
signed values fitting into `log_modulus1` and `log_modulus2` bits respectively.
"""
function _log_modulus_mul(log_modulus1::Int, log_modulus2::Int, polynomial_length::Int)
    # The result is the product of polynomials, so the maximum number we can encounter
    # is the polynomial-length sum of products of maximum numbers from `x` and `y`
    # (perhaps that's even a bit too conservative for negacyclic polynomials).
    #
    # `log_plen = ceil(log2(polynomial_length))`, so that `2^log_plen >= polynomial_length`
    # holds for every length, not only for powers of two (for which it equals `log2`).
    log_plen = if polynomial_length <= 1
        0
    else
        8 * sizeof(Int) - leading_zeros(polynomial_length - 1)
    end

    # `-1` is because each coefficient <= 2^(log_modulus-1)
    # So the maximum is <= 2^(log_modulus1-1) * 2^(log_modulus2-1) * 2^log_plen
    # `+1` because we need to fit both positive and negative numbers of that range.
    return (log_modulus1 - 1) + (log_modulus2 - 1) + log_plen + 1
end
 
# Convert `x` into NTT-transformed RNS form. The range is chosen so that the result can
# take part in one multiplication with another polynomial whose `log_modulus` is at
# most `add_range`.
function to_rns_transformed(plan::RNSPlan, x::Polynomial{BinModuloInt{T, Q}}, add_range::Int=0)::RNSPolynomialTransformed where {T, Q}
    # `Q` becomes the `log_modulus` of the RNS polynomial, so it should not exceed the
    # requested range.
    needed_range = _log_modulus_mul(Q, add_range, length(x.coeffs))
    untransformed = _to_rns(plan, x, needed_range)
    return _ntt_forward(untransformed)
end
 
# Encode every coefficient of `x` as signed residues (`to_rns_signed`) modulo the first
# `min_nprimes(plan, log_range)` primes of `plan`. The residues are laid out with one row
# per coefficient and one column per prime. The result records `Q` as its `log_modulus`.
# Its `log_range` is the full range `max_log_modulus` reports for that number of primes.
function _to_rns(plan::RNSPlan, x::Polynomial{BinModuloInt{T, Q}}, log_range::Int) where {T, Q}
    nprimes = min_nprimes(plan, log_range)
    # The chosen primes may offer more bits than requested; report what they provide.
    full_range = max_log_modulus(plan, nprimes)

    coeffs = x.coeffs
    residuals = Array{UInt64}(undef, length(coeffs), nprimes)
    for (row, c) in enumerate(coeffs)
        residuals[row, :] .= to_rns_signed(plan, c, nprimes)
    end
    return RNSPolynomial(plan, residuals, Q, full_range, x.modulus)
end
 
# Convert an NTT-transformed RNS polynomial back to an ordinary polynomial.
# The inverse NTT comes first, then reconstruction from residues via `_from_rns`.
# The result is a `Polynomial{BinModuloInt{BigInt, log_modulus}}`. Its type depends on
# the *value* of `log_modulus`, so it cannot be written as a return-type annotation
# (there is no static parameter to name it).
# `_from_rns` asserts that `log_modulus <= x.log_range`.
# NOTE(review): with the default `log_modulus = 0` every coefficient is reduced modulo
# `2^0`; confirm whether callers always pass `log_modulus` explicitly.
function from_rns_transformed(x::RNSPolynomialTransformed, log_modulus::Int=0)
    return _from_rns(Polynomial{BinModuloInt{BigInt, log_modulus}}, _ntt_inverse(x))
end
 
# Rebuild a polynomial with `BinModuloInt{T, Q}` coefficients from the residues held in
# `x`. The residues use one row per coefficient and one column per prime; each row is
# decoded with `from_rns_signed`. `Q` must not exceed `x.log_range`, the number of bits
# the stored values can represent.
function _from_rns(::Type{Polynomial{BinModuloInt{T, Q}}}, x::RNSPolynomial) where {T, Q}
    @assert Q <= x.log_range
    Coeff = BinModuloInt{T, Q}
    residues = x.residuals
    coeffs = Coeff[
        from_rns_signed(x.plan, Coeff, residues[row, :]) for row in 1:size(residues, 1)
    ]
    return Polynomial(coeffs, x.polynomial_modulus)
end
 
# Field excerpt from the RNS polynomial type (the struct header is not shown in these
# notes; presumably `RNSPolynomial` / `RNSPolynomialTransformed`, whose constructors take
# `(plan, residuals, log_modulus, log_range, polynomial_modulus)` — confirm).
#
# Notation: `BinModuloInt{T, Q}` is a value of integer type `T` reduced modulo `2^Q`
# (i.e. `T % 2^Q`).
#
# `log_modulus`: the numbers stored are guaranteed to fit into `log_modulus` bits,
# that is they lie in `(P - 2^(log_modulus-1) + 1, P - 1] U [0, 2^(log_modulus-1)]`,
# where `P = prod(plan.primes[1:np])` (not the `P = 2^log_hi_modulus` of the ciphertext
# notes), and `np` is the second dimension of `residuals`
# (and, of course, `P > 2^log_modulus - 1`, so these intervals don't intersect).
# NOTE(review): the open bracket on the lower bound may be a typo for `[` — confirm
# against `to_rns_signed`.
log_modulus :: Int

# `log_range`: the maximum possible range for stored numbers for the given number of
# residuals. `log_modulus` can't get greater than this.
# `log_range` has a one-to-one correspondence with the number of primes used.
log_range :: Int

RNSPolynomialTransformed Operations

RNSPolynomialTransformed.Mult

# Multiply two NTT-transformed RNS polynomials that share the same polynomial modulus.
# In the transformed domain this is a pointwise product of the residues, taken modulo
# each residue column's own prime. Returns a new `RNSPolynomialTransformed`.
#
# Worked example with x, y: (N, r, log_range, log_modulus) = (1024, 22, 1385, 689):
# new_log_modulus = (689 - 1) + (689 - 1) + log2(1024) + 1 = 1387 > 1385,
# so the range assertion below fails for such operands.
function rns_mult(x, y)
    @assert x.polynomial_modulus == y.polynomial_modulus

    plen = size(x.residuals, 1)
    # Bits needed for the product's coefficients; they must fit into the range that
    # both operands provide.
    new_log_modulus = _log_modulus_mul(x.log_modulus, y.log_modulus, plen)
    new_log_range = min(x.log_range, y.log_range)
    @assert new_log_modulus <= new_log_range

    plan = x.plan
    # Only the residue columns present in both operands can be combined.
    np = min(size(x.residuals, 2), size(y.residuals, 2))
    res = Array{UInt64}(undef, plen, np)

    # `@views` avoids copying each column slice before the fused, in-place `mulmod`.
    @views for j in 1:np
        res[:, j] .= mulmod.(x.residuals[:, j], y.residuals[:, j], plan.primes[j])
    end

    return RNSPolynomialTransformed(
        plan, res, new_log_modulus, new_log_range, x.polynomial_modulus,
    )
end