From 9d1b5f52c689a16c622d39754b711693424d3c23 Mon Sep 17 00:00:00 2001 From: n8wb Date: Mon, 23 Dec 2024 14:31:01 -0800 Subject: [PATCH] expose sourceUsers --- kernels/batch.go | 13 +++++++--- kernels/batch_test.go | 60 ++++++++++++++++++++++++++++++------------- 2 files changed, 51 insertions(+), 22 deletions(-) diff --git a/kernels/batch.go b/kernels/batch.go index 5586eee..3c6d0fb 100644 --- a/kernels/batch.go +++ b/kernels/batch.go @@ -15,7 +15,7 @@ type EarnRequestFullBatch struct { UserAddrs []string `json:"userAddrs"` Sources []string `json:"sources"` SubSources []string `json:"subSources"` - SourceUsers []string `json:"-"` + SourceUsers []string `json:"sourceUsers"` StartBlocks []int64 `json:"startBlocks"` StartTimes []int64 `json:"startTimes"` EarnRates []string `json:"earnRates"` @@ -42,15 +42,20 @@ func (e EarnRequestFullBatch) WithReferralBonuses(referralChains [][]string, tie out := e.Clone() - out.SourceUsers = make([]string, len(e.UserAddrs)) - copy(out.SourceUsers, e.UserAddrs) + if len(out.SourceUsers) == 0 { + return EarnRequestFullBatch{}, errors.New("sourceUsers must not be empty") + } for i := range referralChains { earnRate, ok := conversions.NewLargeFloat().SetString(e.EarnRates[i]) if !ok { return EarnRequestFullBatch{}, errors.New("invalid earn rate") } + if out.SourceUsers[i] != out.UserAddrs[i] || len(referralChains[i]) == 0 { + continue + } for j := range referralChains[i] { + out.UserAddrs = append(out.UserAddrs, referralChains[i][j]) out.Sources = append(out.Sources, out.Sources[i]) out.SubSources = append(out.SubSources, out.SubSources[i]) @@ -179,7 +184,7 @@ type EarnRequestBatch struct { UserAddrs []string `json:"userAddrs"` Source string `json:"source"` SubSource string `json:"subSource"` - SourceUsers []string `json:"-"` + SourceUsers []string `json:"sourceUsers"` StartBlock int64 `json:"startBlock"` StartTime int64 `json:"startTime"` EarnRates []string `json:"earnRates"` diff --git a/kernels/batch_test.go b/kernels/batch_test.go index cc9451f..ac312e6 100644 --- a/kernels/batch_test.go +++ b/kernels/batch_test.go @@ -9,38 +9,62 @@ import ( ) func Test_EarnRequestFullBatch_WithReferralBonuses(t *testing.T) { + users := []string{testutils.GenRandEVMAddr(), testutils.GenRandEVMAddr()} req := EarnRequestFullBatch{ - UserAddrs: []string{testutils.GenRandEVMAddr()}, - Sources: []string{"source"}, - SubSources: []string{"subSource"}, - SourceUsers: nil, + UserAddrs: []string{users[0], users[1]}, + Sources: []string{"source", "source"}, + SubSources: []string{"subSource", "subSource"}, + SourceUsers: []string{users[0], testutils.GenRandEVMAddr()}, StartBlocks: nil, - StartTimes: []int64{1000}, - EarnRates: []string{"1000"}, + StartTimes: []int64{1000, 2000}, + EarnRates: []string{"1000", "2000"}, } - referralChains := [][]string{{testutils.GenRandEVMAddr(), testutils.GenRandEVMAddr()}} + referralChains := [][]string{ + {testutils.GenRandEVMAddr(), testutils.GenRandEVMAddr()}, // This will not be ignored + {testutils.GenRandEVMAddr(), testutils.GenRandEVMAddr()}} // Note we expect this to be ignored tierEarnRates := map[int]*big.Rat{0: big.NewRat(1, 2), 1: big.NewRat(1, 4)} result, err := req.WithReferralBonuses(referralChains, tierEarnRates) require.NoError(t, err) - require.Len(t, result.UserAddrs, 3) + require.Len(t, result.UserAddrs, 4) + + for i, source := range result.Sources { + if i == 1 { + require.Equal(t, req.Sources[i], source) + } else { + require.Equal(t, req.Sources[0], source) + } - for _, source := range result.Sources { - require.Equal(t, req.Sources[0], source) } - for _, subSource := range result.SubSources { - require.Equal(t, req.SubSources[0], subSource) + for i, subSource := range result.SubSources { + if i == 1 { + require.Equal(t, req.SubSources[i], subSource) + } else { + require.Equal(t, req.SubSources[0], subSource) + } + } - for _, sourceUser := range result.SourceUsers { - require.Equal(t, req.UserAddrs[0], sourceUser) + for i, sourceUser := range result.SourceUsers { + if i == 1 { + require.Equal(t, req.SourceUsers[i], sourceUser) + } else { + require.Equal(t, req.SourceUsers[0], sourceUser) + } + } - for _, startTime := range result.StartTimes { - require.Equal(t, req.StartTimes[0], startTime) + for i, startTime := range result.StartTimes { + if i == 1 { + require.Equal(t, req.StartTimes[i], startTime) + } else { + require.Equal(t, req.StartTimes[0], startTime) + } + } require.Nil(t, result.StartBlocks) + require.Len(t, result.EarnRates, 4) require.Equal(t, "1000", result.EarnRates[0]) - require.Equal(t, "500", result.EarnRates[1]) - require.Equal(t, "250", result.EarnRates[2]) + require.Equal(t, "500", result.EarnRates[2]) + require.Equal(t, "250", result.EarnRates[3]) }