package main
import (
"fmt"
"math"
"strconv"
"time"
"github.com/ldsec/lattigo/v2/ckks"
"github.com/ldsec/lattigo/v2/rlwe"
)
// InspectPlain decodes plaintext with encoder and prints a short summary to
// stdout: the given name, the plaintext's level, the base-2 logarithm of its
// scale (one decimal place), and the first four decoded values.
//
// encryptor and decryptor are accepted but not used by this function. The
// first four decoded values are indexed directly, so the call panics if the
// decoder returns fewer than four.
func InspectPlain(
	name string,
	params *ckks.Parameters,
	encoder ckks.Encoder,
	encryptor ckks.Encryptor,
	decryptor ckks.Decryptor,
	plaintext *ckks.Plaintext,
) {
	decoded := encoder.Decode(plaintext, params.LogSlots())

	// Pre-format the scale so the summary line shows it with one decimal.
	scaleBits := math.Log2(plaintext.Scale)
	scaleText := fmt.Sprintf("%.1f", scaleBits)

	fmt.Printf("\n%s\n", name)
	fmt.Printf("level = %d, scale = %v bit\n", plaintext.Level(), scaleText)
	fmt.Printf(
		"values=%6.5f %6.5f %6.5f %6.5f...\n",
		decoded[0], decoded[1], decoded[2], decoded[3],
	)
}
// Inspect decrypts ciphertext with decryptor and prints the resulting
// plaintext's summary via InspectPlain, labelled with name.
func Inspect(
	name string,
	params *ckks.Parameters,
	encoder ckks.Encoder,
	encryptor ckks.Encryptor,
	decryptor ckks.Decryptor,
	ciphertext *ckks.Ciphertext,
) {
	decrypted := decryptor.DecryptNew(ciphertext)
	InspectPlain(name, params, encoder, encryptor, decryptor, decrypted)
}
// Pow2 returns 2 raised to the power x, computed as a left shift of 1.
//
// x is converted to an unsigned shift count first, so a negative x becomes a
// very large count and the result is 0; likewise any x at or beyond the bit
// width of int shifts every bit out and yields 0.
func Pow2(x int) int {
	shift := uint(x)
	return int(1) << shift
}
// RotateLeftNew rotates ciphertext by amount slots and returns the result as
// a new ciphertext; the input ciphertext is left unmodified.
//
// amount is first reduced modulo params.Slots() into the range [0, Slots), so
// negative amounts and amounts of any magnitude are accepted. This assumes
// rotation is cyclic with period params.Slots(), which the previous
// "Slots + amount" handling of negative values already relied on.
//
// The rotation is applied as a chain of power-of-two rotations taken from
// DecompPo2, so evaluator only needs rotation keys for powers of two.
func RotateLeftNew(
	params *ckks.Parameters,
	evaluator ckks.Evaluator,
	ciphertext *ckks.Ciphertext,
	amount int,
) *ckks.Ciphertext {
	slots := params.Slots()
	amount %= slots
	if amount < 0 {
		amount += slots
	}

	steps := DecompPo2(amount)
	if len(steps) == 0 {
		// Nothing to rotate, but callers of a *New function still expect a
		// ciphertext that does not alias their input.
		return ciphertext.CopyNew()
	}

	// The first step allocates the output; the remaining steps rotate that
	// output in place so the caller's ciphertext is never written to.
	out := evaluator.RotateNew(ciphertext, steps[0])
	for _, r := range steps[1:] {
		evaluator.Rotate(out, r, out)
	}
	return out
}
// DecompPo2 decomposes amount into the distinct powers of two that sum to it,
// one entry per set bit of its binary representation, ordered from the
// largest power to the smallest. For example, 13 yields [8 4 1].
//
// A non-positive amount has no such decomposition and yields an empty,
// non-nil slice.
func DecompPo2(amount int) []int {
	res := make([]int, 0)
	if amount <= 0 {
		return res
	}

	// Take each digit's exponent from its position in the binary string
	// instead of from math.Log2: the floating-point logarithm rounds up for
	// values just below a large power of two, which would double every term.
	digits := strconv.FormatInt(int64(amount), 2)
	top := len(digits) - 1
	for i, c := range digits {
		if c == '1' {
			res = append(res, int(1)<<uint(top-i))
		}
	}
	return res
}
// Kernel evaluates a fixed degree-7 polynomial on node1 with evaluator, prints
// the elapsed time in whole milliseconds as a bare number on its own line, and
// returns the resulting ciphertext.
//
// The coefficients are 1/k! for k = 0..7, i.e. the leading Taylor coefficients
// of exp(x) (assuming ckks.NewPoly takes coefficients in ascending degree
// order — confirm against the library). The target scale passed to the
// evaluation is node1's own scale. The measured time covers building the
// coefficient list and polynomial as well as the evaluation itself.
//
// params, encoder, encryptor and decryptor are accepted but not used here.
//
// Kernel panics if the polynomial evaluation returns an error, since it has
// no error return and the result would not be usable.
func Kernel(
	params *ckks.Parameters,
	encoder ckks.Encoder,
	encryptor ckks.Encryptor,
	decryptor ckks.Decryptor,
	evaluator ckks.Evaluator,
	node1 *ckks.Ciphertext,
) *ckks.Ciphertext {
	start := time.Now()
	coeffs := []complex128{
		complex(1.0, 0),
		complex(1.0, 0),
		complex(1.0/2, 0),
		complex(1.0/6, 0),
		complex(1.0/24, 0),
		complex(1.0/120, 0),
		complex(1.0/720, 0),
		complex(1.0/5040, 0),
	}
	poly := ckks.NewPoly(coeffs)
	eval, err := evaluator.EvaluatePoly(node1, poly, node1.Scale)
	if err != nil {
		// Fail loudly instead of reporting a timing for a failed evaluation.
		panic(fmt.Errorf("evaluating polynomial: %w", err))
	}
	elapsed := time.Since(start).Milliseconds()
	// fmt.Printf("time: %v ms\n", elapsed)
	fmt.Printf("%v\n", elapsed)
	return eval
}
func main() {
var err error
params, err := ckks.NewParametersFromLiteral(ckks.PN13QP218)
if err != nil {
panic(err)
}
// fmt.Printf("params\nLevel = %d, logQ = %d, Scale = %f bit\n", params.MaxLevel(), params.LogQ(), math.Log2(params.Scale()))
encoder := ckks.NewEncoder(params)
kgen := ckks.NewKeyGenerator(params)
sk, pk := kgen.GenKeyPair()
rlk := kgen.GenRelinearizationKey(sk, 2)
indices := make([]int, 0)
for _, r := range DecompPo2(params.Slots() - 1) {
indices = append(indices, r)
}
// fmt.Printf("Indices: %v\n", indices)
rtks := kgen.GenRotationKeysForRotations(indices, false, sk)
// fmt.Printf("Rotation keys: %v\n", rtks)
// k, _ := rtks.GetRotationKey(1)
// fmt.Printf("Rotation key: %v\n", k)
encryptor := ckks.NewEncryptor(params, pk)
decryptor := ckks.NewDecryptor(params, sk)
evaluator := ckks.NewEvaluator(params, rlwe.EvaluationKey{Rlk: rlk, Rtks: rtks})
// var out *ckks.Ciphertext
v1 := make([]complex128, params.Slots())
p1 := encoder.EncodeNew(v1, params.MaxLevel(), params.DefaultScale(), params.LogSlots())
node1 := encryptor.EncryptNew(p1)
Kernel(¶ms, encoder, encryptor, decryptor, evaluator, node1)
// Inspect("out", ¶ms, encoder, encryptor, decryptor, out)
}