// lattigo BSGS (baby-step giant-step) polynomial-evaluation experiment.

package main
 
import (
	"fmt"
	"math"
	"strconv"
	"time"
 
	"github.com/ldsec/lattigo/v2/ckks"
	"github.com/ldsec/lattigo/v2/rlwe"
)
 
// InspectPlain prints a short diagnostic summary of plaintext under the
// heading name: its level, its scale in log2 bits (one decimal place), and a
// preview of at most the first four decoded slot values.
//
// encryptor and decryptor are accepted but not used here.
func InspectPlain(
	name string,
	params *ckks.Parameters,
	encoder ckks.Encoder,
	encryptor ckks.Encryptor,
	decryptor ckks.Decryptor,
	plaintext *ckks.Plaintext,
) {
	const maxShown = 4

	values := encoder.Decode(plaintext, params.LogSlots())
	logScale := fmt.Sprintf("%.1f", math.Log2(plaintext.Scale))
	fmt.Printf("\n%s\n", name)
	fmt.Printf("level = %d, scale = %v bit\n", plaintext.Level(), logScale)

	// The number of decoded values depends on LogSlots (presumably 2^LogSlots);
	// for small LogSlots there can be fewer than maxShown, and indexing
	// values[3] unconditionally would panic. Print only what exists.
	shown := len(values)
	if shown > maxShown {
		shown = maxShown
	}
	fmt.Print("values=")
	for i := 0; i < shown; i++ {
		if i > 0 {
			fmt.Print(" ")
		}
		fmt.Printf("%6.5f", values[i])
	}
	fmt.Print("...\n")
}
 
// Inspect decrypts ciphertext with decryptor and hands the resulting
// plaintext to InspectPlain, which prints it under the heading name.
func Inspect(
	name string,
	params *ckks.Parameters,
	encoder ckks.Encoder,
	encryptor ckks.Encryptor,
	decryptor ckks.Decryptor,
	ciphertext *ckks.Ciphertext,
) {
	decrypted := decryptor.DecryptNew(ciphertext)
	InspectPlain(name, params, encoder, encryptor, decryptor, decrypted)
}
 
// Pow2 returns two raised to the power x, computed as a left shift of 1.
// x is converted to an unsigned shift count, so a negative x (or one at least
// as large as the bit width of int) yields 0 rather than panicking.
func Pow2(x int) int {
	result := 1
	result <<= uint(x)
	return result
}
 
// RotateLeftNew returns a new ciphertext holding ciphertext rotated left by
// amount slots. The input ciphertext is left unmodified.
//
// The rotation is applied as a sequence of power-of-two rotations (the
// decomposition produced by DecompPo2), so the evaluator only needs rotation
// keys for powers of two. amount is first reduced modulo params.Slots(), so
// rotating by a multiple of Slots() is a no-op. A negative amount is mapped
// to the equivalent left rotation Slots()+amount, i.e. it rotates right.
func RotateLeftNew(
	params *ckks.Parameters,
	evaluator ckks.Evaluator,
	ciphertext *ckks.Ciphertext,
	amount int,
) *ckks.Ciphertext {
	slots := params.Slots()
	amount %= slots
	if amount < 0 {
		amount += slots
	}

	// Rotate writes into its output argument. Rotate a copy so the caller's
	// ciphertext is not clobbered, as the "New" suffix promises.
	out := ciphertext.CopyNew()
	for _, step := range DecompPo2(amount) {
		evaluator.Rotate(out, step, out)
	}
	return out
}
 
// DecompPo2 decomposes amount into the distinct powers of two that sum to it,
// i.e. the set bits of its binary representation, ordered from largest to
// smallest. For example, DecompPo2(13) returns [8 4 1].
//
// Non-positive amounts have no such decomposition and yield an empty slice.
//
// The exponent of the leading bit is taken from the length of the binary
// string instead of math.Log2. The float conversion can round large values
// up (float64(1<<60-1) == 1<<60), which would produce a spurious leading power.
func DecompPo2(amount int) []int {
	res := make([]int, 0)
	if amount <= 0 {
		return res
	}
	digits := strconv.FormatInt(int64(amount), 2)
	top := len(digits) - 1 // exponent of the most significant digit
	for i, c := range digits {
		if c == '1' {
			res = append(res, 1<<uint(top-i))
		}
	}
	return res
}
 
// Kernel evaluates a fixed degree-7 polynomial on node1. It prints the
// elapsed wall-clock time in milliseconds (a bare number plus newline) to
// stdout and returns the encrypted result.
//
// The coefficients are 1/k! for k = 0..7. Assuming ckks.NewPoly takes them in
// ascending-degree order, this is the degree-7 Taylor expansion of exp(x)
// around 0. The result is requested at node1's scale. params, encoder,
// encryptor and decryptor are currently unused.
//
// Kernel panics if EvaluatePoly returns an error rather than silently
// returning a nil ciphertext.
func Kernel(
	params *ckks.Parameters,
	encoder ckks.Encoder,
	encryptor ckks.Encryptor,
	decryptor ckks.Decryptor,
	evaluator ckks.Evaluator,
	node1 *ckks.Ciphertext,
) *ckks.Ciphertext {
	start := time.Now()
	coeffs := []complex128{
		complex(1.0, 0),
		complex(1.0, 0),
		complex(1.0/2, 0),
		complex(1.0/6, 0),
		complex(1.0/24, 0),
		complex(1.0/120, 0),
		complex(1.0/720, 0),
		complex(1.0/5040, 0),
	}
	poly := ckks.NewPoly(coeffs)
	eval, err := evaluator.EvaluatePoly(node1, poly, node1.Scale)
	if err != nil {
		panic(fmt.Errorf("evaluating polynomial: %w", err))
	}
	elapsed := time.Since(start).Milliseconds()
	fmt.Printf("%v\n", elapsed)
	return eval
}
 
// main runs one benchmark pass, panicking on parameter-construction failure:
//   - it builds CKKS parameters from the PN13QP218 literal;
//   - it generates a key pair, a relinearization key and power-of-two
//     rotation keys;
//   - it encrypts an all-zero vector at the maximum level and default scale,
//     and hands it to Kernel, which prints its own timing.
func main() {
	params, err := ckks.NewParametersFromLiteral(ckks.PN13QP218)
	if err != nil {
		panic(err)
	}

	encoder := ckks.NewEncoder(params)
	kgen := ckks.NewKeyGenerator(params)
	sk, pk := kgen.GenKeyPair()
	rlk := kgen.GenRelinearizationKey(sk, 2)

	// Rotation keys for every power of two in the binary expansion of
	// Slots()-1. Slots() is presumably a power of two, in which case this
	// covers every power of two below Slots() — enough for RotateLeftNew to
	// handle any amount.
	rotations := DecompPo2(params.Slots() - 1)
	rtks := kgen.GenRotationKeysForRotations(rotations, false, sk)

	encryptor := ckks.NewEncryptor(params, pk)
	decryptor := ckks.NewDecryptor(params, sk)
	evaluator := ckks.NewEvaluator(params, rlwe.EvaluationKey{Rlk: rlk, Rtks: rtks})

	input := make([]complex128, params.Slots())
	plaintext := encoder.EncodeNew(input, params.MaxLevel(), params.DefaultScale(), params.LogSlots())
	ciphertext := encryptor.EncryptNew(plaintext)

	// To check the result, capture Kernel's return value and pass it to Inspect.
	Kernel(&params, encoder, encryptor, decryptor, evaluator, ciphertext)
}