Skip to content

Commit

Permalink
feat: Add Unsafe method for faster InvFFT
Browse files Browse the repository at this point in the history
  • Loading branch information
sp301415 committed Dec 10, 2023
1 parent 62bbfe6 commit a17ec3b
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 29 deletions.
115 changes: 107 additions & 8 deletions math/poly/fourier_transform.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package poly

// ToFourierPoly transforms Poly to FourierPoly and returns it.
// ToFourierPoly transforms Poly to FourierPoly, and returns it.
func (f *FourierEvaluator[T]) ToFourierPoly(p Poly[T]) FourierPoly {
fp := NewFourierPoly(f.degree)
f.ToFourierPolyAssign(p, fp)
Expand Down Expand Up @@ -43,7 +43,7 @@ func (f *FourierEvaluator[T]) ToFourierPolyAssign(p Poly[T], fpOut FourierPoly)
fftInPlace(fpOut.Coeffs, f.wNj)
}

// ToScaledFourierPoly transforms Poly to FourierPoly and returns it.
// ToScaledFourierPoly transforms Poly to FourierPoly, and returns it.
// Each coefficients are scaled by 1 / 2^sizeT.
func (f *FourierEvaluator[T]) ToScaledFourierPoly(p Poly[T]) FourierPoly {
fp := NewFourierPoly(f.degree)
Expand Down Expand Up @@ -88,7 +88,7 @@ func (f *FourierEvaluator[T]) ToScaledFourierPolyAssign(p Poly[T], fp FourierPol
fftInPlace(fp.Coeffs, f.wNj)
}

// ToStandardPoly transforms FourierPoly to Poly and returns it.
// ToStandardPoly transforms FourierPoly to Poly, and returns it.
func (f *FourierEvaluator[T]) ToStandardPoly(fp FourierPoly) Poly[T] {
p := New[T](f.degree)
f.ToStandardPolyAssign(fp, p)
Expand All @@ -109,7 +109,23 @@ func (f *FourierEvaluator[T]) ToStandardPolyAssign(fp FourierPoly, pOut Poly[T])
}
}

// ToStandardPolyAddAssign transforms FourierPoly to Poly and adds it to pOut.
// ToStandardPolyAssignUnsafe transforms FourierPoly to Poly, and writes it to pOut.
//
// This method is slightly faster than ToStandardPolyAssign, but it modifies fp directly.
// Use it only if you don't need fp after this method (e.g. fp is a buffer).
func (f *FourierEvaluator[T]) ToStandardPolyAssignUnsafe(fp FourierPoly, pOut Poly[T]) {
N := f.degree

invFFTInPlace(fp.Coeffs, f.wNjInv)
unTwistInPlace(fp.Coeffs, f.w2NjInv)

for j := 0; j < N/2; j++ {
pOut.Coeffs[j] = T(int64(real(fp.Coeffs[j])))
pOut.Coeffs[j+N/2] = -T(int64(imag(fp.Coeffs[j])))
}
}

// ToStandardPolyAddAssign transforms FourierPoly to Poly, and adds it to pOut.
func (f *FourierEvaluator[T]) ToStandardPolyAddAssign(fp FourierPoly, pOut Poly[T]) {
N := f.degree

Expand All @@ -123,7 +139,23 @@ func (f *FourierEvaluator[T]) ToStandardPolyAddAssign(fp FourierPoly, pOut Poly[
}
}

// ToStandardPolySubAssign transforms FourierPoly to Poly and subtracts it from pOut.
// ToStandardPolyAddAssignUnsafe transforms FourierPoly to Poly, and adds it to pOut.
//
// This method is slightly faster than ToStandardPolyAddAssign, but it modifies fp directly.
// Use it only if you don't need fp after this method (e.g. fp is a buffer).
func (f *FourierEvaluator[T]) ToStandardPolyAddAssignUnsafe(fp FourierPoly, pOut Poly[T]) {
N := f.degree

invFFTInPlace(fp.Coeffs, f.wNjInv)
unTwistInPlace(fp.Coeffs, f.w2NjInv)

for j := 0; j < N/2; j++ {
pOut.Coeffs[j] += T(int64(real(fp.Coeffs[j])))
pOut.Coeffs[j+N/2] += -T(int64(imag(fp.Coeffs[j])))
}
}

// ToStandardPolySubAssign transforms FourierPoly to Poly, and subtracts it from pOut.
func (f *FourierEvaluator[T]) ToStandardPolySubAssign(fp FourierPoly, pOut Poly[T]) {
N := f.degree

Expand All @@ -137,7 +169,23 @@ func (f *FourierEvaluator[T]) ToStandardPolySubAssign(fp FourierPoly, pOut Poly[
}
}

// ToScaledStandardPoly transforms FourierPoly to Poly and returns it.
// ToStandardPolySubAssignUnsafe transforms FourierPoly to Poly, and subtracts it from pOut.
//
// This method is slightly faster than ToStandardPolySubAssign, but it modifies fp directly.
// Use it only if you don't need fp after this method (e.g. fp is a buffer).
func (f *FourierEvaluator[T]) ToStandardPolySubAssignUnsafe(fp FourierPoly, pOut Poly[T]) {
N := f.degree

invFFTInPlace(fp.Coeffs, f.wNjInv)
unTwistInPlace(fp.Coeffs, f.w2NjInv)

for j := 0; j < N/2; j++ {
pOut.Coeffs[j] -= T(int64(real(fp.Coeffs[j])))
pOut.Coeffs[j+N/2] -= -T(int64(imag(fp.Coeffs[j])))
}
}

// ToScaledStandardPoly transforms FourierPoly to Poly, and returns it.
// Each coefficients are scaled by 2^sizeT.
func (f *FourierEvaluator[T]) ToScaledStandardPoly(fp FourierPoly) Poly[T] {
p := New[T](f.degree)
Expand All @@ -160,7 +208,24 @@ func (f *FourierEvaluator[T]) ToScaledStandardPolyAssign(fp FourierPoly, pOut Po
}
}

// ToScaledStandardPolyAddAssign transforms FourierPoly to Poly and adds it to pOut.
// ToScaledStandardPolyAssignUnsafe transforms FourierPoly to Poly, and writes it to pOut.
// Each coefficients are scaled by 2^sizeT.
//
// This method is slightly faster than ToScaledStandardPolyAssign, but it modifies fp directly.
// Use it only if you don't need fp after this method (e.g. fp is a buffer).
func (f *FourierEvaluator[T]) ToScaledStandardPolyAssignUnsafe(fp FourierPoly, pOut Poly[T]) {
N := f.degree

invFFTInPlace(fp.Coeffs, f.wNjInv)
unTwistAndScaleInPlace(fp.Coeffs, f.w2NjInv, f.maxT)

for j := 0; j < N/2; j++ {
pOut.Coeffs[j] = T(int64(real(fp.Coeffs[j])))
pOut.Coeffs[j+N/2] = -T(int64(imag(fp.Coeffs[j])))
}
}

// ToScaledStandardPolyAddAssign transforms FourierPoly to Poly, and adds it to pOut.
// Each coefficients are scaled by 2^sizeT.
func (f *FourierEvaluator[T]) ToScaledStandardPolyAddAssign(fp FourierPoly, pOut Poly[T]) {
N := f.degree
Expand All @@ -175,7 +240,24 @@ func (f *FourierEvaluator[T]) ToScaledStandardPolyAddAssign(fp FourierPoly, pOut
}
}

// ToScaledStandardPolySubAssign transforms FourierPoly to Poly and subtracts it from pOut.
// ToScaledStandardPolyAddAssignUnsafe transforms FourierPoly to Poly, and adds it to pOut.
// Each coefficients are scaled by 2^sizeT.
//
// This method is slightly faster than ToScaledStandardPolyAddAssign, but it modifies fp directly.
// Use it only if you don't need fp after this method (e.g. fp is a buffer).
func (f *FourierEvaluator[T]) ToScaledStandardPolyAddAssignUnsafe(fp FourierPoly, pOut Poly[T]) {
N := f.degree

invFFTInPlace(fp.Coeffs, f.wNjInv)
unTwistAndScaleInPlace(fp.Coeffs, f.w2NjInv, f.maxT)

for j := 0; j < N/2; j++ {
pOut.Coeffs[j] += T(int64(real(fp.Coeffs[j])))
pOut.Coeffs[j+N/2] += -T(int64(imag(fp.Coeffs[j])))
}
}

// ToScaledStandardPolySubAssign transforms FourierPoly to Poly, and subtracts it from pOut.
// Each coefficients are scaled by 2^sizeT.
func (f *FourierEvaluator[T]) ToScaledStandardPolySubAssign(fp FourierPoly, pOut Poly[T]) {
N := f.degree
Expand All @@ -189,3 +271,20 @@ func (f *FourierEvaluator[T]) ToScaledStandardPolySubAssign(fp FourierPoly, pOut
pOut.Coeffs[j+N/2] -= -T(int64(imag(f.buffer.fpInv.Coeffs[j])))
}
}

// ToScaledStandardPolySubAssignUnsafe transforms FourierPoly to Poly, and subtracts it from pOut.
// Each coefficients are scaled by 2^sizeT.
//
// This method is slightly faster than ToScaledStandardPolySubAssign, but it modifies fp directly.
// Use it only if you don't need fp after this method (e.g. fp is a buffer).
func (f *FourierEvaluator[T]) ToScaledStandardPolySubAssignUnsafe(fp FourierPoly, pOut Poly[T]) {
N := f.degree

invFFTInPlace(fp.Coeffs, f.wNjInv)
unTwistAndScaleInPlace(fp.Coeffs, f.w2NjInv, f.maxT)

for j := 0; j < N/2; j++ {
pOut.Coeffs[j] -= T(int64(real(fp.Coeffs[j])))
pOut.Coeffs[j+N/2] -= -T(int64(imag(fp.Coeffs[j])))
}
}
18 changes: 9 additions & 9 deletions tfhe/glwe_enc.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ func (e *Encryptor[T]) mulFourierGLWEKeyAssign(p0 poly.Poly[T], fp poly.FourierP
}
e.FourierEvaluator.ToFourierPolyAssign(e.buffer.pSplit, e.buffer.fpSplit)
e.FourierEvaluator.MulAssign(e.buffer.fpSplit, fp, e.buffer.fpSplit)
e.FourierEvaluator.ToStandardPolyAssign(e.buffer.fpSplit, pOut)
e.FourierEvaluator.ToStandardPolyAssignUnsafe(e.buffer.fpSplit, pOut)

for i := 0; i < e.Parameters.polyDegree; i++ {
e.buffer.pSplit.Coeffs[i] = (p0.Coeffs[i] >> bits) & mask
}
e.FourierEvaluator.ToFourierPolyAssign(e.buffer.pSplit, e.buffer.fpSplit)
e.FourierEvaluator.MulAssign(e.buffer.fpSplit, fp, e.buffer.fpSplit)
e.FourierEvaluator.ToStandardPolyAssign(e.buffer.fpSplit, e.buffer.pSplit)
e.FourierEvaluator.ToStandardPolyAssignUnsafe(e.buffer.fpSplit, e.buffer.pSplit)
for i := 0; i < e.Parameters.polyDegree; i++ {
pOut.Coeffs[i] += e.buffer.pSplit.Coeffs[i] << bits
}
Expand All @@ -37,7 +37,7 @@ func (e *Encryptor[T]) mulFourierGLWEKeyAssign(p0 poly.Poly[T], fp poly.FourierP
}
e.FourierEvaluator.ToFourierPolyAssign(e.buffer.pSplit, e.buffer.fpSplit)
e.FourierEvaluator.MulAssign(e.buffer.fpSplit, fp, e.buffer.fpSplit)
e.FourierEvaluator.ToStandardPolyAssign(e.buffer.fpSplit, e.buffer.pSplit)
e.FourierEvaluator.ToStandardPolyAssignUnsafe(e.buffer.fpSplit, e.buffer.pSplit)
for i := 0; i < e.Parameters.polyDegree; i++ {
pOut.Coeffs[i] += (e.buffer.pSplit.Coeffs[i] << bits) << bits
}
Expand All @@ -54,14 +54,14 @@ func (e *Encryptor[T]) mulFourierGLWEKeyAddAssign(p0 poly.Poly[T], fp poly.Fouri
}
e.FourierEvaluator.ToFourierPolyAssign(e.buffer.pSplit, e.buffer.fpSplit)
e.FourierEvaluator.MulAssign(e.buffer.fpSplit, fp, e.buffer.fpSplit)
e.FourierEvaluator.ToStandardPolyAddAssign(e.buffer.fpSplit, pOut)
e.FourierEvaluator.ToStandardPolyAddAssignUnsafe(e.buffer.fpSplit, pOut)

for i := 0; i < e.Parameters.polyDegree; i++ {
e.buffer.pSplit.Coeffs[i] = (p0.Coeffs[i] >> bits) & mask
}
e.FourierEvaluator.ToFourierPolyAssign(e.buffer.pSplit, e.buffer.fpSplit)
e.FourierEvaluator.MulAssign(e.buffer.fpSplit, fp, e.buffer.fpSplit)
e.FourierEvaluator.ToStandardPolyAssign(e.buffer.fpSplit, e.buffer.pSplit)
e.FourierEvaluator.ToStandardPolyAssignUnsafe(e.buffer.fpSplit, e.buffer.pSplit)
for i := 0; i < e.Parameters.polyDegree; i++ {
pOut.Coeffs[i] += e.buffer.pSplit.Coeffs[i] << bits
}
Expand All @@ -75,7 +75,7 @@ func (e *Encryptor[T]) mulFourierGLWEKeyAddAssign(p0 poly.Poly[T], fp poly.Fouri
}
e.FourierEvaluator.ToFourierPolyAssign(e.buffer.pSplit, e.buffer.fpSplit)
e.FourierEvaluator.MulAssign(e.buffer.fpSplit, fp, e.buffer.fpSplit)
e.FourierEvaluator.ToStandardPolyAssign(e.buffer.fpSplit, e.buffer.pSplit)
e.FourierEvaluator.ToStandardPolyAssignUnsafe(e.buffer.fpSplit, e.buffer.pSplit)
for i := 0; i < e.Parameters.polyDegree; i++ {
pOut.Coeffs[i] += (e.buffer.pSplit.Coeffs[i] << bits) << bits
}
Expand All @@ -92,14 +92,14 @@ func (e *Encryptor[T]) mulFourierGLWEKeySubAssign(p0 poly.Poly[T], fp poly.Fouri
}
e.FourierEvaluator.ToFourierPolyAssign(e.buffer.pSplit, e.buffer.fpSplit)
e.FourierEvaluator.MulAssign(e.buffer.fpSplit, fp, e.buffer.fpSplit)
e.FourierEvaluator.ToStandardPolySubAssign(e.buffer.fpSplit, pOut)
e.FourierEvaluator.ToStandardPolySubAssignUnsafe(e.buffer.fpSplit, pOut)

for i := 0; i < e.Parameters.polyDegree; i++ {
e.buffer.pSplit.Coeffs[i] = (p0.Coeffs[i] >> bits) & mask
}
e.FourierEvaluator.ToFourierPolyAssign(e.buffer.pSplit, e.buffer.fpSplit)
e.FourierEvaluator.MulAssign(e.buffer.fpSplit, fp, e.buffer.fpSplit)
e.FourierEvaluator.ToStandardPolyAssign(e.buffer.fpSplit, e.buffer.pSplit)
e.FourierEvaluator.ToStandardPolyAssignUnsafe(e.buffer.fpSplit, e.buffer.pSplit)
for i := 0; i < e.Parameters.polyDegree; i++ {
pOut.Coeffs[i] -= e.buffer.pSplit.Coeffs[i] << bits
}
Expand All @@ -113,7 +113,7 @@ func (e *Encryptor[T]) mulFourierGLWEKeySubAssign(p0 poly.Poly[T], fp poly.Fouri
}
e.FourierEvaluator.ToFourierPolyAssign(e.buffer.pSplit, e.buffer.fpSplit)
e.FourierEvaluator.MulAssign(e.buffer.fpSplit, fp, e.buffer.fpSplit)
e.FourierEvaluator.ToStandardPolyAssign(e.buffer.fpSplit, e.buffer.pSplit)
e.FourierEvaluator.ToStandardPolyAssignUnsafe(e.buffer.fpSplit, e.buffer.pSplit)
for i := 0; i < e.Parameters.polyDegree; i++ {
pOut.Coeffs[i] -= (e.buffer.pSplit.Coeffs[i] << bits) << bits
}
Expand Down
24 changes: 12 additions & 12 deletions tfhe/product.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func (e *Evaluator[T]) GadgetProductAssign(ctFourierGLev FourierGLevCiphertext[T
}

for i := 0; i < e.Parameters.glweDimension+1; i++ {
e.FourierEvaluator.ToScaledStandardPolyAssign(e.buffer.ctFourierProd.Value[i], ctOut.Value[i])
e.FourierEvaluator.ToScaledStandardPolyAssignUnsafe(e.buffer.ctFourierProd.Value[i], ctOut.Value[i])
}
}

Expand All @@ -40,7 +40,7 @@ func (e *Evaluator[T]) GadgetProductAddAssign(ctFourierGLev FourierGLevCiphertex
}

for i := 0; i < e.Parameters.glweDimension+1; i++ {
e.FourierEvaluator.ToScaledStandardPolyAddAssign(e.buffer.ctFourierProd.Value[i], ctOut.Value[i])
e.FourierEvaluator.ToScaledStandardPolyAddAssignUnsafe(e.buffer.ctFourierProd.Value[i], ctOut.Value[i])
}
}

Expand All @@ -56,7 +56,7 @@ func (e *Evaluator[T]) GadgetProductSubAssign(ctFourierGLev FourierGLevCiphertex
}

for i := 0; i < e.Parameters.glweDimension+1; i++ {
e.FourierEvaluator.ToScaledStandardPolySubAssign(e.buffer.ctFourierProd.Value[i], ctOut.Value[i])
e.FourierEvaluator.ToScaledStandardPolySubAssignUnsafe(e.buffer.ctFourierProd.Value[i], ctOut.Value[i])
}
}

Expand All @@ -69,7 +69,7 @@ func (e *Evaluator[T]) GadgetProductFourierDecomposedAssign(ctFourierGLev Fourie
}

for i := 0; i < e.Parameters.glweDimension+1; i++ {
e.FourierEvaluator.ToScaledStandardPolyAssign(e.buffer.ctFourierProd.Value[i], ctOut.Value[i])
e.FourierEvaluator.ToScaledStandardPolyAssignUnsafe(e.buffer.ctFourierProd.Value[i], ctOut.Value[i])
}
}

Expand All @@ -82,7 +82,7 @@ func (e *Evaluator[T]) GadgetProductFourierDecomposedAddAssign(ctFourierGLev Fou
}

for i := 0; i < e.Parameters.glweDimension+1; i++ {
e.FourierEvaluator.ToScaledStandardPolyAddAssign(e.buffer.ctFourierProd.Value[i], ctOut.Value[i])
e.FourierEvaluator.ToScaledStandardPolyAddAssignUnsafe(e.buffer.ctFourierProd.Value[i], ctOut.Value[i])
}
}

Expand All @@ -95,7 +95,7 @@ func (e *Evaluator[T]) GadgetProductFourierDecomposedSubAssign(ctFourierGLev Fou
}

for i := 0; i < e.Parameters.glweDimension+1; i++ {
e.FourierEvaluator.ToScaledStandardPolySubAssign(e.buffer.ctFourierProd.Value[i], ctOut.Value[i])
e.FourierEvaluator.ToScaledStandardPolySubAssignUnsafe(e.buffer.ctFourierProd.Value[i], ctOut.Value[i])
}
}

Expand Down Expand Up @@ -126,7 +126,7 @@ func (e *Evaluator[T]) ExternalProductAssign(ctFourierGGSW FourierGGSWCiphertext
}

for i := 0; i < e.Parameters.glweDimension+1; i++ {
e.FourierEvaluator.ToScaledStandardPolyAssign(e.buffer.ctFourierProd.Value[i], ctGLWEOut.Value[i])
e.FourierEvaluator.ToScaledStandardPolyAssignUnsafe(e.buffer.ctFourierProd.Value[i], ctGLWEOut.Value[i])
}
}

Expand All @@ -149,7 +149,7 @@ func (e *Evaluator[T]) ExternalProductAddAssign(ctFourierGGSW FourierGGSWCiphert
}

for i := 0; i < e.Parameters.glweDimension+1; i++ {
e.FourierEvaluator.ToScaledStandardPolyAddAssign(e.buffer.ctFourierProd.Value[i], ctGLWEOut.Value[i])
e.FourierEvaluator.ToScaledStandardPolyAddAssignUnsafe(e.buffer.ctFourierProd.Value[i], ctGLWEOut.Value[i])
}
}

Expand All @@ -172,7 +172,7 @@ func (e *Evaluator[T]) ExternalProductSubAssign(ctFourierGGSW FourierGGSWCiphert
}

for i := 0; i < e.Parameters.glweDimension+1; i++ {
e.FourierEvaluator.ToScaledStandardPolySubAssign(e.buffer.ctFourierProd.Value[i], ctGLWEOut.Value[i])
e.FourierEvaluator.ToScaledStandardPolySubAssignUnsafe(e.buffer.ctFourierProd.Value[i], ctGLWEOut.Value[i])
}
}

Expand All @@ -199,7 +199,7 @@ func (e *Evaluator[T]) ExternalProductFourierDecomposedAssign(ctFourierGGSW Four
}

for i := 0; i < e.Parameters.glweDimension+1; i++ {
e.FourierEvaluator.ToScaledStandardPolyAssign(e.buffer.ctFourierProd.Value[i], ctGLWEOut.Value[i])
e.FourierEvaluator.ToScaledStandardPolyAssignUnsafe(e.buffer.ctFourierProd.Value[i], ctGLWEOut.Value[i])
}
}

Expand All @@ -218,7 +218,7 @@ func (e *Evaluator[T]) ExternalProductFourierDecomposedAddAssign(ctFourierGGSW F
}

for i := 0; i < e.Parameters.glweDimension+1; i++ {
e.FourierEvaluator.ToScaledStandardPolyAddAssign(e.buffer.ctFourierProd.Value[i], ctGLWEOut.Value[i])
e.FourierEvaluator.ToScaledStandardPolyAddAssignUnsafe(e.buffer.ctFourierProd.Value[i], ctGLWEOut.Value[i])
}
}

Expand All @@ -237,7 +237,7 @@ func (e *Evaluator[T]) ExternalProductFourierDecomposedSubAssign(ctFourierGGSW F
}

for i := 0; i < e.Parameters.glweDimension+1; i++ {
e.FourierEvaluator.ToScaledStandardPolySubAssign(e.buffer.ctFourierProd.Value[i], ctGLWEOut.Value[i])
e.FourierEvaluator.ToScaledStandardPolySubAssignUnsafe(e.buffer.ctFourierProd.Value[i], ctGLWEOut.Value[i])
}
}

Expand Down

0 comments on commit a17ec3b

Please sign in to comment.