From 590b9b2dc67161e87989cb07bc6a6d9f20d3110b Mon Sep 17 00:00:00 2001 From: Hwang In Tak Date: Sun, 26 Jan 2025 15:29:20 +0000 Subject: [PATCH] fix: Fix asm_decompose --- internal/asmgen/decompose.go | 28 ++++++++++++++++++++++++++-- math/poly/asm_convert.go | 4 +--- math/poly/asm_convert_amd64.go | 4 +--- tfhe/asm_decompose_amd64.s | 24 ++++++++++++++++++------ 4 files changed, 46 insertions(+), 14 deletions(-) diff --git a/internal/asmgen/decompose.go b/internal/asmgen/decompose.go index f9580dd..5511757 100644 --- a/internal/asmgen/decompose.go +++ b/internal/asmgen/decompose.go @@ -79,9 +79,21 @@ func decomposePolyAssignUint32AVX2() { SUBQ(Imm(3), jj) Label("level_loop_end") - CMPQ(j, Imm(0)) + CMPQ(j, Imm(1)) JGE(LabelRef("level_loop_body")) + u = YMM() + VANDPD(baseMask, c, u) + + uCarry = YMM() + VANDPD(baseHalf, u, uCarry) + VPSLLD(Imm(1), uCarry, uCarry) + VPSUBD(uCarry, u, u) + + decomposedOut0 := GP64() + MOVQ(Mem{Base: decomposedOut}, decomposedOut0) + VMOVDQU(u, Mem{Base: decomposedOut0, Index: i, Scale: 4}) + ADDQ(Imm(8), i) Label("N_loop_end") @@ -159,9 +171,21 @@ func decomposePolyAssignUint64AVX2() { SUBQ(Imm(3), jj) Label("level_loop_end") - CMPQ(j, Imm(0)) + CMPQ(j, Imm(1)) JGE(LabelRef("level_loop_body")) + u = YMM() + VANDPD(baseMask, c, u) + + uCarry = YMM() + VANDPD(baseHalf, u, uCarry) + VPSLLQ(Imm(1), uCarry, uCarry) + VPSUBQ(uCarry, u, u) + + decomposedOut0 := GP64() + MOVQ(Mem{Base: decomposedOut}, decomposedOut0) + VMOVDQU(u, Mem{Base: decomposedOut0, Index: i, Scale: 8}) + ADDQ(Imm(4), i) Label("N_loop_end") diff --git a/math/poly/asm_convert.go b/math/poly/asm_convert.go index 7ce4885..e054247 100644 --- a/math/poly/asm_convert.go +++ b/math/poly/asm_convert.go @@ -92,9 +92,7 @@ func convertPolyToFourierPolyAssign[T num.Integer](p []T, fpOut []float64) { // floatModQInPlace computes coeffs mod Q in place. func floatModQInPlace(coeffs []float64, Q float64) { for i := range coeffs { - cQuo := coeffs[i] / Q - cRem := cQuo - math.Round(cQuo) - coeffs[i] = math.Round(cRem * Q) + coeffs[i] = math.Round(coeffs[i] - Q*math.Round(coeffs[i]/Q)) } } diff --git a/math/poly/asm_convert_amd64.go b/math/poly/asm_convert_amd64.go index b79e6a6..e90e765 100644 --- a/math/poly/asm_convert_amd64.go +++ b/math/poly/asm_convert_amd64.go @@ -116,9 +116,7 @@ func floatModQInPlace(coeffs []float64, Q float64) { } for i := range coeffs { - cQuo := coeffs[i] / Q - cRem := cQuo - math.Round(cQuo) - coeffs[i] = math.Round(cRem * Q) + coeffs[i] = math.Round(coeffs[i] - Q*math.Round(coeffs[i]/Q)) } } diff --git a/tfhe/asm_decompose_amd64.s b/tfhe/asm_decompose_amd64.s index 7d7c65e..0699e32 100644 --- a/tfhe/asm_decompose_amd64.s +++ b/tfhe/asm_decompose_amd64.s @@ -52,9 +52,15 @@ level_loop_body: SUBQ $0x03, R8 level_loop_end: - CMPQ DI, $0x00 - JGE level_loop_body - ADDQ $0x08, SI + CMPQ DI, $0x01 + JGE level_loop_body + VANDPD Y4, Y6, Y6 + VANDPD Y0, Y6, Y7 + VPSLLD $0x01, Y7, Y7 + VPSUBD Y7, Y6, Y6 + MOVQ (CX), DI + VMOVDQU Y6, (DI)(SI*4) + ADDQ $0x08, SI N_loop_end: CMPQ SI, DX @@ -106,9 +112,15 @@ level_loop_body: SUBQ $0x03, R8 level_loop_end: - CMPQ DI, $0x00 - JGE level_loop_body - ADDQ $0x04, SI + CMPQ DI, $0x01 + JGE level_loop_body + VANDPD Y4, Y6, Y6 + VANDPD Y0, Y6, Y7 + VPSLLQ $0x01, Y7, Y7 + VPSUBQ Y7, Y6, Y6 + MOVQ (CX), DI + VMOVDQU Y6, (DI)(SI*8) + ADDQ $0x04, SI N_loop_end: CMPQ SI, DX