diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 84e9fa4..3e79beb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,7 +7,7 @@ on: branches: [main] jobs: - ci: + tests: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -26,6 +26,22 @@ jobs: - name: Tests run: go test -v ./... + - name: Fuzz Uint64 addition + run: go test -v -fuzz=OverflowAddUint64 -fuzztime=60s ./cashu + + - name: Fuzz Uint64 subtraction + run: go test -v -fuzz=FuzzUnderflowSubUint64 -fuzztime=60s ./cashu + + integration-tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: 1.22.2 + - name: Integration Tests run: go test -v --tags=integration ./mint - - run: go test -v --tags=integration ./wallet + - run: go test -v --tags=integration ./wallet \ No newline at end of file diff --git a/cashu/cashu.go b/cashu/cashu.go index 9c747f4..e0de317 100644 --- a/cashu/cashu.go +++ b/cashu/cashu.go @@ -10,6 +10,7 @@ import ( "encoding/json" "errors" "fmt" + "math" "github.com/decred/dcrd/dcrec/secp256k1/v4" "github.com/fxamacker/cbor/v2" @@ -33,9 +34,10 @@ func (unit Unit) String() string { } var ( - ErrInvalidTokenV3 = errors.New("invalid V3 token") - ErrInvalidTokenV4 = errors.New("invalid V4 token") - ErrInvalidUnit = errors.New("invalid unit") + ErrInvalidTokenV3 = errors.New("invalid V3 token") + ErrInvalidTokenV4 = errors.New("invalid V4 token") + ErrInvalidUnit = errors.New("invalid unit") + ErrAmountOverflows = errors.New("amount overflows") ) // Cashu BlindedMessage. See https://github.com/cashubtc/nuts/blob/main/00.md#blindedmessage @@ -79,6 +81,38 @@ func (bm BlindedMessages) Amount() uint64 { return totalAmount } +// AmountChecked returns the total amount in the blinded messages +// and an error if it overflows +func (bm BlindedMessages) AmountChecked() (uint64, error) { + var totalAmount uint64 = 0 + overflows := false + for _, msg := range bm { + totalAmount, overflows = OverflowAddUint64(totalAmount, msg.Amount) + if overflows { + return 0, ErrAmountOverflows + } + } + + return totalAmount, nil +} + +// OverflowAddUint64 adds two uint64 and checks if that results in an overflow +func OverflowAddUint64(a, b uint64) (uint64, bool) { + sum := a + b + if sum < a || sum < b { + return math.MaxUint64, true + } + return sum, false +} + +// UnderflowSubUint64 subtracts two uint64 and checks if that results in an underflow +func UnderflowSubUint64(a, b uint64) (uint64, bool) { + if b > a { + return 0, true + } + return a - b, false +} + // Cashu BlindedSignature. See https://github.com/cashubtc/nuts/blob/main/00.md#blindsignature type BlindedSignature struct { Amount uint64 `json:"amount"` diff --git a/cashu/cashu_test.go b/cashu/cashu_test.go index 507cc43..baa5edc 100644 --- a/cashu/cashu_test.go +++ b/cashu/cashu_test.go @@ -2,10 +2,182 @@ package cashu import ( "encoding/hex" + "math" + "math/big" "reflect" "testing" ) +func TestAmountChecked(t *testing.T) { + split := AmountSplit(math.MaxUint64) + overflowBlindedMessages := make(BlindedMessages, len(split)+1) + for i, amount := range split { + overflowBlindedMessages[i] = BlindedMessage{Amount: amount} + } + overflowBlindedMessages[len(split)] = BlindedMessage{Amount: 4} + + tests := []struct { + blindedMessages BlindedMessages + expectedAmount uint64 + expectedErr error + }{ + { + blindedMessages: BlindedMessages{ + BlindedMessage{Amount: 2}, + BlindedMessage{Amount: 4}, + BlindedMessage{Amount: 8}, + BlindedMessage{Amount: 64}, + }, + expectedAmount: 78, + expectedErr: nil, + }, + { + blindedMessages: overflowBlindedMessages, + expectedAmount: 0, + expectedErr: ErrAmountOverflows, + }, + } + + for _, test := range tests { + totalAmount, err := test.blindedMessages.AmountChecked() + if totalAmount != test.expectedAmount { + t.Fatalf("expected total amount of '%v' but got '%v'", test.expectedAmount, totalAmount) + } + + if err != test.expectedErr { + t.Fatalf("expected error '%v' but got '%v'", test.expectedErr, err) + } + } +} + +func TestOverflowAddUint64(t *testing.T) { + tests := []struct { + a uint64 + b uint64 + expectedUint64 uint64 + expectedOverflow bool + }{ + { + a: 21, + b: 42, + expectedUint64: 63, + expectedOverflow: false, + }, + { + a: math.MaxUint64 - 5, + b: 10, + expectedUint64: math.MaxUint64, + expectedOverflow: true, + }, + } + + for _, test := range tests { + result, overflow := OverflowAddUint64(test.a, test.b) + if result != test.expectedUint64 { + t.Fatalf("expected result '%v' but got '%v'", test.expectedUint64, result) + } + + if overflow != test.expectedOverflow { + t.Fatalf("expected overflow '%v' but got '%v'", test.expectedOverflow, overflow) + } + } +} + +func FuzzOverflowAddUint64(f *testing.F) { + cases := [][2]uint64{ + {21, 42}, + {math.MaxUint64, 10}, + } + for _, seed := range cases { + f.Add(seed[0], seed[1]) + } + + f.Fuzz(func(t *testing.T, a uint64, b uint64) { + bigA := new(big.Int).SetUint64(a) + bigB := new(big.Int).SetUint64(b) + bigA.Add(bigA, bigB) + + result, overflow := OverflowAddUint64(a, b) + // IsUint64 reports whether the number can be represented as uint64 + if bigA.IsUint64() { + uint64Result := bigA.Uint64() + if uint64Result != result { + t.Errorf("a = %v and b = %v. expected result %v but got %v", a, b, uint64Result, result) + } + } else { + // if result from addition cannot be represented as uint64, + // then function should return overflow == true + if !overflow { + t.Error("addition is above max uint64 but did not return overflow") + } + } + }) +} + +func TestUnderflowSubUint64(t *testing.T) { + tests := []struct { + a uint64 + b uint64 + expectedUint64 uint64 + expectedUnderflow bool + }{ + { + a: 42, + b: 21, + expectedUint64: 21, + expectedUnderflow: false, + }, + { + a: 10, + b: 210, + expectedUint64: 0, + expectedUnderflow: true, + }, + } + + for _, test := range tests { + result, underflow := UnderflowSubUint64(test.a, test.b) + if result != test.expectedUint64 { + t.Fatalf("expected result '%v' but got '%v'", test.expectedUint64, result) + } + + if underflow != test.expectedUnderflow { + t.Fatalf("expected overflow '%v' but got '%v'", test.expectedUnderflow, underflow) + } + } +} + +func FuzzUnderflowSubUint64(f *testing.F) { + cases := [][2]uint64{ + {42, 21}, + {10, 210}, + } + for _, seed := range cases { + f.Add(seed[0], seed[1]) + } + + f.Fuzz(func(t *testing.T, a uint64, b uint64) { + bigA := new(big.Int).SetUint64(a) + bigB := new(big.Int).SetUint64(b) + bigA.Sub(bigA, bigB) + + result, underflow := UnderflowSubUint64(a, b) + // IsUint64 reports whether the number can be represented as uint64 + if bigA.IsUint64() { + uint64Result := bigA.Uint64() + if uint64Result != result { + t.Errorf("a = %v and b = %v. expected result %v but got %v", a, b, uint64Result, result) + } + } else { + // if result from sub cannot be represented as uint64, + // then function should return underflow == true + if !underflow { + t.Error("subtraction is below 0 but did not return underflow") + } + } + }) +} + func TestDecodeTokenV4(t *testing.T) { keysetIdBytes, _ := hex.DecodeString("00ad268c4d1f5826") Cbytes, _ := hex.DecodeString("038618543ffb6b8695df4ad4babcde92a34a96bdcd97dcee0d7ccf98d472126792") diff --git a/mint/mint.go b/mint/mint.go index 99caeff..a06fba5 100644 --- a/mint/mint.go +++ b/mint/mint.go @@ -10,7 +10,6 @@ import ( "fmt" "io" "log/slog" - "math" "os" "path/filepath" "reflect" @@ -400,15 +399,9 @@ func (m *Mint) MintTokens(mintTokensRequest nut04.PostMintBolt11Request) (cashu. } blindedMessages := mintTokensRequest.Outputs - var blindedMessagesAmount uint64 - B_s := make([]string, len(blindedMessages)) - var overflows bool - for i, bm := range blindedMessages { - blindedMessagesAmount, overflows = overflowAddUint64(blindedMessagesAmount, bm.Amount) - if overflows { - return cashu.InvalidBlindedMessageAmount - } - B_s[i] = bm.B_ + blindedMessagesAmount, err := blindedMessages.AmountChecked() + if err != nil { + return cashu.InvalidBlindedMessageAmount } // verify that amount from blinded messages is enough @@ -417,6 +410,11 @@ func (m *Mint) MintTokens(mintTokensRequest nut04.PostMintBolt11Request) (cashu. return cashu.OutputsOverQuoteAmountErr } + B_s := make([]string, len(blindedMessages)) + for i, bm := range blindedMessages { + B_s[i] = bm.B_ + } + sigs, err := m.db.GetBlindSignatures(B_s) if err != nil { errmsg := fmt.Sprintf("error getting blind signatures from db: %v", err) @@ -500,19 +498,18 @@ func (m *Mint) Swap(proofs cashu.Proofs, blindedMessages cashu.BlindedMessages) Ys[i] = Yhex } - var blindedMessagesAmount uint64 + blindedMessagesAmount, err := blindedMessages.AmountChecked() + if err != nil { + return nil, cashu.InvalidBlindedMessageAmount + } + B_s := make([]string, len(blindedMessages)) - var overflows bool for i, bm := range blindedMessages { - blindedMessagesAmount, overflows = overflowAddUint64(blindedMessagesAmount, bm.Amount) - if overflows { - return nil, cashu.InvalidBlindedMessageAmount - } B_s[i] = bm.B_ } fees := uint64(m.TransactionFees(proofs)) - proofsMinusFees, underflow := underflowSubUint64(proofsAmount, fees) + proofsMinusFees, underflow := cashu.UnderflowSubUint64(proofsAmount, fees) if underflow { return nil, cashu.InvalidProofAmount } @@ -520,8 +517,7 @@ func (m *Mint) Swap(proofs cashu.Proofs, blindedMessages cashu.BlindedMessages) return nil, cashu.InsufficientProofsAmount } - err := m.verifyProofs(proofs, Ys) - if err != nil { + if err := m.verifyProofs(proofs, Ys); err != nil { return nil, err } @@ -1464,22 +1460,6 @@ func (m *Mint) signBlindedMessages(blindedMessages cashu.BlindedMessages) (cashu return blindedSignatures, nil } -// overflowAddUint64 adds two uint64 and checks if that results in an overflow -func overflowAddUint64(a, b uint64) (uint64, bool) { - sum := a + b - if sum < a || sum < b { - return math.MaxUint64, true - } - return sum, false -} - -func underflowSubUint64(a, b uint64) (uint64, bool) { - if b > a { - return 0, true - } - return a - b, false -} - // requestInvoice requests an invoice from the Lightning backend // for the given amount func (m *Mint) requestInvoice(amount uint64) (*lightning.Invoice, error) { diff --git a/mint/mint_test.go b/mint/mint_test.go deleted file mode 100644 index 6ee59e3..0000000 --- a/mint/mint_test.go +++ /dev/null @@ -1,72 +0,0 @@ -package mint - -import ( - "math" - "testing" -) - -func TestOverflowAddUint64(t *testing.T) { - tests := []struct { - a uint64 - b uint64 - expectedUint64 uint64 - expectedOverflow bool - }{ - { - a: 21, - b: 42, - expectedUint64: 63, - expectedOverflow: false, - }, - { - a: math.MaxUint64 - 5, - b: 10, - expectedUint64: math.MaxUint64, - expectedOverflow: true, - }, - } - - for _, test := range tests { - result, overflow := overflowAddUint64(test.a, test.b) - if result != test.expectedUint64 { - t.Fatalf("expected result '%v' but got '%v'", test.expectedUint64, result) - } - - if overflow != test.expectedOverflow { - t.Fatalf("expected overflow '%v' but got '%v'", test.expectedOverflow, overflow) - } - } -} - -func TestUnderflowSubUint64(t *testing.T) { - tests := []struct { - a uint64 - b uint64 - expectedUint64 uint64 - expectedUnderflow bool - }{ - { - a: 42, - b: 21, - expectedUint64: 21, - expectedUnderflow: false, - }, - { - a: 10, - b: 210, - expectedUint64: 0, - expectedUnderflow: true, - }, - } - - for _, test := range tests { - result, underflow := underflowSubUint64(test.a, test.b) - if result != test.expectedUint64 { - t.Fatalf("expected result '%v' but got '%v'", test.expectedUint64, result) - } - - if underflow != test.expectedUnderflow { - t.Fatalf("expected overflow '%v' but got '%v'", test.expectedUnderflow, underflow) - } - } -}