|
1 | 1 | import logging
|
2 | 2 | from typing import (
|
3 |
| - cast |
| 3 | + Callable, |
| 4 | + cast, |
4 | 5 | )
|
5 | 6 | from eth_utils import (
|
6 | 7 | ValidationError,
|
|
16 | 17 | )
|
17 | 18 |
|
18 | 19 |
|
| 20 | +def default_refund_strategy(meter: "GasMeter", amount: int) -> None: |
| 21 | + if amount < 0: |
| 22 | + raise ValidationError("Gas refund amount must be positive") |
| 23 | + |
| 24 | + meter.gas_refunded += amount |
| 25 | + |
| 26 | + meter.logger.trace( |
| 27 | + 'GAS REFUND: %s + %s -> %s', |
| 28 | + meter.gas_refunded - amount, |
| 29 | + amount, |
| 30 | + meter.gas_refunded, |
| 31 | + ) |
| 32 | + |
| 33 | + |
| 34 | +def allow_negative_refund_strategy(meter: "GasMeter", amount: int) -> None: |
| 35 | + meter.gas_refunded += amount |
| 36 | + |
| 37 | + meter.logger.trace( |
| 38 | + 'GAS REFUND: %s + %s -> %s', |
| 39 | + meter.gas_refunded - amount, |
| 40 | + amount, |
| 41 | + meter.gas_refunded, |
| 42 | + ) |
| 43 | + |
| 44 | + |
| 45 | +RefundStrategy = Callable[["GasMeter", int], None] |
| 46 | + |
| 47 | + |
19 | 48 | class GasMeter(object):
|
| 49 | + |
20 | 50 | start_gas = None # type: int
|
21 | 51 |
|
22 | 52 | gas_refunded = None # type: int
|
23 | 53 | gas_remaining = None # type: int
|
24 | 54 |
|
25 | 55 | logger = cast(TraceLogger, logging.getLogger('eth.gas.GasMeter'))
|
26 | 56 |
|
27 |
| - def __init__(self, start_gas: int) -> None: |
| 57 | + def __init__(self, |
| 58 | + start_gas: int, |
| 59 | + refund_strategy: RefundStrategy = default_refund_strategy) -> None: |
28 | 60 | validate_uint256(start_gas, title="Start Gas")
|
29 | 61 |
|
| 62 | + self.refund_strategy = refund_strategy |
30 | 63 | self.start_gas = start_gas
|
31 | 64 |
|
32 | 65 | self.gas_remaining = self.start_gas
|
@@ -70,14 +103,4 @@ def return_gas(self, amount: int) -> None:
|
70 | 103 | )
|
71 | 104 |
|
72 | 105 | def refund_gas(self, amount: int) -> None:
|
73 |
| - if amount < 0: |
74 |
| - raise ValidationError("Gas refund amount must be positive") |
75 |
| - |
76 |
| - self.gas_refunded += amount |
77 |
| - |
78 |
| - self.logger.trace( |
79 |
| - 'GAS REFUND: %s + %s -> %s', |
80 |
| - self.gas_refunded - amount, |
81 |
| - amount, |
82 |
| - self.gas_refunded, |
83 |
| - ) |
| 106 | + return self.refund_strategy(self, amount) |
0 commit comments