Skip to content

Commit

Permalink
AmountChecked method to blinded messages and move functions to cashu …
Browse files Browse the repository at this point in the history
…package
  • Loading branch information
elnosh committed Feb 14, 2025
1 parent 3667bdb commit 708d3ae
Show file tree
Hide file tree
Showing 5 changed files with 242 additions and 112 deletions.
20 changes: 18 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ on:
branches: [main]

jobs:
ci:
tests:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
Expand All @@ -26,6 +26,22 @@ jobs:
- name: Tests
run: go test -v ./...

- name: Fuzz Uint64 addition
run: go test -v -fuzz=OverflowAddUint64 -fuzztime=60s ./cashu

- name: Fuzz Uint64 subtraction
run: go test -v -fuzz=FuzzUnderflowSubUint64 -fuzztime=60s ./cashu

integration-tests:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: 1.22.2

- name: Integration Tests
run: go test -v --tags=integration ./mint
- run: go test -v --tags=integration ./wallet
- run: go test -v --tags=integration ./wallet
40 changes: 37 additions & 3 deletions cashu/cashu.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"encoding/json"
"errors"
"fmt"
"math"

"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/fxamacker/cbor/v2"
Expand All @@ -33,9 +34,10 @@ func (unit Unit) String() string {
}

var (
ErrInvalidTokenV3 = errors.New("invalid V3 token")
ErrInvalidTokenV4 = errors.New("invalid V4 token")
ErrInvalidUnit = errors.New("invalid unit")
ErrInvalidTokenV3 = errors.New("invalid V3 token")
ErrInvalidTokenV4 = errors.New("invalid V4 token")
ErrInvalidUnit = errors.New("invalid unit")
ErrAmountOverflows = errors.New("amount overflows")
)

// Cashu BlindedMessage. See https://github.com/cashubtc/nuts/blob/main/00.md#blindedmessage
Expand Down Expand Up @@ -79,6 +81,38 @@ func (bm BlindedMessages) Amount() uint64 {
return totalAmount
}

// AmountChecked returns the total amount in the blinded messages
// and an error if it overflows
func (bm BlindedMessages) AmountChecked() (uint64, error) {
var totalAmount uint64 = 0
overflows := false
for _, msg := range bm {
totalAmount, overflows = OverflowAddUint64(totalAmount, msg.Amount)
if overflows {
return 0, ErrAmountOverflows
}
}

return totalAmount, nil
}

// OverflowAddUint64 adds two uint64 and checks if that results in an overflow
func OverflowAddUint64(a, b uint64) (uint64, bool) {
sum := a + b
if sum < a || sum < b {
return math.MaxUint64, true
}
return sum, false
}

// UnderflowSubUint64 subtracts two uint64 and checks if that results in an underflow
func UnderflowSubUint64(a, b uint64) (uint64, bool) {
if b > a {
return 0, true
}
return a - b, false
}

// Cashu BlindedSignature. See https://github.com/cashubtc/nuts/blob/main/00.md#blindsignature
type BlindedSignature struct {
Amount uint64 `json:"amount"`
Expand Down
172 changes: 172 additions & 0 deletions cashu/cashu_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,182 @@ package cashu

import (
"encoding/hex"
"math"
"math/big"
"reflect"
"testing"
)

func TestAmountChecked(t *testing.T) {
split := AmountSplit(math.MaxUint64)
overflowBlindedMessages := make(BlindedMessages, len(split)+1)
for i, amount := range split {
overflowBlindedMessages[i] = BlindedMessage{Amount: amount}
}
overflowBlindedMessages[len(split)] = BlindedMessage{Amount: 4}

tests := []struct {
blindedMessages BlindedMessages
expectedAmount uint64
expectedErr error
}{
{
blindedMessages: BlindedMessages{
BlindedMessage{Amount: 2},
BlindedMessage{Amount: 4},
BlindedMessage{Amount: 8},
BlindedMessage{Amount: 64},
},
expectedAmount: 78,
expectedErr: nil,
},
{
blindedMessages: overflowBlindedMessages,
expectedAmount: 0,
expectedErr: ErrAmountOverflows,
},
}

for _, test := range tests {
totalAmount, err := test.blindedMessages.AmountChecked()
if totalAmount != test.expectedAmount {
t.Fatalf("expected total amount of '%v' but got '%v'", test.expectedAmount, totalAmount)
}

if err != test.expectedErr {
t.Fatalf("expected error '%v' but got '%v'", test.expectedErr, err)
}
}
}

func TestOverflowAddUint64(t *testing.T) {
tests := []struct {
a uint64
b uint64
expectedUint64 uint64
expectedOverflow bool
}{
{
a: 21,
b: 42,
expectedUint64: 63,
expectedOverflow: false,
},
{
a: math.MaxUint64 - 5,
b: 10,
expectedUint64: math.MaxUint64,
expectedOverflow: true,
},
}

for _, test := range tests {
result, overflow := OverflowAddUint64(test.a, test.b)
if result != test.expectedUint64 {
t.Fatalf("expected result '%v' but got '%v'", test.expectedUint64, result)
}

if overflow != test.expectedOverflow {
t.Fatalf("expected overflow '%v' but got '%v'", test.expectedOverflow, overflow)
}
}
}

func FuzzOverflowAddUint64(f *testing.F) {
cases := [][2]uint64{
{21, 42},
{math.MaxUint64, 10},
}
for _, seed := range cases {
f.Add(seed[0], seed[1])
}

f.Fuzz(func(t *testing.T, a uint64, b uint64) {
bigA := new(big.Int).SetUint64(a)
bigB := new(big.Int).SetUint64(b)
bigA.Add(bigA, bigB)

result, overflow := OverflowAddUint64(a, b)
// IsUint64 reports whether the number can be represented as uint64
if bigA.IsUint64() {
uint64Result := bigA.Uint64()
if uint64Result != result {
t.Errorf("a = %v and b = %v. expected result %v but got %v", a, b, uint64Result, result)
}
} else {
// if result from addition cannot be represented as uint64,
// then function should return overflow == true
if !overflow {
t.Error("addition is above max uint64 but did not return overflow")
}
}
})
}

func TestUnderflowSubUint64(t *testing.T) {
tests := []struct {
a uint64
b uint64
expectedUint64 uint64
expectedUnderflow bool
}{
{
a: 42,
b: 21,
expectedUint64: 21,
expectedUnderflow: false,
},
{
a: 10,
b: 210,
expectedUint64: 0,
expectedUnderflow: true,
},
}

for _, test := range tests {
result, underflow := UnderflowSubUint64(test.a, test.b)
if result != test.expectedUint64 {
t.Fatalf("expected result '%v' but got '%v'", test.expectedUint64, result)
}

if underflow != test.expectedUnderflow {
t.Fatalf("expected overflow '%v' but got '%v'", test.expectedUnderflow, underflow)
}
}
}

func FuzzUnderflowSubUint64(f *testing.F) {
cases := [][2]uint64{
{42, 21},
{10, 210},
}
for _, seed := range cases {
f.Add(seed[0], seed[1])
}

f.Fuzz(func(t *testing.T, a uint64, b uint64) {
bigA := new(big.Int).SetUint64(a)
bigB := new(big.Int).SetUint64(b)
bigA.Sub(bigA, bigB)

result, underflow := UnderflowSubUint64(a, b)
// IsUint64 reports whether the number can be represented as uint64
if bigA.IsUint64() {
uint64Result := bigA.Uint64()
if uint64Result != result {
t.Errorf("a = %v and b = %v. expected result %v but got %v", a, b, uint64Result, result)
}
} else {
// if result from sub cannot be represented as uint64,
// then function should return underflow == true
if !underflow {
t.Error("subtraction is below 0 but did not return underflow")
}
}
})
}

func TestDecodeTokenV4(t *testing.T) {
keysetIdBytes, _ := hex.DecodeString("00ad268c4d1f5826")
Cbytes, _ := hex.DecodeString("038618543ffb6b8695df4ad4babcde92a34a96bdcd97dcee0d7ccf98d472126792")
Expand Down
50 changes: 15 additions & 35 deletions mint/mint.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"fmt"
"io"
"log/slog"
"math"
"os"
"path/filepath"
"reflect"
Expand Down Expand Up @@ -400,15 +399,9 @@ func (m *Mint) MintTokens(mintTokensRequest nut04.PostMintBolt11Request) (cashu.
}

blindedMessages := mintTokensRequest.Outputs
var blindedMessagesAmount uint64
B_s := make([]string, len(blindedMessages))
var overflows bool
for i, bm := range blindedMessages {
blindedMessagesAmount, overflows = overflowAddUint64(blindedMessagesAmount, bm.Amount)
if overflows {
return cashu.InvalidBlindedMessageAmount
}
B_s[i] = bm.B_
blindedMessagesAmount, err := blindedMessages.AmountChecked()
if err != nil {
return cashu.InvalidBlindedMessageAmount
}

// verify that amount from blinded messages is enough
Expand All @@ -417,6 +410,11 @@ func (m *Mint) MintTokens(mintTokensRequest nut04.PostMintBolt11Request) (cashu.
return cashu.OutputsOverQuoteAmountErr
}

B_s := make([]string, len(blindedMessages))
for i, bm := range blindedMessages {
B_s[i] = bm.B_
}

sigs, err := m.db.GetBlindSignatures(B_s)
if err != nil {
errmsg := fmt.Sprintf("error getting blind signatures from db: %v", err)
Expand Down Expand Up @@ -500,28 +498,26 @@ func (m *Mint) Swap(proofs cashu.Proofs, blindedMessages cashu.BlindedMessages)
Ys[i] = Yhex
}

var blindedMessagesAmount uint64
blindedMessagesAmount, err := blindedMessages.AmountChecked()
if err != nil {
return nil, cashu.InvalidBlindedMessageAmount
}

B_s := make([]string, len(blindedMessages))
var overflows bool
for i, bm := range blindedMessages {
blindedMessagesAmount, overflows = overflowAddUint64(blindedMessagesAmount, bm.Amount)
if overflows {
return nil, cashu.InvalidBlindedMessageAmount
}
B_s[i] = bm.B_
}

fees := uint64(m.TransactionFees(proofs))
proofsMinusFees, underflow := underflowSubUint64(proofsAmount, fees)
proofsMinusFees, underflow := cashu.UnderflowSubUint64(proofsAmount, fees)
if underflow {
return nil, cashu.InvalidProofAmount
}
if proofsMinusFees < blindedMessagesAmount {
return nil, cashu.InsufficientProofsAmount
}

err := m.verifyProofs(proofs, Ys)
if err != nil {
if err := m.verifyProofs(proofs, Ys); err != nil {
return nil, err
}

Expand Down Expand Up @@ -1464,22 +1460,6 @@ func (m *Mint) signBlindedMessages(blindedMessages cashu.BlindedMessages) (cashu
return blindedSignatures, nil
}

// overflowAddUint64 adds two uint64 and checks if that results in an overflow
func overflowAddUint64(a, b uint64) (uint64, bool) {
sum := a + b
if sum < a || sum < b {
return math.MaxUint64, true
}
return sum, false
}

func underflowSubUint64(a, b uint64) (uint64, bool) {
if b > a {
return 0, true
}
return a - b, false
}

// requestInvoice requests an invoice from the Lightning backend
// for the given amount
func (m *Mint) requestInvoice(amount uint64) (*lightning.Invoice, error) {
Expand Down
Loading

0 comments on commit 708d3ae

Please sign in to comment.