Skip to content

Commit

Permalink
expose sourceUsers
Browse files Browse the repository at this point in the history
  • Loading branch information
n8wb committed Dec 23, 2024
1 parent 227534c commit 9d1b5f5
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 22 deletions.
13 changes: 9 additions & 4 deletions kernels/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type EarnRequestFullBatch struct {
UserAddrs []string `json:"userAddrs"`
Sources []string `json:"sources"`
SubSources []string `json:"subSources"`
SourceUsers []string `json:"-"`
SourceUsers []string `json:"sourceUsers"`
StartBlocks []int64 `json:"startBlocks"`
StartTimes []int64 `json:"startTimes"`
EarnRates []string `json:"earnRates"`
Expand All @@ -42,15 +42,20 @@ func (e EarnRequestFullBatch) WithReferralBonuses(referralChains [][]string, tie

out := e.Clone()

out.SourceUsers = make([]string, len(e.UserAddrs))
copy(out.SourceUsers, e.UserAddrs)
if len(out.SourceUsers) == 0 {
return EarnRequestFullBatch{}, errors.New("sourceUsers must not be empty")
}

for i := range referralChains {
earnRate, ok := conversions.NewLargeFloat().SetString(e.EarnRates[i])
if !ok {
return EarnRequestFullBatch{}, errors.New("invalid earn rate")
}
if out.SourceUsers[i] != out.UserAddrs[i] || len(referralChains[i]) == 0 {
continue
}
for j := range referralChains[i] {

out.UserAddrs = append(out.UserAddrs, referralChains[i][j])
out.Sources = append(out.Sources, out.Sources[i])
out.SubSources = append(out.SubSources, out.SubSources[i])
Expand Down Expand Up @@ -179,7 +184,7 @@ type EarnRequestBatch struct {
UserAddrs []string `json:"userAddrs"`
Source string `json:"source"`
SubSource string `json:"subSource"`
SourceUsers []string `json:"-"`
SourceUsers []string `json:"sourceUsers"`
StartBlock int64 `json:"startBlock"`
StartTime int64 `json:"startTime"`
EarnRates []string `json:"earnRates"`
Expand Down
60 changes: 42 additions & 18 deletions kernels/batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,38 +9,62 @@ import (
)

func Test_EarnRequestFullBatch_WithReferralBonuses(t *testing.T) {
users := []string{testutils.GenRandEVMAddr(), testutils.GenRandEVMAddr()}
req := EarnRequestFullBatch{
UserAddrs: []string{testutils.GenRandEVMAddr()},
Sources: []string{"source"},
SubSources: []string{"subSource"},
SourceUsers: nil,
UserAddrs: []string{users[0], users[1]},
Sources: []string{"source", "source"},
SubSources: []string{"subSource", "subSource"},
SourceUsers: []string{users[0], testutils.GenRandEVMAddr()},
StartBlocks: nil,
StartTimes: []int64{1000},
EarnRates: []string{"1000"},
StartTimes: []int64{1000, 2000},
EarnRates: []string{"1000", "2000"},
}

referralChains := [][]string{{testutils.GenRandEVMAddr(), testutils.GenRandEVMAddr()}}
referralChains := [][]string{
{testutils.GenRandEVMAddr(), testutils.GenRandEVMAddr()}, // This will not be ignored
{testutils.GenRandEVMAddr(), testutils.GenRandEVMAddr()}} // Note we expect this to be ignored
tierEarnRates := map[int]*big.Rat{0: big.NewRat(1, 2), 1: big.NewRat(1, 4)}

result, err := req.WithReferralBonuses(referralChains, tierEarnRates)
require.NoError(t, err)
require.Len(t, result.UserAddrs, 3)
require.Len(t, result.UserAddrs, 4)

for i, source := range result.Sources {
if i == 1 {
require.Equal(t, req.Sources[i], source)
} else {
require.Equal(t, req.Sources[0], source)
}

for _, source := range result.Sources {
require.Equal(t, req.Sources[0], source)
}
for _, subSource := range result.SubSources {
require.Equal(t, req.SubSources[0], subSource)
for i, subSource := range result.SubSources {
if i == 1 {
require.Equal(t, req.SubSources[i], subSource)
} else {
require.Equal(t, req.SubSources[0], subSource)
}

}
for _, sourceUser := range result.SourceUsers {
require.Equal(t, req.UserAddrs[0], sourceUser)
for i, sourceUser := range result.SourceUsers {
if i == 1 {
require.Equal(t, req.SourceUsers[i], sourceUser)
} else {
require.Equal(t, req.SourceUsers[0], sourceUser)
}

}
for _, startTime := range result.StartTimes {
require.Equal(t, req.StartTimes[0], startTime)
for i, startTime := range result.StartTimes {
if i == 1 {
require.Equal(t, req.StartTimes[i], startTime)
} else {
require.Equal(t, req.StartTimes[0], startTime)
}

}
require.Nil(t, result.StartBlocks)

require.Len(t, result.EarnRates, 4)
require.Equal(t, "1000", result.EarnRates[0])
require.Equal(t, "500", result.EarnRates[1])
require.Equal(t, "250", result.EarnRates[2])
require.Equal(t, "500", result.EarnRates[2])
require.Equal(t, "250", result.EarnRates[3])
}

0 comments on commit 9d1b5f5

Please sign in to comment.