diff --git a/Makefile b/Makefile index 5d6aa72..c5cff98 100644 --- a/Makefile +++ b/Makefile @@ -30,3 +30,6 @@ run-nofinal: run-binary: venv/bin/python casper.py rand --protocol binary --report-interval 3 + +run-sharding: + venv/bin/python casper.py rand --protocol sharding --validators 14 --report-interval 3 diff --git a/casper/protocols/sharding/__init__.py b/casper/protocols/sharding/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/casper/protocols/sharding/block.py b/casper/protocols/sharding/block.py new file mode 100644 index 0000000..1845d9b --- /dev/null +++ b/casper/protocols/sharding/block.py @@ -0,0 +1,63 @@ +"""The block module implements the message data structure for a sharded blockchain""" +from casper.message import Message +NUM_MERGE_SHARDS = 2 + + +class Block(Message): + """Message data structure for a sharded blockchain""" + + @classmethod + def is_valid_estimate(cls, estimate): + for key in ['prev_blocks', 'shard_ids']: + if key not in estimate: + return False + if not isinstance(estimate[key], set): + return False + return True + + def on_shard(self, shard_id): + return shard_id in self.estimate['shard_ids'] + + def prev_block(self, shard_id): + """Returns the previous block on the shard: shard_id + Throws a KeyError if there is no previous block""" + if shard_id not in self.estimate['shard_ids']: + raise KeyError("No previous block on that shard") + + for block in self.estimate['prev_blocks']: + # if this block is the genesis, previous is None + if block is None: + return None + + # otherwise, return the previous block on that shard + if block.on_shard(shard_id): + return block + + raise KeyError("Block on {}, but has no previous block on that shard!".format(shard_id)) + + @property + def is_merge_block(self): + return len(self.estimate['shard_ids']) == NUM_MERGE_SHARDS + + @property + def is_genesis_block(self): + return None in self.estimate['prev_blocks'] + + def conflicts_with(self, message): + """Returns true if self is not in the prev blocks of other_message""" + assert isinstance(message, Block), "...expected a block" + + return not self.is_in_blockchain(message, '') + + def is_in_blockchain(self, block, shard_id): + """Could be a zero generation ancestor!""" + if not block: + return False + + if not block.on_shard(shard_id): + return False + + if self == block: + return True + + return self.is_in_blockchain(block.prev_block(shard_id), shard_id) diff --git a/casper/protocols/sharding/forkchoice.py b/casper/protocols/sharding/forkchoice.py new file mode 100644 index 0000000..f8124ae --- /dev/null +++ b/casper/protocols/sharding/forkchoice.py @@ -0,0 +1,88 @@ +"""The forkchoice module implements the estimator function a sharded blockchain""" + + +def get_max_weight_indexes(scores): + """Returns the keys that map to the max value in a dict. + The max value must be greater than zero.""" + + max_score = max(scores.values()) + + assert max_score > 0, "max_score should be greater than zero" + + max_weight_estimates = {e for e in scores if scores[e] == max_score} + + return max_weight_estimates + + +def get_scores(starting_block, latest_messages, shard_id): + """Returns a dict of block => weight""" + scores = dict() + + for validator, current_block in latest_messages.items(): + if not current_block.on_shard(shard_id): + continue + + while current_block and current_block != starting_block: + scores[current_block] = scores.get(current_block, 0) + validator.weight + current_block = current_block.prev_block(shard_id) + + return scores + + +def get_shard_fork_choice(starting_block, children, latest_messages, shard_id): + """Get the forkchoice for a specific shard""" + + scores = get_scores(starting_block, latest_messages, shard_id) + + best_block = starting_block + while best_block in children: + curr_scores = dict() + max_score = 0 + for child in children[best_block]: + if not child.on_shard(shard_id): + continue # we only select children on the same shard + # can't pick a child that a merge block with a higher shard + if child.is_merge_block: + not_in_forkchoice = False + for shard in child.estimate['shard_ids']: + if len(shard) < len(shard_id): + not_in_forkchoice = True + break + if not_in_forkchoice: + continue + curr_scores[child] = scores.get(child, 0) + max_score = max(curr_scores[child], max_score) + + # If no child on shard, or 0 weight block, stop + if max_score == 0: + break + + max_weight_children = get_max_weight_indexes(curr_scores) + + assert len(max_weight_children) == 1, "... there should be no ties!" + + best_block = max_weight_children.pop() + + return best_block + + +def get_all_shards_fork_choice(starting_blocks, children, latest_messages_on_shard): + """Returns a dict of shard_id -> forkchoice. + Starts from starting block for shard, and stops when it reaches tip""" + + # for any shard we have latest messages on, we should have a starting block + for key in starting_blocks.keys(): + assert key in latest_messages_on_shard + for key in latest_messages_on_shard.keys(): + assert key in latest_messages_on_shard + + shards_forkchoice = { + shard_id: get_shard_fork_choice( + starting_blocks[shard_id], + children, + latest_messages_on_shard[shard_id], + shard_id + ) for shard_id in starting_blocks + } + + return shards_forkchoice diff --git a/casper/protocols/sharding/sharding_plot_tool.py b/casper/protocols/sharding/sharding_plot_tool.py new file mode 100644 index 0000000..a30c7de --- /dev/null +++ b/casper/protocols/sharding/sharding_plot_tool.py @@ -0,0 +1,105 @@ +"""The blockchain plot tool implements functions for plotting sharded blockchain data structures""" + +from casper.plot_tool import PlotTool +from casper.safety_oracles.clique_oracle import CliqueOracle +import casper.utils as utils + + +class ShardingPlotTool(PlotTool): + """The module contains functions for plotting a blockchain data structure""" + + def __init__(self, display, save, view, validator_set): + super().__init__(display, save, 's') + self.view = view + self.validator_set = validator_set + self.starting_blocks = self.view.starting_blocks + self.message_fault_tolerance = dict() + + self.blockchain = [] + self.communications = [] + + self.block_fault_tolerance = {} + self.message_labels = {} + self.justifications = { + validator: [] + for validator in validator_set + } + + def update(self, new_messages=None): + """Updates displayable items with new messages and paths""" + return + + if new_messages is None: + new_messages = [] + + self._update_new_justifications(new_messages) + self._update_blockchain(new_messages) + self._update_block_fault_tolerance() + self._update_message_labels(new_messages) + + def plot(self): + """Builds relevant edges to display and creates next viewgraph using them""" + return + best_chain_edge = self.get_best_chain() + + validator_chain_edges = self.get_validator_chains() + + edgelist = [] + edgelist.append(utils.edge(self.blockchain, 2, 'grey', 'solid')) + edgelist.append(utils.edge(self.communications, 1, 'black', 'dotted')) + edgelist.append(best_chain_edge) + edgelist.extend(validator_chain_edges) + + self.next_viewgraph( + self.view, + self.validator_set, + edges=edgelist, + message_colors=self.block_fault_tolerance, + message_labels=self.message_labels + ) + + def get_best_chain(self): + """Returns an edge made of the global forkchoice to genesis""" + best_message = self.view.estimate() + best_chain = utils.build_chain(best_message, None)[:-1] + return utils.edge(best_chain, 5, 'red', 'solid') + + def get_validator_chains(self): + """Returns a list of edges main from validators current forkchoice to genesis""" + vals_chain_edges = [] + for validator in self.validator_set: + chain = utils.build_chain(validator.my_latest_message(), None)[:-1] + vals_chain_edges.append(utils.edge(chain, 2, 'blue', 'solid')) + + return vals_chain_edges + + def _update_new_justifications(self, new_messages): + for message in new_messages: + sender = message.sender + for validator in message.justification: + last_message = self.view.justified_messages[message.justification[validator]] + # only show if new justification + if last_message not in self.justifications[sender]: + self.communications.append([last_message, message]) + self.justifications[sender].append(last_message) + + def _update_blockchain(self, new_messages): + for message in new_messages: + if message.estimate is not None: + self.blockchain.append([message, message.estimate]) + + def _update_message_labels(self, new_messages): + for message in new_messages: + self.message_labels[message] = message.sequence_number + + def _update_block_fault_tolerance(self): + tip = self.view.estimate() + + while tip and self.block_fault_tolerance.get(tip, 0) != len(self.validator_set) - 1: + oracle = CliqueOracle(tip, self.view, self.validator_set) + fault_tolerance, num_node_ft = oracle.check_estimate_safety() + + if fault_tolerance > 0: + self.block_fault_tolerance[tip] = num_node_ft + + tip = tip.estimate diff --git a/casper/protocols/sharding/sharding_protocol.py b/casper/protocols/sharding/sharding_protocol.py new file mode 100644 index 0000000..5749116 --- /dev/null +++ b/casper/protocols/sharding/sharding_protocol.py @@ -0,0 +1,56 @@ +from casper.protocols.sharding.sharding_view import ShardingView +from casper.protocols.sharding.block import Block +from casper.protocols.sharding.sharding_plot_tool import ShardingPlotTool +from casper.protocol import Protocol + + +class ShardingProtocol(Protocol): + View = ShardingView + Message = Block + PlotTool = ShardingPlotTool + + shard_genesis_blocks = dict() + curr_shard_idx = 0 + curr_shard_ids = [''] + + """Shard ID's look like this: + '' + / \ + '0' '1' + / \ / \ + '00''01''10''11' + + + Blocks can be merge mined between shards if + there is an edge between shards + That is, for ids shard_1 and shard_2, there can be a merge block if + abs(len(shard_1) - len(shard_2)) = 1 AND + for i in range(min(len(shard_1), len(shard_2))): + shard_1[i] = shard_2[i] + """ + + @classmethod + def initial_message(cls, validator): + """Returns a starting block for a shard""" + shard_id = cls.get_next_shard_id() + + estimate = {'prev_blocks': set([None]), 'shard_ids': set([shard_id])} + cls.shard_genesis_blocks[shard_id] = Block(estimate, dict(), validator, -1, 0) + + return cls.shard_genesis_blocks[''] + + @classmethod + def get_next_shard_id(cls): + next_id = cls.curr_shard_ids[cls.curr_shard_idx] + cls.curr_shard_idx += 1 + + if cls.curr_shard_idx == len(cls.curr_shard_ids): + next_ids = [] + for shard_id in cls.curr_shard_ids: + next_ids.append(shard_id + '0') + next_ids.append(shard_id + '1') + + cls.curr_shard_idx = 0 + cls.curr_shard_ids = next_ids + + return next_id diff --git a/casper/protocols/sharding/sharding_view.py b/casper/protocols/sharding/sharding_view.py new file mode 100644 index 0000000..dada971 --- /dev/null +++ b/casper/protocols/sharding/sharding_view.py @@ -0,0 +1,139 @@ +"""The sharding view module extends a view for sharded blockchain data structure""" +import random as r + +from casper.abstract_view import AbstractView +import casper.protocols.sharding.forkchoice as forkchoice + + +class ShardingView(AbstractView): + """A view class that keeps track of children, latest_messages_on_shard and starting_blocks""" + def __init__(self, messages=None, shard_genesis_block=None): + self.children = dict() + + self.shard_genesis_blocks = dict() # shard_id -> genesis for shard + self.starting_blocks = dict() # shard_id -> starting block for forkchoice + self.latest_messages_on_shard = dict() # shard_id -> validator -> message + + self.select_shards = self.select_random_shards + + if shard_genesis_block: + for shard_id in shard_genesis_block.estimate['shard_ids']: + self.shard_genesis_blocks[shard_id] = shard_genesis_block + self.starting_blocks[shard_id] = shard_genesis_block + + super().__init__(messages) + + def estimate(self): + """Returns the current forkchoice in this view""" + shards_forkchoice = dict() + + for shard_id in sorted(self.starting_blocks): + shard_tip = forkchoice.get_shard_fork_choice( + self.starting_blocks[shard_id], + self.children, + self.latest_messages_on_shard[shard_id], + shard_id + ) + shards_forkchoice[shard_id] = shard_tip + + left_child_shard = shard_id + '0' + right_child_shard = shard_id + '1' + if left_child_shard in self.starting_blocks: + self.set_child_starting_block(shard_tip, shard_id, left_child_shard) + if right_child_shard in self.starting_blocks: + self.set_child_starting_block(shard_tip, shard_id, right_child_shard) + + self.check_forkchoice_atomicity(shards_forkchoice) + + shards_to_build_on = self.select_shards(shards_forkchoice) + return {'prev_blocks': {shards_forkchoice[shard_id] for shard_id in shards_to_build_on}, + 'shard_ids': shards_to_build_on} + + def set_child_starting_block(self, tip_block, parent_id, child_id): + """Changes the starting block for the forkchoice of a shard""" + child_merge_block = self.previous_merge_block_on_shard( + tip_block, + parent_id, + child_id + ) + + if child_merge_block: + self.starting_blocks[child_id] = child_merge_block + else: + self.starting_blocks[child_id] = self.shard_genesis_blocks[child_id] + + def select_random_shards(self, shards_forkchoice): + """Randomly selects a shard to build on, and sometimes selects another child shard""" + shards_to_build_on = [r.choice(list(self.starting_blocks.keys()))] + if r.choice([True, False]): + child_shard_id = shards_to_build_on[0] + str(r.randint(0, 1)) + if child_shard_id in self.starting_blocks: + shards_to_build_on.append(child_shard_id) + + return set(shards_to_build_on) + + def check_forkchoice_atomicity(self, shards_forkchoice): + """Asserts that if a merge block is in the forkchoice for a parent chain + then it is in the forkchoice for a child chain""" + print("Checking merge block atomicity") + for shard_id in sorted(shards_forkchoice): + tip = shards_forkchoice[shard_id] + + for child_shard_id in [shard_id + '0', shard_id + '1']: + if child_shard_id not in shards_forkchoice: + continue + + merge_block = self.previous_merge_block_on_shard( + tip, + shard_id, + child_shard_id + ) + if merge_block: + assert merge_block.is_in_blockchain( + shards_forkchoice[child_shard_id], + child_shard_id + ) + print("Passed") + + def previous_merge_block_on_shard(self, starting_block, block_shard_id, merge_shard): + """Get the most recent merge block between block_shard_id and merge_shard + Starts from starting_block""" + assert starting_block.on_shard(block_shard_id) + current_block = starting_block + while current_block: + if current_block.on_shard(merge_shard): + assert current_block.is_merge_block + return current_block + current_block = current_block.prev_block(block_shard_id) + return None + + def update_safe_estimates(self, validator_set): + """Checks safety on messages in views forkchoice, and updates last_finalized_block""" + # check the safety of the top shard! + pass + + def _update_protocol_specific_view(self, message): + """Given a now justified message, updates children and when_recieved""" + assert message.hash in self.justified_messages, "...should not have seen message!" + + # set starting messages! ::)) + if message.is_genesis_block: + for shard_id in message.estimate['shard_ids']: + self.shard_genesis_blocks[shard_id] = message + self.starting_blocks[shard_id] = message + + # update the latest_messages + for shard_id in message.estimate['shard_ids']: + if shard_id not in self.latest_messages_on_shard: + self.latest_messages_on_shard[shard_id] = dict() + latest_messages = self.latest_messages_on_shard[shard_id] + if message.sender not in latest_messages: + latest_messages[message.sender] = message + elif latest_messages[message.sender].sequence_number < message.sequence_number: + latest_messages[message.sender] = message + + # update children dictonary + for parent in message.estimate['prev_blocks']: + if parent not in self.children: + self.children[parent] = set() + self.children[parent].add(message) diff --git a/simulations/utils.py b/simulations/utils.py index cb26a40..5f06b32 100644 --- a/simulations/utils.py +++ b/simulations/utils.py @@ -21,6 +21,7 @@ from casper.protocols.integer.integer_protocol import IntegerProtocol from casper.protocols.order.order_protocol import OrderProtocol from casper.protocols.concurrent.concurrent_protocol import ConcurrentProtocol +from casper.protocols.sharding.sharding_protocol import ShardingProtocol from casper.validator_set import ValidatorSet @@ -37,7 +38,8 @@ 'binary': BinaryProtocol, 'integer': IntegerProtocol, 'order': OrderProtocol, - 'concurrent': ConcurrentProtocol + 'concurrent': ConcurrentProtocol, + 'sharding': ShardingProtocol } SELECT_MESSAGE_MODE = {