diff --git a/.circleci/config.yml b/.circleci/config.yml index f56044b..e069cc1 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -20,6 +20,7 @@ jobs: - run: command: | . venv/bin/activate + flake8 --ignore E501 casper/ casper.py pytest - store_artifacts: path: test-reports/ diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md new file mode 100644 index 0000000..e4d02db --- /dev/null +++ b/.github/CONTRIBUTING.md @@ -0,0 +1,11 @@ +# Contributing +Hello fellow contributor! This document will provide an overview of the process for contributing to the CBC casper codebase! Please note that these guidelines may be changed in the future. + +# Getting Started +To get started with the codebase, see the download and installation instructions in the readme. + +# Making a Pull Request +Once you have implemented your proposed changes, create a Pull Request to the `develop` branch. Ensure there are no merge conflicts with your branch and that all tests pass. + +# Resources +See the [wiki](https://github.com/ethereum/cbc-casper/wiki) for a collection of resources related to the project. diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..4d08b74 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,5 @@ +Thank you for contributing to CBC Casper! + +Please make your Pull Request against the `develop` branch and ensure that all tests pass. + +Unless fixing a typo or other small issue, PRs should link to specific issues. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md deleted file mode 100644 index a4ad495..0000000 --- a/CONTRIBUTING.md +++ /dev/null @@ -1,15 +0,0 @@ -# Contributing -Hello fellow contributor! This document will provide an overview of the process for contributing to the CBC casper codebase! Please note that these guidelines may be changed in the future. - -# Getting Started -To get started with the codebase, see the download and installation instructions in the readme. - -# Making a Pull Request -Once you have implemented your proposed changes, you need to create a pull request. Before making a PR, make sure there are no merge conflicts with your branch and that all tests pass. - -TODO: Add notes about licensing and rights. - -# Resources -Consensus protocols are notorious for being incredibly hard to understand. Luckily, CBC Casper's spec is much easier to consume than most other consensus protocols that exist today (a result of the correct by construction process). That being said, here are some resources that may be helpful in getting started understanding the protocol. - -TODO: Add Casper papers, add Vlad's talk. \ No newline at end of file diff --git a/README.md b/README.md index 0f63f20..495c1d2 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ pip install using `requirements.txt` ### Standard Standard simulations are marked up for use as follows: -NOTE: after each viewgraph appears, you must manually exit the window for the simulation to continue to run! + ``` make run-[rand | rrob | full | nofinal | binary] ``` @@ -54,7 +54,7 @@ make run-[rand | rrob | full | nofinal | binary] `binary:` unlike the above message propagation schemes, this changes the protocol to cbc-casper with binary data structures! Instead of a blockchain, this protocol just comes to consensus on a single bit. -The number of validators, the number of messages that propagate per round, and the report interval can be edited in `casper/settings.py`. +By default, a gif and associated images of the simulation will be saved in `graphs/graph_num_0/`. These settings can be modified, along with the number of validators, the number of messages that propagate per round, and the report interval in the `config.ini`. ### Advanced Advanced simulations can be run with a little command line wizardy. @@ -99,7 +99,10 @@ The following are the fields that make up an experiment to be defined in a `.jso available schemes are "rand", "rrob", "full", and "nofinal". `protocol` (string): Specifies the protocol to test. Available protocols are -"blockchain" and "binary" (for now!). +"blockchain", "integer", and "binary". + +`network` (string): Specifies the network model test. Available networks are +"no-delay", "constant", "linear", and "gaussian". `num_simulations` (number): Specifies the number of simulations to run. Each simulation starts with a fresh setup -- messages, validators, etc. diff --git a/casper.py b/casper.py index 1ab87b9..d027428 100644 --- a/casper.py +++ b/casper.py @@ -14,12 +14,23 @@ from simulations.utils import ( generate_random_gaussian_validator_set, message_maker, + select_network, select_protocol, MESSAGE_MODES, + NETWORKS, PROTOCOLS ) +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + def default_configuration(): config = ConfigParser() config.read("config.ini") @@ -39,6 +50,11 @@ def main(): choices=PROTOCOLS, help='specifies the protocol for the simulation' ) + parser.add_argument( + '--network', type=str, default=config.get("DefaultNetwork"), + choices=NETWORKS, + help='specifies the network model for the simulation' + ) parser.add_argument( '--validators', type=int, default=config.getint("NumValidators"), help='specifies the number of validators in validator set' @@ -52,30 +68,38 @@ def main(): help='specifies the interval in rounds at which to plot results' ) parser.add_argument( - '--hide-display', help='hide simulation display', action='store_true' + '--hide-display', action="store_true", + help='display simulations round by round' ) parser.add_argument( - '--save', help='hide simulation display', action='store_true' + '--save', type=str2bool, default=config.getboolean("Save"), + help='save the simulation in graphs/ directory' + ) + parser.add_argument( + '--justify-messages', type=str2bool, default=config.getboolean("JustifyMessages"), + help='force full propagation of all messages in justification of message when sending' ) args = parser.parse_args() protocol = select_protocol(args.protocol) + network_type = select_network(args.network) validator_set = generate_random_gaussian_validator_set( protocol, args.validators ) + network = network_type(validator_set, protocol) msg_gen = message_maker(args.mode) - display = not args.hide_display simulation_runner = SimulationRunner( validator_set, msg_gen, - protocol, + protocol=protocol, + network=network, total_rounds=args.rounds, report_interval=args.report_interval, - display=display, + display=(not args.hide_display), save=args.save, ) simulation_runner.run() diff --git a/casper/abstract_view.py b/casper/abstract_view.py index 02b0e55..fe78488 100644 --- a/casper/abstract_view.py +++ b/casper/abstract_view.py @@ -1,5 +1,4 @@ -"""The view module ... """ -from casper.justification import Justification +"""The parent view module that specific protocol views inherit from""" class AbstractView(object): @@ -9,68 +8,98 @@ def __init__(self, messages=None): if messages is None: messages = set() - self.add_messages(messages) - - self.messages = set() - self.latest_messages = dict() - - def __str__(self): - output = "View: \n" - for bet in self.messages: - output += str(bet) + "\n" - return output - - def justification(self): - """Returns the latest messages seen from other validators, to justify estimate.""" - return Justification(self.latest_messages) - - def get_new_messages(self, showed_messages): - """This method returns the set of messages out of showed_messages - and their dependency that isn't part of the view.""" - - new_messages = set() - # The memo will keep track of messages we've already looked at, so we don't redo work. - memo = set() - - # At the start, our working set will be the "showed messages" parameter. - current_set = set(showed_messages) - while current_set != set(): - - next_set = set() - # If there's no message in the current working set. - for message in current_set: - - # Which we haven't seen it in the view or during this loop. - if message not in self.messages and message not in memo: - - # But if we do have a new message, then we add it to our pile.. - new_messages.add(message) + self.justified_messages = dict() # message hash => message + self.pending_messages = dict() # message hash => message - # and add the bet in its justification to our next working set - for bet in message.justification.latest_messages.values(): - next_set.add(bet) - # Keeping a record of very message we inspect, being sure not - # to do any extra (exponential complexity) work. - memo.add(message) + self.num_missing_dependencies = dict() # message hash => number of message hashes + self.dependents_of_message = dict() # message hash => list(message hashes) - current_set = next_set + self.latest_messages = dict() # validator => message - # After the loop is done, we return a set of new messages. - return new_messages + self.add_messages(messages) def estimate(self): '''Must be defined in child class. Returns estimate based on current messages in the view''' - pass + raise NotImplementedError - def add_messages(self, showed_messages): + def update_safe_estimates(self, validator_set): '''Must be defined in child class.''' - pass + raise NotImplementedError - def make_new_message(self, validator): - '''Must be defined in child class.''' + def add_messages(self, showed_messages): + """Adds a set of newly received messages to pending or justified""" + for message in showed_messages: + if message.hash in self.pending_messages or message.hash in self.justified_messages: + continue + + missing_message_hashes = self._missing_messages_in_justification(message) + if not any(missing_message_hashes): + self.receive_justified_message(message) + else: + self.receive_pending_message(message, missing_message_hashes) + + def receive_justified_message(self, message): + """Upon receiving a justified message, resolves waiting messages and adds to view""" + newly_justified_messages = self.get_newly_justified_messages(message) + + for justified_message in newly_justified_messages: + self._add_to_latest_messages(justified_message) + self._add_justified_remove_pending(justified_message) + self._update_protocol_specific_view(justified_message) + + def receive_pending_message(self, message, missing_message_hashes): + """Updates and stores pending messages and dependencies""" + self.pending_messages[message.hash] = message + self.num_missing_dependencies[message.hash] = len(missing_message_hashes) + + for missing_message_hash in missing_message_hashes: + if missing_message_hash not in self.dependents_of_message: + self.dependents_of_message[missing_message_hash] = [] + + self.dependents_of_message[missing_message_hash].append(message.hash) + + def get_newly_justified_messages(self, message): + """Given a new justified message, get all messages that are now justified + due to its receipt""" + newly_justified_messages = set([message]) + + for dependent_hash in self.dependents_of_message.get(message.hash, set()): + self.num_missing_dependencies[dependent_hash] -= 1 + + if self.num_missing_dependencies[dependent_hash] == 0: + new_message = self.pending_messages[dependent_hash] + newly_justified_messages.update(self.get_newly_justified_messages(new_message)) + + return newly_justified_messages + + def _update_protocol_specific_view(self, message): + """ Can be implemented by child, though not necessary + Updates a view's specific info, given a justified message""" pass - def update_safe_estimates(self, validator_set): - '''Must be defined in child class.''' - pass + def _add_to_latest_messages(self, message): + """Updates a views most recent messages, if this message is later""" + if message.sender not in self.latest_messages: + self.latest_messages[message.sender] = message + elif self.latest_messages[message.sender].sequence_number < message.sequence_number: + self.latest_messages[message.sender] = message + + def _add_justified_remove_pending(self, message): + """Atomic action that: + - removes all data related to tracking the not yet justified message + - adds message to justified dict""" + self.justified_messages[message.hash] = message + if message.hash in self.num_missing_dependencies: + del self.num_missing_dependencies[message.hash] + if message.hash in self.dependents_of_message: + del self.dependents_of_message[message.hash] + if message.hash in self.pending_messages: + del self.pending_messages[message.hash] + + def _missing_messages_in_justification(self, message): + """Returns the set of not seen messages hashes from the justification of a message""" + return { + message_hash for message_hash in message.justification.values() + if message_hash not in self.justified_messages + } diff --git a/casper/binary/binary_protocol.py b/casper/binary/binary_protocol.py deleted file mode 100644 index 30af95c..0000000 --- a/casper/binary/binary_protocol.py +++ /dev/null @@ -1,10 +0,0 @@ -from casper.binary.binary_view import BinaryView -from casper.binary.bet import Bet -from casper.binary.binary_plot_tool import BinaryPlotTool -from casper.protocol import Protocol - - -class BinaryProtocol(Protocol): - View = BinaryView - Message = Bet - PlotTool = BinaryPlotTool diff --git a/casper/binary/binary_view.py b/casper/binary/binary_view.py deleted file mode 100644 index 17a447d..0000000 --- a/casper/binary/binary_view.py +++ /dev/null @@ -1,68 +0,0 @@ -"""The blockchain view module extends a view for blockchain data structures """ -from casper.safety_oracles.clique_oracle import CliqueOracle -from casper.abstract_view import AbstractView -from casper.binary.bet import Bet -import casper.binary.binary_estimator as estimator - -import random as r - - -class BinaryView(AbstractView): - """A view class that also keeps track of a last_finalized_block and children""" - def __init__(self, messages=None): - super().__init__(messages) - - self.last_finalized_estimate = None - - def estimate(self): - """Returns the current forkchoice in this view""" - return estimator.get_estimate_from_latest_messages( - self.latest_messages - ) - - def add_messages(self, showed_messages): - """Updates views latest_messages and children based on new messages""" - - if not showed_messages: - return - - for message in showed_messages: - assert isinstance(message, Bet), "expected only to add a Bet!" - - # find any not-seen messages - newly_discovered_messages = self.get_new_messages(showed_messages) - - # add these new messages to the messages in view - self.messages.update(newly_discovered_messages) - - # update views most recently seen messages - for message in newly_discovered_messages: - if message.sender not in self.latest_messages: - self.latest_messages[message.sender] = message - elif self.latest_messages[message.sender].sequence_number < message.sequence_number: - self.latest_messages[message.sender] = message - - def make_new_message(self, validator): - """Make a new bet!""" - justification = self.justification() - estimate = self.estimate() - if not any(self.messages): - estimate = r.randint(0, 1) - - new_message = Bet(estimate, justification, validator) - self.add_messages(set([new_message])) - - return new_message - - def update_safe_estimates(self, validator_set): - """Checks safety on most recent created by this view""" - # check estimate safety on the most - for bet in self.latest_messages.values(): - oracle = CliqueOracle(bet, self, validator_set) - fault_tolerance, _ = oracle.check_estimate_safety() - - if fault_tolerance > 0: - if self.last_finalized_estimate: - assert not self.last_finalized_estimate.conflicts_with(bet) - self.last_finalized_estimate = bet - break diff --git a/casper/blockchain/blockchain_protocol.py b/casper/blockchain/blockchain_protocol.py deleted file mode 100644 index e2ab39d..0000000 --- a/casper/blockchain/blockchain_protocol.py +++ /dev/null @@ -1,9 +0,0 @@ -from casper.blockchain.blockchain_view import BlockchainView -from casper.blockchain.block import Block -from casper.blockchain.blockchain_plot_tool import BlockchainPlotTool -from casper.protocol import Protocol - -class BlockchainProtocol(Protocol): - View = BlockchainView - Message = Block - PlotTool = BlockchainPlotTool diff --git a/casper/blockchain/blockchain_view.py b/casper/blockchain/blockchain_view.py deleted file mode 100644 index 27d1296..0000000 --- a/casper/blockchain/blockchain_view.py +++ /dev/null @@ -1,93 +0,0 @@ -"""The blockchain view module extends a view for blockchain data structures """ -from casper.safety_oracles.clique_oracle import CliqueOracle -from casper.abstract_view import AbstractView -from casper.blockchain.block import Block -import casper.blockchain.forkchoice as forkchoice - - -class BlockchainView(AbstractView): - """A view class that also keeps track of a last_finalized_block and children""" - def __init__(self, messages=None): - super().__init__(messages) - - self.children = dict() - self.last_finalized_block = None - - # cache info about message events - self.when_added = {} - for message in self.messages: - self.when_added[message] = 0 - self.when_finalized = {} - - def estimate(self): - """Returns the current forkchoice in this view""" - return forkchoice.get_fork_choice( - self.last_finalized_block, - self.children, - self.latest_messages - ) - - def add_messages(self, showed_messages): - """Updates views latest_messages and children based on new messages""" - - if not showed_messages: - return - - for message in showed_messages: - assert isinstance(message, Block), "expected only to add a block!" - - # find any not-seen messages - newly_discovered_messages = self.get_new_messages(showed_messages) - - # add these new messages to the messages in view - self.messages.update(newly_discovered_messages) - - for message in newly_discovered_messages: - # update views most recently seen messages - if message.sender not in self.latest_messages: - self.latest_messages[message.sender] = message - elif self.latest_messages[message.sender].sequence_number < message.sequence_number: - self.latest_messages[message.sender] = message - - # update the children dictonary with the new message - if message.estimate not in self.children: - self.children[message.estimate] = set() - self.children[message.estimate].add(message) - - # update when_added cache - if message not in self.when_added: - self.when_added[message] = len(self.messages) - - def make_new_message(self, validator): - justification = self.justification() - estimate = self.estimate() - - new_message = Block(estimate, justification, validator) - self.add_messages(set([new_message])) - - return new_message - - def update_safe_estimates(self, validator_set): - """Checks safety on messages in views forkchoice, and updates last_finalized_block""" - tip = self.estimate() - - prev_last_finalized_block = self.last_finalized_block - - while tip and tip != prev_last_finalized_block: - oracle = CliqueOracle(tip, self, validator_set) - fault_tolerance, _ = oracle.check_estimate_safety() - - if fault_tolerance > 0: - self.last_finalized_block = tip - # then, a sanity check! - if prev_last_finalized_block: - assert prev_last_finalized_block.is_in_blockchain(self.last_finalized_block) - - # cache when_finalized - while tip and tip not in self.when_finalized: - self.when_finalized[tip] = len(self.messages) - tip = tip.estimate - - return self.last_finalized_block - - tip = tip.estimate diff --git a/casper/justification.py b/casper/justification.py deleted file mode 100644 index 979ce20..0000000 --- a/casper/justification.py +++ /dev/null @@ -1,11 +0,0 @@ -"""The justification module ...""" - - -class Justification(object): - """The justification class ...""" - def __init__(self, latest_messages=None): - if latest_messages is None: - latest_messages = {} - self.latest_messages = dict() - for validator in latest_messages: - self.latest_messages[validator] = latest_messages[validator] diff --git a/casper/message.py b/casper/message.py index 74fe2b2..8d2eae9 100644 --- a/casper/message.py +++ b/casper/message.py @@ -1,46 +1,61 @@ """The message module defines an abstract message class """ import random as r -from casper.justification import Justification class Message(object): """Message/bet data structure for blockchain consensus""" - def __eq__(self, message): - if message is None: - return False - return self.__hash__() == message.__hash__() + def __init__(self, estimate, justification, sender, sequence_number, display_height): + if not self.is_valid_estimate(estimate): + raise TypeError("Estimate {} is invalid!".format(estimate)) - def __ne__(self, message): - return not self.__eq__(message) - - def __init__(self, estimate, justification, sender): - assert isinstance(justification, Justification), "justification should be a Justification!" + assert isinstance(justification, dict), "expected justification a Justification!" self.sender = sender self.estimate = estimate self.justification = justification + self.sequence_number = sequence_number + self.display_height = display_height + self.header = r.random() - if self.sender in self.justification.latest_messages: - latest_message = self.justification.latest_messages[self.sender] - self.sequence_number = latest_message.sequence_number + 1 - else: - self.sequence_number = 0 + def __hash__(self): + # defined differently than self.hash to avoid confusion with builtin + # use of __hash__ in dictionaries, sets, etc + return hash(self.hash) - # The "display_height" of bets are used for visualization of views - if not any(self.justification.latest_messages): - self.display_height = 0 - else: - max_height = max( - self.justification.latest_messages[validator].display_height - for validator in self.justification.latest_messages - ) - self.display_height = max_height + 1 + def __eq__(self, message): + if not isinstance(message, Message): + return False + return self.hash == message.hash - self.salt = r.randint(0, 1000000) + def __lt__(self, message): + if not isinstance(message, Message): + return False + return self.hash < message.hash - def __hash__(self): - return hash(str(self.sender.name) + str(self.sequence_number) + str(self.salt)) + def __le__(self, message): + if not isinstance(message, Message): + return False + return self.hash <= message.hash + + def __gt__(self, message): + if not isinstance(message, Message): + return False + return self.hash > message.hash + + def __ge__(self, message): + if not isinstance(message, Message): + return False + return self.hash >= message.hash + + @property + def hash(self): + return hash(str(self.header)) + + @classmethod + def is_valid_estimate(cls, estimate): + '''Must be implemented by child class''' + raise NotImplementedError def conflicts_with(self, message): '''Must be implemented by child class''' - pass + raise NotImplementedError diff --git a/casper/network.py b/casper/network.py index 97da966..ea6c846 100644 --- a/casper/network.py +++ b/casper/network.py @@ -1,44 +1,85 @@ """The network module contains a network class allowing for message passing """ -from casper.blockchain.blockchain_protocol import BlockchainProtocol +from utils.priority_queue import PriorityQueue +from casper.protocols.blockchain.blockchain_protocol import BlockchainProtocol class Network(object): """Simulates a network that allows for message passing between validators.""" def __init__(self, validator_set, protocol=BlockchainProtocol): self.validator_set = validator_set - self.global_view = protocol.View(set()) + self.global_view = protocol.View( + self._collect_initial_messages(), + protocol.initial_message(None) + ) + self.message_queues = { + validator: PriorityQueue() + for validator in self.validator_set + } + self._current_time = 0 - def propagate_message_to_validator(self, message, validator): - """Propagate a message to a validator.""" - assert message in self.global_view.messages, ("...expected only to propagate messages " - "from the global view") - assert validator in self.validator_set, "...expected a known validator" + def delay(self, sender, receiver): + '''Must be defined in child class. + Returns delay of next message for sender to receiver''' + raise NotImplementedError - validator.receive_messages(set([message])) + # + # Network Time API + # Example: SimulationRunner + # Base model comes with vary simple notion of time advanced forward in clicks by `advance_time` + # + # For more advanced usage, override `time` property with a real clock + # and disregard `advance_time` + # + @property + def time(self): + return self._current_time - def get_message_from_validator(self, validator): - """Get a message from a validator.""" - assert validator in self.validator_set, "...expected a known validator" + def advance_time(self, amount=1): + self._current_time += amount - new_message = validator.make_new_message() - self.global_view.add_messages(set([new_message])) + # + # Validator API to Network + # + def send(self, validator, message): + self.global_view.add_messages( + set([message]) + ) + self.message_queues[validator].put(( + self.time + self.delay(message.sender, validator), + message + )) - return new_message + def send_to_all(self, message): + for validator in self.validator_set: + if validator == message.sender: + continue + self.send(validator, message) + + def receive(self, validator): + queue = self.message_queues[validator] + if queue.qsize() == 0: + return None + if queue.peek()[0] > self.time: + return None - def view_initialization(self, view): - """ - Initalizes all validators with all messages in some view. - NOTE: This method is not currently tested or called anywhere in repo - """ - self.global_view = view.messages + return queue.get()[1] - latest = view.latest_messages + def receive_all_available(self, validator): + messages = [] + message = self.receive(validator) + while message: + messages.append(message) + message = self.receive(validator) - for validator in latest: - validator.receive_messages(set([latest[validator]])) + return messages + + # + # helpers + # + def _collect_initial_messages(self): + initial_messages = set() - def random_initialization(self): - """Generates starting messages for all validators with None as an estiamte.""" for validator in self.validator_set: - new_bet = self.get_message_from_validator(validator) - self.global_view.add_messages(set([new_bet])) + initial_messages.update(validator.view.justified_messages.values()) + + return initial_messages diff --git a/casper/networks.py b/casper/networks.py new file mode 100644 index 0000000..c8912df --- /dev/null +++ b/casper/networks.py @@ -0,0 +1,36 @@ +import random as r + +from casper.network import Network + + +class NoDelayNetwork(Network): + def delay(self, sender, receiver): + return 0 + + +class ConstantDelayNetwork(Network): + CONSTANT = 5 + + def delay(self, sender, receiver): + return self.CONSTANT + + +class StepNetwork(ConstantDelayNetwork): + CONSTANT = 1 + + +class LinearDelayNetwork(Network): + MAX_DELAY = 5 + + def delay(self, sender, receiver): + return r.randint(1, self.MAX_DELAY) + + +class GaussianDelayNetwork(Network): + MU = 10 + SIGMA = 5 + MIN_DELAY = 1 + + def delay(self, sender, receiver): + random_delay = round(r.gauss(self.MU, self.SIGMA)) + return max(self.MIN_DELAY, random_delay) diff --git a/casper/plot_tool.py b/casper/plot_tool.py index 135df15..c58471b 100644 --- a/casper/plot_tool.py +++ b/casper/plot_tool.py @@ -32,7 +32,6 @@ def __init__(self, display, save, node_shape): self.report_number = 0 - def _create_graph_folder(self): graph_path = os.path.dirname(os.path.abspath(__file__)) + '/../graphs/' # if there isn't a graph folder, make one! @@ -42,7 +41,7 @@ def _create_graph_folder(self): # find the next name for the next plot! graph_num = 0 while True: - new_plot = graph_path + 'graph_num_' + str(graph_num) + new_plot = graph_path + 'graph_num_' + str(graph_num).zfill(3) graph_num += 1 if not os.path.isdir(new_plot): os.makedirs(new_plot) @@ -57,7 +56,7 @@ def build_viewgraph(self, view, validator_set, message_colors, message_labels, e graph = nx.Graph() - nodes = view.messages + nodes = view.justified_messages.values() fig_size = plt.rcParams["figure.figsize"] fig_size[0] = 20 @@ -70,7 +69,7 @@ def build_viewgraph(self, view, validator_set, message_colors, message_labels, e edge = [] if edges == []: for message in nodes: - for msg_in_justification in message.justification.latest_messages.values(): + for msg_in_justification in message.justification.values(): if msg_in_justification is not None: edge.append((msg_in_justification, message)) @@ -81,8 +80,15 @@ def build_viewgraph(self, view, validator_set, message_colors, message_labels, e sorted_validators = validator_set.sorted_by_name() for message in nodes: # Index of val in list may have some small performance concerns. - positions[message] = (float)(sorted_validators.index(message.sender) + 1) / \ - (float)(len(validator_set) + 1), 0.2 + 0.1*message.display_height + if message.estimate is not None: + xslot = sorted_validators.index(message.sender) + 1 + else: + xslot = (len(validator_set) + 1) / 2.0 + + positions[message] = ( + (float)(xslot) / (float)(len(validator_set) + 1), + 0.2 + 0.1 * message.display_height + ) node_color_map = {} for message in nodes: @@ -91,9 +97,8 @@ def build_viewgraph(self, view, validator_set, message_colors, message_labels, e elif message_colors[message] == len(validator_set) - 1: node_color_map[message] = "Black" else: - node_color_map[message] = COLOURS[int(len(COLOURS) * message_colors[message] / \ - len(validator_set))] - + node_color_map[message] = COLOURS[int(len(COLOURS) * message_colors[message] / + len(validator_set))] color_values = [node_color_map.get(node) for node in nodes] @@ -127,10 +132,9 @@ def build_viewgraph(self, view, validator_set, message_colors, message_labels, e ax.text(-0.05, 0.1, "Weights: ", fontsize=20) for validator in validator_set: - xpos = (float)(validator.name + 1)/(float)(len(validator_set) + 1) - 0.01 + xpos = (float)(validator.name + 1) / (float)(len(validator_set) + 1) - 0.01 ax.text(xpos, 0.1, (str)((int)(validator.weight)), fontsize=20) - def next_viewgraph( self, view, @@ -186,7 +190,6 @@ def make_thumbnails(self, frame_count_limit=IMAGE_LIMIT, xsize=1000, ysize=1000) for file_name in file_names: images.append(Image.open(self.graph_path + file_name)) - size = (xsize, ysize) iterator = 0 for image in images: @@ -194,7 +197,6 @@ def make_thumbnails(self, frame_count_limit=IMAGE_LIMIT, xsize=1000, ysize=1000) image.save(self.thumbnail_path + str(1000 + iterator) + "thumbnail.png", "PNG") iterator += 1 - def make_gif(self, frame_count_limit=IMAGE_LIMIT, gif_name="mygif.gif", frame_duration=0.4): """Make a GIF visualization of view graph.""" diff --git a/casper/binary/__init__.py b/casper/protocols/__init__.py similarity index 100% rename from casper/binary/__init__.py rename to casper/protocols/__init__.py diff --git a/casper/blockchain/__init__.py b/casper/protocols/binary/__init__.py similarity index 100% rename from casper/blockchain/__init__.py rename to casper/protocols/binary/__init__.py diff --git a/casper/binary/bet.py b/casper/protocols/binary/bet.py similarity index 58% rename from casper/binary/bet.py rename to casper/protocols/binary/bet.py index 19e88d9..5b47293 100644 --- a/casper/binary/bet.py +++ b/casper/protocols/binary/bet.py @@ -3,12 +3,11 @@ class Bet(Message): - """Message data structure for blockchain consensus""" + """Message data structure for binary consensus""" - def __init__(self, estimate, justification, sender): - # Do some type checking for safety! - assert estimate in {0, 1}, "... estimate should be binary!" - super().__init__(estimate, justification, sender) + @classmethod + def is_valid_estimate(cls, estimate): + return estimate in [0, 1] def conflicts_with(self, message): """Returns true if the other_message estimate is not the same as this estimate""" diff --git a/casper/binary/binary_estimator.py b/casper/protocols/binary/binary_estimator.py similarity index 91% rename from casper/binary/binary_estimator.py rename to casper/protocols/binary/binary_estimator.py index 70c7fb4..8f719c0 100644 --- a/casper/binary/binary_estimator.py +++ b/casper/protocols/binary/binary_estimator.py @@ -1,4 +1,6 @@ """The forkchoice module implements the estimator function a blockchain""" +import random as r + def get_estimate_from_latest_messages(latest_bets, default=None): """Picks the highest weight estimate (0 or 1) given some latest bets.""" @@ -11,6 +13,6 @@ def get_estimate_from_latest_messages(latest_bets, default=None): elif zero_weight < one_weight: return 1 elif zero_weight == 0: - return default + return r.randint(0, 1) else: raise Exception("Should be no ties!") diff --git a/casper/binary/binary_plot_tool.py b/casper/protocols/binary/binary_plot_tool.py similarity index 93% rename from casper/binary/binary_plot_tool.py rename to casper/protocols/binary/binary_plot_tool.py index b50d88a..d6e7e46 100644 --- a/casper/binary/binary_plot_tool.py +++ b/casper/protocols/binary/binary_plot_tool.py @@ -35,7 +35,7 @@ def update(self, message_paths=None, sent_messages=None, new_messages=None): self._update_message_labels(new_messages) def plot(self): - """Builds relevant edges to display and creates next viegraph using them""" + """Builds relevant edges to display and creates next viewgraph using them""" if self.first_time: self._update_first_message_labels() self.first_time = False @@ -53,7 +53,7 @@ def plot(self): ) def _update_first_message_labels(self): - for message in self.view.messages: + for message in self.view.justified_messages.values(): self.message_labels[message] = message.estimate def _update_communications(self, message_paths, sent_messages, new_messages): @@ -64,8 +64,8 @@ def _update_self_communications(self, new_messages): for validator in new_messages: message = new_messages[validator] - if validator in message.justification.latest_messages: - last_message = message.justification.latest_messages[validator] + if validator in message.justification: + last_message = self.view.justified_messages[message.justification[validator]] self.self_communications.append([last_message, message]) def _update_message_labels(self, new_messages): diff --git a/casper/protocols/binary/binary_protocol.py b/casper/protocols/binary/binary_protocol.py new file mode 100644 index 0000000..845b076 --- /dev/null +++ b/casper/protocols/binary/binary_protocol.py @@ -0,0 +1,27 @@ +import random as r + +from casper.protocols.binary.binary_view import BinaryView +from casper.protocols.binary.bet import Bet +from casper.protocols.integer.integer_plot_tool import IntegerPlotTool +from casper.protocol import Protocol + + +class BinaryProtocol(Protocol): + View = BinaryView + Message = Bet + PlotTool = IntegerPlotTool + + @staticmethod + def initial_message(validator): + if not validator: + return None + + rand_int = r.randint(0, 1) + + return Bet( + rand_int, + dict(), + validator, + 0, + 0 + ) diff --git a/casper/protocols/binary/binary_view.py b/casper/protocols/binary/binary_view.py new file mode 100644 index 0000000..35ad39d --- /dev/null +++ b/casper/protocols/binary/binary_view.py @@ -0,0 +1,8 @@ +"""The binary view module extends a view for binary data structures """ +from casper.protocols.integer.integer_view import IntegerView + + +class BinaryView(IntegerView): + """A view class that also keeps track of messages about a bit""" + def __init__(self, messages=None, first_message=None): + super().__init__(messages) diff --git a/tests/casper/blockchain/__init__.py b/casper/protocols/blockchain/__init__.py similarity index 100% rename from tests/casper/blockchain/__init__.py rename to casper/protocols/blockchain/__init__.py diff --git a/casper/blockchain/block.py b/casper/protocols/blockchain/block.py similarity index 78% rename from casper/blockchain/block.py rename to casper/protocols/blockchain/block.py index 6a96539..1752f65 100644 --- a/casper/blockchain/block.py +++ b/casper/protocols/blockchain/block.py @@ -5,11 +5,11 @@ class Block(Message): """Message data structure for blockchain consensus""" - def __init__(self, estimate, justification, sender): + def __init__(self, estimate, justification, sender, sequence_number, display_height): # Do some type checking for safety! assert isinstance(estimate, Block) or estimate is None, "...expected a prevblock!" - super().__init__(estimate, justification, sender) + super().__init__(estimate, justification, sender, sequence_number, display_height) # height is the traditional block height - number of blocks back to genesis block if estimate: @@ -17,6 +17,10 @@ def __init__(self, estimate, justification, sender): else: self.height = 1 + @classmethod + def is_valid_estimate(cls, estimate): + return isinstance(estimate, Block) or estimate is None + def conflicts_with(self, message): """Returns true if self is not in the prev blocks of other_message""" assert isinstance(message, Block), "...expected a block" diff --git a/casper/blockchain/blockchain_plot_tool.py b/casper/protocols/blockchain/blockchain_plot_tool.py similarity index 75% rename from casper/blockchain/blockchain_plot_tool.py rename to casper/protocols/blockchain/blockchain_plot_tool.py index 8c7797d..d16eaf2 100644 --- a/casper/blockchain/blockchain_plot_tool.py +++ b/casper/protocols/blockchain/blockchain_plot_tool.py @@ -12,29 +12,33 @@ def __init__(self, display, save, view, validator_set): super().__init__(display, save, 's') self.view = view self.validator_set = validator_set + self.genesis_block = self.view.genesis_block self.message_fault_tolerance = dict() self.blockchain = [] self.communications = [] + self.block_fault_tolerance = {} self.message_labels = {} + self.justifications = { + validator: [] + for validator in validator_set + } + + self.message_labels[self.genesis_block] = "G" - def update(self, message_paths=None, sent_messages=None, new_messages=None): + def update(self, new_messages=None): """Updates displayable items with new messages and paths""" - if message_paths is None: - message_paths = [] - if sent_messages is None: - sent_messages = dict() if new_messages is None: - new_messages = dict() + new_messages = [] - self._update_communications(message_paths, sent_messages, new_messages) + self._update_new_justifications(new_messages) self._update_blockchain(new_messages) self._update_block_fault_tolerance() self._update_message_labels(new_messages) def plot(self): - """Builds relevant edges to display and creates next viegraph using them""" + """Builds relevant edges to display and creates next viewgraph using them""" best_chain_edge = self.get_best_chain() validator_chain_edges = self.get_validator_chains() @@ -68,17 +72,23 @@ def get_validator_chains(self): return vals_chain_edges - def _update_communications(self, message_paths, sent_messages, new_messages): - for sender, receiver in message_paths: - self.communications.append([sent_messages[sender], new_messages[receiver]]) + def _update_new_justifications(self, new_messages): + for message in new_messages: + sender = message.sender + for validator in message.justification: + last_message = self.view.justified_messages[message.justification[validator]] + # only show if new justification + if last_message not in self.justifications[sender]: + self.communications.append([last_message, message]) + self.justifications[sender].append(last_message) def _update_blockchain(self, new_messages): - for message in new_messages.values(): + for message in new_messages: if message.estimate is not None: self.blockchain.append([message, message.estimate]) def _update_message_labels(self, new_messages): - for message in new_messages.values(): + for message in new_messages: self.message_labels[message] = message.sequence_number def _update_block_fault_tolerance(self): diff --git a/casper/protocols/blockchain/blockchain_protocol.py b/casper/protocols/blockchain/blockchain_protocol.py new file mode 100644 index 0000000..3af8c93 --- /dev/null +++ b/casper/protocols/blockchain/blockchain_protocol.py @@ -0,0 +1,18 @@ +from casper.protocols.blockchain.blockchain_view import BlockchainView +from casper.protocols.blockchain.block import Block +from casper.protocols.blockchain.blockchain_plot_tool import BlockchainPlotTool +from casper.protocol import Protocol + + +class BlockchainProtocol(Protocol): + View = BlockchainView + Message = Block + PlotTool = BlockchainPlotTool + + genesis_block = None + + @classmethod + def initial_message(cls, validator): + if not cls.genesis_block: + cls.genesis_block = Block(None, dict(), validator, -1, 0) + return cls.genesis_block diff --git a/casper/protocols/blockchain/blockchain_view.py b/casper/protocols/blockchain/blockchain_view.py new file mode 100644 index 0000000..b16b1fb --- /dev/null +++ b/casper/protocols/blockchain/blockchain_view.py @@ -0,0 +1,63 @@ +"""The blockchain view module extends a view for blockchain data structures """ +from casper.safety_oracles.clique_oracle import CliqueOracle +from casper.abstract_view import AbstractView +import casper.protocols.blockchain.forkchoice as forkchoice + + +class BlockchainView(AbstractView): + """A view class that also keeps track of a last_finalized_block and children""" + def __init__(self, messages=None, genesis_block=None): + self.children = dict() + self.last_finalized_block = genesis_block + self.genesis_block = genesis_block + + self._initialize_message_caches(messages) + + super().__init__(messages) + + def estimate(self): + """Returns the current forkchoice in this view""" + return forkchoice.get_fork_choice( + self.last_finalized_block, + self.children, + self.latest_messages + ) + + def update_safe_estimates(self, validator_set): + """Checks safety on messages in views forkchoice, and updates last_finalized_block""" + tip = self.estimate() + + while tip and tip != self.last_finalized_block: + oracle = CliqueOracle(tip, self, validator_set) + fault_tolerance, _ = oracle.check_estimate_safety() + + if fault_tolerance > 0: + self.last_finalized_block = tip + self._update_when_finalized_cache(tip) + return self.last_finalized_block + + tip = tip.estimate + + def _update_protocol_specific_view(self, message): + """Given a now justified message, updates children and when_recieved""" + assert message.hash in self.justified_messages, "...should not have seen message!" + + # update the children dictonary with the new message + if message.estimate not in self.children: + self.children[message.estimate] = set() + self.children[message.estimate].add(message) + + self._update_when_added_cache(message) + + def _initialize_message_caches(self, messages): + self.when_added = {message: 0 for message in messages} + self.when_finalized = {self.genesis_block: 0} + + def _update_when_added_cache(self, message): + if message not in self.when_added: + self.when_added[message] = len(self.justified_messages) + + def _update_when_finalized_cache(self, tip): + while tip and tip not in self.when_finalized: + self.when_finalized[tip] = len(self.justified_messages) + tip = tip.estimate diff --git a/casper/blockchain/forkchoice.py b/casper/protocols/blockchain/forkchoice.py similarity index 100% rename from casper/blockchain/forkchoice.py rename to casper/protocols/blockchain/forkchoice.py diff --git a/casper/protocols/concurrent/__init__.py b/casper/protocols/concurrent/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/casper/protocols/concurrent/block.py b/casper/protocols/concurrent/block.py new file mode 100644 index 0000000..2aa252a --- /dev/null +++ b/casper/protocols/concurrent/block.py @@ -0,0 +1,48 @@ +"""The block module implements the message data structure for a concurrent protocol""" +from casper.message import Message + + +class Block(Message): + """Message data structure for concurrent consensus""" + + @classmethod + def is_valid_estimate(cls, estimate): + if not isinstance(estimate, dict): + return False + + for field in ['blocks', 'inputs', 'outputs']: + if field not in estimate: + return False + + if len(estimate) != 3: + return False + + if not isinstance(estimate['blocks'], set) or len(estimate['blocks']) < 1: + return False + + return True + + def conflicts_with(self, message): + """Returns true if self is not in the prev blocks of other_message""" + assert isinstance(message, Block), "...expected a block" + + return not self.is_in_history(message) + + def is_in_history(self, block): + """Returns True if self is an ancestor of block.""" + assert isinstance(block, Block), "...should be block, is" + + if self == block: + return True + + if len(block.estimate['blocks']) == 1: + for b in block.estimate['blocks']: + if b is None: + return False + + for b in block.estimate['blocks']: + # memoize in future for efficiency + if self.is_in_history(b): + return True + + return False diff --git a/casper/protocols/concurrent/concurrent_plot_tool.py b/casper/protocols/concurrent/concurrent_plot_tool.py new file mode 100644 index 0000000..964a45c --- /dev/null +++ b/casper/protocols/concurrent/concurrent_plot_tool.py @@ -0,0 +1,96 @@ +"""The concurrent plot tool implements functions for plotting concurrent data structures""" + +from casper.plot_tool import PlotTool +import casper.utils as utils + + +class ConcurrentPlotTool(PlotTool): + """The module contains functions for plotting a concurrent data structure""" + + def __init__(self, display, save, view, validator_set): + super().__init__(display, save, 's') + self.view = view + self.validator_set = validator_set + self.genesis_block = self.view.genesis_block + self.message_fault_tolerance = dict() + + self.schedule = [] + self.communications = [] + + self.block_fault_tolerance = {} + self.message_labels = {} + self.justifications = { + validator: [] + for validator in validator_set + } + + self.message_labels[self.genesis_block] = "G" + + def update(self, new_messages=None): + """Updates displayable items with new messages and paths""" + if new_messages is None: + new_messages = [] + + self._update_new_justifications(new_messages) + self._update_schedule(new_messages) + self._update_block_fault_tolerance() + self._update_message_labels(new_messages) + + def plot(self): + """Builds relevant edges to display and creates next viewgraph using them""" + + best_schedule_edge = self.get_best_schedule() + + validator_chain_edges = self.get_validator_chains() + + edgelist = [] + edgelist.append(utils.edge(self.schedule, 2, 'grey', 'solid')) + edgelist.append(utils.edge(self.communications, 1, 'black', 'dotted')) + edgelist.append(best_schedule_edge) + edgelist.extend(validator_chain_edges) + + self.next_viewgraph( + self.view, + self.validator_set, + edges=edgelist, + message_colors=self.block_fault_tolerance, + message_labels=self.message_labels + ) + + def get_best_schedule(self): + """Returns an edge made of the global forkchoice to genesis""" + best_messages = self.view.estimate()['blocks'] + best_schedule = utils.build_schedule(best_messages) + return utils.edge(best_schedule, 5, 'red', 'solid') + + def get_validator_chains(self): + """Returns a list of edges main from validators current forkchoice to genesis""" + vals_chain_edges = [] + for validator in self.validator_set: + chain = utils.build_schedule(set([validator.my_latest_message()])) + vals_chain_edges.append(utils.edge(chain, 2, 'blue', 'solid')) + + return vals_chain_edges + + def _update_new_justifications(self, new_messages): + for message in new_messages: + sender = message.sender + for validator in message.justification: + last_message = self.view.justified_messages[message.justification[validator]] + # only show if new justification + if last_message not in self.justifications[sender]: + self.communications.append([last_message, message]) + self.justifications[sender].append(last_message) + + def _update_schedule(self, new_messages): + for message in new_messages: + for ancestor in message.estimate['blocks']: + if ancestor is not None: + self.schedule.append([message, ancestor]) + + def _update_message_labels(self, new_messages): + for message in new_messages: + self.message_labels[message] = message.sequence_number + + def _update_block_fault_tolerance(self): + return diff --git a/casper/protocols/concurrent/concurrent_protocol.py b/casper/protocols/concurrent/concurrent_protocol.py new file mode 100644 index 0000000..c301168 --- /dev/null +++ b/casper/protocols/concurrent/concurrent_protocol.py @@ -0,0 +1,26 @@ +import random as r + +from casper.protocols.concurrent.concurrent_view import ConcurrentView +from casper.protocols.concurrent.block import Block +from casper.protocols.concurrent.concurrent_plot_tool import ConcurrentPlotTool + +from casper.protocol import Protocol + + +class ConcurrentProtocol(Protocol): + View = ConcurrentView + Message = Block + PlotTool = ConcurrentPlotTool + + genesis_block = None + + @classmethod + def initial_message(cls, validator): + if not cls.genesis_block: + blocks = set([None]) + inputs = set([r.randint(0, 1000000000) for x in range(7)]) + outputs = set([r.randint(0, 1000000000) for x in inputs]) + + estimate = {'blocks': blocks, 'inputs': inputs, 'outputs': outputs} + cls.genesis_block = Block(estimate, dict(), validator, -1, 0) + return cls.genesis_block diff --git a/casper/protocols/concurrent/concurrent_view.py b/casper/protocols/concurrent/concurrent_view.py new file mode 100644 index 0000000..12b526e --- /dev/null +++ b/casper/protocols/concurrent/concurrent_view.py @@ -0,0 +1,71 @@ +"""The concurrent view module extends a view for concurrent data structures """ +import random as r + +from casper.abstract_view import AbstractView +import casper.protocols.concurrent.forkchoice as forkchoice + + +class ConcurrentView(AbstractView): + """A view class that also keeps track of a last_finalized_estimate and children""" + def __init__(self, messages=None, genesis_block=None): + self.children = dict() + self.last_finalized_estimate = set([genesis_block]) + self.genesis_block = genesis_block + + self._initialize_message_caches(messages) + + # In the future, can change this to any function that follows the interface + self.select_outputs = self.select_random_outputs_to_consume + self.create_outputs = self.create_random_new_outputs + + super().__init__(messages) + + def estimate(self): + """Returns the current forkchoice in this view""" + available_outputs, output_sources = forkchoice.get_fork_choice( + self.last_finalized_estimate, + self.children, + self.latest_messages + ) + + old_outputs = self.select_outputs(available_outputs, output_sources) + new_outputs = self.create_outputs(old_outputs, len(old_outputs)) + blocks = {output_sources[output] for output in old_outputs} + + return {'blocks': blocks, 'inputs': old_outputs, 'outputs': new_outputs} + + def select_random_outputs_to_consume(self, available_outputs, output_sources): + num_outputs = r.randint(1, len(available_outputs)) + return set(r.sample(available_outputs, num_outputs)) + + def create_random_new_outputs(self, old_outputs, num_new_outputs): + return set([r.randint(0, 1000000000) for _ in range(num_new_outputs)]) + + def update_safe_estimates(self, validator_set): + """Checks safety on messages in views forkchoice, and updates last_finalized_estimate""" + pass + + def _update_protocol_specific_view(self, message): + """Given a now justified message, updates children and when_recieved""" + assert message.hash in self.justified_messages, "...should not have seen message!" + + # update the children dictonary with the new message + for ancestor in message.estimate['blocks']: + if ancestor not in self.children: + self.children[ancestor] = set() + self.children[ancestor].add(message) + + self._update_when_added_cache(message) + + def _initialize_message_caches(self, messages): + self.when_added = {message: 0 for message in messages} + self.when_finalized = {self.genesis_block: 0} + + def _update_when_added_cache(self, message): + if message not in self.when_added: + self.when_added[message] = len(self.justified_messages) + + def _update_when_finalized_cache(self, tip): + while tip and tip not in self.when_finalized: + self.when_finalized[tip] = len(self.justified_messages) + tip = tip.estimate diff --git a/casper/protocols/concurrent/forkchoice.py b/casper/protocols/concurrent/forkchoice.py new file mode 100644 index 0000000..ad3030d --- /dev/null +++ b/casper/protocols/concurrent/forkchoice.py @@ -0,0 +1,106 @@ +"""The forkchoice module implements the estimator function a concurrent schedule""" + + +def get_ancestors(block): + ancestors = set() + stack = [block] + + while any(stack): + curr_block = stack.pop() + + if curr_block is None: + continue + + if curr_block not in ancestors: + ancestors.add(curr_block) + stack.extend([b for b in curr_block.estimate['blocks']]) + + return ancestors + + +def get_scores(latest_messages): + scores = dict() + + for validator in latest_messages: + ancestors = get_ancestors(latest_messages[validator]) + for b in ancestors: + scores[b] = scores.get(b, 0) + validator.weight + + return scores + + +def get_outputs(blocks): + outputs = set() + + for block in blocks: + outputs.update(block.estimate['outputs']) + + return outputs + + +def update_outputs(outputs, blocks): + for block in blocks: + for output in block.estimate['inputs']: + outputs.remove(output) + for output in block.estimate['outputs']: + outputs.add(output) + + +def track_output_sources(output_sources, new_blocks): + for block in new_blocks: + for output in block.estimate['outputs']: + assert output not in output_sources # only should be spent once... + output_sources[output] = block + + +def is_consumable(block, current_blocks, scores, available_outputs): + for other_block in current_blocks: + if any(block.estimate['inputs'].intersection(other_block.estimate['inputs'])): + if scores.get(block, 0) < scores.get(other_block, 0): + return False + # we can't eat a block if it's outputs are not yet available + for output in block.estimate['inputs']: + if output not in available_outputs: + return False + + return True + + +def get_children(blocks, children_dict): + children_blocks = set() + + for block in blocks: + if block in children_dict: + children_blocks.update(children_dict[block]) + + return children_blocks + + +def get_fork_choice(last_finalized_estimate, children, latest_messages): + """Returns the estimate by selecting highest weight sub-trees. + Starts from the last_finalized_estimate and stops when it reaches a tips.""" + output_sources = dict() + available_outputs = set() # should start w/ all the stuff from the last finalized estimate... + for block in last_finalized_estimate: + available_outputs.update(block.estimate['inputs']) + + scores = get_scores(latest_messages) + + current_blocks = last_finalized_estimate # this is a set of blocks + track_output_sources(output_sources, current_blocks) + update_outputs(available_outputs, current_blocks) + current_children = get_children(current_blocks, children) + + while any(current_children): + next_blocks = set() + + for block in current_children: + if is_consumable(block, current_children, scores, available_outputs): + next_blocks.add(block) + + current_blocks = next_blocks + track_output_sources(output_sources, current_blocks) + update_outputs(available_outputs, current_blocks) + current_children = get_children(current_blocks, children) + + return available_outputs, output_sources diff --git a/casper/protocols/integer/__init__.py b/casper/protocols/integer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/casper/protocols/integer/bet.py b/casper/protocols/integer/bet.py new file mode 100644 index 0000000..c1cbe0a --- /dev/null +++ b/casper/protocols/integer/bet.py @@ -0,0 +1,16 @@ +"""The Bet module implements the message data structure for integer consensus""" +from casper.message import Message + + +class Bet(Message): + """Message data structure for integer consensus""" + + @classmethod + def is_valid_estimate(cls, estimate): + return isinstance(estimate, int) + + def conflicts_with(self, message): + """Returns true if the other_message estimate is not the same as this estimate""" + assert isinstance(message.estimate, int), "... estimate should be an integer!" + + return self.estimate != message.estimate diff --git a/casper/protocols/integer/integer_estimator.py b/casper/protocols/integer/integer_estimator.py new file mode 100644 index 0000000..fd33e81 --- /dev/null +++ b/casper/protocols/integer/integer_estimator.py @@ -0,0 +1,16 @@ +"""The integer estimator module implements the estimator function integer consensus""" + + +def get_estimate_from_latest_messages(latest_bets): + """Picks the median weight estimate given some latest bets.""" + + sorted_bets = sorted(latest_bets.values(), key=lambda bet: bet.estimate) + half_seen_weight = sum(v.weight for v in latest_bets) / 2.0 + + assert half_seen_weight > 0 + + total_estimate_weight = 0 + for bet in sorted_bets: + total_estimate_weight += bet.sender.weight + if total_estimate_weight >= half_seen_weight: + return bet.estimate diff --git a/casper/protocols/integer/integer_plot_tool.py b/casper/protocols/integer/integer_plot_tool.py new file mode 100644 index 0000000..1a06261 --- /dev/null +++ b/casper/protocols/integer/integer_plot_tool.py @@ -0,0 +1,85 @@ +"""The integer plot tool implements functions for plotting integer consensus""" + +from casper.plot_tool import PlotTool +from casper.safety_oracles.clique_oracle import CliqueOracle +import casper.utils as utils + + +class IntegerPlotTool(PlotTool): + """The module contains functions for plotting an integer data structure""" + + def __init__(self, display, save, view, validator_set): + super().__init__(display, save, 'o') + self.view = view + self.validator_set = validator_set + + self.new_justifications = [] + self.bet_fault_tolerance = {} + self.message_labels = {} + self.justifications = { + validator: [] + for validator in validator_set + } + + self.first_time = True + + def update(self, new_messages=None): + """Updates displayable items with new messages and paths""" + if new_messages is None: + new_messages = [] + + self._update_new_justifications(new_messages) + self._update_message_fault_tolerance() + self._update_message_labels(new_messages) + + def plot(self): + """Builds relevant edges to display and creates next viewgraph using them""" + if self.first_time: + self._update_first_message_labels() + self.first_time = False + + edgelist = [] + edgelist.append(utils.edge(self.new_justifications, 1, 'black', 'solid')) + + self.next_viewgraph( + self.view, + self.validator_set, + edges=edgelist, + message_colors=self.bet_fault_tolerance, + message_labels=self.message_labels + ) + + def _update_first_message_labels(self): + for message in self.view.justified_messages.values(): + self.message_labels[message] = message.estimate + + def _update_new_justifications(self, new_messages): + for message in new_messages: + sender = message.sender + for validator in message.justification: + last_message = self.view.justified_messages[message.justification[validator]] + # only show if new justification + if last_message not in self.justifications[sender]: + self.new_justifications.append([last_message, message]) + self.justifications[sender].append(last_message) + # always show self as justification + elif last_message.sender == message.sender: + self.new_justifications.append([last_message, message]) + + def _update_message_labels(self, new_messages): + for message in new_messages: + self.message_labels[message] = message.estimate + + def _update_message_fault_tolerance(self): + for validator in self.view.latest_messages: + + latest_message = self.view.latest_messages[validator] + + if latest_message in self.bet_fault_tolerance: + continue + + oracle = CliqueOracle(latest_message, validator.view, self.validator_set) + fault_tolerance, num_node_ft = oracle.check_estimate_safety() + + if fault_tolerance > 0: + self.bet_fault_tolerance[latest_message] = num_node_ft diff --git a/casper/protocols/integer/integer_protocol.py b/casper/protocols/integer/integer_protocol.py new file mode 100644 index 0000000..5beff9c --- /dev/null +++ b/casper/protocols/integer/integer_protocol.py @@ -0,0 +1,27 @@ +import random as r + +from casper.protocols.integer.integer_view import IntegerView +from casper.protocols.integer.bet import Bet +from casper.protocols.integer.integer_plot_tool import IntegerPlotTool +from casper.protocol import Protocol + + +class IntegerProtocol(Protocol): + View = IntegerView + Message = Bet + PlotTool = IntegerPlotTool + + @staticmethod + def initial_message(validator): + if not validator: + return None + + rand_int = r.randint(0, 100) + + return Bet( + rand_int, + dict(), + validator, + 0, + 0 + ) diff --git a/casper/protocols/integer/integer_view.py b/casper/protocols/integer/integer_view.py new file mode 100644 index 0000000..fe9c766 --- /dev/null +++ b/casper/protocols/integer/integer_view.py @@ -0,0 +1,31 @@ +"""The integer view module extends a view for integer data structures """ +from casper.safety_oracles.clique_oracle import CliqueOracle +from casper.abstract_view import AbstractView +import casper.protocols.integer.integer_estimator as estimator + + +class IntegerView(AbstractView): + """A view class for integer values that also keeps track of a last_finalized_estimate""" + def __init__(self, messages=None, first_message=None): + super().__init__(messages) + + self.last_finalized_estimate = None + + def estimate(self): + """Returns the current forkchoice in this view""" + return estimator.get_estimate_from_latest_messages( + self.latest_messages + ) + + def update_safe_estimates(self, validator_set): + """Checks safety on most recent created by this view""" + # check estimate safety on the most + for bet in self.latest_messages.values(): + oracle = CliqueOracle(bet, self, validator_set) + fault_tolerance, _ = oracle.check_estimate_safety() + + if fault_tolerance > 0: + if self.last_finalized_estimate: + assert not self.last_finalized_estimate.conflicts_with(bet) + self.last_finalized_estimate = bet + break diff --git a/casper/protocols/order/__init__.py b/casper/protocols/order/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/casper/protocols/order/bet.py b/casper/protocols/order/bet.py new file mode 100644 index 0000000..2e028ac --- /dev/null +++ b/casper/protocols/order/bet.py @@ -0,0 +1,16 @@ +"""The Bet module implements the message data structure for order consensus""" +from casper.message import Message + + +class Bet(Message): + """Message data structure for order consensus""" + + @classmethod + def is_valid_estimate(cls, estimate): + return isinstance(estimate, list) + + def conflicts_with(self, message): + """Returns true if the other_message estimate is not the same as this estimate""" + assert isinstance(message.estimate, list) + + return self.estimate != message.estimate diff --git a/casper/protocols/order/order_estimator.py b/casper/protocols/order/order_estimator.py new file mode 100644 index 0000000..54d749b --- /dev/null +++ b/casper/protocols/order/order_estimator.py @@ -0,0 +1,10 @@ +def get_estimate_from_latest_messages(latest_bets): + sample_list = list(latest_bets.values())[0].estimate + elem_weights = {elem: 0 for elem in sample_list} + for validator in latest_bets: + bet = latest_bets[validator] + estimate = bet.estimate + for i, elem in enumerate(estimate): + elem_weights[elem] += validator.weight * (len(estimate) - i) + + return sorted(elem_weights, key=lambda elem: elem_weights[elem], reverse=True) diff --git a/casper/protocols/order/order_plot_tool.py b/casper/protocols/order/order_plot_tool.py new file mode 100644 index 0000000..b12e7c2 --- /dev/null +++ b/casper/protocols/order/order_plot_tool.py @@ -0,0 +1,47 @@ +"""The order plot tool implements functions for plotting order consensus +NOTE: currently only prints to terminal instead of plotting +""" + +from casper.plot_tool import PlotTool + + +class OrderPlotTool(PlotTool): + """The module contains functions for plotting an order data structure""" + + def __init__(self, display, save, view, validator_set): + self.view = view + + print("initial validator bets:") + for validator in validator_set.sorted_by_name(): + print("{} [{}]:\t{}".format(validator.name, round(validator.weight, 2), validator.estimate())) + print() + + def update(self, message_paths=None, sent_messages=None, new_messages=None): + if message_paths is None: + message_paths = [] + if sent_messages is None: + sent_messages = dict() + if new_messages is None: + new_messages = dict() + + def plot(self): + print("{}:\t{}".format(round(self.view.last_fault_tolerance, 1), self.view.estimate())) + + def next_viewgraph( + self, + view, + validator_set, + message_colors=None, + message_labels=None, + edges=None + ): + pass + + def build_viewgraph(self, view, validator_set, message_colors, message_labels, edges): + pass + + def make_thumbnails(self, frame_count_limit=None, xsize=None, ysize=None): + pass + + def make_gif(self, frame_count_limit=None, gif_name=None, frame_duration=None): + pass diff --git a/casper/protocols/order/order_protocol.py b/casper/protocols/order/order_protocol.py new file mode 100644 index 0000000..c513257 --- /dev/null +++ b/casper/protocols/order/order_protocol.py @@ -0,0 +1,31 @@ +import random as r +import copy + +from casper.protocols.order.order_view import OrderView +from casper.protocols.order.bet import Bet +from casper.protocols.order.order_plot_tool import OrderPlotTool +from casper.protocol import Protocol + + +class OrderProtocol(Protocol): + View = OrderView + Message = Bet + PlotTool = OrderPlotTool + + LIST = ["dog", "frog", "horse", "pig", "rat", "whale", "cat"] + + @staticmethod + def initial_message(validator): + if not validator: + return None + + rand_order_list = copy.deepcopy(OrderProtocol.LIST) + r.shuffle(rand_order_list) + + return Bet( + rand_order_list, + dict(), + validator, + 0, + 0 + ) diff --git a/casper/protocols/order/order_view.py b/casper/protocols/order/order_view.py new file mode 100644 index 0000000..1fa8e68 --- /dev/null +++ b/casper/protocols/order/order_view.py @@ -0,0 +1,33 @@ +"""The order view module extends a view for order data structures """ +from casper.safety_oracles.clique_oracle import CliqueOracle +from casper.abstract_view import AbstractView +import casper.protocols.order.order_estimator as estimator + + +class OrderView(AbstractView): + """A view class that also keeps track of a last_finalized_estimate""" + def __init__(self, messages=None, first_message=None): + super().__init__(messages) + + self.last_finalized_estimate = None + self.last_fault_tolerance = 0 + + def estimate(self): + """Returns the current forkchoice in this view""" + return estimator.get_estimate_from_latest_messages( + self.latest_messages + ) + + def update_safe_estimates(self, validator_set): + """Checks safety on most recent created by this view""" + # check estimate safety on the most + for bet in self.latest_messages.values(): + oracle = CliqueOracle(bet, self, validator_set) + fault_tolerance, _ = oracle.check_estimate_safety() + + if fault_tolerance > 0: + if self.last_finalized_estimate: + assert not self.last_finalized_estimate.conflicts_with(bet) + self.last_fault_tolerance = fault_tolerance + self.last_finalized_estimate = bet + break diff --git a/casper/safety_oracles/adversary_models/model_validator.py b/casper/safety_oracles/adversary_models/model_validator.py index d3fc761..702b52a 100644 --- a/casper/safety_oracles/adversary_models/model_validator.py +++ b/casper/safety_oracles/adversary_models/model_validator.py @@ -30,7 +30,6 @@ def show(self, bet): self.latest_observed_bets[bet.sender] = bet - def make_new_latest_bet(self): """This function attempts to make a new latest bet for this validator (self) with a given estimate.""" diff --git a/casper/safety_oracles/adversary_oracle.py b/casper/safety_oracles/adversary_oracle.py index 26f6849..6735595 100644 --- a/casper/safety_oracles/adversary_oracle.py +++ b/casper/safety_oracles/adversary_oracle.py @@ -49,13 +49,14 @@ def get_messages_and_viewables(self): for val2 in self.validator_set: # if they have seen nothing from some validator, assume the worst - if val2 not in val_latest_message.justification.latest_messages: + if val2 not in val_latest_message.justification: viewables[validator][val2] = ModelBet(AdversaryOracle.ADV_ESTIMATE, val2) continue # If they have seen something from other validators, do a free block check # If there is a free block, assume they will see that (side-effects free!) - val2_msg_in_v_view = val_latest_message.justification.latest_messages[val2] + message_hash = val_latest_message.justification[val2] + val2_msg_in_v_view = self.view.justified_messages[message_hash] if utils.exists_free_message( self.candidate_estimate, val2, diff --git a/casper/safety_oracles/clique_oracle.py b/casper/safety_oracles/clique_oracle.py index b8c34f2..7558e47 100644 --- a/casper/safety_oracles/clique_oracle.py +++ b/casper/safety_oracles/clique_oracle.py @@ -28,19 +28,21 @@ def _collect_edges(self): for val1, val2 in itertools.combinations(self.with_candidate, 2): # the latest message val1 has seen from val2 is on the candidate estimate, v1_msg = self.view.latest_messages[val1] - if val2 not in v1_msg.justification.latest_messages: + if val2 not in v1_msg.justification: continue - v2_msg_in_v1_view = v1_msg.justification.latest_messages[val2] + message_hash = v1_msg.justification[val2] + v2_msg_in_v1_view = self.view.justified_messages[message_hash] if self.candidate_estimate.conflicts_with(v2_msg_in_v1_view): continue # the latest block val2 has seen from val1 is on the candidate estimate v2_msg = self.view.latest_messages[val2] - if val1 not in v2_msg.justification.latest_messages: + if val1 not in v2_msg.justification: continue - v1_msg_in_v2_view = v2_msg.justification.latest_messages[val1] + message_hash = v2_msg.justification[val1] + v1_msg_in_v2_view = self.view.justified_messages[message_hash] if self.candidate_estimate.conflicts_with(v1_msg_in_v2_view): continue diff --git a/casper/utils.py b/casper/utils.py index b803829..c77ea3c 100644 --- a/casper/utils.py +++ b/casper/utils.py @@ -15,7 +15,8 @@ def exists_free_message(estimate, val, sequence_num, view): if curr_message.sequence_number == 0: break - curr_message = curr_message.justification.latest_messages[val] + next_message_hash = curr_message.justification[val] + curr_message = view.justified_messages[next_message_hash] return False @@ -49,3 +50,23 @@ def build_chain(tip, base): next_block = next_block.estimate return chain + + +def build_schedule(tip): + """Returns a list of blocks and blocks estimates from tip to base.""" + stack = [block for block in tip] + schedule = [] + + while any(stack): + curr_block = stack.pop() + + if curr_block is None: + continue + + for ancestor in curr_block.estimate['blocks']: + if ancestor is None: + continue + schedule.append((curr_block, ancestor)) + stack.append(ancestor) + + return schedule diff --git a/casper/validator.py b/casper/validator.py index d6c01e0..819d2ee 100644 --- a/casper/validator.py +++ b/casper/validator.py @@ -1,6 +1,7 @@ """The validator module contains the Validator class, which creates/sends/recieves messages """ import numbers -from casper.blockchain.blockchain_protocol import BlockchainProtocol +from casper.protocols.blockchain.blockchain_protocol import BlockchainProtocol + class Validator(object): """A validator has a view from which it generates new messages and detects finalized blocks.""" @@ -14,8 +15,23 @@ def __init__(self, name, weight, protocol=BlockchainProtocol, validator_set=None self.name = name self.weight = weight - self.view = protocol.View(set()) self.validator_set = validator_set + self.protocol = protocol + + self.initial_message = protocol.initial_message(self) + self.view = protocol.View(set([self.initial_message]), self.initial_message) + + def __eq__(self, val): + if val is None: + return False + if not isinstance(val, Validator): + return False + return hash(self) == hash(val) + + def __hash__(self): + # defined differently than self.hash to avoid confusion with builtin + # use of __hash__ in dictionaries, sets, etc + return hash(self.name) def receive_messages(self, messages): """Allows the validator to receive protocol messages.""" @@ -30,7 +46,7 @@ def my_latest_message(self): """This function returns the validator's latest message.""" if self in self.view.latest_messages: return self.view.latest_messages[self] - raise KeyError("Validator has not previously created a message") + return None def update_safe_estimates(self): """The validator checks estimate safety on some estimate with some safety oracle.""" @@ -39,4 +55,40 @@ def update_safe_estimates(self): def make_new_message(self): """This function produces a new latest message for the validator. It updates the validator's latest message, estimate, view, and latest observed messages.""" - return self.view.make_new_message(self) + new_message = self.protocol.Message( + self.estimate(), + self.justification(), + self, + self._next_sequence_number(), + self._next_display_height() + ) + self.view.add_messages(set([new_message])) + assert new_message.hash in self.view.justified_messages # sanity check + + return new_message + + def justification(self): + """Returns the headers of latest message seen from other validators.""" + latest_message_headers = dict() + for validator in self.view.latest_messages: + latest_message_headers[validator] = self.view.latest_messages[validator].hash + return latest_message_headers + + def _next_sequence_number(self): + """Returns the sequence number for the next message from a validator""" + last_message = self.my_latest_message() + + if last_message: + return last_message.sequence_number + 1 + return 0 + + def _next_display_height(self): + """Returns the display height for a message created in this view""" + if not any(self.view.latest_messages): + return 0 + + max_height = max( + self.view.latest_messages[validator].display_height + for validator in self.view.latest_messages + ) + return max_height + 1 diff --git a/casper/validator_set.py b/casper/validator_set.py index 6bc5b36..18d5891 100644 --- a/casper/validator_set.py +++ b/casper/validator_set.py @@ -1,12 +1,15 @@ """The validator set module contains the ValidatorSet class """ from casper.validator import Validator -from casper.blockchain.blockchain_protocol import BlockchainProtocol +from casper.protocols.blockchain.blockchain_protocol import BlockchainProtocol class ValidatorSet: """Defines the validator set.""" def __init__(self, weights, protocol=BlockchainProtocol): - self.validators = {Validator(name, weights[name], protocol, self) for name in weights} + self.validators = { + Validator(name, weights[name], protocol, self) + for name in weights + } def __len__(self): return len(self.validators) diff --git a/config.ini b/config.ini index fcba5ec..90294fb 100644 --- a/config.ini +++ b/config.ini @@ -2,4 +2,7 @@ NumValidators: 5 NumRounds: 100 ReportInterval: 20 -DefaultProtocol: blockchain \ No newline at end of file +DefaultProtocol: blockchain +DefaultNetwork: no-delay +Save: true +JustifyMessages: true diff --git a/conftest.py b/conftest.py index 4351a4d..de62c13 100644 --- a/conftest.py +++ b/conftest.py @@ -1,28 +1,59 @@ +import random as r import pytest -from casper.blockchain.blockchain_protocol import BlockchainProtocol -from casper.network import Network -from casper.validator import Validator +from casper.protocols.blockchain.blockchain_protocol import BlockchainProtocol +from casper.protocols.integer.integer_protocol import IntegerProtocol +from casper.protocols.binary.binary_protocol import BinaryProtocol +from casper.protocols.order.order_protocol import OrderProtocol + +from casper.networks import ( + ConstantDelayNetwork, + NoDelayNetwork +) -from simulations.testing_language import TestLangCBC from simulations.utils import generate_random_gaussian_validator_set +PROTOCOLS = [BlockchainProtocol, BinaryProtocol, IntegerProtocol, OrderProtocol] +GENESIS_PROTOCOLS = [BlockchainProtocol] +RAND_START_PROTOCOLS = [BinaryProtocol, IntegerProtocol, OrderProtocol] + + def pytest_addoption(parser): parser.addoption("--report", action="store_true", default=False, help="plot TestLangCBC tests") -def run_test_lang_with_reports(test_string, weights): - TestLangCBC(weights, BlockchainProtocol, True).parse(test_string) +@pytest.fixture(params=PROTOCOLS) +def protocol(request): + return request.param -def run_test_lang_without_reports(test_string, weights): - TestLangCBC(weights, BlockchainProtocol, False).parse(test_string) +@pytest.fixture(params=GENESIS_PROTOCOLS) +def genesis_protocol(request): + return request.param -def random_gaussian_validator_set_from_protocol(protocol=BlockchainProtocol): - return generate_random_gaussian_validator_set(protocol) +@pytest.fixture(params=RAND_START_PROTOCOLS) +def rand_start_protocol(request): + return request.param + + +@pytest.fixture +def message(protocol): + return protocol.Message + + +@pytest.fixture +def example_function(): + def example_func(): + return + return example_func + + +@pytest.fixture(autouse=True) +def reset_blockchain_protocol(request): + BlockchainProtocol.genesis_block = None @pytest.fixture @@ -31,30 +62,63 @@ def report(request): @pytest.fixture -def test_lang_runner(report): - if report: - return run_test_lang_with_reports - else: - return run_test_lang_without_reports +def empty_just(): + return {} + + +@pytest.fixture +def test_weight(): + return {i: 5 - i for i in range(5)} @pytest.fixture def generate_validator_set(): - return random_gaussian_validator_set_from_protocol + return generate_random_gaussian_validator_set + + +@pytest.fixture +def validator_set(protocol): + return generate_random_gaussian_validator_set(protocol) + + +@pytest.fixture +def validator(validator_set): + return r.choice(list(validator_set.validators)) + + +@pytest.fixture +def to_from_validators(validator_set): + return r.sample( + validator_set.validators, + 2 + ) + + +@pytest.fixture +def to_validator(to_from_validators): + return to_from_validators[0] + + +@pytest.fixture +def from_validator(to_from_validators): + return to_from_validators[1] + + +@pytest.fixture +def network(validator_set, protocol): + return NoDelayNetwork(validator_set, protocol) @pytest.fixture -def validator_set(): - return random_gaussian_validator_set_from_protocol(BlockchainProtocol) +def no_delay_network(validator_set, protocol): + return NoDelayNetwork(validator_set, protocol) @pytest.fixture -def validator(): - return Validator("Name", 15.5) +def constant_delay_network(validator_set, protocol): + return ConstantDelayNetwork(validator_set, protocol) @pytest.fixture -def network(validator_set): - network = Network(validator_set) - network.random_initialization() - return network +def global_view(network): + return network.global_view diff --git a/experiments/binary/test.json b/experiments/binary/test.json index 5b41f16..f2f6807 100644 --- a/experiments/binary/test.json +++ b/experiments/binary/test.json @@ -1,6 +1,7 @@ { "msg_mode": "rand", "protocol": "binary", + "network": "linear", "num_simulations": 5, "rounds_per_sim": 40, "report_interval": 10, diff --git a/experiments/blockchain/WHALE_rand-mes-gen_50-sims_500-rounds_fixed-weights.json b/experiments/blockchain/WHALE_rand-mes-gen_50-sims_500-rounds_fixed-weights.json index 3384163..ecd3ec5 100644 --- a/experiments/blockchain/WHALE_rand-mes-gen_50-sims_500-rounds_fixed-weights.json +++ b/experiments/blockchain/WHALE_rand-mes-gen_50-sims_500-rounds_fixed-weights.json @@ -1,6 +1,7 @@ { "msg_mode": "rand", "protocol": "blockchain", + "network": "sync", "num_simulations": 50, "rounds_per_sim": 500, "report_interval": 20, diff --git a/experiments/blockchain/fast_test.json b/experiments/blockchain/fast_test.json index 6707107..adcc285 100644 --- a/experiments/blockchain/fast_test.json +++ b/experiments/blockchain/fast_test.json @@ -1,6 +1,7 @@ { "msg_mode": "rand", "protocol": "blockchain", + "network": "sync", "num_simulations": 5, "rounds_per_sim": 40, "report_interval": 10, diff --git a/experiments/blockchain/rand-mes-gen_500-rounds_5-gauss-validators.json b/experiments/blockchain/rand-mes-gen_500-rounds_5-gauss-validators.json index 3b9572e..055d0f0 100644 --- a/experiments/blockchain/rand-mes-gen_500-rounds_5-gauss-validators.json +++ b/experiments/blockchain/rand-mes-gen_500-rounds_5-gauss-validators.json @@ -1,6 +1,7 @@ { "msg_mode": "rand", "protocol": "blockchain", + "network": "gaussian", "num_simulations": 5, "rounds_per_sim": 500, "report_interval": 20, diff --git a/run_experiment.py b/run_experiment.py index ffc592e..3f3fb03 100644 --- a/run_experiment.py +++ b/run_experiment.py @@ -7,6 +7,7 @@ from simulations.experiment import Experiment from simulations.utils import ( + select_network, select_protocol, validator_generator ) @@ -34,6 +35,7 @@ def main(): experiment_name = "{}-{}".format(file_name, timestamp()) protocol = select_protocol(config['protocol']) + network_class = select_network(config['network']) experiment = Experiment( experiment_name, @@ -42,6 +44,7 @@ def main(): validator_generator(config['validator_info'], protocol), config['msg_mode'], protocol, + network_class, config['rounds_per_sim'], config['report_interval'] ) diff --git a/simulations/analyzer.py b/simulations/analyzer.py index 1cfeb27..29f842c 100644 --- a/simulations/analyzer.py +++ b/simulations/analyzer.py @@ -8,40 +8,49 @@ def __init__(self, simulation): self.simulation = simulation self.global_view = simulation.network.global_view + @property def num_messages(self): return len(self.messages()) + @property def num_safe_messages(self): return len(self.safe_messages()) + @property def num_unsafe_messages(self): return len(self.unsafe_messages()) + @property def num_bivalent_messages(self): return len(self.bivalent_messages()) + @property def prop_safe_messages(self): - return float(self.num_safe_messages()) / self.num_messages() + return float(self.num_safe_messages) / self.num_messages + @property def safe_to_tip_length(self): - return self.global_view.estimate().height - self.safe_tip_height() + return self.global_view.estimate().height - self.safe_tip_height + @property def safe_tip_height(self): - if self.safe_tip(): - return self.safe_tip().height + if self.safe_tip: + return self.safe_tip.height return 0 + @property def bivalent_message_depth(self): max_height = max( message.height for message in self.global_view.latest_messages.values() ) - return max_height - self.safe_tip_height() + return max_height - self.safe_tip_height + @property def bivalent_message_branching_factor(self): to_check = set(self.bivalent_messages()) - if self.safe_tip(): - to_check.add(self.safe_tip()) + if self.safe_tip: + to_check.add(self.safe_tip) check = to_check.pop() branches = 0 @@ -58,11 +67,34 @@ def bivalent_message_branching_factor(self): return 0 return branches / num_checked + @property def safe_tip(self): return self.global_view.last_finalized_block + @property + def latency_to_finality(self): + safe_messages = self.safe_messages() + + if not any(safe_messages): + return None + + individual_latency = [ + self.global_view.when_finalized[message] - self.global_view.when_added[message] + for message in safe_messages + ] + + return statistics.mean(individual_latency) + + @property + def orphan_rate(self): + num_unsafe_messages = self.num_unsafe_messages + num_safe_messages = self.num_safe_messages + if num_unsafe_messages + num_safe_messages == 0: + return 0 + return float(num_unsafe_messages) / (num_unsafe_messages + num_safe_messages) + def messages(self): - return self.global_view.messages + return set(self.global_view.justified_messages.values()) def safe_messages(self): if not self.global_view.last_finalized_block: @@ -81,30 +113,10 @@ def bivalent_messages(self): def unsafe_messages(self): potential = self.messages() - self.safe_messages() - if not self.safe_tip(): + if not self.safe_tip: return set() return { message for message in potential - if message.height <= self.safe_tip().height + if message.height <= self.safe_tip.height } - - def latency_to_finality(self): - safe_messages = self.safe_messages() - - if not any(safe_messages): - return None - - individual_latency = [ - self.global_view.when_finalized[message] - self.global_view.when_added[message] - for message in safe_messages - ] - - return statistics.mean(individual_latency) - - def orphan_rate(self): - num_unsafe_messages = self.num_unsafe_messages() - num_safe_messages = self.num_safe_messages() - if num_unsafe_messages + num_safe_messages == 0: - return 0 - return float(num_unsafe_messages) / (num_unsafe_messages + num_safe_messages) diff --git a/simulations/experiment.py b/simulations/experiment.py index 28c4171..d50e367 100644 --- a/simulations/experiment.py +++ b/simulations/experiment.py @@ -21,6 +21,7 @@ def __init__( validator_set_generator, msg_mode, protocol, + network_class, sim_rounds, sim_report_interval ): @@ -30,6 +31,7 @@ def __init__( self.validator_set_generator = validator_set_generator self.msg_mode = msg_mode self.protocol = protocol + self.network_class = network_class self.sim_rounds = sim_rounds self.sim_report_interval = sim_report_interval self.intervals = int(self.sim_rounds / self.sim_report_interval) @@ -50,10 +52,12 @@ def run(self): def run_sim(self, sim_id): validator_set = self.validator_set_generator() + network = self.network_class(validator_set, self.protocol) runner = SimulationRunner( validator_set, message_maker(self.msg_mode), self.protocol, + network, total_rounds=self.sim_rounds, report_interval=self.sim_report_interval, display=False, @@ -96,7 +100,7 @@ def _collect_data(self, runner, sim_id, interval): analyzer = Analyzer(runner) self.analyzer_data['simulation_data'][sim_id][interval] = { - d: getattr(analyzer, d)() + d: getattr(analyzer, d) for d in self.data } diff --git a/simulations/simulation_runner.py b/simulations/simulation_runner.py index 556623f..1bd88bb 100644 --- a/simulations/simulation_runner.py +++ b/simulations/simulation_runner.py @@ -1,7 +1,5 @@ import sys -from casper.network import Network - class SimulationRunner: def __init__( @@ -9,6 +7,7 @@ def __init__( validator_set, msg_gen, protocol, + network, total_rounds, report_interval, display, @@ -29,8 +28,7 @@ def __init__( else: self.report_interval = 1 - self.network = Network(validator_set, protocol) - self.network.random_initialization() + self.network = network self.plot_tool = protocol.PlotTool(display, save, self.network.global_view, validator_set) self.plot_tool.plot() @@ -38,47 +36,56 @@ def __init__( def run(self): """ run simulation total_rounds if specified otherwise, run indefinitely """ + self._send_initial_messages() + while self.round < self.total_rounds: self.step() if self.save: + print("making gif") self.plot_tool.make_gif() def step(self): """ run one round of the simulation """ + """ this becomes, who is going to make a message and send to the network """ + """ rather than what explicit paths happen """ self.round += 1 - message_paths = self.msg_gen(self.validator_set) - - affected_validators = {j for i, j in message_paths} + received_messages = self._receive_messages() + self._update_safe_estimates(received_messages.keys()) - sent_messages = self._send_messages_along_paths(message_paths) - new_messages = self._make_new_messages(affected_validators) - self._check_for_new_safety(affected_validators) + new_messages = self._generate_new_messages() - self.plot_tool.update(message_paths, sent_messages, new_messages) + self.plot_tool.update(new_messages) if self.round % self.report_interval == self.report_interval - 1: self.plot_tool.plot() - def _send_messages_along_paths(self, message_paths): - sent_messages = {} - # Send most recent message of sender to receive - for sender, receiver in message_paths: - message = sender.my_latest_message() - self.network.propagate_message_to_validator(message, receiver) - sent_messages[sender] = message + self.network.advance_time() - return sent_messages - - def _make_new_messages(self, validators): - messages = {} + def _generate_new_messages(self): + validators = self.msg_gen(self.validator_set) + new_messages = [] + for validator in validators: + message = validator.make_new_message() + self.network.send_to_all(message) + new_messages.append(message) + return new_messages + + def _receive_messages(self): + received_messages = {} + for validator in self.validator_set: + messages = self.network.receive_all_available(validator) + if messages: + validator.receive_messages(set(messages)) + received_messages[validator] = messages + return received_messages + + def _update_safe_estimates(self, validators): for validator in validators: - message = self.network.get_message_from_validator(validator) - messages[validator] = message - - return messages - - def _check_for_new_safety(self, affected_validators): - for validator in affected_validators: validator.update_safe_estimates() - self.network.global_view.update_safe_estimates(self.validator_set) + + def _send_initial_messages(self): + """ ensures that initial messages are attempted to be propogated. + requirement for any protocol where initial message is not shared """ + for validator in self.validator_set: + self.network.send_to_all(validator.initial_message) diff --git a/simulations/testing_language.py b/simulations/testing_language.py deleted file mode 100644 index 3522b62..0000000 --- a/simulations/testing_language.py +++ /dev/null @@ -1,203 +0,0 @@ -"""The testing language module ... """ -import re -import random as r - -from casper.blockchain.blockchain_protocol import BlockchainProtocol -from casper.network import Network -from casper.plot_tool import PlotTool -from casper.safety_oracles.clique_oracle import CliqueOracle -from casper.validator_set import ValidatorSet -import casper.utils as utils - - -class TestLangCBC(object): - """Allows testing of simulation scenarios with small testing language.""" - - # Signal to py.test that TestLangCBC should not be discovered. - __test__ = False - - TOKEN_PATTERN = '([A-Za-z]*)([0-9]*)([-]*)([A-Za-z0-9]*)' - - def __init__(self, val_weights, protocol=BlockchainProtocol, display=False): - - self.validator_set = ValidatorSet(val_weights, protocol) - self.display = display - self.network = Network(self.validator_set, protocol) - - # This seems to be misnamed. Just generates starting blocks. - self.network.random_initialization() - - self.plot_tool = PlotTool(display, False, 's') - self.blocks = dict() - self.blockchain = [] - self.communications = [] - self.block_fault_tolerance = dict() - - # Register token handlers. - self.handlers = dict() - self.handlers['B'] = self.make_block - self.handlers['S'] = self.send_block - self.handlers['C'] = self.check_safety - self.handlers['U'] = self.no_safety - self.handlers['H'] = self.check_head_equals_block - self.handlers['RR'] = self.round_robin - self.handlers['R'] = self.report - - def _validate_validator(self, validator): - if validator not in self.validator_set: - raise ValueError('Validator {} does not exist'.format(validator)) - - def _validate_block_exists(self, block_name): - if block_name not in self.blocks: - raise ValueError('Block {} does not exist'.format(block_name)) - - def _validate_block_does_not_exist(self, block_name): - if block_name in self.blocks: - raise ValueError('Block {} already exists'.format(block_name)) - - def parse(self, test_string): - """Parse the test_string, and run the test""" - for token in test_string.split(' '): - letter, validator, dash, name = re.match(self.TOKEN_PATTERN, token).groups() - if letter+validator+dash+name != token: - raise ValueError("Bad token: %s" % token) - if validator != '': - try: - validator = self.validator_set.get_validator_by_name(int(validator)) - except KeyError: - raise ValueError("Validator {} does not exist".format(validator)) - - self.handlers[letter](validator, name) - - def send_block(self, validator, block_name): - """Send some validator a block.""" - self._validate_validator(validator) - self._validate_block_exists(block_name) - - block = self.blocks[block_name] - - if block in validator.view.messages: - raise Exception( - 'Validator {} has already seen block {}' - .format(validator, block_name) - ) - - self.network.propagate_message_to_validator(block, validator) - - def make_block(self, validator, block_name): - """Have some validator produce a block.""" - self._validate_validator(validator) - self._validate_block_does_not_exist(block_name) - - new_block = self.network.get_message_from_validator(validator) - - if new_block.estimate is not None: - self.blockchain.append([new_block, new_block.estimate]) - - self.blocks[block_name] = new_block - - def round_robin(self, validator, block_name): - """Have each validator create a block in a perfect round robin.""" - self._validate_validator(validator) - self._validate_block_does_not_exist(block_name) - - # start round robin at validator speicied by validator in args - validators = self.validator_set.sorted_by_name() - start_index = validators.index(validator) - validators = validators[start_index:] + validators[:start_index] - - for i in range(len(self.validator_set)): - if i == len(self.validator_set) - 1: - name = block_name - else: - name = r.random() - maker = validators[i] - receiver = validators[(i + 1) % len(validators)] - - self.make_block(maker, name) - self.send_block(receiver, name) - - def check_safety(self, validator, block_name): - """Check that some validator detects safety on a block.""" - self._validate_validator(validator) - self._validate_block_exists(block_name) - - block = self.blocks[block_name] - validator.update_safe_estimates() - - assert validator.view.last_finalized_block is None or \ - not block.conflicts_with(validator.view.last_finalized_block), \ - "Block {0} failed safety assert for validator-{1}".format(block_name, validator.name) - - def no_safety(self, validator, block_name): - """Check that some validator does not detect safety on a block.""" - self._validate_validator(validator) - self._validate_block_exists(block_name) - - block = self.blocks[block_name] - validator.update_safe_estimates() - - #NOTE: This should never fail - assert validator.view.last_finalized_block is None or \ - block.conflicts_with(validator.view.last_finalized_block), \ - "Block {} failed no-safety assert".format(block_name) - - def check_head_equals_block(self, validator, block_name): - """Check some validators forkchoice is the correct block.""" - self._validate_validator(validator) - self._validate_block_exists(block_name) - - block = self.blocks[block_name] - - head = validator.view.estimate() - - assert block == head, "Validator {} does not have " \ - "block {} at head".format(validator, block_name) - - def report(self, num, name): - """Display the view graph of the current global_view""" - assert num == name and num == '', "...no validator or number needed to report!" - - if not self.display: - return - - # Update the safe blocks! - tip = self.network.global_view.estimate() - while tip and self.block_fault_tolerance.get(tip, 0) != len(self.validator_set) - 1: - oracle = CliqueOracle(tip, self.network.global_view, self.validator_set) - fault_tolerance, num_node_ft = oracle.check_estimate_safety() - - if fault_tolerance > 0: - self.block_fault_tolerance[tip] = num_node_ft - - tip = tip.estimate - - edgelist = [] - - best_chain = utils.build_chain( - self.network.global_view.estimate(), - None - ) - edgelist.append(utils.edge(best_chain, 5, 'red', 'solid')) - - for validator in self.validator_set: - chain = utils.build_chain( - validator.my_latest_message(), - None - ) - edgelist.append(utils.edge(chain, 2, 'blue', 'solid')) - - edgelist.append(utils.edge(self.blockchain, 2, 'grey', 'solid')) - edgelist.append(utils.edge(self.communications, 1, 'black', 'dotted')) - - message_labels = {} - for block in self.network.global_view.messages: - message_labels[block] = block.sequence_number - - self.plot_tool.next_viewgraph( - self.network.global_view, - self.validator_set, - edges=edgelist, - message_colors=self.block_fault_tolerance, - message_labels=message_labels - ) diff --git a/simulations/utils.py b/simulations/utils.py index 363384b..7f6bc52 100644 --- a/simulations/utils.py +++ b/simulations/utils.py @@ -1,14 +1,36 @@ """The simulution utils module ... """ -import itertools import random as r -from casper.blockchain.blockchain_protocol import BlockchainProtocol -from casper.binary.binary_protocol import BinaryProtocol - +from casper.networks import ( + ConstantDelayNetwork, + GaussianDelayNetwork, + LinearDelayNetwork, + NoDelayNetwork, + StepNetwork +) +from casper.protocols.blockchain.blockchain_protocol import BlockchainProtocol +from casper.protocols.binary.binary_protocol import BinaryProtocol +from casper.protocols.integer.integer_protocol import IntegerProtocol +from casper.protocols.order.order_protocol import OrderProtocol +from casper.protocols.concurrent.concurrent_protocol import ConcurrentProtocol from casper.validator_set import ValidatorSet MESSAGE_MODES = ['rand', 'rrob', 'full', 'nofinal'] -PROTOCOLS = ['blockchain', 'binary'] +NETWORKS = ['no-delay', 'step', 'constant', 'linear', 'gaussian'] +PROTOCOLS = ['blockchain', 'binary', 'integer', 'order', 'concurrent'] + + +def select_network(network): + if network == 'no-delay': + return NoDelayNetwork + if network == 'constant': + return ConstantDelayNetwork + if network == 'step': + return StepNetwork + if network == 'linear': + return LinearDelayNetwork + if network == 'gaussian': + return GaussianDelayNetwork def select_protocol(protocol): @@ -16,6 +38,12 @@ def select_protocol(protocol): return BlockchainProtocol if protocol == 'binary': return BinaryProtocol + if protocol == 'order': + return OrderProtocol + if protocol == 'integer': + return IntegerProtocol + if protocol == 'concurrent': + return ConcurrentProtocol def message_maker(mode): @@ -24,27 +52,23 @@ def message_maker(mode): if mode == "rand": def random(validator_set, num_messages=1): - """Each round, some randomly selected validators propagate their most recent - message to other randomly selected validators, who then create new messages.""" - pairs = list(itertools.permutations(validator_set, 2)) - return r.sample(pairs, num_messages) + """Each round, some randomly selected validator makes a message""" + return r.sample(validator_set.validators, 1) + # pairs = list(itertools.permutations(validator_set, 2)) + # return r.sample(pairs, num_messages) return random if mode == "rrob": def round_robin(validator_set): - """Each round, the creator of the last round's block sends it to the next - receiver, who then creates a block.""" + """Each round, the next validator in a set order makes a message""" sorted_validators = validator_set.sorted_by_name() sender_index = round_robin.next_sender_index round_robin.next_sender_index = (sender_index + 1) % len(validator_set) - receiver_index = round_robin.next_sender_index + # receiver_index = round_robin.next_sender_index - return [[ - sorted_validators[sender_index], - sorted_validators[receiver_index] - ]] + return [sorted_validators[sender_index]] round_robin.next_sender_index = 0 return round_robin @@ -52,10 +76,8 @@ def round_robin(validator_set): if mode == "full": def full_propagation(validator_set): - """Each round, all validators receive all other validators previous - messages, and then all create messages.""" - pairs = list(itertools.permutations(validator_set, 2)) - return pairs + """Each round, all validators make all messages""" + return validator_set.validators return full_propagation diff --git a/state_languages/__init__.py b/state_languages/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/state_languages/binary_test_lang.py b/state_languages/binary_test_lang.py new file mode 100644 index 0000000..de55452 --- /dev/null +++ b/state_languages/binary_test_lang.py @@ -0,0 +1,49 @@ +"""The testing language module ... """ +import re +import random as r + +from state_languages.state_language import StateLanguage +from casper.protocols.binary.binary_protocol import BinaryProtocol + + +class BinaryTestLang(StateLanguage): + """Allows testing of simulation scenarios with small testing language.""" + + # Signal to py.test that TestLangCBC should not be discovered. + __test__ = False + + def __init__(self, val_weights, display=False): + super().__init__(val_weights, BinaryProtocol, display) + + def check_estimate(self, validator, estimate): + """Check that a validators estimate is the correct""" + estimate = int(estimate) + assert estimate in {0, 1}, "estimate must be a bit" + + bit = validator.view.estimate() + + assert bit == estimate, "Validator {} does not have " \ + "estimate {}".format(validator, estimate) + + def check_safe(self, validator, estimate): + """Check that some validator is safe on the correct bit.""" + estimate = int(estimate) + assert estimate in {0, 1}, "estimate must be a bit" + + validator.update_safe_estimates() + + assert validator.view.last_finalized_estimate is not None and \ + validator.view.last_finalized_estimate.estimate == estimate, \ + "{0} failed safety assert for validator-{1}".format(estimate, validator.name) + + def check_unsafe(self, validator, estimate): + """Check that some validator is not safe on some integer.""" + estimate = int(estimate) + assert estimate in {0, 1}, "estimate must be a bit" + + validator.update_safe_estimates() + + # NOTE: This should never fail + assert validator.view.last_finalized_estimate is None or \ + validator.view.last_finalized_estimate.estimate != estimate, \ + "{0} failed no-safety assert for validator-{1}".format(estimate, validator) diff --git a/state_languages/blockchain_test_lang.py b/state_languages/blockchain_test_lang.py new file mode 100644 index 0000000..02d5328 --- /dev/null +++ b/state_languages/blockchain_test_lang.py @@ -0,0 +1,49 @@ +"""The testing language module ... """ +import re +import random as r + +from state_languages.state_language import StateLanguage +from casper.protocols.blockchain.blockchain_protocol import BlockchainProtocol + + +class BlockchainTestLang(StateLanguage): + """Allows testing of simulation scenarios with small testing language.""" + + # Signal to py.test that TestLangCBC should not be discovered. + __test__ = False + + def __init__(self, val_weights, display=False): + super().__init__(val_weights, BlockchainProtocol, display) + + def check_estimate(self, validator, estimate): + """Check that a validators forkchoice is some block""" + self.require_message_exists(estimate) + + message = self.messages[estimate] + + head = validator.view.estimate() + + assert message == head, "Validator {} does not have " \ + "block {} at head".format(validator.name, estimate) + + def check_safe(self, validator, estimate): + """Check that some validator does not detect safety on a block.""" + self.require_message_exists(estimate) + + message = self.messages[estimate] + validator.update_safe_estimates() + + assert validator.view.last_finalized_block is None or \ + not message.conflicts_with(validator.view.last_finalized_block), \ + "Block {0} failed safety assert for validator-{1}".format(estimate, validator.name) + + def check_unsafe(self, validator, estimate): + """Must be implemented by child class""" + self.require_message_exists(estimate) + + message = self.messages[estimate] + validator.update_safe_estimates() + + assert validator.view.last_finalized_block is None or \ + message.conflicts_with(validator.view.last_finalized_block), \ + "Block {} failed no-safety assert".format(estimate) diff --git a/state_languages/integer_test_lang.py b/state_languages/integer_test_lang.py new file mode 100644 index 0000000..55d8e6f --- /dev/null +++ b/state_languages/integer_test_lang.py @@ -0,0 +1,45 @@ +"""The testing language module ... """ +import re +import random as r + +from state_languages.state_language import StateLanguage +from casper.protocols.integer.integer_protocol import IntegerProtocol + + +class IntegerTestLang(StateLanguage): + """Allows testing of simulation scenarios with small testing language.""" + + # Signal to py.test that TestLangCBC should not be discovered. + __test__ = False + + def __init__(self, val_weights, display=False): + super().__init__(val_weights, IntegerProtocol, display) + + def check_estimate(self, validator, estimate): + """Check that a validators estimate is the correct number""" + estimate = int(estimate) + + num = validator.view.estimate() + + assert num == estimate, "Validator {} does not have " \ + "estimate {}".format(validator, estimate) + + def check_safe(self, validator, estimate): + """Check that some validator is safe on the correct integer.""" + estimate = int(estimate) + + validator.update_safe_estimates() + + assert validator.view.last_finalized_estimate is not None and \ + validator.view.last_finalized_estimate.estimate == estimate, \ + "{0} failed safety assert for validator-{1}".format(estimate, validator.name) + + def check_unsafe(self, validator, estimate): + """Check that some validator is not safe on some integer.""" + estimate = int(estimate) + + validator.update_safe_estimates() + + assert validator.view.last_finalized_estimate is None or \ + validator.view.last_finalized_estimate.estimate != estimate, \ + "{0} failed no-safety assert for validator-{1}".format(estimate, validator) diff --git a/state_languages/state_language.py b/state_languages/state_language.py new file mode 100644 index 0000000..dc7f809 --- /dev/null +++ b/state_languages/state_language.py @@ -0,0 +1,185 @@ +"""The testing language module ... """ +import re +import random as r + +from casper.protocols.blockchain.blockchain_protocol import BlockchainProtocol +from casper.networks import NoDelayNetwork +from casper.validator_set import ValidatorSet + + +class StateLanguage(object): + """Allows testing of simulation scenarios with small testing language.""" + + TOKEN_PATTERN = '([A-Za-z]*)([0-9]*)([-]*)([A-Za-z0-9]*)([\\{A-Za-z,}]*)' + + def __init__(self, val_weights, protocol=BlockchainProtocol, display=False): + + self.validator_set = ValidatorSet(val_weights, protocol) + self.network = NoDelayNetwork(self.validator_set, protocol) + + self.messages = dict() + + self.plot_tool = protocol.PlotTool( + display, + False, + self.network.global_view, + self.validator_set + ) + + # Register token handlers. + self.handlers = dict() + + self.register_handler('M', self.make_message) + self.register_handler('I', self.make_invalid) + self.register_handler('S', self.send_message) + self.register_handler('P', self.plot) + self.register_handler('SJ', self.send_and_justify) + self.register_handler('RR', self.round_robin) + self.register_handler('CE', self.check_estimate) + self.register_handler('CS', self.check_safe) + self.register_handler('CU', self.check_unsafe) + + def register_handler(self, token, function): + """Registers a function with a new token. Throws an error if already registered""" + if token in self.handlers: + raise KeyError('A function has been registered with that token') + + self.handlers[token] = function + + def make_message(self, validator, message_name, messages_to_hide=None): + """Have a validator generate a new message""" + self.require_message_not_exists(message_name) + + #NOTE: Once validators have the ability to lie about their view, hide messages_to_hide! + + new_message = validator.make_new_message() + self.network.global_view.add_messages( + set([new_message]) + ) + + self.plot_tool.update([new_message]) + + self.messages[message_name] = new_message + + def send_message(self, validator, message_name): + """Send a message to a specific validator""" + self.require_message_exists(message_name) + + message = self.messages[message_name] + + self._propagate_message_to_validator(validator, message) + + def make_invalid(self, validator, message_name): + """TODO: Implement this when validators can make/handle invalid messages""" + raise NotImplementedError + + def send_and_justify(self, validator, message_name): + self.require_message_exists(message_name) + + message = self.messages[message_name] + self._propagate_message_to_validator(validator, message) + + messages_to_send = self._messages_needed_to_justify(message, validator) + for message in messages_to_send: + self._propagate_message_to_validator(validator, message) + + assert self.messages[message_name].hash in validator.view.justified_messages + + def round_robin(self, validator, message_name): + """Have each validator create a message in a perfect round robin.""" + self.require_message_not_exists(message_name) + + # start round robin at validator specified by validator in args + validators = self.validator_set.sorted_by_name() + start_index = validators.index(validator) + validators = validators[start_index:] + validators[:start_index] + + for i in range(len(self.validator_set)): + if i == len(self.validator_set) - 1: + name = message_name + else: + name = r.random() + maker = validators[i] + receiver = validators[(i + 1) % len(validators)] + + self.make_message(maker, name) + self.send_and_justify(receiver, name) + + def plot(self): + """Display or save a viewgraph""" + self.plot_tool.plot() + + def check_estimate(self, validator, estimate): + """Must be implemented by child class""" + raise NotImplementedError + + def check_safe(self, validator, estimate): + """Must be implemented by child class""" + raise NotImplementedError + + def check_unsafe(self, validator, estimate): + """Must be implemented by child class""" + raise NotImplementedError + + def require_message_exists(self, message_name): + """Throws an error if message_name does not exist""" + if message_name not in self.messages: + raise ValueError('Block {} does not exist'.format(message_name)) + + def require_message_not_exists(self, message_name): + """Throws an error if message_name does not exist""" + if message_name in self.messages: + raise ValueError('Block {} already exists'.format(message_name)) + + def _propagate_message_to_validator(self, validator, message): + self.network.send(validator, message) + received_message = self.network.receive(validator) + if received_message: + validator.receive_messages(set([received_message])) + + def _messages_needed_to_justify(self, message, validator): + """Returns the set of messages needed to justify a message to a validator""" + messages_needed = set() + + current_message_hashes = set() + for message_hash in message.justification.values(): + if message_hash not in validator.view.pending_messages and \ + message_hash not in validator.view.justified_messages: + current_message_hashes.add(message_hash) + + while any(current_message_hashes): + next_hashes = set() + + for message_hash in current_message_hashes: + message = self.network.global_view.justified_messages[message_hash] + messages_needed.add(message) + + for other_hash in message.justification.values(): + if other_hash not in validator.view.pending_messages and \ + other_hash not in validator.view.justified_messages: + next_hashes.add(other_hash) + + current_message_hashes = next_hashes + + return messages_needed + + def parse(self, protocol_state_string): + """Parse the state string!""" + for token in protocol_state_string.split(): + letter, validator, message = self.parse_token(token) + + if letter == 'P': + self.plot() + else: + validator = self.validator_set.get_validator_by_name(int(validator)) + self.handlers[letter](validator, message) + + def parse_token(self, token): + letter, validator, dash, message, removed_message_names = re.match( + self.TOKEN_PATTERN, token + ).groups() + + if letter + validator + dash + message + removed_message_names != token: + raise ValueError("Bad token: %s" % token) + + return letter, validator, message diff --git a/tests/casper/blockchain/test_block.py b/tests/casper/blockchain/test_block.py deleted file mode 100644 index 55d93ed..0000000 --- a/tests/casper/blockchain/test_block.py +++ /dev/null @@ -1,117 +0,0 @@ -"""The block testing module ...""" -import copy - -import pytest - -from casper.blockchain.block import Block -from casper.blockchain.blockchain_protocol import BlockchainProtocol -from casper.justification import Justification -from casper.validator import Validator - -from simulations.testing_language import TestLangCBC - - -def test_equality_of_copies_off_genesis(validator): - block = Block(None, Justification(), validator) - - shallow_copy = copy.copy(block) - deep_copy = copy.deepcopy(block) - - assert block == shallow_copy - assert block == deep_copy - assert shallow_copy == deep_copy - - -def test_equality_of_copies_of_non_genesis(report): - test_string = "B0-A S1-A B1-B S0-B B0-C S1-C B1-D S0-D H0-D" - test_lang = TestLangCBC({0: 10, 1: 11}, BlockchainProtocol, report) - test_lang.parse(test_string) - - for block in test_lang.blocks: - shallow_copy = copy.copy(block) - deep_copy = copy.deepcopy(block) - - assert block == shallow_copy - assert block == deep_copy - assert shallow_copy == deep_copy - - -def test_non_equality_of_copies_off_genesis(): - validator_0 = Validator("v0", 10) - validator_1 = Validator("v1", 11) - - block_0 = Block(None, Justification(), validator_0) - block_1 = Block(None, Justification(), validator_1) - - assert block_0 != block_1 - - -def test_unique_block_creation_in_test_lang(report): - test_string = "B0-A S1-A B1-B S0-B B0-C S1-C B1-D S0-D H0-D" - test_lang = TestLangCBC({0: 10, 1: 11}, BlockchainProtocol, report) - test_lang.parse(test_string) - - num_equal = 0 - for block in test_lang.blocks: - for block1 in test_lang.blocks: - if block1 == block: - num_equal += 1 - continue - - assert block != block1 - - assert num_equal == len(test_lang.blocks) - - -def test_is_in_blockchain__separate_genesis(): - validator_0 = Validator("v0", 10) - validator_1 = Validator("v1", 11) - - block_0 = Block(None, Justification(), validator_0) - block_1 = Block(None, Justification(), validator_1) - - assert not block_0.is_in_blockchain(block_1) - assert not block_1.is_in_blockchain(block_0) - - -def test_is_in_blockchain__test_lang(report): - test_string = "B0-A S1-A B1-B S0-B B0-C S1-C B1-D S0-D H0-D" - test_lang = TestLangCBC({0: 11, 1: 10}, BlockchainProtocol, report) - test_lang.parse(test_string) - - prev = test_lang.blocks['A'] - for b in ['B', 'C', 'D']: - block = test_lang.blocks[b] - assert prev.is_in_blockchain(block) - assert not block.is_in_blockchain(prev) - - prev = block - - -@pytest.mark.parametrize( - 'test_string, weights, block_heights', - [ - ( - "B0-A S1-A B1-B S0-B B0-C S1-C B1-D S0-D H0-D", - {0: 11, 1: 10}, - {"A": 2, "B": 3, "C": 4, "D": 5} - ), - ( - "B0-A S1-A B1-B S0-B B0-C S1-C B1-D S0-D H0-D", - {0: 1, 1: 10}, - {"A": 2, "B": 2, "C": 3, "D": 4} - ), - ( - "B0-A S1-A B0-B S1-B B1-C S0-C S2-C B2-D S0-D B0-E B1-F S0-F H0-E", - {0: 11, 1: 10, 2: 500}, - {"A": 2, "B": 3, "C": 4, "D": 2, "E": 3, "F": 5} - ), - ] -) -def test_block_height(report, test_string, weights, block_heights): - test_lang = TestLangCBC(weights, BlockchainProtocol, report) - test_lang.parse(test_string) - - for block_name in block_heights: - block = test_lang.blocks[block_name] - assert block.height == block_heights[block_name] diff --git a/tests/casper/blockchain/test_blockchain_protocol.py b/tests/casper/blockchain/test_blockchain_protocol.py deleted file mode 100644 index 8b13789..0000000 --- a/tests/casper/blockchain/test_blockchain_protocol.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/tests/casper/blockchain/test_blockchain_view.py b/tests/casper/blockchain/test_blockchain_view.py deleted file mode 100644 index 4b16fa7..0000000 --- a/tests/casper/blockchain/test_blockchain_view.py +++ /dev/null @@ -1,48 +0,0 @@ -"""The BlockchainView testing module...""" -import pytest - -from casper.blockchain.blockchain_protocol import BlockchainProtocol -from simulations.testing_language import TestLangCBC - - -@pytest.mark.parametrize( - 'weights, test_string, showed_message_names, new_message_names', - [ - ( - {0: 10, 1: 11}, - "B0-A S1-A B1-B S0-B B0-C S1-C B1-D S0-D", - ["C", "D"], - ["A", "B", "C", "D"] - ), - ( - {0: 10, 1: 11}, - "B0-A S1-A B1-B S0-B B0-C S1-C B1-D S0-D", - ["C"], - ["A", "B", "C"] - ), - ( - {0: 10, 1: 11, 2: 30}, - "B0-A S1-A B1-B S0-B B0-C S1-C B2-D S0-D S1-D", - ["D"], - ["D"] - ), - ( - {0: 10, 1: 11, 2: 500}, - "B0-A S1-A B1-B S0-B B0-C S2-C B2-D S0-D S1-D", - ["D"], - ["A", "B", "C", "D"] - ), - ] -) -def test_get_new_messages(weights, test_string, showed_message_names, new_message_names, report): - test_lang = TestLangCBC(weights, BlockchainProtocol, report) - - view = BlockchainProtocol.View() - # add the initial messages to the view - view.add_messages(test_lang.network.global_view.messages) - test_lang.parse(test_string) - - showed_messages = {test_lang.blocks[name] for name in showed_message_names} - new_messages = {test_lang.blocks[name] for name in new_message_names} - - assert view.get_new_messages(showed_messages) == new_messages diff --git a/tests/casper/conftest.py b/tests/casper/conftest.py new file mode 100644 index 0000000..3a3dc54 --- /dev/null +++ b/tests/casper/conftest.py @@ -0,0 +1,14 @@ +import pytest + +from state_languages.blockchain_test_lang import BlockchainTestLang +from state_languages.integer_test_lang import IntegerTestLang +from state_languages.binary_test_lang import BinaryTestLang + +TEST_LANGS = [BlockchainTestLang, IntegerTestLang, BinaryTestLang] + + +@pytest.fixture(params=TEST_LANGS) +def test_lang_creator(request, report): + def creator(weights): + return request.param(weights, report) + return creator diff --git a/tests/casper/protocols/__init__.py b/tests/casper/protocols/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/casper/protocols/binary/conftest.py b/tests/casper/protocols/binary/conftest.py new file mode 100644 index 0000000..b26e6e2 --- /dev/null +++ b/tests/casper/protocols/binary/conftest.py @@ -0,0 +1,46 @@ +import random +import pytest + +from state_languages.binary_test_lang import BinaryTestLang +from casper.protocols.binary.binary_protocol import BinaryProtocol + + +@pytest.fixture +def binary_lang(report, test_weight): + return BinaryTestLang(test_weight, report) + + +@pytest.fixture +def binary_lang_runner(report): + def runner(weights, test_string): + BinaryTestLang(weights, report).parse(test_string) + return runner + + +@pytest.fixture +def binary_lang_creator(report): + def creator(weights): + return BinaryTestLang(weights, report) + return creator + + +@pytest.fixture +def binary_validator_set(generate_validator_set): + return generate_validator_set(BinaryProtocol) + + +@pytest.fixture +def binary_validator(binary_validator_set): + return random.choice(list(binary_validator_set)) + + +@pytest.fixture +def bet(empty_just, binary_validator): + return BinaryProtocol.Message(0, empty_just, binary_validator, 0, 0) + + +@pytest.fixture +def create_bet(empty_just, binary_validator): + def c_bet(estimate): + return BinaryProtocol.Message(estimate, empty_just, binary_validator, 0, 0) + return c_bet diff --git a/tests/casper/protocols/binary/test_binary_estimator.py b/tests/casper/protocols/binary/test_binary_estimator.py new file mode 100644 index 0000000..49b3e4d --- /dev/null +++ b/tests/casper/protocols/binary/test_binary_estimator.py @@ -0,0 +1,45 @@ +"""The binary estimator testing module tests the binary estimator""" +import pytest + +from casper.protocols.binary.binary_protocol import BinaryProtocol +from casper.protocols.binary.bet import Bet +from casper.validator_set import ValidatorSet +import casper.protocols.binary.binary_estimator as estimator + + +@pytest.mark.parametrize( + 'weights, latest_estimates, estimate', + [ + ( + {0: 1}, + {0: 1}, + 1 + ), + ( + {0: 5, 1: 6, 2: 7}, + {0: 0, 1: 0, 2: 1}, + 0 + ), + ( + {0: 5, 1: 10, 2: 14}, + {0: 1, 1: 1, 2: 0}, + 1 + ), + ( + {0: 5, 1: 11}, + {0: 0, 1: 1}, + 1 + ), + ] +) +def test_estimator_picks_correct_estimate(weights, latest_estimates, estimate, empty_just): + validator_set = ValidatorSet(weights, BinaryProtocol) + + latest_messages = dict() + for val_name in latest_estimates: + validator = validator_set.get_validator_by_name(val_name) + latest_messages[validator] = Bet( + latest_estimates[val_name], empty_just, validator, 1, 1 + ) + + assert estimate == estimator.get_estimate_from_latest_messages(latest_messages) diff --git a/tests/casper/protocols/binary/test_binary_message.py b/tests/casper/protocols/binary/test_binary_message.py new file mode 100644 index 0000000..6507d55 --- /dev/null +++ b/tests/casper/protocols/binary/test_binary_message.py @@ -0,0 +1,40 @@ +import pytest + +from casper.protocols.binary.bet import Bet + + +@pytest.mark.parametrize( + 'estimate, is_valid', + [ + (0, True), + (1, True), + (True, True), + (False, True), + (-1, False), + (2, False), + ((0, 1), False), + (None, False), + ] +) +def test_accepts_valid_estimates(estimate, is_valid): + assert Bet.is_valid_estimate(estimate) == is_valid + + +@pytest.mark.parametrize( + 'estimate_one, estimate_two, conflicts', + [ + (0, 0, False), + (1, 1, False), + (False, 0, False), + (True, 1, False), + (1, 0, True), + (0, 1, True), + (True, 0, True), + (1, False, True), + ] +) +def test_conflicts_with(estimate_one, estimate_two, conflicts, create_bet): + bet_one = create_bet(estimate_one) + bet_two = create_bet(estimate_two) + + assert bet_one.conflicts_with(bet_two) == conflicts diff --git a/tests/casper/protocols/binary/test_binary_test_lang.py b/tests/casper/protocols/binary/test_binary_test_lang.py new file mode 100644 index 0000000..937b2aa --- /dev/null +++ b/tests/casper/protocols/binary/test_binary_test_lang.py @@ -0,0 +1,111 @@ +"""The language testing module ... """ +import pytest + +from state_languages.binary_test_lang import BinaryTestLang +from casper.network import Network +from casper.validator_set import ValidatorSet + + +def test_init_creates_state_lang(test_weight): + binary_lang = BinaryTestLang(test_weight, False) + + binary_lang.messages + binary_lang.plot_tool + + assert isinstance(binary_lang.network, Network) + assert isinstance(binary_lang.validator_set, ValidatorSet) + + assert len(binary_lang.validator_set) == len(test_weight) + + # should only have seen their initial message + for validator in binary_lang.validator_set: + assert len(validator.view.justified_messages) == 1 + + +@pytest.mark.parametrize( + 'test_string, error', + [ + ('CE0-2', AssertionError), + ('CS0-2', AssertionError), + ('CU0-2', AssertionError), + ('CE0-A', ValueError), + ('CS0-A', ValueError), + ('CU0-A', ValueError), + ] +) +def test_only_binary_estimates(test_string, error, binary_lang): + binary_lang.parse('M0-A') + + with pytest.raises(error): + binary_lang.parse(test_string) + + +def test_check_estimate_passes_on_valid_assertions(binary_lang): + binary_lang.parse('M0-A S1-A S2-A S3-A S4-A') + + current_estimates = dict() + for validator in binary_lang.validator_set: + current_estimates[validator] = validator.estimate() + + check_estimate = '' + for validator in binary_lang.validator_set: + check_estimate += 'CE' + str(validator.name) + '-' + str(current_estimates[validator]) + ' ' + check_estimate = check_estimate[:-1] + + binary_lang.parse(check_estimate) + + +@pytest.mark.parametrize( + 'test_string', + [ + ('M0-A CE0-0 CE0-1'), + ('RR0-A RR0-B CE0-0 CE0-1'), + ('M0-A CS0-0 CS0-1'), + ('RR0-A RR0-B CS1-0 CS1-1'), + ] +) +def test_checks_fails_on_invalid_assertions(test_string, binary_lang): + with pytest.raises(AssertionError): + binary_lang.parse(test_string) + + +def test_check_safe_passes_on_valid_assertions(binary_lang): + binary_lang.parse('RR0-A RR0-B RR0-C RR0-D') + + current_estimate = binary_lang.network.global_view.estimate() + + check_safe = '' + for validator in binary_lang.validator_set: + check_safe += 'CS' + str(validator.name) + '-' + str(current_estimate) + ' ' + check_safe = check_safe[:-1] + + binary_lang.parse(check_safe) + + +def test_check_unsafe_passes_on_valid_assertions(binary_lang): + binary_lang.parse('M0-A S1-A S2-A S3-A S4-A CU0-0 CU0-1 CU1-0 CU1-1 CU2-0 CU2-1 CU3-0 CU3-1 CU4-0 CU4-1') + + for validator in binary_lang.validator_set: + assert validator.view.last_finalized_estimate is None + + +def test_check_unsafe_passes_on_valid_assertions_rr(binary_lang): + binary_lang.parse('RR0-A RR0-B RR0-C RR0-D') + + current_estimate = binary_lang.network.global_view.estimate() + + check_unsafe = '' + for validator in binary_lang.validator_set: + check_unsafe += 'CU' + str(validator.name) + '-' + str(1 - current_estimate) + ' ' + check_unsafe = check_unsafe[:-1] + + binary_lang.parse(check_unsafe) + + +def test_check_unsafe_fails_on_invalid_assertions(binary_lang): + binary_lang.parse('RR0-A RR0-B RR0-C RR0-D') + + current_estimate = binary_lang.network.global_view.estimate() + + with pytest.raises(AssertionError): + binary_lang.parse('CU0-' + str(current_estimate)) diff --git a/tests/casper/protocols/binary/test_binary_view.py b/tests/casper/protocols/binary/test_binary_view.py new file mode 100644 index 0000000..5f3521a --- /dev/null +++ b/tests/casper/protocols/binary/test_binary_view.py @@ -0,0 +1,13 @@ +import pytest + + +@pytest.mark.skip(reason="inital messages not yet specified") +def test_update_safe_estimates(weights, test_string, finalized, binary_lang_creator): + binary_lang = binary_lang_creator(weights) + binary_lang.parse(test_string) + + validator = binary_lang.validator_set.get_validator_by_name(0) + + validator.view.update_safe_estimates(binary_lang.validator_set) + + assert validator.view.last_finalized_estimate.estimate == finalized diff --git a/tests/casper/protocols/blockchain/__init__.py b/tests/casper/protocols/blockchain/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/casper/protocols/blockchain/conftest.py b/tests/casper/protocols/blockchain/conftest.py new file mode 100644 index 0000000..c1590d0 --- /dev/null +++ b/tests/casper/protocols/blockchain/conftest.py @@ -0,0 +1,46 @@ +import random +import pytest + +from state_languages.blockchain_test_lang import BlockchainTestLang +from casper.protocols.blockchain.blockchain_protocol import BlockchainProtocol + + +@pytest.fixture +def blockchain_lang(report, test_weight): + return BlockchainTestLang(test_weight, report) + + +@pytest.fixture +def blockchain_lang_runner(report): + def runner(weights, test_string): + BlockchainTestLang(weights, report).parse(test_string) + return runner + + +@pytest.fixture +def blockchain_lang_creator(report): + def creator(weights): + return BlockchainTestLang(weights, report) + return creator + + +@pytest.fixture +def blockchain_validator_set(generate_validator_set): + return generate_validator_set(BlockchainProtocol) + + +@pytest.fixture +def blockchain_validator(blockchain_validator_set): + return random.choice(list(blockchain_validator_set)) + + +@pytest.fixture +def block(empty_just, blockchain_validator): + return BlockchainProtocol.Message(None, empty_just, blockchain_validator, 0, 0) + + +@pytest.fixture +def create_block(empty_just, blockchain_validator): + def c_block(estimate): + return BlockchainProtocol.Message(estimate, empty_just, blockchain_validator, 0, 0) + return c_block diff --git a/tests/casper/protocols/blockchain/test_blockchain_estimator.py b/tests/casper/protocols/blockchain/test_blockchain_estimator.py new file mode 100644 index 0000000..e107272 --- /dev/null +++ b/tests/casper/protocols/blockchain/test_blockchain_estimator.py @@ -0,0 +1,97 @@ +"""The forkchoice testing module ... """ +import pytest +import random as r + +import casper.protocols.blockchain.forkchoice as forkchoice + + +def test_single_validator_correct_forkchoice(blockchain_lang_runner): + """ This tests that a single validator remains on their own chain """ + test_string = "" + for i in range(100): + test_string += "M0-" + str(i) + " " + "CE0-" + str(i) + " " + test_string = test_string[:-1] + + blockchain_lang_runner({0: 10}, test_string) + + +def test_two_validators_round_robin_forkchoice(blockchain_lang_runner): + test_string = "M0-A SJ1-A M1-B SJ0-B M0-C SJ1-C M1-D SJ0-D CE0-D P" + blockchain_lang_runner({0: 10, 1: 11}, test_string) + + +def test_many_val_round_robin_forkchoice(blockchain_lang_runner): + """ + Tests that during a perfect round robin, + validators choose the one chain as their fork choice + """ + test_string = "" + for i in range(25): + test_string += "M" + str(i % 10) + "-" + str(i) + " " \ + + "SJ" + str((i + 1) % 10) + "-" + str(i) + " " \ + + "CE" + str((i + 1) % 10) + "-" + str(i) + " " + test_string = test_string[:-1] + + blockchain_lang_runner( + {i: 10 - i + r.random() for i in range(10)}, + test_string + ) + + +def test_fail_on_tie(blockchain_lang_runner): + """ + Tests that if there are two subsets of the validator + set with the same weight, the forkchoice fails + """ + test_string = "M1-A SJ0-A M0-B SJ1-B SJ2-A M2-C SJ1-C CE1-C" + with pytest.raises(AssertionError): + blockchain_lang_runner({0: 5, 1: 6, 2: 5}, test_string) + + +def test_ignore_zero_weight_validator(blockchain_lang_runner): + """ + Tests that a validator with zero weight + will not affect the forkchoice + """ + test_string = "M0-A SJ1-A M1-B SJ0-B CE1-A CE0-A" + blockchain_lang_runner({0: 1, 1: 0}, test_string) + + +def test_ignore_zero_weight_block(blockchain_lang_runner): + """ Tests that the forkchoice ignores zero weight blocks """ + # for more info about test, see + # https://gist.github.com/naterush/8d8f6ec3509f50939d7911d608f912f4 + test_string = ( + "M0-A1 M0-A2 CE0-A2 M1-B1 M1-B2 SJ3-B2 M3-D1 CE3-D1 " + "SJ3-A2 CE3-A2 M3-D2 SJ2-B1 CE2-B1 M2-C1 CE2-C1 SJ1-D1 " + "SJ1-D2 SJ1-C1 CE1-B2" + ) + blockchain_lang_runner({0: 10, 1: 9, 2: 8, 3: 0.5}, test_string) + + +def test_reverse_message_arrival_order_forkchoice(blockchain_lang_runner): + test_string = ( + "M0-A SJ1-A M1-B SJ0-B M0-C SJ1-C M1-D SJ0-D M1-E SJ0-E " + "SJ2-E CE2-E SJ3-A SJ3-B SJ3-C SJ3-D SJ3-E CE3-E" + ) + blockchain_lang_runner({0: 5, 1: 6, 2: 7, 3: 8.1}, test_string) + + +@pytest.mark.parametrize( + 'weights, expected', + [ + ({i: i for i in range(10)}, {9}), + ({i: 9 - i for i in range(10)}, {0}), + ({i: i % 5 for i in range(10)}, {4, 9}), + ({i: 10 for i in range(10)}, {i for i in range(10)}), + ({}, ValueError), + ({i: 0 for i in range(10)}, AssertionError), + ] +) +def test_max_weight_indexes(weights, expected): + if isinstance(expected, type) and issubclass(expected, Exception): + with pytest.raises(expected): + forkchoice.get_max_weight_indexes(weights) + return + + assert forkchoice.get_max_weight_indexes(weights) == expected diff --git a/tests/casper/protocols/blockchain/test_blockchain_message.py b/tests/casper/protocols/blockchain/test_blockchain_message.py new file mode 100644 index 0000000..fea39c3 --- /dev/null +++ b/tests/casper/protocols/blockchain/test_blockchain_message.py @@ -0,0 +1,146 @@ +"""The block testing module ...""" +import copy + +import pytest + +from casper.protocols.blockchain.block import Block +from casper.validator import Validator + +from state_languages.blockchain_test_lang import BlockchainTestLang + + +@pytest.mark.parametrize( + 'estimate, is_valid', + [ + (None, True), + ('block', False), + (0, False), + (True, False), + ] +) +def test_accepts_valid_estimates(estimate, is_valid, block): + if estimate == 'block': + Block.is_valid_estimate(block) == is_valid + return + + assert Block.is_valid_estimate(estimate) == is_valid + + +@pytest.mark.parametrize( + 'estimate_one, estimate_two, conflicts', + [ + (None, 'prev', False), + (None, None, True), + ] +) +def test_conflicts_with(estimate_one, estimate_two, conflicts, create_block): + bet_one = create_block(estimate_one) + if estimate_two == 'prev': + estimate_two = bet_one + bet_two = create_block(estimate_two) + + assert bet_one.conflicts_with(bet_two) == conflicts + + +def test_equality_of_copies_off_genesis(validator, empty_just): + block = Block(None, empty_just, validator, 0, 0) + + shallow_copy = copy.copy(block) + + assert block == shallow_copy + + +@pytest.mark.skip(reason="current deepcopy bug") +def test_equality_of_copies_of_non_genesis(report): + test_string = "M0-A SJ1-A M1-B SJ0-B M0-C SJ1-C M1-D SJ0-D CE0-D" + test_lang = BlockchainTestLang({0: 10, 1: 11}, report) + test_lang.parse(test_string) + + for message in test_lang.messages.values(): + shallow_copy = copy.copy(message) + deep_copy = copy.deepcopy(message) + + assert message == shallow_copy + assert message == deep_copy + assert shallow_copy == deep_copy + + +def test_non_equality_of_copies_off_genesis(empty_just): + validator_0 = Validator("v0", 10) + validator_1 = Validator("v1", 11) + + block_0 = Block(None, empty_just, validator_0, 0, 0) + block_1 = Block(None, empty_just, validator_1, 0, 0) + + assert block_0 != block_1 + + +def test_unique_block_creation_in_test_lang(report): + test_string = "M0-A SJ1-A M1-B SJ0-B M0-C SJ1-C M1-D SJ0-D CE0-D" + test_lang = BlockchainTestLang({0: 10, 1: 11}, report) + test_lang.parse(test_string) + + num_equal = 0 + for message1 in test_lang.messages: + for message2 in test_lang.messages: + if message1 == message2: + num_equal += 1 + continue + + assert message1 != message2 + + assert num_equal == len(test_lang.messages) + + +def test_is_in_blockchain__separate_genesis(empty_just): + validator_0 = Validator("v0", 10) + validator_1 = Validator("v1", 11) + + block_0 = Block(None, empty_just, validator_0, 0, 0) + block_1 = Block(None, empty_just, validator_1, 0, 0) + + assert not block_0.is_in_blockchain(block_1) + assert not block_1.is_in_blockchain(block_0) + + +def test_is_in_blockchain__test_lang(report): + test_string = "M0-A SJ1-A M1-B SJ0-B M0-C SJ1-C M1-D SJ0-D CE0-D" + test_lang = BlockchainTestLang({0: 11, 1: 10}, report) + test_lang.parse(test_string) + + prev = test_lang.messages['A'] + for b in ['B', 'C', 'D']: + block = test_lang.messages[b] + assert prev.is_in_blockchain(block) + assert not block.is_in_blockchain(prev) + + prev = block + + +@pytest.mark.parametrize( + 'test_string, weights, block_heights', + [ + ( + "M0-A SJ1-A M1-B SJ0-B M0-C SJ1-C M1-D SJ0-D CE0-D", + {0: 11, 1: 10}, + {"A": 2, "B": 3, "C": 4, "D": 5} + ), + ( + "M0-A SJ1-A M1-B SJ0-B M0-C SJ1-C M1-D SJ0-D CE0-D", + {0: 1, 1: 10}, + {"A": 2, "B": 3, "C": 4, "D": 5} + ), + ( + "M0-A SJ1-A M0-B SJ1-B M1-C SJ0-C SJ2-C M2-D SJ0-D M0-E M1-F SJ0-F CE0-E", + {0: 11, 1: 10, 2: 500}, + {"A": 2, "B": 3, "C": 4, "D": 5, "E": 6, "F": 5} + ), + ] +) +def test_block_height(blockchain_lang_creator, test_string, weights, block_heights): + test_lang = blockchain_lang_creator(weights) + test_lang.parse(test_string) + + for block_name in block_heights: + block = test_lang.messages[block_name] + assert block.height == block_heights[block_name] diff --git a/tests/casper/protocols/blockchain/test_blockchain_test_lang.py b/tests/casper/protocols/blockchain/test_blockchain_test_lang.py new file mode 100644 index 0000000..3bfa54c --- /dev/null +++ b/tests/casper/protocols/blockchain/test_blockchain_test_lang.py @@ -0,0 +1,70 @@ +"""The language testing module ... """ +import pytest + +from state_languages.blockchain_test_lang import BlockchainTestLang +from casper.network import Network +from casper.validator_set import ValidatorSet + + +def test_init_creates_state_lang(test_weight): + blockchain_lang = BlockchainTestLang(test_weight, False) + + blockchain_lang.messages + blockchain_lang.plot_tool + + assert isinstance(blockchain_lang.network, Network) + assert isinstance(blockchain_lang.validator_set, ValidatorSet) + + assert len(blockchain_lang.validator_set) == len(test_weight) + + # should only have seen the genesis block + for validator in blockchain_lang.validator_set: + assert len(validator.view.justified_messages) == 1 + + +def test_check_estimate_passes_on_valid_assertions(blockchain_lang): + blockchain_lang.parse('M0-A S1-A S2-A S3-A S4-A CE0-A CE1-A CE2-A CE3-A CE4-A') + + forkchoice = blockchain_lang.messages['A'] + + for validator in blockchain_lang.validator_set: + assert validator.estimate() == forkchoice + + +@pytest.mark.parametrize( + 'test_string', + [ + ('M0-A S1-A M1-B CE1-A'), + ('M0-A M1-B CE1-A'), + ('M0-A M0-B S1-B CE1-B'), + ('M0-A S1-A S2-A S3-A S4-A CE0-A CE1-A CE2-A CE3-A CE4-A M0-B CE0-A'), + ('M0-A CS0-A'), + ('M0-A CS1-A'), + ('M0-A S1-A S2-A S3-A S4-A CS4-A'), + ('RR0-A CS0-A'), + ('RR0-A RR0-B RR0-C RR0-D RR0-E RR0-F CS1-F'), + ('RR0-A RR0-B RR0-C RR0-D CU0-A'), + ('M0-A RR0-B RR0-C RR0-D CE1-A CS1-A CU0-A'), + ] +) +def test_checks_fails_on_invalid_assertions(test_string, blockchain_lang): + with pytest.raises(AssertionError): + blockchain_lang.parse(test_string) + + +def test_check_safe_passes_on_valid_assertions(blockchain_lang): + blockchain_lang.parse('RR0-A RR0-B RR0-C RR0-D CS0-A CS1-A CS2-A CS3-A CS4-A') + + safe_block = blockchain_lang.messages['A'] + + for validator in blockchain_lang.validator_set: + assert not safe_block.conflicts_with(validator.view.last_finalized_block) + + +def test_check_unsafe_passes_on_valid_assertions(blockchain_lang): + blockchain_lang.parse('M0-A CU0-A') + + not_safe_block = blockchain_lang.messages['A'] + + for validator in blockchain_lang.validator_set: + assert not_safe_block.conflicts_with(validator.view.last_finalized_block) diff --git a/tests/casper/protocols/blockchain/test_blockchain_view.py b/tests/casper/protocols/blockchain/test_blockchain_view.py new file mode 100644 index 0000000..6a589ba --- /dev/null +++ b/tests/casper/protocols/blockchain/test_blockchain_view.py @@ -0,0 +1,30 @@ +"""The BlockchainView testing module...""" +import pytest + +@pytest.mark.parametrize( + 'test_string, children', + [ + ( + "M0-A S1-A M1-B S0-B", + {'A': ['B']} + ), + ( + "M0-A S1-A S2-A S3-A S4-A M1-B M2-C M3-D M4-E S0-B S0-C S0-D S0-E", + {'A': ['B', 'C', 'D', 'E']} + ), + ( + "M0-A S1-A S2-A M1-B M2-C S0-B S0-C M0-D", + {'A': ['B', 'C'], 'B': ['D']} + ), + ] +) +def test_update_protocol_specific_view(test_string, children, blockchain_lang): + blockchain_lang.parse(test_string) + + validator = blockchain_lang.validator_set.get_validator_by_name(0) + + for block_name in children: + block = blockchain_lang.messages[block_name] + + for child_name in children[block_name]: + assert blockchain_lang.messages[child_name] in validator.view.children[block] diff --git a/tests/casper/protocols/concurrent/__init__.py b/tests/casper/protocols/concurrent/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/casper/protocols/concurrent/conftest.py b/tests/casper/protocols/concurrent/conftest.py new file mode 100644 index 0000000..9a97796 --- /dev/null +++ b/tests/casper/protocols/concurrent/conftest.py @@ -0,0 +1,30 @@ +import random +import pytest + +from casper.protocols.concurrent.concurrent_protocol import ConcurrentProtocol + + +@pytest.fixture +def concurrent_validator_set(generate_validator_set): + return generate_validator_set(ConcurrentProtocol) + + +@pytest.fixture +def concurrent_validator(concurrent_validator_set): + return random.choice(list(concurrent_validator_set)) + + +@pytest.fixture +def empty_concurrent_estimate(): + return {'blocks': {None}, 'inputs': [], 'outputs': []} + +@pytest.fixture +def block(empty_just, concurrent_validator, empty_concurrent_estimate): + return ConcurrentProtocol.Message(empty_concurrent_estimate, empty_just, concurrent_validator, 0, 0) + + +@pytest.fixture +def create_block(empty_just, concurrent_validator): + def c_block(estimate): + return ConcurrentProtocol.Message(estimate, empty_just, concurrent_validator, 0, 0) + return c_block diff --git a/tests/casper/protocols/concurrent/test_concurrent_estimator.py b/tests/casper/protocols/concurrent/test_concurrent_estimator.py new file mode 100644 index 0000000..77389dc --- /dev/null +++ b/tests/casper/protocols/concurrent/test_concurrent_estimator.py @@ -0,0 +1,7 @@ +"""The forkchoice testing module ... """ +import pytest +import random as r + +import casper.protocols.concurrent.forkchoice as forkchoice + +# TODO: Test once there is a complete concurrent testing language diff --git a/tests/casper/protocols/concurrent/test_concurrent_message.py b/tests/casper/protocols/concurrent/test_concurrent_message.py new file mode 100644 index 0000000..438e622 --- /dev/null +++ b/tests/casper/protocols/concurrent/test_concurrent_message.py @@ -0,0 +1,28 @@ +"""The block testing module ...""" +import copy + +import pytest + +from casper.protocols.concurrent.block import Block +from casper.protocols.concurrent.concurrent_protocol import ConcurrentProtocol + + +@pytest.mark.parametrize( + 'estimate, is_valid', + [ + ({'blocks': {None}, 'inputs': [], 'outputs': []}, True), + ({'blocks': {None}, 'inputs': [1, 2, 3], 'outputs': [4, 5, 6]}, True), + ({'blocks': {None}, 'inputs': []}, False), + ({'blocks': {None}, 'outputs': []}, False), + ({'blocks': {None}}, False), + ({'blocks': [], 'inputs': [], 'outputs': []}, False), + ({'inputs': [], 'outputs': []}, False), + (0, False), + (True, False), + ] +) +def test_accepts_valid_estimates(estimate, is_valid, block): + assert Block.is_valid_estimate(estimate) == is_valid + + +# TODO: Test once there is a complete concurrent testing language diff --git a/tests/casper/protocols/concurrent/test_concurrent_test_lang.py b/tests/casper/protocols/concurrent/test_concurrent_test_lang.py new file mode 100644 index 0000000..3bc1d55 --- /dev/null +++ b/tests/casper/protocols/concurrent/test_concurrent_test_lang.py @@ -0,0 +1,4 @@ +"""The language testing module ... """ +import pytest + +# TODO: Test once there is a complete concurrent testing language diff --git a/tests/casper/protocols/concurrent/test_concurrent_view.py b/tests/casper/protocols/concurrent/test_concurrent_view.py new file mode 100644 index 0000000..5826114 --- /dev/null +++ b/tests/casper/protocols/concurrent/test_concurrent_view.py @@ -0,0 +1,4 @@ +"""The concurrent view testing module...""" +import pytest + +# TODO: Test once there is a complete concurrent testing language diff --git a/tests/casper/protocols/integer/__init__.py b/tests/casper/protocols/integer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/casper/protocols/integer/conftest.py b/tests/casper/protocols/integer/conftest.py new file mode 100644 index 0000000..a2efff9 --- /dev/null +++ b/tests/casper/protocols/integer/conftest.py @@ -0,0 +1,46 @@ +import random +import pytest + +from state_languages.integer_test_lang import IntegerTestLang +from casper.protocols.integer.integer_protocol import IntegerProtocol + + +@pytest.fixture +def integer_lang(report, test_weight): + return IntegerTestLang(test_weight, report) + + +@pytest.fixture +def integer_lang_runner(report): + def runner(weights, test_string): + IntegerTestLang(weights, report).parse(test_string) + return runner + + +@pytest.fixture +def integer_lang_creator(report): + def creator(weights): + return IntegerTestLang(weights, report) + return creator + + +@pytest.fixture +def integer_validator_set(generate_validator_set): + return generate_validator_set(IntegerProtocol) + + +@pytest.fixture +def integer_validator(integer_validator_set): + return random.choice(list(integer_validator_set)) + + +@pytest.fixture +def bet(empty_just, integer_validator): + return IntegerProtocol.Message(0, empty_just, integer_validator, 0, 0) + + +@pytest.fixture +def create_bet(empty_just, integer_validator): + def c_bet(estimate): + return IntegerProtocol.Message(estimate, empty_just, integer_validator, 0, 0) + return c_bet diff --git a/tests/casper/protocols/integer/test_integer_estimator.py b/tests/casper/protocols/integer/test_integer_estimator.py new file mode 100644 index 0000000..5761650 --- /dev/null +++ b/tests/casper/protocols/integer/test_integer_estimator.py @@ -0,0 +1,55 @@ +"""The block testing module ...""" +import pytest + +from casper.protocols.integer.integer_protocol import IntegerProtocol +from casper.protocols.integer.bet import Bet +from casper.validator_set import ValidatorSet +import casper.protocols.integer.integer_estimator as estimator + + +@pytest.mark.parametrize( + 'weights, latest_estimates, estimate', + [ + ( + {0: 5}, + {0: 5}, + 5 + ), + ( + {0: 5, 1: 6, 2: 7}, + {0: 5, 1: 5, 2: 5}, + 5 + ), + ( + {0: 5, 1: 10, 2: 14}, + {0: 0, 1: 5, 2: 10}, + 5 + ), + ( + {0: 5, 1: 11}, + {0: 0, 1: 6}, + 6 + ), + ( + {0: 5, 1: 10, 2: 14}, + {0: 0, 1: 0, 2: 1}, + 0 + ), + ( + {0: 5, 1: 5}, + {0: 0, 1: 1}, + 0 + ), + ] +) +def test_estimator_picks_correct_estimate(weights, latest_estimates, estimate, empty_just): + validator_set = ValidatorSet(weights, IntegerProtocol) + + latest_messages = dict() + for val_name in latest_estimates: + validator = validator_set.get_validator_by_name(val_name) + latest_messages[validator] = Bet( + latest_estimates[val_name], empty_just, validator, 1, 1 + ) + + assert estimate == estimator.get_estimate_from_latest_messages(latest_messages) diff --git a/tests/casper/protocols/integer/test_integer_message.py b/tests/casper/protocols/integer/test_integer_message.py new file mode 100644 index 0000000..1b4c20d --- /dev/null +++ b/tests/casper/protocols/integer/test_integer_message.py @@ -0,0 +1,39 @@ +import pytest + +from casper.protocols.integer.bet import Bet + + +@pytest.mark.parametrize( + 'estimate, is_valid', + [ + (0, True), + (1, True), + (10000000, True), + (-10000000, True), + (True, True), + (False, True), + ((0, 1), False), + (None, False), + ] +) +def test_accepts_valid_estimates(estimate, is_valid): + assert Bet.is_valid_estimate(estimate) == is_valid + + +@pytest.mark.parametrize( + 'estimate_one, estimate_two, conflicts', + [ + (0, 0, False), + (1, 1, False), + (1000, 1000, False), + (-1000, -1000, False), + (1000, -1000, True), + (True, 10, True), + (10, False, True), + ] +) +def test_conflicts_with(estimate_one, estimate_two, conflicts, create_bet): + bet_one = create_bet(estimate_one) + bet_two = create_bet(estimate_two) + + assert bet_one.conflicts_with(bet_two) == conflicts diff --git a/tests/casper/protocols/integer/test_integer_test_lang.py b/tests/casper/protocols/integer/test_integer_test_lang.py new file mode 100644 index 0000000..9447b78 --- /dev/null +++ b/tests/casper/protocols/integer/test_integer_test_lang.py @@ -0,0 +1,119 @@ +"""The language testing module ... """ +import pytest + +from state_languages.integer_test_lang import IntegerTestLang +from casper.network import Network +from casper.validator_set import ValidatorSet + + +def test_init_creates_state_lang(test_weight): + integer_lang = IntegerTestLang(test_weight, False) + + integer_lang.messages + integer_lang.plot_tool + + assert isinstance(integer_lang.network, Network) + assert isinstance(integer_lang.validator_set, ValidatorSet) + + assert len(integer_lang.validator_set) == len(test_weight) + + # should only have seen their intial message + for validator in integer_lang.validator_set: + assert len(validator.view.justified_messages) == 1 + + +@pytest.mark.parametrize( + 'test_string, error', + [ + ('CE0-A', ValueError), + ('CS0-A', ValueError), + ('CU0-A', ValueError), + ] +) +def test_only_integer_estimates(test_string, error, integer_lang): + integer_lang.parse('M0-A') + + with pytest.raises(error): + integer_lang.parse(test_string) + + +def test_check_estimate_passes_on_valid_assertions(integer_lang): + integer_lang.parse('M0-A S1-A S2-A S3-A S4-A') + + current_estimates = dict() + for validator in integer_lang.validator_set: + current_estimates[validator] = validator.estimate() + + check_estimate = '' + for validator in integer_lang.validator_set: + check_estimate += 'CE' + str(validator.name) + '-' + str(current_estimates[validator]) + ' ' + check_estimate = check_estimate[:-1] + + integer_lang.parse(check_estimate) + + +@pytest.mark.parametrize( + 'test_string', + [ + ('M0-A CE0-0 CE0-1'), + ('RR0-A RR0-B CE0-0 CE0-1'), + ('M0-A CS0-0 CS0-1'), + ('RR0-A RR0-B CS1-0 CS1-1'), + ] +) +def test_checks_fails_on_invalid_assertions(test_string, integer_lang): + with pytest.raises(AssertionError): + integer_lang.parse(test_string) + + +def test_check_safe_passes_on_valid_assertions(integer_lang): + integer_lang.parse('RR0-A RR0-B RR0-C RR0-D') + + current_estimate = integer_lang.network.global_view.estimate() + + check_safe = '' + for validator in integer_lang.validator_set: + check_safe += 'CS' + str(validator.name) + '-' + str(current_estimate) + ' ' + check_safe = check_safe[:-1] + + integer_lang.parse(check_safe) + + +def test_check_unsafe_passes_on_valid_assertions(integer_lang): + integer_lang.parse('M0-A S1-A S2-A S3-A S4-A') + + current_estimates = dict() + for validator in integer_lang.validator_set: + current_estimates[validator] = validator.estimate() + + check_unsafe = '' + for validator in integer_lang.validator_set: + check_unsafe += 'CE' + str(validator.name) + '-' + str(current_estimates[validator]) + ' ' + check_unsafe = check_unsafe[:-1] + + integer_lang.parse(check_unsafe) + + for validator in integer_lang.validator_set: + assert validator.view.last_finalized_estimate is None + + +def test_check_unsafe_passes_on_valid_assertions_rr(integer_lang): + integer_lang.parse('RR0-A RR0-B RR0-C RR0-D') + + current_estimate = integer_lang.network.global_view.estimate() + + check_unsafe = '' + for validator in integer_lang.validator_set: + check_unsafe += 'CU' + str(validator.name) + '-' + str(1 + current_estimate) + ' ' + check_unsafe = check_unsafe[:-1] + + integer_lang.parse(check_unsafe) + + +def test_check_unsafe_fails_on_invalid_assertions(integer_lang): + integer_lang.parse('RR0-A RR0-B RR0-C RR0-D') + + current_estimate = integer_lang.network.global_view.estimate() + + with pytest.raises(AssertionError): + integer_lang.parse('CU0-' + str(current_estimate)) diff --git a/tests/casper/protocols/integer/test_integer_view.py b/tests/casper/protocols/integer/test_integer_view.py new file mode 100644 index 0000000..82caaab --- /dev/null +++ b/tests/casper/protocols/integer/test_integer_view.py @@ -0,0 +1,7 @@ +import pytest + +from casper.protocols.integer.integer_view import IntegerView + +@pytest.mark.skip(reason="cannot specify inital messages yet") +def test_update_safe_estimates(): + pass diff --git a/tests/casper/protocols/order/conftest.py b/tests/casper/protocols/order/conftest.py new file mode 100644 index 0000000..4f1fe79 --- /dev/null +++ b/tests/casper/protocols/order/conftest.py @@ -0,0 +1,25 @@ +import random +import pytest + +from casper.protocols.order.order_protocol import OrderProtocol + +@pytest.fixture +def order_validator_set(generate_validator_set): + return generate_validator_set(OrderProtocol) + + +@pytest.fixture +def order_validator(order_validator_set): + return random.choice(list(order_validator_set)) + + +@pytest.fixture +def bet(empty_just, order_validator): + return OrderProtocol.Message([0, 1], empty_just, order_validator, 0, 0) + + +@pytest.fixture +def create_bet(empty_just, order_validator): + def c_bet(estimate): + return OrderProtocol.Message(estimate, empty_just, order_validator, 0, 0) + return c_bet diff --git a/tests/casper/protocols/order/test_order_estimator.py b/tests/casper/protocols/order/test_order_estimator.py new file mode 100644 index 0000000..d9852d1 --- /dev/null +++ b/tests/casper/protocols/order/test_order_estimator.py @@ -0,0 +1,67 @@ +"""test order estimator""" +import pytest + +from casper.protocols.order.bet import Bet +from casper.validator_set import ValidatorSet +import casper.protocols.order.order_estimator as estimator + + +@pytest.mark.parametrize( + 'weights, latest_estimates, estimate', + [ + ( + {0: 5}, + {0: [1, 5, 6]}, + [1, 5, 6] + ), + ( + {0: 5, 1: 6, 2: 7}, + { + 0: [1, 2], + 1: [1, 2], + 2: [2, 1] + }, + [1, 2] + ), + ( + {0: 5, 1: 6, 2: 7}, + { + 0: [1, 2, 3], + 1: [2, 3, 1], + 2: [2, 1, 3] + }, # {2: 31, 1: 16, 3: 6} + [2, 1, 3] + ), + ( + {0: 5, 1: 10, 2: 14}, + { + 0: ["fish", "pig", "horse", "dog"], + 1: ["dog", "horse", "pig", "fish"], + 2: ["pig", "horse", "fish", "dog"] + }, # {"fish": 29, "pig": 62, "horse": 53, "dog": 30} + ["pig", "horse", "dog", "fish"] + ), + ( + {0: 5, 1: 6, 2: 7, 3: 8}, + { + 0: ["fish", "pig", "horse"], + 1: ["horse", "pig", "fish"], + 2: ["pig", "horse", "fish"], + 3: ["fish", "horse", "pig"] + }, # {"fish": 26, "pig": 25, "horse": 27} + ["horse", "fish", "pig"] + ), + + ] +) +def test_estimator_picks_correct_estimate(weights, latest_estimates, estimate, empty_just): + validator_set = ValidatorSet(weights) + + latest_messages = dict() + for val_name in latest_estimates: + validator = validator_set.get_validator_by_name(val_name) + latest_messages[validator] = Bet( + latest_estimates[val_name], empty_just, validator, 1, 1 + ) + + assert estimate == estimator.get_estimate_from_latest_messages(latest_messages) diff --git a/tests/casper/protocols/order/test_order_message.py b/tests/casper/protocols/order/test_order_message.py new file mode 100644 index 0000000..4a346de --- /dev/null +++ b/tests/casper/protocols/order/test_order_message.py @@ -0,0 +1,38 @@ +import pytest + +from casper.protocols.order.bet import Bet + + +@pytest.mark.parametrize( + 'estimate, is_valid', + [ + ([], True), + ([(0, 1, 2)], True), + ([True, False, True], True), + (-10000000, False), + (True, False), + ((0, 1), False), + (None, False), + ] +) +def test_accepts_valid_estimates(estimate, is_valid): + assert Bet.is_valid_estimate(estimate) == is_valid + + +@pytest.mark.parametrize( + 'estimate_one, estimate_two, conflicts', + [ + ([1], [1], False), + ([1, 2, 3], [1, 2, 3], False), + ([[]], [[]], False), + (['Hi', 'Hello'], ['Hi', 'Hello'], False), + ([1], [1, 2], True), + ([1, 2], [2, 1], True), + (['Hi', 'Hello'], ['Hi', 'Welcome'], True), + ] +) +def test_conflicts_with(estimate_one, estimate_two, conflicts, create_bet): + bet_one = create_bet(estimate_one) + bet_two = create_bet(estimate_two) + + assert bet_one.conflicts_with(bet_two) == conflicts diff --git a/tests/casper/protocols/order/test_order_view.py b/tests/casper/protocols/order/test_order_view.py new file mode 100644 index 0000000..5c2ff9b --- /dev/null +++ b/tests/casper/protocols/order/test_order_view.py @@ -0,0 +1,11 @@ +import pytest + + +@pytest.mark.skip(reason="test language not written") +def test_update_safe_estimates(): + pass + + +@pytest.mark.skip(reason="test language not written") +def test_update_protocol_specific_view(): + pass diff --git a/tests/casper/safety_oracles/conftest.py b/tests/casper/safety_oracles/conftest.py new file mode 100644 index 0000000..80f28d3 --- /dev/null +++ b/tests/casper/safety_oracles/conftest.py @@ -0,0 +1,12 @@ +import pytest + +from casper.safety_oracles.clique_oracle import CliqueOracle +from casper.safety_oracles.turan_oracle import TuranOracle +from casper.safety_oracles.adversary_oracle import AdversaryOracle + +ORACLES = [CliqueOracle, TuranOracle, AdversaryOracle] + + +@pytest.fixture(params=ORACLES) +def oracle_class(request): + return request.param diff --git a/tests/casper/safety_oracles/test_safety_oracle.py b/tests/casper/safety_oracles/test_safety_oracle.py new file mode 100644 index 0000000..87ecc38 --- /dev/null +++ b/tests/casper/safety_oracles/test_safety_oracle.py @@ -0,0 +1,114 @@ +"""The safety oracle testing module ... """ + + +def test_round_robin_safety(oracle_class, test_lang_creator): + test_lang = test_lang_creator({0: 9.3, 1: 8.2, 2: 7.1, 3: 6, 4: 5}) + test_string = ( + 'P M0-A SJ1-A RR1-B RR1-C RR1-D RR1-E SJ2-E ' + 'SJ3-E SJ4-E' + ) + + test_lang.parse(test_string) + + message = test_lang.messages['A'] + + for validator in test_lang.validator_set: + oracle = oracle_class(message, validator.view, test_lang.validator_set) + fault_tolerance, _ = oracle.check_estimate_safety() + + assert fault_tolerance > 0 + + +def test_majority_fork_safe(oracle_class, test_lang_creator): + test_lang = test_lang_creator({0: 5, 1: 6, 2: 7}) + test_string = ( + # create right hand side of fork and check for safety + 'P M1-A SJ0-A M0-L0 SJ1-L0 M1-L1 SJ0-L1 M0-L2 SJ1-L2 ' + 'M1-L3 SJ0-L3 M0-L4 SJ1-L4' + ) + test_lang.parse(test_string) + + message = test_lang.messages['L0'] + + for validator in test_lang.validator_set.get_validators_by_names([0, 1]): + oracle = oracle_class(message, validator.view, test_lang.validator_set) + fault_tolerance, _ = oracle.check_estimate_safety() + + assert fault_tolerance > 0 + + +def test_no_majority_fork_unsafe(oracle_class, test_lang_creator): + test_lang = test_lang_creator({0: 5, 1: 4.5, 2: 6, 3: 4, 4: 5.25}) + test_string = ( + # create right hand side of fork and check for no safety + 'M2-A SJ1-A M1-L0 SJ0-L0 M0-L1 SJ1-L1 M1-L2 SJ0-L2 ' + 'M0-L3 SJ1-L3 M1-L4 SJ0-L4 ' + # now, left hand side as well. should still have no safety + 'SJ3-A M3-R0 SJ4-R0 M4-R1 SJ3-R1 M3-R2 SJ4-R2 M4-R3 ' + 'SJ3-R3 M3-R4 SJ4-R4' + ) + test_lang.parse(test_string) + + # left hand side of fork + for validator in test_lang.validator_set.get_validators_by_names([0, 1]): + message = test_lang.messages['L0'] + + oracle = oracle_class(message, validator.view, test_lang.validator_set) + fault_tolerance, _ = oracle.check_estimate_safety() + + assert fault_tolerance == 0 + + # right hand side of fork + for validator in test_lang.validator_set.get_validators_by_names([2, 3]): + message = test_lang.messages['R0'] + + oracle = oracle_class(message, validator.view, test_lang.validator_set) + fault_tolerance, _ = oracle.check_estimate_safety() + + assert fault_tolerance == 0 + + +def test_no_majority_fork_safe_after_union(oracle_class, test_lang_creator): + test_lang = test_lang_creator({0: 5, 1: 4.5, 2: 6, 3: 4, 4: 5.25}) + test_string = ( + # generate both sides of an extended fork + 'M2-A SJ1-A M1-L0 SJ0-L0 M0-L1 SJ1-L1 M1-L2 SJ0-L2 ' + 'M0-L3 SJ1-L3 M1-L4 SJ0-L4 ' + 'SJ3-A M3-R0 SJ4-R0 M4-R1 SJ3-R1 M3-R2 SJ4-R2 M4-R3 ' + 'SJ3-R3 M3-R4 SJ4-R4' + ) + test_lang.parse(test_string) + + # left hand side of fork + for validator in test_lang.validator_set.get_validators_by_names([0, 1]): + message = test_lang.messages['L0'] + + oracle = oracle_class(message, validator.view, test_lang.validator_set) + fault_tolerance, _ = oracle.check_estimate_safety() + + assert fault_tolerance == 0 + + # right hand side of fork + for validator in test_lang.validator_set.get_validators_by_names([2, 3]): + message = test_lang.messages['R0'] + + oracle = oracle_class(message, validator.view, test_lang.validator_set) + fault_tolerance, _ = oracle.check_estimate_safety() + + assert fault_tolerance == 0 + + test_string = ( + # show all validators all messages + 'SJ0-R4 SJ1-R4 SJ2-R4 SJ2-L4 SJ3-L4 SJ4-L4 ' + # two rounds of round robin, check have safety on the correct fork + 'RR0-J0 RR0-J1' + ) + test_lang.parse(test_string) + + validator = test_lang.validator_set.get_validator_by_name(0) + message = test_lang.messages['L0'] + + oracle = oracle_class(message, validator.view, test_lang.validator_set) + fault_tolerance, _ = oracle.check_estimate_safety() + + assert fault_tolerance > 0 diff --git a/tests/casper/test_forkchoice.py b/tests/casper/test_forkchoice.py deleted file mode 100644 index d26c897..0000000 --- a/tests/casper/test_forkchoice.py +++ /dev/null @@ -1,102 +0,0 @@ -"""The forkchoice testing module ... """ -import random as r -import pytest - -import casper.blockchain.forkchoice as forkchoice - - -def test_single_validator_correct_forkchoice(test_lang_runner): - """ This tests that a single validator remains on their own chain """ - test_string = "" - for i in range(100): - test_string += "B0-" + str(i) + " " + "H0-" + str(i) + " " - test_string = test_string[:-1] - - test_lang_runner(test_string, {0: 10}) - - -def test_two_validators_round_robin_forkchoice(test_lang_runner): - test_string = "B0-A S1-A B1-B S0-B B0-C S1-C B1-D S0-D H0-D R" - test_lang_runner(test_string, {0: 10, 1: 11}) - - -def test_many_val_round_robin_forkchoice(test_lang_runner): - """ - Tests that during a perfect round robin, - validators choose the one chain as their fork choice - """ - test_string = "" - for i in range(100): - test_string += "B" + str(i % 10) + "-" + str(i) + " " \ - + "S" + str((i+1) % 10) + "-" + str(i) + " " \ - + "H" + str((i+1) % 10) + "-" + str(i) + " " - test_string = test_string[:-1] - - test_lang_runner( - test_string, - {i: 10 - i + r.random() for i in range(10)} - ) - - -def test_fail_on_tie(test_lang_runner): - """ - Tests that if there are two subsets of the validator - set with the same weight, the forkchoice fails - """ - test_string = "B1-A S0-A B0-B S1-B S2-A B2-C S1-C H1-C" - with pytest.raises(AssertionError): - test_lang_runner(test_string, {0: 5, 1: 6, 2: 5}) - - -def test_ignore_zero_weight_validator(test_lang_runner): - """ - Tests that a validator with zero weight - will not affect the forkchoice - """ - test_string = "B0-A S1-A B1-B S0-B H1-A H0-A" - test_lang_runner(test_string, {0: 1, 1: 0}) - - -def test_ignore_zero_weight_block(test_lang_runner): - """ Tests that the forkchoice ignores zero weight blocks """ - # for more info about test, see - # https://gist.github.com/naterush/8d8f6ec3509f50939d7911d608f912f4 - test_string = ( - "B0-A1 B0-A2 H0-A2 B1-B1 B1-B2 S3-B2 B3-D1 H3-D1 " - "S3-A2 H3-A2 B3-D2 S2-B1 H2-B1 B2-C1 H2-C1 S1-D1 " - "S1-D2 S1-C1 H1-B2" - ) - test_lang_runner(test_string, {0: 10, 1: 9, 2: 8, 3: 0.5}) - - -def test_reverse_message_arrival_order_forkchoice_four_val(test_lang_runner): - test_string = ( - "B0-A S1-A B1-B S0-B B0-C S1-C B1-D S0-D B1-E S0-E " - "S2-E H2-E S3-A S3-B S3-C S3-D S3-E H3-E" - ) - test_lang_runner(test_string, {0: 5, 1: 6, 2: 7, 3: 8.1}) - - -@pytest.mark.skip(reason="test not yet implemented") -def test_different_message_arrival_order_forkchoice_many_val(): - pass - - -@pytest.mark.parametrize( - 'weights, expected', - [ - ({i: i for i in range(10)}, {9}), - ({i: 9 - i for i in range(10)}, {0}), - ({i: i % 5 for i in range(10)}, {4, 9}), - ({i: 10 for i in range(10)}, {i for i in range(10)}), - ({}, ValueError), - ({i: 0 for i in range(10)}, AssertionError), - ] -) -def test_max_weight_indexes(weights, expected): - if isinstance(expected, type) and issubclass(expected, Exception): - with pytest.raises(expected): - forkchoice.get_max_weight_indexes(weights) - return - - assert forkchoice.get_max_weight_indexes(weights) == expected diff --git a/tests/casper/test_message.py b/tests/casper/test_message.py index 7ac2141..38967d2 100644 --- a/tests/casper/test_message.py +++ b/tests/casper/test_message.py @@ -1,35 +1,5 @@ -import copy +def test_message_implements_interface(message): + assert callable(message.is_valid_estimate) + assert callable(message.conflicts_with) -from casper.justification import Justification -from casper.message import Message -from casper.validator import Validator - - -def test_new_message(validator): - message = Message(None, Justification(), validator) - - assert message.sender == validator - assert message.estimate is None - assert not message.justification.latest_messages - assert message.sequence_number == 0 - - -def test_equality_of_copies_off_genesis(validator): - message = Message(None, Justification(), validator) - - shallow_copy = copy.copy(message) - deep_copy = copy.deepcopy(message) - - assert message == shallow_copy - assert message == deep_copy - assert shallow_copy == deep_copy - - -def test_non_equality_of_copies_off_genesis(): - validator_0 = Validator("v0", 10) - validator_1 = Validator("v1", 11) - - message_0 = Message(None, Justification(), validator_0) - message_1 = Message(None, Justification(), validator_1) - - assert message_0 != message_1 + assert message.is_valid_estimate(0) is not None diff --git a/tests/casper/test_network.py b/tests/casper/test_network.py index cb2d59f..b434a1c 100644 --- a/tests/casper/test_network.py +++ b/tests/casper/test_network.py @@ -1,44 +1,156 @@ """The network testing module ... """ -import random as r -import pytest - from casper.network import Network -def test_new_network(validator_set): - network = Network(validator_set) +def test_new_network_genesis(generate_validator_set, genesis_protocol): + validator_set = generate_validator_set(genesis_protocol) + network = Network(validator_set, genesis_protocol) + assert network.validator_set == validator_set + assert len(network.global_view.justified_messages) == 1 + + +def test_new_network_rand_start(generate_validator_set, rand_start_protocol): + validator_set = generate_validator_set(rand_start_protocol) + network = Network(validator_set, rand_start_protocol) assert network.validator_set == validator_set - assert not network.global_view.messages + assert len(network.global_view.justified_messages) == len(validator_set) + + +def test_default_time(network): + assert network.time == 0 + network.advance_time() + assert network.time == 1 + network.advance_time() + assert network.time == 2 + + jump = 50 + network.advance_time(jump) + assert network.time == 2 + jump + + +def test_send(network, from_validator, to_validator): + message = from_validator.make_new_message() + assert network.message_queues[to_validator].qsize() == 0 + + network.send(to_validator, message) + assert network.message_queues[to_validator].qsize() == 1 + + +def test_send_adds_to_global_view(network, global_view, from_validator, to_validator): + message = from_validator.make_new_message() + num_justified = len(global_view.justified_messages) + network.send(to_validator, message) + + assert len(global_view.justified_messages) == num_justified + 1 + + +def test_send_zero_delay(no_delay_network, from_validator, to_validator): + message = from_validator.make_new_message() + no_delay_network.send(to_validator, message) + + message_queue = no_delay_network.message_queues[to_validator] + assert message_queue.qsize() == 1 + assert message_queue.peek()[0] == no_delay_network.time + + +def test_send_constant_delay(constant_delay_network, from_validator, to_validator): + network = constant_delay_network + message = from_validator.make_new_message() + network.send(to_validator, message) + + message_queue = network.message_queues[to_validator] + assert message_queue.qsize() == 1 + assert message_queue.peek()[0] == network.time + network.CONSTANT + + +def test_send_to_all(network, from_validator): + message = from_validator.make_new_message() + for validator in network.validator_set: + assert network.message_queues[validator].qsize() == 0 + + network.send_to_all(message) + assert network.message_queues[from_validator].qsize() == 0 + for validator in network.validator_set: + if validator == from_validator: + continue + message_queue = network.message_queues[validator] + assert message_queue.qsize() == 1 + assert message_queue.queue[0][1] == message + + +def test_receive_empty(network, validator): + assert network.message_queues[validator].qsize() == 0 + assert network.receive(validator) is None + + +def test_receive_before_delay(constant_delay_network, to_validator, from_validator): + network = constant_delay_network + message = from_validator.make_new_message() + network.send(to_validator, message) + + message_queue = network.message_queues[to_validator] + assert message_queue.qsize() == 1 + assert message_queue.peek()[0] > constant_delay_network.time + + assert network.receive(to_validator) is None + assert network.message_queues[to_validator].qsize() == 1 + + +def test_receive_at_delay(constant_delay_network, from_validator, to_validator): + network = constant_delay_network + message = from_validator.make_new_message() + network.send(to_validator, message) + + assert network.message_queues[to_validator].qsize() == 1 + network.advance_time(network.CONSTANT) + assert network.message_queues[to_validator].peek()[0] == network.time + + assert network.receive(to_validator) is message + assert network.message_queues[to_validator].qsize() == 0 -def test_random_initialization(validator_set): - network = Network(validator_set) +def test_receive_after_delay(constant_delay_network, from_validator, to_validator): + network = constant_delay_network + message = from_validator.make_new_message() + network.send(to_validator, message) - assert not network.global_view.messages - network.random_initialization() - assert len(network.global_view.messages) == len(validator_set) + assert network.message_queues[to_validator].qsize() == 1 + network.advance_time(network.CONSTANT * 3) + assert network.message_queues[to_validator].peek()[0] < network.time + assert network.receive(to_validator) is message + assert network.message_queues[to_validator].qsize() == 0 -def test_get_message_from_validator(network): - validator = r.sample(network.validator_set.validators, 1)[0] - message = network.get_message_from_validator(validator) - assert message.sender == validator +def test_receive_multiple_after_delay(constant_delay_network, from_validator, to_validator): + network = constant_delay_network + num_messages_to_send = 3 + messages = [] + for i in range(num_messages_to_send): + message = from_validator.make_new_message() + network.send(to_validator, message) + messages.append(message) + assert network.message_queues[to_validator].qsize() == num_messages_to_send + network.advance_time(network.CONSTANT * 3) -def test_propagate_message_to_validator(network): - from_validator, to_validator = r.sample( - network.validator_set.validators, - 2 - ) + for i in range(num_messages_to_send): + message = network.receive(to_validator) + assert message in messages + messages.remove(message) - message = network.get_message_from_validator(from_validator) - network.propagate_message_to_validator(message, to_validator) - assert message in to_validator.view.messages - assert message == to_validator.view.latest_messages[from_validator] +def test_receive_all_available_after_delay(constant_delay_network, from_validator, to_validator): + network = constant_delay_network + num_messages_to_send = 3 + messages = [] + for i in range(num_messages_to_send): + message = from_validator.make_new_message() + network.send(to_validator, message) + messages.append(message) + assert network.message_queues[to_validator].qsize() == num_messages_to_send + network.advance_time(network.CONSTANT * 3) -@pytest.mark.skip(reason="test not yet implemented") -def test_view_initialization(): - pass + received_messages = network.receive_all_available(to_validator) + assert set(messages) == set(received_messages) diff --git a/tests/casper/test_protocol.py b/tests/casper/test_protocol.py index 93ae16b..10487f5 100644 --- a/tests/casper/test_protocol.py +++ b/tests/casper/test_protocol.py @@ -1,18 +1,3 @@ -import pytest - -from casper.protocol import Protocol -from casper.blockchain.blockchain_protocol import BlockchainProtocol -from casper.binary.binary_protocol import BinaryProtocol - - -@pytest.mark.parametrize( - 'protocol', - ( - Protocol, - BlockchainProtocol, - BinaryProtocol, - ) -) def test_class_properties_defined(protocol): protocol.View protocol.Message diff --git a/tests/casper/test_safety_oracle.py b/tests/casper/test_safety_oracle.py deleted file mode 100644 index 2ca6155..0000000 --- a/tests/casper/test_safety_oracle.py +++ /dev/null @@ -1,53 +0,0 @@ -"""The safety oracle testing module ... """ - -def test_round_robin_safety(test_lang_runner): - test_string = ( - 'R B0-A S1-A RR1-B RR1-C RR1-D RR1-E S2-E ' - 'S3-E S4-E H0-E H1-E H2-E H3-E H4-E C0-A ' - 'C1-A C2-A C3-A C4-A R' - ) - weights = {0: 9.3, 1: 8.2, 2: 7.1, 3: 6, 4: 5} - test_lang_runner(test_string, weights) - - -def test_majority_fork_safe(test_lang_runner): - test_string = ( - # create right hand side of fork and check for safety - 'R B1-A S0-A B0-L0 S1-L0 B1-L1 S0-L1 B0-L2 S1-L2 ' - 'B1-L3 S0-L3 B0-L4 S1-L4 H1-L4 C1-L0 H0-L4 C0-L0 R ' - # other fork shows safe fork blocks, but they remain stuck - 'S2-A B2-R0 S0-R0 H0-L4 S1-R0 H0-L4 R' - ) - weights = {0: 5, 1: 6, 2: 7} - test_lang_runner(test_string, weights) - - -def test_no_majority_fork_unsafe(test_lang_runner): - test_string = ( - # create right hand side of fork and check for no safety - 'R B2-A S1-A B1-L0 S0-L0 B0-L1 S1-L1 B1-L2 S0-L2 ' - 'B0-L3 S1-L3 B1-L4 S0-L4 H0-L4 U0-L0 H1-L4 U1-L0 R ' - # now, left hand side as well. still no safety - 'S3-A B3-R0 S4-R0 B4-R1 S3-R1 B3-R2 S4-R2 B4-R3 ' - 'S3-R3 B3-R4 S4-R4 H4-R4 U4-R0 H3-R4 U3-R0 R' - ) - weights = {0: 5, 1: 4.5, 2: 6, 3: 4, 4: 5.25} - test_lang_runner(test_string, weights) - - -def test_no_majority_fork_safe_after_union(test_lang_runner): - test_string = ( - # generate both sides of an extended fork - 'R B2-A S1-A B1-L0 S0-L0 B0-L1 S1-L1 B1-L2 S0-L2 ' - 'B0-L3 S1-L3 B1-L4 S0-L4 H0-L4 U0-L0 H1-L4 U1-L0 R ' - 'S3-A B3-R0 S4-R0 B4-R1 S3-R1 B3-R2 S4-R2 B4-R3 ' - 'S3-R3 B3-R4 S4-R4 H4-R4 U4-R0 H3-R4 U3-R0 R ' - # show all validators all blocks - 'S0-R4 S1-R4 S2-R4 S2-L4 S3-L4 S4-L4 ' - # check all have correct forkchoice - 'H0-L4 H1-L4 H2-L4 H3-L4 H4-L4 ' - # two rounds of round robin, check have safety on the correct fork - 'RR0-J0 RR0-J1 C0-L0 R' - ) - weights = {0: 5, 1: 4.5, 2: 6, 3: 4, 4: 5.25} - test_lang_runner(test_string, weights) diff --git a/tests/casper/test_state_lang.py b/tests/casper/test_state_lang.py new file mode 100644 index 0000000..6ed1d07 --- /dev/null +++ b/tests/casper/test_state_lang.py @@ -0,0 +1,296 @@ +"""The language testing module ... """ +import pytest + +from state_languages.state_language import StateLanguage +from casper.network import Network + + +def test_init(protocol, test_weight): + state_lang = StateLanguage(test_weight, protocol, False) + + assert isinstance(state_lang.plot_tool, protocol.PlotTool) + assert len(state_lang.validator_set) == len(test_weight) + assert state_lang.validator_set.validator_weights() == set(test_weight.values()) + assert state_lang.validator_set.weight() == sum(test_weight.values()) + + assert isinstance(state_lang.network, Network) + + +def test_registers_handlers(protocol, test_weight): + state_lang = StateLanguage(test_weight, protocol, False) + + assert callable(state_lang.make_message) + assert callable(state_lang.make_invalid) + assert callable(state_lang.send_message) + assert callable(state_lang.plot) + assert callable(state_lang.check_estimate) + assert callable(state_lang.check_safe) + assert callable(state_lang.check_unsafe) + + assert state_lang.make_message == state_lang.handlers['M'] + assert state_lang.make_invalid == state_lang.handlers['I'] + assert state_lang.send_message == state_lang.handlers['S'] + assert state_lang.plot == state_lang.handlers['P'] + assert state_lang.check_estimate == state_lang.handlers['CE'] + assert state_lang.check_safe == state_lang.handlers['CS'] + assert state_lang.check_unsafe == state_lang.handlers['CU'] + + +@pytest.mark.parametrize( + 'handler, error', + ( + ('M', KeyError), + ('S', KeyError), + ('I', KeyError), + ('P', KeyError), + ('CE', KeyError), + ('CS', KeyError), + ('CU', KeyError), + ('CV', None), + ('H', None), + ('123-123', None), + ) +) +def test_allows_new_handlers_to_register(handler, error, protocol, test_weight, example_function): + state_lang = StateLanguage(test_weight, protocol, False) + + if isinstance(error, type) and issubclass(error, Exception): + with pytest.raises(error): + state_lang.register_handler(handler, example_function) + return + + state_lang.register_handler(handler, example_function) + assert callable(state_lang.handlers[handler]) + assert state_lang.handlers[handler] == example_function + + +def test_init_validators_have_only_inital_messages(protocol, test_weight): + state_lang = StateLanguage(test_weight, protocol, False) + + for validator in state_lang.network.validator_set: + assert len(validator.view.justified_messages) == 1 + + +@pytest.mark.parametrize( + 'test_string, error', + [ + ('M0-A', None), + ('M0-A S1-A', None), + ('M0-A M0-B', None), + ('M0-A M0-B S1-A S1-B', None), + ('I1-B', NotImplementedError), + ('A-B', ValueError), + ('BA0-A, S1-A', KeyError), + ('BA0-A S1-A', KeyError), + ('RR0-A-A', ValueError), + ('B0-A S1-A T1-A', KeyError), + ('RR0-AB1-A', ValueError), + ('RRR', ValueError), + ('A0-A S1-A', KeyError), + ] +) +def test_parse_only_valid_tokens(test_string, test_weight, error, protocol): + state_lang = StateLanguage(test_weight, protocol, False) + + if isinstance(error, type) and issubclass(error, Exception): + with pytest.raises(error): + state_lang.parse(test_string) + return + + state_lang.parse(test_string) + + +@pytest.mark.parametrize( + 'test_strings, error', + [ + (['M0-A', 'S1-A'], None), + (['M0-A', 'M0-B S1-A S1-B'], None), + (['M0-A', 'I1-B'], NotImplementedError), + (['M0-A', 'S1-A T1-A'], KeyError), + (['M1-A', 'RR0-AB1-A'], ValueError), + ] +) +def test_parse_only_valid_tokens_split_strings(test_strings, error, protocol, test_weight): + state_lang = StateLanguage(test_weight, protocol, False) + + if isinstance(error, type) and issubclass(error, Exception): + with pytest.raises(error): + for test_string in test_strings: + state_lang.parse(test_string) + return + + for test_string in test_strings: + state_lang.parse(test_string) + + +@pytest.mark.parametrize( + 'test_string, test_weight, exception', + [ + ('M0-A M1-B M2-C M3-D M4-E', {i: 5 - i for i in range(5)}, ''), + ('M0-A S1-A S2-A S3-A S4-A', {i: 5 - i for i in range(5)}, ''), + ('M0-A S1-A', {0: 1, 1: 2}, ''), + ('M5-A', {i: 5 - i for i in range(5)}, True), + ('M0-A S1-A', {0: 1}, True), + ('M0-A S1-A S2-A S3-A S4-A', {0: 0}, True), + ('M4-A S5-A', {i: 5 - i for i in range(5)}, True), + ('M0-A S1-B', {i: 5 - i for i in range(5)}, True), + ('M0-A M1-A', {i: 5 - i for i in range(5)}, True), + ] +) +def test_parse_only_valid_val_and_messages(test_string, test_weight, exception, protocol): + state_lang = StateLanguage(test_weight, protocol, False) + + if exception: + with pytest.raises(Exception): + state_lang.parse(test_string) + return + + state_lang.parse(test_string) + + +@pytest.mark.parametrize( + 'test_strings, test_weight, exception', + [ + (['M0-A M1-B', 'M2-C M3-D M4-E'], {i: 5 - i for i in range(5)}, ''), + (['M0-A S1-A', 'S2-A S3-A S4-A'], {i: 5 - i for i in range(5)}, ''), + (['M0-A', 'S1-A'], {0: 1, 1: 2}, ''), + (['M0-A', 'S1-A'], {0: 1}, True), + (['M0-A S1-A', 'S2-A S3-A S4-A'], {0: 0}, True), + (['M4-A', 'S5-A'], {i: 5 - i for i in range(5)}, True), + (['M0-A', 'S1-B'], {i: 5 - i for i in range(5)}, True), + (['M0-A', 'M1-A'], {i: 5 - i for i in range(5)}, True), + ] +) +def test_parse_only_valid_val_and_messages_split_strings( + test_strings, + test_weight, + exception, + protocol + ): + state_lang = StateLanguage(test_weight, protocol, False) + + if exception: + with pytest.raises(Exception): + for test_string in test_strings: + state_lang.parse(test_string) + return + + for test_string in test_strings: + state_lang.parse(test_string) + + +@pytest.mark.parametrize( + 'test_string, num_blocks, exception', + [ + ('M0-A', 1, ''), + ('M0-A S1-A', 1, ''), + ('M0-A S1-A M1-B', 2, ''), + ('M0-A M1-B M2-C M3-D M4-E', 5, ''), + ('M0-A S1-A S2-A S3-A S4-A', 1, ''), + ('M0-A M1-A', None, 'already exists'), + ('M0-A S1-A S2-A S3-A S4-A M4-B M4-A', None, 'already exists'), + ] +) +def test_make_adds_to_global_view_( + test_string, + test_weight, + num_blocks, + exception, + protocol + ): + state_lang = StateLanguage(test_weight, protocol, False) + + num_inital_blocks = len(state_lang.network.global_view.justified_messages) + if exception: + with pytest.raises(Exception, match=exception): + state_lang.parse(test_string) + return + + state_lang.parse(test_string) + assert len(state_lang.network.global_view.justified_messages) == num_blocks + num_inital_blocks + + +@pytest.mark.parametrize( + 'test_string, block_justification', + [ + ('M0-A', {'A': {0: "GEN"}}), + ('M0-A S1-A M1-B', {'B': {0: 'A'}}), + ( + 'M0-A S1-A M1-B S2-A S2-B M2-C S3-A S3-B S3-C M3-D S4-A S4-B S4-C S4-D M4-E', + {'E': {0: 'A', 1: 'B', 2: 'C', 3: 'D'}} + ), + ] +) +def test_make_messages_builds_on_view( + test_string, + block_justification, + test_weight, + genesis_protocol + ): + state_lang = StateLanguage(test_weight, genesis_protocol, False) + state_lang.parse(test_string) + global_view = state_lang.network.global_view + + for b in block_justification: + block = state_lang.messages[b] + assert len(block.justification) == len(block_justification[b]) + for validator_name in block_justification[b]: + block_in_justification = block_justification[b][validator_name] + validator = state_lang.validator_set.get_validator_by_name(validator_name) + + if block_in_justification: + message_hash = block.justification[validator] + justification_message = global_view.justified_messages[message_hash] + + if block_in_justification == "GEN": + assert global_view.genesis_block == justification_message + else: + assert state_lang.messages[block_in_justification] == justification_message + + +@pytest.mark.parametrize( + 'test_string, num_messages_per_view, message_keys', + [ + ( + 'M0-A S1-A', + {0: 2, 1: 2}, + {0: ['A'], 1: ['A']} + ), + ( + 'M0-A S1-A S2-A S3-A S4-A', + {0: 2, 1: 2, 2: 2, 3: 2, 4: 2}, + {i: ['A'] for i in range(5)} + ), + ( + 'M0-A S1-A M1-B S2-A S2-B M2-C S3-A S3-B S3-C M3-D S4-A S4-B S4-C S4-D M4-E', + {0: 2, 1: 3, 2: 4, 3: 5, 4: 6}, + { + 0: ['A'], + 1: ['A', 'B'], + 2: ['A', 'B', 'C'], + 3: ['A', 'B', 'C', 'D'], + 4: ['A', 'B', 'C', 'D', 'E'] + } + ), + ( + 'M0-A M0-B M0-C M0-D M0-E', + {0: 6, 1: 1, 2: 1, 3: 1, 4: 1}, + {0: ['A', 'B', 'C', 'D', 'E'], 1: [], 2: [], 3: [], 4: []} + ), + ] +) +def test_send_updates_val_view_genesis_protocols( + test_string, + test_weight, + num_messages_per_view, + message_keys, + genesis_protocol + ): + state_lang = StateLanguage(test_weight, genesis_protocol, False) + state_lang.parse(test_string) + + for validator_name in num_messages_per_view: + validator = state_lang.validator_set.get_validator_by_name(validator_name) + assert len(validator.view.justified_messages) == num_messages_per_view[validator_name] + for message_name in message_keys[validator_name]: + assert state_lang.messages[message_name] in validator.view.justified_messages.values() diff --git a/tests/casper/test_test_lang.py b/tests/casper/test_test_lang.py deleted file mode 100644 index 4a7b9b8..0000000 --- a/tests/casper/test_test_lang.py +++ /dev/null @@ -1,391 +0,0 @@ -"""The language testing module ... """ - -import pytest - -from casper.network import Network -from simulations.testing_language import TestLangCBC - -TEST_STRING = 'B0-A' -TEST_WEIGHT = {i: 5 - i for i in range(5)} - - -def test_init(): - TestLangCBC({0: 1}) - - -@pytest.mark.parametrize( - 'weights, num_val, total_weight', - [ - ({i: i for i in range(10)}, 10, 45), - ({i: 9 - i for i in range(10)}, 10, 45), - ({0: 0}, 1, 0), - ({}, 0, 0), - ] -) -def test_initialize_validator_set(weights, num_val, total_weight): - test_lang = TestLangCBC(weights) - validator_set = test_lang.validator_set - - assert len(validator_set) == num_val - assert validator_set.validator_weights() == set(weights.values()) - assert validator_set.weight() == total_weight - - -def test_init_creates_network(): - test_lang = TestLangCBC(TEST_WEIGHT) - - assert isinstance(test_lang.network, Network) - - -def test_init_validators_create_blocks(): - test_lang = TestLangCBC(TEST_WEIGHT) - - assert len(test_lang.network.global_view.messages) == len(TEST_WEIGHT) - - for validator in test_lang.network.validator_set: - assert len(validator.view.messages) == 1 - assert len(validator.view.latest_messages) == 1 - assert validator.view.latest_messages[validator].estimate is None - - -@pytest.mark.parametrize( - 'test_string, error', - [ - ('B0-A', None), - ('B0-A S1-A', None), - ('B0-A S1-A U1-A', None), - ('B0-A S1-A H1-A', None), - ('B0-A RR0-B RR0-C C0-A', None), - ('R', None), - ('A-B', KeyError), - ('BA0-A, S1-A', ValueError), - ('BA0-A S1-A', KeyError), - ('RR0-A-A', ValueError), - ('B0-A S1-A T1-A', KeyError), - ('RR0-AB1-A', ValueError), - ('RRR', KeyError), - ('A0-A S1-A', KeyError), - ] -) -def test_parse_only_valid_tokens(test_string, error): - test_lang = TestLangCBC(TEST_WEIGHT) - - if isinstance(error, type) and issubclass(error, Exception): - with pytest.raises(error): - test_lang.parse(test_string) - return - - test_lang.parse(test_string) - - -@pytest.mark.parametrize( - 'test_strings, error', - [ - (['B0-A', 'S1-A'], None), - (['B0-A', 'S1-A U1-A'], None), - (['B0-A', 'RR0-B RR0-C', 'C0-A'], None), - (['B0-A', 'S1-A T1-A'], KeyError), - (['B1-A', 'RR0-AB1-A'], ValueError), - ] -) -def test_parse_only_valid_tokens_split_strings(test_strings, error): - test_lang = TestLangCBC(TEST_WEIGHT) - - if isinstance(error, type) and issubclass(error, Exception): - with pytest.raises(error): - for test_string in test_strings: - test_lang.parse(test_string) - return - - for test_string in test_strings: - test_lang.parse(test_string) - - -@pytest.mark.parametrize( - 'test_string, val_weights, exception', - [ - ('B0-A B1-B B2-C B3-D B4-E', TEST_WEIGHT, ''), - ('B0-A S1-A S2-A S3-A S4-A', TEST_WEIGHT, ''), - ('B0-A S1-A U1-A', {0: 1, 1: 2}, ''), - ('B0-A S1-A H1-A', {0: 2, 1: 1}, ''), - ('RR0-A RR0-B C0-A', {0: 2, 1: 1}, ''), - ('B5-A', TEST_WEIGHT, 'Validator'), - ('B0-A S1-A', {0: 1}, 'Validator'), - ('B0-A S1-A S2-A S3-A S4-A', {0: 0}, 'Validator'), - ('B4-A S5-A', TEST_WEIGHT, 'Validator'), - ('B0-A S1-A U2-A', {0: 1, 1: 2}, 'Validator'), - ('B0-A S1-A H2-A', {0: 2, 1: 1}, 'Validator'), - ('RR0-A RR0-B C2-A', {0: 2, 1: 1}, 'Validator'), - ('B0-A S1-B', TEST_WEIGHT, 'Block'), - ('B0-A S1-A U1-B', TEST_WEIGHT, 'Block'), - ('B0-A S1-A H1-B', TEST_WEIGHT, 'Block'), - ('B0-A RR0-B RR0-C C0-D', TEST_WEIGHT, 'Block'), - ] -) -def test_parse_only_valid_val_and_blocks(test_string, val_weights, exception): - test_lang = TestLangCBC(val_weights) - - if exception: - with pytest.raises(ValueError, match=exception): - test_lang.parse(test_string) - return - - test_lang.parse(test_string) - - -@pytest.mark.parametrize( - 'test_strings, val_weights, exception', - [ - (['B0-A B1-B', 'B2-C B3-D B4-E'], TEST_WEIGHT, ''), - (['B0-A', 'S1-A S2-A S3-A S4-A'], TEST_WEIGHT, ''), - (['B0-A S1-A', 'U1-A'], {0: 1, 1: 2}, ''), - (['B0-A', 'S1-A', 'S2-A', 'S3-A', 'S4-A'], {0: 0}, 'Validator'), - (['B4-A', 'S5-A'], TEST_WEIGHT, 'Validator'), - (['RR0-A', 'RR0-B', 'C2-A'], {0: 2, 1: 1}, 'Validator'), - (['B0-A', 'S1-A', 'U1-B'], TEST_WEIGHT, 'Block'), - (['B0-A', 'RR0-B', 'RR0-C', 'C0-D'], TEST_WEIGHT, 'Block'), - ] -) -def test_parse_only_valid_val_and_blocks_split_strings(test_strings, val_weights, exception): - test_lang = TestLangCBC(val_weights) - - if exception: - with pytest.raises(ValueError, match=exception): - for test_string in test_strings: - test_lang.parse(test_string) - return - - for test_string in test_strings: - test_lang.parse(test_string) - - -# NOTE: network.global_view.messages starts with 5 messages from random_initialization -@pytest.mark.parametrize( - 'test_string, num_blocks, exception', - [ - ('B0-A', 6, ''), - ('B0-A S1-A', 6, ''), - ('B0-A S1-A U1-A B1-B', 7, ''), - ('B0-A S1-A H1-A B1-B', 7, ''), - ('B0-A RR0-B RR0-C C0-A B0-D', 17, ''), - ('B0-A B1-B B2-C B3-D B4-E', 10, ''), - ('B0-A S1-A S2-A S3-A S4-A', 6, ''), - ('RR0-A RR0-B', 15, ''), - ('B0-A B1-A', 6, 'already exists'), - ('B0-A S1-A S2-A S3-A S4-A B4-B B4-A', 6, 'already exists'), - ('RR0-A RR0-A', 15, 'already exists'), - ] -) -def test_make_blocks_makes_new_blocks_adds_global_view(test_string, num_blocks, exception): - test_lang = TestLangCBC(TEST_WEIGHT) - - if exception: - with pytest.raises(Exception, match=exception): - test_lang.parse(test_string) - return - - test_lang.parse(test_string) - assert len(test_lang.network.global_view.messages) == num_blocks - - -# NOTE: None means the block is not named by the testing language -# this means the block was a init block, or was created by round robin -@pytest.mark.parametrize( - 'test_string, block_justification', - [ - ('B0-A', {'A': {0: None}}), - ('B0-A S1-A B1-B', {'B': {0: 'A', 1: None}}), - ('RR0-A', {'A': {i: None for i in range(5)}}), - ( - 'RR0-A B0-B S1-B B1-C', - {'C': {0: 'B', 1: None, 2: None, 3: None, 4: None}} - ), - ( - 'B0-A S1-A B1-B S2-B B2-C S3-C B3-D S4-D B4-E', - {'E': {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: None}} - ), - ( - 'B0-A S1-A B1-B S2-B B2-C S3-C B3-D S4-D B4-E S0-E B0-F', - {'F': {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E'}} - ), - ] -) -def test_make_block_builds_on_entire_view(test_string, block_justification): - test_lang = TestLangCBC(TEST_WEIGHT) - test_lang.parse(test_string) - - for b in block_justification: - block = test_lang.blocks[b] - assert len(block.justification.latest_messages) == len(block_justification[b].keys()) - for validator_name in block_justification[b]: - block_in_justification = block_justification[b][validator_name] - validator = test_lang.validator_set.get_validator_by_name(validator_name) - - if block_in_justification: - validator_justification_message = block.justification.latest_messages[validator] - assert test_lang.blocks[block_in_justification] == validator_justification_message - - -@pytest.mark.parametrize( - 'test_string, exception', - [ - ('B0-A S1-A', ''), - ('B0-A S1-A S2-A S3-A S4-A', ''), - ('B0-A S1-A S1-A', 'has already seen block'), - ('B0-A S0-A', 'has already seen block'), - ('RR0-A RR0-B S0-A', 'has already seen block'), - ] -) -def test_send_block_sends_only_existing_blocks(test_string, exception): - test_lang = TestLangCBC(TEST_WEIGHT) - - if exception: - with pytest.raises(Exception, match=exception): - test_lang.parse(test_string) - return - - test_lang.parse(test_string) - - -@pytest.mark.parametrize( - 'test_string, num_messages_per_view, message_keys', - [ - ( - 'B0-A S1-A', - {0: 2, 1: 3}, - {0: ['A'], 1: ['A']} - ), - ( - 'B0-A S1-A S2-A S3-A S4-A', - {0: 2, 1: 3, 2: 3, 3: 3, 4: 3}, - {i: ['A'] for i in range(5)} - ), - ( - 'B0-A S1-A B1-B S2-B B2-C S3-C B3-D S4-D B4-E', - {0: 2, 1: 4, 2: 6, 3: 8, 4: 10}, - { - 0: ['A'], - 1: ['A', 'B'], - 2: ['A', 'B', 'C'], - 3: ['A', 'B', 'C', 'D'], - 4: ['A', 'B', 'C', 'D', 'E'] - } - ), - ( - 'B0-A B0-B B0-C B0-D B0-E', - {0: 6, 1: 1, 2: 1, 3: 1, 4: 1}, - {0: ['A', 'B', 'C', 'D', 'E'], 1: [], 2: [], 3: [], 4: []} - ), - ] -) -def test_send_block_updates_val_view(test_string, num_messages_per_view, message_keys): - test_lang = TestLangCBC(TEST_WEIGHT) - test_lang.parse(test_string) - - for validator_name in num_messages_per_view: - validator = test_lang.validator_set.get_validator_by_name(validator_name) - assert len(validator.view.messages) == num_messages_per_view[validator_name] - for message_name in message_keys[validator_name]: - assert test_lang.blocks[message_name] in validator.view.messages - - -@pytest.mark.parametrize( - 'test_string, num_messages_per_view, other_val_seen', - [ - ( - 'RR0-A', - {0: 10, 1: 4, 2: 6, 3: 8, 4: 10}, - { - 0: [0, 1, 2, 3, 4], - 1: [0, 1], - 2: [0, 1, 2], - 3: [0, 1, 2, 3], - 4: [0, 1, 2, 3, 4] - } - ), - ( - 'RR0-A RR0-B', - {0: 15, 1: 12, 2: 13, 3: 14, 4: 15}, - {i: list(range(5)) for i in range(5)} - ), - ( - 'B0-A S1-A B1-B RR1-C', - {0: 12, 1: 12, 2: 7, 3: 9, 4: 11}, - { - 0: [0, 1, 2, 3, 4], - 1: [0, 1, 2, 3, 4], - 2: [0, 1, 2], - 3: [0, 1, 2, 3], - 4: [0, 1, 2, 3, 4] - } - ), - ( - 'RR0-A B0-B S1-B RR1-C', - {0: 16, 1: 16, 2: 13, 3: 14, 4: 15}, - {i: list(range(5)) for i in range(5)} - ), - ] -) -def test_round_robin_updates_val_view(test_string, num_messages_per_view, other_val_seen): - test_lang = TestLangCBC(TEST_WEIGHT) - test_lang.parse(test_string) - - for validator_name in num_messages_per_view: - validator = test_lang.validator_set.get_validator_by_name(validator_name) - - assert len(validator.view.messages) == num_messages_per_view[validator_name] - assert len(validator.view.latest_messages) == len(other_val_seen[validator_name]) - for other_validator_name in other_val_seen[validator_name]: - other_validator = test_lang.validator_set.get_validator_by_name(other_validator_name) - - assert other_validator in validator.view.latest_messages - - -@pytest.mark.parametrize( - 'test_string, val_forkchoice', - [ - ('B0-A S1-A H1-A', {1: 'A'}), - ('RR0-A', {0: 'A'}), - ('B0-A S1-A S2-A S3-A S4-A H1-A H2-A H3-A H4-A', {i: 'A' for i in range(5)}), - ('RR0-A RR0-B H0-B', {0: 'B'}), - ( - 'B0-A S1-A B1-B S2-B B2-C S3-C B3-D S4-D B4-E H0-A H1-B H2-C H3-D H4-E', - {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E'} - ), - ] -) -def test_head_equals_block_checks_forkchoice(test_string, val_forkchoice): - test_lang = TestLangCBC(TEST_WEIGHT) - test_lang.parse(test_string) - - for validator_name in val_forkchoice: - validator = test_lang.validator_set.get_validator_by_name(validator_name) - block_name = val_forkchoice[validator_name] - assert test_lang.blocks[block_name] == validator.estimate() - - -@pytest.mark.parametrize( - 'test_string, error', - [ - (('NOFINAL'), ''), - ('RR0-A RR0-B RR0-C RR0-D RR0-E RR0-F U0-A', 'failed no-safety assert'), - ] -) -def test_no_safety(test_string, error): - if test_string == 'NOFINAL': - test_string = '' - for i in range(100): - test_string += 'B' + str(i % len(TEST_WEIGHT)) + '-' + str(i) + ' ' + \ - 'B' + str((i + 1) % len(TEST_WEIGHT)) + '-' + str(100 + i) + ' ' + \ - 'S' + str((i + 2) % len(TEST_WEIGHT)) + '-' + str(100 + i) + ' ' +\ - 'S' + str((i + 1) % len(TEST_WEIGHT)) + '-' + str(i) + ' ' - test_string += 'U0-0' - - test_lang = TestLangCBC(TEST_WEIGHT) - - if error: - with pytest.raises(Exception, match=error): - test_lang.parse(test_string) - return - - test_lang.parse(test_string) diff --git a/tests/casper/test_utils.py b/tests/casper/test_utils.py index 13269f8..399e18b 100644 --- a/tests/casper/test_utils.py +++ b/tests/casper/test_utils.py @@ -13,8 +13,8 @@ ({i: i for i in range(10)}, 45, None), ({i: 9 - i for i in range(9, -1, -1)}, 45, None), ({i: r.random() for i in range(10)}, None, None), - ({i: i*2 for i in range(10)}, 12, set([0, 1, 2, 3])), - ({i: i*2 for i in range(10)}, 12, [0, 1, 2, 3]), + ({i: i * 2 for i in range(10)}, 12, set([0, 1, 2, 3])), + ({i: i * 2 for i in range(10)}, 12, [0, 1, 2, 3]), ] ) def test_get_weight(weights, expected_weight, validator_names): diff --git a/tests/casper/test_validator.py b/tests/casper/test_validator.py index e1e6d50..eea98d2 100644 --- a/tests/casper/test_validator.py +++ b/tests/casper/test_validator.py @@ -1,9 +1,6 @@ """The validator testing module ... """ - import pytest -from casper.blockchain.block import Block -from casper.justification import Justification from casper.validator import Validator @@ -30,9 +27,11 @@ def test_new_validator(name, weight, error): assert validator.weight == weight -def test_check_estimate_safety_without_validator_set(): - validator = Validator("cool", 10.2) - block = Block(None, Justification(), validator) +def test_validator_created_with_genesis(genesis_protocol): + validator = Validator(0, 1, genesis_protocol) + assert validator.view.last_finalized_block is not None + - with pytest.raises(AttributeError): - validator.check_estimate_safety(block) +def test_validator_created_with_inital_message(rand_start_protocol): + validator = Validator(0, 1, rand_start_protocol) + assert validator.my_latest_message() is not None diff --git a/tests/casper/test_validator_set.py b/tests/casper/test_validator_set.py index 288ec5b..fbe21f3 100644 --- a/tests/casper/test_validator_set.py +++ b/tests/casper/test_validator_set.py @@ -2,6 +2,7 @@ import random as r import pytest +import itertools from casper.validator_set import ValidatorSet @@ -87,8 +88,8 @@ def test_validator_weights(weights, expected_weights): ({i: i for i in range(10)}, 45, None), ({i: 9 - i for i in range(9, -1, -1)}, 45, None), ({i: r.random() for i in range(10)}, None, None), - ({i: i*2 for i in range(10)}, 12, set([0, 1, 2, 3])), - ({i: i*2 for i in range(10)}, 12, [0, 1, 2, 3]), + ({i: i * 2 for i in range(10)}, 12, set([0, 1, 2, 3])), + ({i: i * 2 for i in range(10)}, 12, [0, 1, 2, 3]), ] ) def test_weight(weights, expected_weight, validator_names): @@ -104,11 +105,42 @@ def test_weight(weights, expected_weight, validator_names): assert round(val_set.weight(validators), 2) == round(expected_weight, 2) -@pytest.mark.skip(reason="test not yet implemented") -def test_get_validator_by_name(): - pass +@pytest.mark.parametrize( + 'weights', + [ + ({i: i for i in range(10)}), + ({i: 9 - i for i in range(9, -1, -1)}), + ({i: r.random() for i in range(10)}), + ({i: i*2 for i in range(10)}), + ({i: i*2 for i in range(10)}), + ] +) +def test_get_validator_by_name(weights): + val_set = ValidatorSet(weights) + + for validator in val_set: + returned_val = val_set.get_validator_by_name(validator.name) + assert validator == returned_val + + +@pytest.mark.parametrize( + 'weights', + [ + ({i: i for i in range(10)}), + ({i: 9 - i for i in range(9, -1, -1)}), + ({i: r.random() for i in range(10)}), + ({i: i*2 for i in range(10)}), + ({i: i*2 for i in range(10)}), + ] +) +def test_get_validators_by_names(weights): + val_set = ValidatorSet(weights) + for i in range(1, len(weights)): + val_subsets = itertools.combinations(val_set, i) + for subset in val_subsets: + subset = {val for val in subset} + val_names = {validator.name for validator in subset} + returned_set = val_set.get_validators_by_names(val_names) -@pytest.mark.skip(reason="test not yet implemented") -def test_get_validators_by_names(): - pass + assert subset == returned_set diff --git a/tests/casper/test_view.py b/tests/casper/test_view.py index cb9f3f1..e598426 100644 --- a/tests/casper/test_view.py +++ b/tests/casper/test_view.py @@ -1,25 +1,169 @@ import pytest from casper.abstract_view import AbstractView +from casper.protocols.blockchain.block import Block +from state_languages.blockchain_test_lang import BlockchainTestLang + + +TEST_WEIGHT = {0: 10, 1: 11} def test_new_view(): view = AbstractView() - assert not view.messages - assert not view.latest_messages - assert not view.justification().latest_messages + assert not any(view.justified_messages) + assert not any(view.latest_messages) + + +def test_justification_stores_hash(): + test_lang = BlockchainTestLang(TEST_WEIGHT) + test_lang.parse('M0-A SJ1-A M1-B') + + validator_0 = test_lang.validator_set.get_validator_by_name(0) + validator_1 = test_lang.validator_set.get_validator_by_name(1) + + justification = validator_1.justification() + + assert len(justification) == 2 + assert not isinstance(justification[validator_0], Block) + assert not isinstance(justification[validator_1], Block) + assert justification[validator_0] == test_lang.messages['A'].hash + assert justification[validator_1] == test_lang.messages['B'].hash + + +def test_justification_includes_justified_messages(): + test_lang = BlockchainTestLang(TEST_WEIGHT) + test_lang.parse('M0-A M0-B S1-B M1-C') + + validator_0 = test_lang.validator_set.get_validator_by_name(0) + validator_1 = test_lang.validator_set.get_validator_by_name(1) + + justification = validator_1.justification() + + assert len(justification) == 2 + assert test_lang.messages["A"].hash not in justification.values() + assert test_lang.messages["B"].hash not in justification.values() + assert test_lang.network.global_view.genesis_block.hash in justification.values() + assert justification[validator_1] == test_lang.messages['C'].hash + + test_lang.parse('SJ1-B') + + justification = validator_1.justification() + + assert len(justification) == 2 + assert justification[validator_0] == test_lang.messages['B'].hash + assert justification[validator_1] == test_lang.messages['C'].hash + + +def test_add_justified_message(): + test_lang = BlockchainTestLang(TEST_WEIGHT) + test_lang.parse('M0-A M0-B SJ1-A') + validator_0 = test_lang.validator_set.get_validator_by_name(0) + validator_1 = test_lang.validator_set.get_validator_by_name(1) + assert test_lang.messages['A'] in validator_0.view.justified_messages.values() + assert test_lang.messages['A'] in validator_1.view.justified_messages.values() + assert test_lang.messages['B'] in validator_0.view.justified_messages.values() + assert test_lang.messages['B'] not in validator_1.view.justified_messages.values() @pytest.mark.parametrize( - 'latest_messages', + 'test_string, justified_messages, unjustified_messages', [ - ({"face": 10}), - ({1: 10, 2: 30}), + ('M0-A M0-B S1-B', [['A', 'B'], []], [[], ['B']]), + ('M0-A M0-B S1-B SJ1-A', [['A', 'B'], ['A', 'B']], [[], []]), + ('M0-A M0-B M0-C S1-C S1-B', [['A', 'B', 'C'], []], [[], ['B', 'C']]), ] ) -def test_justification(latest_messages): - view = AbstractView() - view.latest_messages = latest_messages +def test_only_add_justified_messages(test_string, justified_messages, unjustified_messages): + test_lang = BlockchainTestLang(TEST_WEIGHT) + test_lang.parse(test_string) + + for validator in test_lang.validator_set: + idx = validator.name + + for message in justified_messages[idx]: + assert test_lang.messages[message] in validator.view.justified_messages.values() + assert test_lang.messages[message] not in validator.view.pending_messages.values() + + for message in unjustified_messages[idx]: + assert test_lang.messages[message] not in validator.view.justified_messages.values() + assert test_lang.messages[message] in validator.view.pending_messages.values() + + +@pytest.mark.parametrize( + 'weight, test_string, justified_messages, unjustified_messages, resolving_string', + [ + ( + TEST_WEIGHT, + 'M0-A M0-B S1-B', + [['A', 'B'], []], + [[], ['B']], + 'SJ1-A' + ), + ( + {0: 10, 1: 9, 2: 8}, + 'RR0-A M0-B SJ1-B M1-C S2-C', + [['A', 'B'], ['A', 'B'], []], + [[], [], ['C']], + 'S2-B' + ), + ( + TEST_WEIGHT, + 'M0-A SJ1-A M0-B M0-C M0-D M0-E S1-E', + [['A', 'B', 'C', 'D', 'E'], ['A'], []], + [[], [], ['E']], + 'S1-B S1-C S1-D' + ), + ( + TEST_WEIGHT, + 'M0-A SJ1-A M0-B M0-C M0-D M0-E S1-E', + [['A', 'B', 'C', 'D', 'E'], ['A'], []], + [[], [], ['E']], + 'S1-D S1-B S1-C' + ), + ] +) +def test_resolve_message_when_justification_arrives(weight, test_string, justified_messages, unjustified_messages, resolving_string): + test_lang = BlockchainTestLang(weight) + test_lang.parse(test_string) + + for validator in test_lang.validator_set: + idx = validator.name + + for message in justified_messages[idx]: + assert test_lang.messages[message] in validator.view.justified_messages.values() + assert test_lang.messages[message] not in validator.view.pending_messages.values() + + for message in unjustified_messages[idx]: + assert test_lang.messages[message] not in validator.view.justified_messages.values() + assert test_lang.messages[message] in validator.view.pending_messages.values() + + test_lang.parse(resolving_string) + + for validator in test_lang.validator_set: + idx = validator.name + + for message in justified_messages[idx]: + assert test_lang.messages[message] in validator.view.justified_messages.values() + assert test_lang.messages[message] not in validator.view.pending_messages.values() + + for message in unjustified_messages[idx]: + assert test_lang.messages[message] in validator.view.justified_messages.values() + assert test_lang.messages[message] not in validator.view.pending_messages.values() + + + +def test_multiple_messages_arriving_resolve(): + test_string = "M0-A SJ1-A M0-B M0-C M0-D M0-E M0-F S1-F" + test_lang = BlockchainTestLang(TEST_WEIGHT) + test_lang.parse(test_string) + + validator_1 = test_lang.validator_set.get_validator_by_name(1) + + assert len(validator_1.view.justified_messages) == 2 + assert len(validator_1.view.pending_messages) == 1 + assert test_lang.messages['F'] in validator_1.view.pending_messages.values() + + validator_1.receive_messages(test_lang.network.global_view.justified_messages.values()) - assert view.justification().latest_messages == latest_messages + assert len(validator_1.view.justified_messages) == 7 diff --git a/tests/simulations/conftest.py b/tests/simulations/conftest.py index 68517e9..8ce7b58 100644 --- a/tests/simulations/conftest.py +++ b/tests/simulations/conftest.py @@ -1,11 +1,20 @@ import pytest -from casper.blockchain.blockchain_protocol import BlockchainProtocol +from casper.protocols.blockchain.blockchain_protocol import BlockchainProtocol from simulations.simulation_runner import SimulationRunner import simulations.utils as utils @pytest.fixture -def simulation_runner(validator_set): +def simulation_runner(protocol, validator_set, network): msg_gen = utils.message_maker('rand') - return SimulationRunner(validator_set, msg_gen, BlockchainProtocol, 20, 20, False, False) + return SimulationRunner( + validator_set, + msg_gen, + protocol, + network, + 20, + 20, + False, + False + ) diff --git a/tests/simulations/test_analyzer.py b/tests/simulations/test_analyzer.py index 7300809..f4d68a6 100644 --- a/tests/simulations/test_analyzer.py +++ b/tests/simulations/test_analyzer.py @@ -1,6 +1,7 @@ import pytest -from casper.blockchain.blockchain_protocol import BlockchainProtocol +from casper.networks import NoDelayNetwork +from casper.protocols.blockchain.blockchain_protocol import BlockchainProtocol from simulations.analyzer import Analyzer from simulations.simulation_runner import SimulationRunner @@ -16,12 +17,15 @@ ('nofinal', 2), ] ) -def test_num_messages(validator_set, mode, messages_generated_per_round): +def test_num_messages_genesis(generate_validator_set, genesis_protocol, mode, messages_generated_per_round): + validator_set = generate_validator_set(genesis_protocol) + network = NoDelayNetwork(validator_set, genesis_protocol) msg_gen = utils.message_maker(mode) simulation_runner = SimulationRunner( validator_set, msg_gen, BlockchainProtocol, + network, 100, 20, False, @@ -29,13 +33,14 @@ def test_num_messages(validator_set, mode, messages_generated_per_round): ) analyzer = Analyzer(simulation_runner) - # due to random_initialization - assert analyzer.num_messages() == len(validator_set) + assert analyzer.num_messages == 1 + potential_extra_messages = len(validator_set) - 1 for i in range(10): simulation_runner.step() - assert analyzer.num_messages() == \ - len(validator_set) + (i + 1) * messages_generated_per_round + messages_generated = 1 + (i + 1) * messages_generated_per_round + + assert analyzer.num_messages <= messages_generated + potential_extra_messages @pytest.mark.skip(reason="test not written") diff --git a/tests/simulations/test_simulation_runner.py b/tests/simulations/test_simulation_runner.py index d625481..1a0e3be 100644 --- a/tests/simulations/test_simulation_runner.py +++ b/tests/simulations/test_simulation_runner.py @@ -1,10 +1,11 @@ import sys import pytest -from casper.blockchain.blockchain_protocol import BlockchainProtocol -from casper.binary.binary_protocol import BinaryProtocol +from casper.protocols.blockchain.blockchain_protocol import BlockchainProtocol +from casper.protocols.binary.binary_protocol import BinaryProtocol from casper.network import Network +from casper.networks import StepNetwork from simulations.simulation_runner import SimulationRunner import simulations.utils as utils @@ -21,11 +22,13 @@ def test_new_simulation_runner(generate_validator_set, protocol, mode, rounds, report_interval): msg_gen = utils.message_maker(mode) validator_set = generate_validator_set(protocol) + network = StepNetwork(validator_set, protocol) simulation_runner = SimulationRunner( validator_set, msg_gen, protocol, + network, rounds, report_interval, False, @@ -76,40 +79,47 @@ def test_simulation_runner_step(simulation_runner): @pytest.mark.parametrize( - 'protocol, mode, messages_generated_per_round', + 'protocol, mode, messages_generated_per_round, potential_extra_messages', [ - (BlockchainProtocol, 'rand', 1), - (BlockchainProtocol, 'rrob', 1), - (BlockchainProtocol, 'full', 5), - (BlockchainProtocol, 'nofinal', 2), - (BinaryProtocol, 'rand', 1), - (BinaryProtocol, 'rrob', 1), - (BinaryProtocol, 'full', 5), - (BinaryProtocol, 'nofinal', 2), + (BlockchainProtocol, 'rand', 1, 4), + (BlockchainProtocol, 'rrob', 1, 4), + (BlockchainProtocol, 'full', 5, 4), + (BlockchainProtocol, 'nofinal', 2, 2), + (BinaryProtocol, 'rand', 1, 0), + (BinaryProtocol, 'rrob', 1, 0), + (BinaryProtocol, 'full', 5, 0), + (BinaryProtocol, 'nofinal', 2, 0), ] ) def test_simulation_runner_send_messages( generate_validator_set, protocol, mode, - messages_generated_per_round + messages_generated_per_round, + potential_extra_messages ): msg_gen = utils.message_maker(mode) validator_set = generate_validator_set(protocol) + network = StepNetwork(validator_set, protocol) simulation_runner = SimulationRunner( validator_set, msg_gen, protocol, + network, 100, 20, False, False ) - assert len(simulation_runner.network.global_view.messages) == len(validator_set) + if protocol == BlockchainProtocol: + assert len(simulation_runner.network.global_view.justified_messages) == 1 + if protocol == BinaryProtocol: + assert len(simulation_runner.network.global_view.justified_messages) == len(validator_set) + initial_message_length = len(simulation_runner.network.global_view.justified_messages) for i in range(10): simulation_runner.step() - assert len(simulation_runner.network.global_view.messages) == \ - (i + 1) * messages_generated_per_round + len(validator_set) + assert len(simulation_runner.network.global_view.justified_messages) <= \ + initial_message_length + potential_extra_messages + (i+1)*messages_generated_per_round diff --git a/tests/simulations/test_simulation_utils.py b/tests/simulations/test_simulation_utils.py index 12b7067..2e9572c 100644 --- a/tests/simulations/test_simulation_utils.py +++ b/tests/simulations/test_simulation_utils.py @@ -1,6 +1,6 @@ import pytest -from casper.blockchain.blockchain_protocol import BlockchainProtocol +from casper.protocols.blockchain.blockchain_protocol import BlockchainProtocol from casper.validator_set import ValidatorSet from simulations.utils import ( generate_random_gaussian_validator_set, @@ -40,12 +40,10 @@ def test_random_message_maker(validator_set): msg_gen = message_maker("rand") for i in range(20): - message_paths = msg_gen(validator_set) - assert len(message_paths) == 1 - for message_path in message_paths: - assert len(message_path) == 2 - for validator in message_path: - assert validator in validator_set + message_makers = msg_gen(validator_set) + assert len(message_makers) == 1 + for validator in message_makers: + assert validator in validator_set def test_round_robin_message_maker(validator_set): @@ -53,50 +51,39 @@ def test_round_robin_message_maker(validator_set): assert msg_gen.next_sender_index == 0 senders = validator_set.sorted_by_name() - receivers = senders[1:] + senders[0:1] for i in range(3): for j in range(len(validator_set)): - message_paths = msg_gen(validator_set) - assert len(message_paths) == 1 - message_path = message_paths[0] - assert len(message_path) == 2 - assert message_path[0] == senders[j] - assert message_path[1] == receivers[j] + message_makers = msg_gen(validator_set) + assert len(message_makers) == 1 + validator = message_makers[0] + assert validator == senders[j] @pytest.mark.parametrize( - 'weights, pairs', + 'weights', [ ( - {"jim": 1, "dan": 30}, - [["jim", "dan"], ["dan", "jim"]] + {"jim": 1, "dan": 30} ), ( - {0: 10, 1: 8, 2: 12}, - [ - [0, 1], [0, 2], [1, 2], - [1, 0], [2, 0], [2, 1] - ] + {0: 10, 1: 8, 2: 12} ), ] ) -def test_full_message_maker(weights, pairs): +def test_full_message_maker(weights): validator_set = ValidatorSet(weights) msg_gen = message_maker("full") - message_paths = msg_gen(validator_set) - for sender_name, receiver_name in pairs: - sender = validator_set.get_validator_by_name(sender_name) - receiver = validator_set.get_validator_by_name(receiver_name) - assert (sender, receiver) in message_paths + message_makers = msg_gen(validator_set) + assert len(validator_set) == len(message_makers) + assert set(validator_set.validators) == set(message_makers) def test_no_final_message_maker(validator_set): msg_gen = message_maker("nofinal") senders = validator_set.sorted_by_name() - receivers = senders[1:] + senders[0:1] for i in range(3): for j in range(len(validator_set)): @@ -105,13 +92,9 @@ def test_no_final_message_maker(validator_set): assert len(message_paths) == 2 # first rr message this round - first_message_path = message_paths[0] - assert len(first_message_path) == 2 - assert first_message_path[0] == senders[index] - assert first_message_path[1] == receivers[index] + first_validator = message_paths[0] + assert first_validator == senders[index] # second rr message this round - second_message_path = message_paths[1] - assert len(second_message_path) == 2 - assert second_message_path[0] == senders[(index + 1) % len(validator_set)] - assert second_message_path[1] == receivers[(index + 1) % len(validator_set)] + second_validator = message_paths[1] + assert second_validator == senders[(index + 1) % len(validator_set)] diff --git a/utils/priority_queue.py b/utils/priority_queue.py new file mode 100644 index 0000000..f01170f --- /dev/null +++ b/utils/priority_queue.py @@ -0,0 +1,6 @@ +from queue import PriorityQueue as PQ + + +class PriorityQueue(PQ): + def peek(self): + return self.queue[0]