Skip to content

Commit

Permalink
mint: refactor overflow checks
Browse files Browse the repository at this point in the history
  • Loading branch information
elnosh committed Feb 13, 2025
1 parent 7e7ee7b commit 3667bdb
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 21 deletions.
1 change: 1 addition & 0 deletions cashu/cashu.go
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,7 @@ var (
PaymentMethodNotSupportedErr = Error{Detail: "payment method not supported", Code: PaymentMethodErrCode}
UnitNotSupportedErr = Error{Detail: "unit not supported", Code: UnitErrCode}
InvalidBlindedMessageAmount = Error{Detail: "invalid amount in blinded message", Code: StandardErrCode}
InvalidProofAmount = Error{Detail: "invalid amount in proof", Code: StandardErrCode}
BlindedMessageAlreadySigned = Error{Detail: "blinded message already signed", Code: BlindedMessageAlreadySignedErrCode}
MintQuoteRequestNotPaid = Error{Detail: "quote request has not been paid", Code: MintQuoteRequestNotPaidErrCode}
MintQuoteAlreadyIssued = Error{Detail: "quote already issued", Code: MintQuoteAlreadyIssuedErrCode}
Expand Down
55 changes: 34 additions & 21 deletions mint/mint.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"fmt"
"io"
"log/slog"
"math"
"os"
"path/filepath"
"reflect"
Expand Down Expand Up @@ -401,21 +402,17 @@ func (m *Mint) MintTokens(mintTokensRequest nut04.PostMintBolt11Request) (cashu.
blindedMessages := mintTokensRequest.Outputs
var blindedMessagesAmount uint64
B_s := make([]string, len(blindedMessages))
var overflows bool
for i, bm := range blindedMessages {
blindedMessagesAmount += bm.Amount
B_s[i] = bm.B_
}

if len(blindedMessages) > 0 {
for _, msg := range blindedMessages {
if blindedMessagesAmount < msg.Amount {
return cashu.InvalidBlindedMessageAmount
}
blindedMessagesAmount, overflows = overflowAddUint64(blindedMessagesAmount, bm.Amount)
if overflows {
return cashu.InvalidBlindedMessageAmount
}
B_s[i] = bm.B_
}

// verify that amount from blinded messages is less
// than quote amount
// verify that amount from blinded messages is enough
// for quote amount
if blindedMessagesAmount > mintQuote.Amount {
return cashu.OutputsOverQuoteAmountErr
}
Expand Down Expand Up @@ -505,21 +502,21 @@ func (m *Mint) Swap(proofs cashu.Proofs, blindedMessages cashu.BlindedMessages)

var blindedMessagesAmount uint64
B_s := make([]string, len(blindedMessages))
var overflows bool
for i, bm := range blindedMessages {
blindedMessagesAmount += bm.Amount
blindedMessagesAmount, overflows = overflowAddUint64(blindedMessagesAmount, bm.Amount)
if overflows {
return nil, cashu.InvalidBlindedMessageAmount
}
B_s[i] = bm.B_
}

// check overflow
if len(blindedMessages) > 0 {
for _, msg := range blindedMessages {
if blindedMessagesAmount < msg.Amount {
return nil, cashu.InvalidBlindedMessageAmount
}
}
fees := uint64(m.TransactionFees(proofs))
proofsMinusFees, underflow := underflowSubUint64(proofsAmount, fees)
if underflow {
return nil, cashu.InvalidProofAmount
}
fees := m.TransactionFees(proofs)
if proofsAmount-uint64(fees) < blindedMessagesAmount {
if proofsMinusFees < blindedMessagesAmount {
return nil, cashu.InsufficientProofsAmount
}

Expand Down Expand Up @@ -1467,6 +1464,22 @@ func (m *Mint) signBlindedMessages(blindedMessages cashu.BlindedMessages) (cashu
return blindedSignatures, nil
}

// overflowAddUint64 adds two uint64 and checks if that results in an overflow
func overflowAddUint64(a, b uint64) (uint64, bool) {
sum := a + b
if sum < a || sum < b {
return math.MaxUint64, true
}
return sum, false
}

func underflowSubUint64(a, b uint64) (uint64, bool) {
if b > a {
return 0, true
}
return a - b, false
}

// requestInvoice requests an invoice from the Lightning backend
// for the given amount
func (m *Mint) requestInvoice(amount uint64) (*lightning.Invoice, error) {
Expand Down
21 changes: 21 additions & 0 deletions mint/mint_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"errors"
"flag"
"log"
"math"
"os"
"path/filepath"
"reflect"
Expand Down Expand Up @@ -249,6 +250,17 @@ func TestMintTokens(t *testing.T) {
t.Fatalf("expected error '%v' but got '%v' instead", cashu.UnknownKeysetErr, err)
}

// test overflow in blinded messages amount
overflowBlindedMessages, _, _, err := testutils.CreateBlindedMessages(math.MaxUint64, keyset)
bms, _, _, err := testutils.CreateBlindedMessages(mintAmount, keyset)
overflowBlindedMessages = append(overflowBlindedMessages, bms...)
mintTokensRequest = nut04.PostMintBolt11Request{Quote: mintQuoteResponse.Id, Outputs: overflowBlindedMessages}
_, err = testMint.MintTokens(mintTokensRequest)
if !errors.Is(err, cashu.InvalidBlindedMessageAmount) {
t.Fatalf("expected error '%v' but got '%v' instead", cashu.InvalidBlindedMessageAmount, err)
}

// valid mint request
mintTokensRequest = nut04.PostMintBolt11Request{Quote: mintQuoteResponse.Id, Outputs: blindedMessages}
_, err = testMint.MintTokens(mintTokensRequest)
if err != nil {
Expand Down Expand Up @@ -354,6 +366,15 @@ func TestSwap(t *testing.T) {
t.Fatalf("expected error '%v' but got '%v' instead", cashu.InsufficientProofsAmount, err)
}

// test overflow in blinded messages amount
overflowBlindedMessages, _, _, err := testutils.CreateBlindedMessages(math.MaxUint64, keyset)
bms, _, _, err := testutils.CreateBlindedMessages(amount, keyset)
overflowBlindedMessages = append(overflowBlindedMessages, bms...)
_, err = testMint.Swap(proofs, overflowBlindedMessages)
if !errors.Is(err, cashu.InvalidBlindedMessageAmount) {
t.Fatalf("expected error '%v' but got '%v' instead", cashu.InvalidBlindedMessageAmount, err)
}

// test with duplicates in proofs list passed
proofsLen := len(proofs)
duplicateProofs := make(cashu.Proofs, proofsLen)
Expand Down
72 changes: 72 additions & 0 deletions mint/mint_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package mint

import (
"math"
"testing"
)

func TestOverflowAddUint64(t *testing.T) {
tests := []struct {
a uint64
b uint64
expectedUint64 uint64
expectedOverflow bool
}{
{
a: 21,
b: 42,
expectedUint64: 63,
expectedOverflow: false,
},
{
a: math.MaxUint64 - 5,
b: 10,
expectedUint64: math.MaxUint64,
expectedOverflow: true,
},
}

for _, test := range tests {
result, overflow := overflowAddUint64(test.a, test.b)
if result != test.expectedUint64 {
t.Fatalf("expected result '%v' but got '%v'", test.expectedUint64, result)
}

if overflow != test.expectedOverflow {
t.Fatalf("expected overflow '%v' but got '%v'", test.expectedOverflow, overflow)
}
}
}

func TestUnderflowSubUint64(t *testing.T) {
tests := []struct {
a uint64
b uint64
expectedUint64 uint64
expectedUnderflow bool
}{
{
a: 42,
b: 21,
expectedUint64: 21,
expectedUnderflow: false,
},
{
a: 10,
b: 210,
expectedUint64: 0,
expectedUnderflow: true,
},
}

for _, test := range tests {
result, underflow := underflowSubUint64(test.a, test.b)
if result != test.expectedUint64 {
t.Fatalf("expected result '%v' but got '%v'", test.expectedUint64, result)
}

if underflow != test.expectedUnderflow {
t.Fatalf("expected overflow '%v' but got '%v'", test.expectedUnderflow, underflow)
}
}
}

0 comments on commit 3667bdb

Please sign in to comment.