Skip to content

Commit

Permalink
fix the bug discovered during shield code review (shentufoundation#563)
Browse files Browse the repository at this point in the history
* fix the bug discovered during shield code review

* add unit test for UpdatePool and purchase expiration handling
  • Loading branch information
haozhan9 authored Jan 12, 2023
1 parent 3d02406 commit b0a4b73
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 40 deletions.
122 changes: 95 additions & 27 deletions x/shield/keeper/keeper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (

tmproto "github.com/tendermint/tendermint/proto/tendermint/types"

sdksimapp "github.com/cosmos/cosmos-sdk/simapp"
sdk "github.com/cosmos/cosmos-sdk/types"
stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types"

Expand Down Expand Up @@ -213,42 +212,27 @@ func TestWithdrawsByRedelegate(t *testing.T) {
func TestClaimProposal(t *testing.T) {
app := shentuapp.Setup(false)
ctx := app.BaseApp.NewContext(false, tmproto.Header{Time: time.Now().UTC()})

// set up testing helpers
tstaking := teststaking.NewHelper(t, ctx, app.StakingKeeper)
bondDenom := tstaking.Denom
tshield := testshield.NewHelper(t, ctx, app.ShieldKeeper, bondDenom)
tgov := testgov.NewHelper(t, ctx, app.GovKeeper, bondDenom)
// create and add addresses
pks := shentuapp.CreateTestPubKeys(5)
shentuapp.AddTestAddrsFromPubKeys(app, ctx, pks, sdk.ZeroInt())

shieldAdmin := sdk.AccAddress(pks[0].Address())
err := sdksimapp.FundAccount(app.BankKeeper, ctx, shieldAdmin, sdk.Coins{sdk.NewInt64Coin("uctk", 250e9)})
require.NoError(t, err)
shieldAdmin := tshield.GetFundedAcc(app.BankKeeper, pks[0], 250e9)
app.ShieldKeeper.SetAdmin(ctx, shieldAdmin)

sponsorAddr := sdk.AccAddress(pks[1].Address())
err = sdksimapp.FundAccount(app.BankKeeper, ctx, sponsorAddr, sdk.Coins{sdk.NewInt64Coin("uctk", 1)})
require.NoError(t, err)

purchaser := sdk.AccAddress(pks[2].Address())
err = sdksimapp.FundAccount(app.BankKeeper, ctx, purchaser, sdk.Coins{sdk.NewInt64Coin("uctk", 10e9)})
require.NoError(t, err)

del1addr := sdk.AccAddress(pks[3].Address())
err = sdksimapp.FundAccount(app.BankKeeper, ctx, del1addr, sdk.Coins{sdk.NewInt64Coin("uctk", 125e9)})
require.NoError(t, err)

sponsorAddr := tshield.GetFundedAcc(app.BankKeeper, pks[1], 1)
purchaser := tshield.GetFundedAcc(app.BankKeeper, pks[2], 10e9)
del1addr := tshield.GetFundedAcc(app.BankKeeper, pks[3], 125e9)
_ = tshield.GetFundedAcc(app.BankKeeper, pks[4], 100e6)
val1pk, val1addr := pks[4], sdk.ValAddress(pks[4].Address())
err = sdksimapp.FundAccount(app.BankKeeper, ctx, sdk.AccAddress(pks[4].Address()), sdk.Coins{sdk.NewInt64Coin("uctk", 100e6)})
require.NoError(t, err)

var adminDeposit int64 = 200e9
var delegatorDeposit int64 = 125e9
totalDeposit := adminDeposit + delegatorDeposit

// set up testing helpers
tstaking := teststaking.NewHelper(t, ctx, app.StakingKeeper)
bondDenom := tstaking.Denom
tshield := testshield.NewHelper(t, ctx, app.ShieldKeeper, bondDenom)
tgov := testgov.NewHelper(t, ctx, app.GovKeeper, bondDenom)

// set up a validator
tstaking.CreateValidatorWithValPower(val1addr, val1pk, 100, true)
ctx = nextBlock(ctx, tstaking, tshield, tgov)
Expand Down Expand Up @@ -324,7 +308,7 @@ func TestClaimProposal(t *testing.T) {

// create reimbursement
lossCoins := sdk.NewCoins(sdk.NewInt64Coin(bondDenom, loss))
err = app.ShieldKeeper.CreateReimbursement(ctx, proposalID, lossCoins, purchaser)
err := app.ShieldKeeper.CreateReimbursement(ctx, proposalID, lossCoins, purchaser)
require.NoError(t, err)
reimbursement, err := app.ShieldKeeper.GetReimbursement(ctx, proposalID)
require.NoError(t, err)
Expand Down Expand Up @@ -360,3 +344,87 @@ func TestClaimProposal(t *testing.T) {
afterInt := app.BankKeeper.GetBalance(ctx, purchaser, bondDenom).Amount
require.True(t, beforeInt.Add(sdk.NewInt(loss)).Equal(afterInt))
}

func TestUpdatePool(t *testing.T) {
app := shentuapp.Setup(false)
ctx := app.BaseApp.NewContext(false, tmproto.Header{Time: time.Now().UTC()})
// set up testing helpers
tstaking := teststaking.NewHelper(t, ctx, app.StakingKeeper)
tshield := testshield.NewHelper(t, ctx, app.ShieldKeeper, tstaking.Denom)
tgov := testgov.NewHelper(t, ctx, app.GovKeeper, tstaking.Denom)
// create and add addresses
pks := shentuapp.CreateTestPubKeys(4)
shentuapp.AddTestAddrsFromPubKeys(app, ctx, pks, sdk.ZeroInt())

orginalFund := int64(250e9)
shieldAdmin := tshield.GetFundedAcc(app.BankKeeper, pks[0], orginalFund)
app.ShieldKeeper.SetAdmin(ctx, shieldAdmin)
sponsorAddr := tshield.GetFundedAcc(app.BankKeeper, pks[1], 1)
purchaserAddr := tshield.GetFundedAcc(app.BankKeeper, pks[2], 125e9)
_ = tshield.GetFundedAcc(app.BankKeeper, pks[3], 100e6)
val1pk, val1addr := pks[3], sdk.ValAddress(pks[3].Address())

// set up a validator
tstaking.CreateValidatorWithValPower(val1addr, val1pk, 100, true)
ctx = nextBlock(ctx, tstaking, tshield, tgov)
tstaking.CheckValidator(val1addr, stakingtypes.Bonded, false)

// 1)delegate tokens
// 2)deposite collateral as a shield provider
// 3)create pool
adminDeposit := int64(200e9)
serviceFee0 := int64(200e6)
shield1 := int64(50e9)
tstaking.Delegate(shieldAdmin, val1addr, adminDeposit)
tshield.DepositCollateral(shieldAdmin, adminDeposit, true)
tshield.CreatePool(shieldAdmin, sponsorAddr, serviceFee0, shield1, 500e9, "Shentu", "fake_description")
pools := app.ShieldKeeper.GetAllPools(ctx)
require.True(t, len(pools) == 1)
require.True(t, strAddrEqualsAccAddr(pools[0].SponsorAddr, sponsorAddr))

//update the pool with shield purchasement
serviceFee1, shield2 := int64(20000), int64(30e9)
tshield.UpdatePool(pools[0].Id, shieldAdmin, serviceFee1, shield2, 0, "updatepool1")
require.True(t,
app.ShieldKeeper.GetServiceFees(ctx).IsEqual(
tshield.DecCoinsI64(serviceFee0+serviceFee1)))
shieldAdminBalance := app.BankKeeper.GetBalance(ctx, shieldAdmin, tstaking.Denom).Amount.Int64()
require.True(t, shieldAdminBalance == orginalFund-adminDeposit-serviceFee0-serviceFee1)
purchases := app.ShieldKeeper.GetAllPurchases(ctx)
require.True(t, len(purchases) == 2)
require.True(t, purchases[0].Shield.Int64() == shield1)
require.True(t, purchases[1].Shield.Int64() == shield2)

//update the pool without shield purchasement, but with service fees payment
serviceFee2 := int64(7e9)
tshield.UpdatePool(pools[0].Id, shieldAdmin, serviceFee2, 0, 0, "updatepool2")
shieldAdminBalance = app.BankKeeper.GetBalance(ctx, shieldAdmin, tstaking.Denom).Amount.Int64()
require.True(t, shieldAdminBalance == orginalFund-adminDeposit-serviceFee0-serviceFee1-serviceFee2)
purchases = app.ShieldKeeper.GetAllPurchases(ctx)
require.True(t, len(purchases) == 2)

// 1)stake for shield
// 2)pass the purchase's protection end time
// 3)check the newly created staked purchase
shield3 := int64(7e9)
tshield.StakeForShield(pools[0].Id, shield3, "shield created by staking", purchaserAddr)
purchases = app.ShieldKeeper.GetAllPurchases(ctx)
require.True(t, len(purchases) == 3)
require.True(t, purchases[0].Description == "shield created by staking")
stakedPurchaseId1 := purchases[0].PurchaseId
allStakes := app.ShieldKeeper.GetAllOriginalStakings(ctx)
require.True(t, len(allStakes) == 1)
stakeAmt1 := app.ShieldKeeper.GetOriginalStaking(ctx, stakedPurchaseId1)
protectionSecs := int64(app.ShieldKeeper.GetPoolParams(ctx).ProtectionPeriod.Seconds())
numBlocks := protectionSecs/5 + 1
ctx = skipBlocks(ctx, numBlocks, tstaking, tshield, tgov)
allStakes = app.ShieldKeeper.GetAllOriginalStakings(ctx)
require.True(t, len(allStakes) == 1)
purchases = app.ShieldKeeper.GetAllPurchases(ctx)
require.True(t, len(purchases) == 1)
stakedPurchaseId2 := purchases[0].PurchaseId
require.False(t, stakedPurchaseId1 == stakedPurchaseId2)
stakeAmt2 := app.ShieldKeeper.GetOriginalStaking(ctx, stakedPurchaseId2)
//the two stakes are both calculated the same way, they should equal
require.True(t, stakeAmt1.Equal(stakeAmt2))
}
3 changes: 3 additions & 0 deletions x/shield/keeper/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,9 @@ func (k Keeper) UpdatePool(ctx sdk.Context, poolID uint64, description string, u
}
} else if !serviceFees.IsZero() {
// Allow adding service fees without purchasing more shield.
if err := k.bk.SendCoinsFromAccountToModule(ctx, updater, types.ModuleName, serviceFees); err != nil {
return pool, err
}
totalServiceFees := k.GetServiceFees(ctx)
totalServiceFees = totalServiceFees.Add(sdk.NewDecCoinsFromCoins(serviceFees...)...)
k.SetServiceFees(ctx, totalServiceFees)
Expand Down
2 changes: 1 addition & 1 deletion x/shield/keeper/proposal.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ func (k Keeper) UpdateProviderCollateralForPayout(ctx sdk.Context, providerAddr
payoutFromWithdraw := payout.Sub(payoutFromCollateral)

// Update provider's collateral and total withdraw.
provider.Collateral = provider.Collateral.Sub(payout)
provider.Collateral = provider.Collateral.Sub(payoutFromCollateral)
provider.Withdrawing = provider.Withdrawing.Sub(payoutFromWithdraw)
totalWithdrawing = totalWithdrawing.Sub(payoutFromWithdraw)

Expand Down
22 changes: 10 additions & 12 deletions x/shield/keeper/purchase.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,16 +233,15 @@ func (k Keeper) RemoveExpiredPurchasesAndDistributeFees(ctx sdk.Context) {
totalServiceFees = totalServiceFees.Sub(entry.ServiceFees)
// Set purchase fees to zero because it can be reached again.
purchaseList.Entries[i].ServiceFees = sdk.DecCoins{}

originalStaking := k.GetOriginalStaking(ctx, entry.PurchaseId)
if !originalStaking.IsZero() {
// keep track of the list to be updated to avoid overwriting the purchase list
stakeForShieldUpdateList = append(stakeForShieldUpdateList, pPPTriplet{
poolID: poolPurchaser.PoolId,
purchaseID: entry.PurchaseId,
purchaser: purchaser,
})
}
}
originalStaking := k.GetOriginalStaking(ctx, entry.PurchaseId)
if !originalStaking.IsZero() {
// keep track of the list to be updated to avoid overwriting the purchase list
stakeForShieldUpdateList = append(stakeForShieldUpdateList, pPPTriplet{
poolID: poolPurchaser.PoolId,
purchaseID: entry.PurchaseId,
purchaser: purchaser,
})
}

// If purchaseDeletionTime < currentBlockTime, remove the purchase.
Expand Down Expand Up @@ -295,6 +294,7 @@ func (k Keeper) RemoveExpiredPurchasesAndDistributeFees(ctx sdk.Context) {

// Add block service fees that need to be distributed for this block
blockServiceFees := k.GetBlockServiceFees(ctx)
remainingServiceFees = remainingServiceFees.Add(blockServiceFees...)
serviceFees = serviceFees.Add(blockServiceFees...)
k.DeleteBlockServiceFees(ctx)

Expand All @@ -318,8 +318,6 @@ func (k Keeper) RemoveExpiredPurchasesAndDistributeFees(ctx sdk.Context) {
remainingServiceFees = remainingServiceFees.Sub(newFees)
}

// add back block fees
remainingServiceFees = remainingServiceFees.Add(blockServiceFees...)
k.SetRemainingServiceFees(ctx, remainingServiceFees)
k.SetLastUpdateTime(ctx, ctx.BlockTime())
}
Expand Down
28 changes: 28 additions & 0 deletions x/shield/testshield/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ import (

"github.com/stretchr/testify/require"

cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types"
sdksimapp "github.com/cosmos/cosmos-sdk/simapp"
sdk "github.com/cosmos/cosmos-sdk/types"
govtypes "github.com/cosmos/cosmos-sdk/x/gov/types"
bankkeeper "github.com/shentufoundation/shentu/v2/x/bank/keeper"

"github.com/shentufoundation/shentu/v2/x/shield"
"github.com/shentufoundation/shentu/v2/x/shield/keeper"
Expand Down Expand Up @@ -95,3 +98,28 @@ func (sh *Helper) HandleProposal(content govtypes.Content, ok bool) {
require.Error(sh.t, err)
}
}

func (sh *Helper) GetFundedAcc(bk bankkeeper.Keeper, pk cryptotypes.PubKey, amt int64) sdk.AccAddress {
accAdd := sdk.AccAddress(pk.Address())
err := sdksimapp.FundAccount(bk, sh.ctx, accAdd, sdk.Coins{sdk.NewInt64Coin(sh.denom, amt)})
require.NoError(sh.t, err)
return accAdd
}

func (sh *Helper) UpdatePool(poolID uint64, fromAddr sdk.AccAddress, serviceFee, shield, shieldLimit int64, desc string) {
shieldCoins := sdk.NewCoins(sdk.NewInt64Coin(sh.denom, shield))
serviceFeeCoins := sdk.NewCoins(sdk.NewInt64Coin(sh.denom, serviceFee))
limit := sdk.NewInt(shieldLimit)
msg := types.NewMsgUpdatePool(fromAddr, shieldCoins, serviceFeeCoins, poolID, desc, limit)
sh.Handle(msg, true)
}

func (sh *Helper) StakeForShield(poolID uint64, shield int64, desc string, from sdk.AccAddress) {
shieldCoins := sdk.NewCoins(sdk.NewInt64Coin(sh.denom, shield))
msg := types.NewMsgStakeForShield(poolID, shieldCoins, desc, from)
sh.Handle(msg, true)
}

func (sh *Helper) DecCoinsI64(amt int64) sdk.DecCoins {
return sdk.DecCoins{sdk.NewInt64DecCoin(sh.denom, amt)}
}

0 comments on commit b0a4b73

Please sign in to comment.