Skip to content
This repository has been archived by the owner on Jan 6, 2025. It is now read-only.

Add blockchain sharding #169

Merged
merged 11 commits into from
Feb 24, 2018
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,6 @@ run-nofinal:

run-binary:
venv/bin/python casper.py rand --protocol binary --report-interval 3

run-sharding:
venv/bin/python casper.py rand --protocol sharding --validators 14 --report-interval 3
Empty file.
57 changes: 57 additions & 0 deletions casper/protocols/sharding/block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""The block module implements the message data structure for a sharded blockchain"""
from casper.message import Message


class Block(Message):
"""Message data structure for a sharded blockchain"""

@classmethod
def is_valid_estimate(cls, estimate):
if not isinstance(estimate, dict):
return False
if not isinstance(estimate['prev_blocks'], set):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

estimate['prev_blocks'] will throw a KeyError if 'prev_blocks' is not set. Should probably check for inclusion of keys and then if the value is in fact a set. It's probably more likely (or at least just as likely) that the key is forgotten as the value being incorrect.

return False
if not isinstance(estimate['shard_ids'], set):
return False
return True

def on_shard(self, shard_id):
return shard_id in self.estimate['shard_ids']

def prev_block(self, shard_id):
"""Returns the previous block on the shard: shard_id
Throws a KeyError if there is no previous block"""
if shard_id not in self.estimate['shard_ids']:
raise KeyError("No previous block on that shard")

for block in self.estimate['prev_blocks']:
if block is None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be worth a comment on what is going on here.

return None

if block.on_shard(shard_id):
return block

raise KeyError("Should have found previous block on shard!")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be useful to add why we should have found a previous block:

"shard_id in estimate['shard_ids']. Should have found previous block on shard!"


@property
def is_merge_block(self):
return len(self.estimate['shard_ids']) == 2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know we're in binary sharding so we are checking for == rather than >=, but if we aren't checking for >= here, we should probably enforce len(estimate['shard_ids']) in is_valid_estimate.

Maybe consider adding a CONSTANT to the class that enforces this max: MAX_MERGE_BLOCK_SIZE or something.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we add such a constant then this method would be checking that the len in range(2, MAX_MERGE_BLOCK_SIZE+1)


def conflicts_with(self, message):
"""Returns true if self is not in the prev blocks of other_message"""
assert isinstance(message, Block), "...expected a block"

return not self.is_in_blockchain(message, '')

def is_in_blockchain(self, block, shard_id):
"""Could be a zero generation ancestor!"""
if not block:
return False

if not block.on_shard(shard_id):
return False

if self == block:
return True

return self.is_in_blockchain(block.prev_block(shard_id), shard_id)
79 changes: 79 additions & 0 deletions casper/protocols/sharding/forkchoice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""The forkchoice module implements the estimator function a blockchain"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"sharded blockchain"



def get_max_weight_indexes(scores):
"""Returns the keys that map to the max value in a dict.
The max value must be greater than zero."""

max_score = max(scores.values())

assert max_score != 0, "max_score of a block should never be zero"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe max_score > 0, "max_score should be greater than zero"


max_weight_estimates = {e for e in scores if scores[e] == max_score}

return max_weight_estimates


def get_scores(starting_block, latest_messages, shard_id):
"""Returns a dict of block => weight"""
scores = dict()

for validator, current_block in latest_messages.items():
if not current_block.on_shard(shard_id):
continue

while current_block and current_block != starting_block:
scores[current_block] = scores.get(current_block, 0) + validator.weight
current_block = current_block.prev_block(shard_id)

return scores


def get_shard_fork_choice(starting_block, children, latest_messages, shard_id):
"""Get the forkchoice for a specific shard"""

scores = get_scores(starting_block, latest_messages, shard_id)

best_block = starting_block
while best_block in children:
curr_scores = dict()
max_score = 0
for child in children[best_block]:
if not child.on_shard(shard_id):
continue # we only select children on the same shard
curr_scores[child] = scores.get(child, 0)
max_score = max(curr_scores[child], max_score)

# If no child on shard, or 0 weight block, stop
if max_score == 0:
break

max_weight_children = get_max_weight_indexes(curr_scores)

assert len(max_weight_children) == 1, "... there should be no ties!"

best_block = max_weight_children.pop()

return best_block


def get_all_shards_fork_choice(starting_blocks, children, latest_messages_on_shard):
"""Returns a dict of shard_id -> forkchoice.
Starts from starting block for shard, and stops when it reaches tip"""

# for any shard we have latest messages on, we should have a starting block
for key in starting_blocks.keys():
assert key in latest_messages_on_shard
for key in latest_messages_on_shard.keys():
assert key in latest_messages_on_shard

shards_forkchoice = {
shard_id: get_shard_fork_choice(
starting_blocks[shard_id],
children,
latest_messages_on_shard[shard_id],
shard_id
) for shard_id in starting_blocks
}

return shards_forkchoice
105 changes: 105 additions & 0 deletions casper/protocols/sharding/sharding_plot_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""The blockchain plot tool implements functions for plotting blockchain data structures"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"sharded blockchain"


from casper.plot_tool import PlotTool
from casper.safety_oracles.clique_oracle import CliqueOracle
import casper.utils as utils


class ShardingPlotTool(PlotTool):
"""The module contains functions for plotting a blockchain data structure"""

def __init__(self, display, save, view, validator_set):
super().__init__(display, save, 's')
self.view = view
self.validator_set = validator_set
self.starting_blocks = self.view.starting_blocks
self.message_fault_tolerance = dict()

self.blockchain = []
self.communications = []

self.block_fault_tolerance = {}
self.message_labels = {}
self.justifications = {
validator: []
for validator in validator_set
}

def update(self, new_messages=None):
"""Updates displayable items with new messages and paths"""
return

if new_messages is None:
new_messages = []

self._update_new_justifications(new_messages)
self._update_blockchain(new_messages)
self._update_block_fault_tolerance()
self._update_message_labels(new_messages)

def plot(self):
"""Builds relevant edges to display and creates next viewgraph using them"""
return
best_chain_edge = self.get_best_chain()

validator_chain_edges = self.get_validator_chains()

edgelist = []
edgelist.append(utils.edge(self.blockchain, 2, 'grey', 'solid'))
edgelist.append(utils.edge(self.communications, 1, 'black', 'dotted'))
edgelist.append(best_chain_edge)
edgelist.extend(validator_chain_edges)

self.next_viewgraph(
self.view,
self.validator_set,
edges=edgelist,
message_colors=self.block_fault_tolerance,
message_labels=self.message_labels
)

def get_best_chain(self):
"""Returns an edge made of the global forkchoice to genesis"""
best_message = self.view.estimate()
best_chain = utils.build_chain(best_message, None)[:-1]
return utils.edge(best_chain, 5, 'red', 'solid')

def get_validator_chains(self):
"""Returns a list of edges main from validators current forkchoice to genesis"""
vals_chain_edges = []
for validator in self.validator_set:
chain = utils.build_chain(validator.my_latest_message(), None)[:-1]
vals_chain_edges.append(utils.edge(chain, 2, 'blue', 'solid'))

return vals_chain_edges

def _update_new_justifications(self, new_messages):
for message in new_messages:
sender = message.sender
for validator in message.justification:
last_message = self.view.justified_messages[message.justification[validator]]
# only show if new justification
if last_message not in self.justifications[sender]:
self.communications.append([last_message, message])
self.justifications[sender].append(last_message)

def _update_blockchain(self, new_messages):
for message in new_messages:
if message.estimate is not None:
self.blockchain.append([message, message.estimate])

def _update_message_labels(self, new_messages):
for message in new_messages:
self.message_labels[message] = message.sequence_number

def _update_block_fault_tolerance(self):
tip = self.view.estimate()

while tip and self.block_fault_tolerance.get(tip, 0) != len(self.validator_set) - 1:
oracle = CliqueOracle(tip, self.view, self.validator_set)
fault_tolerance, num_node_ft = oracle.check_estimate_safety()

if fault_tolerance > 0:
self.block_fault_tolerance[tip] = num_node_ft

tip = tip.estimate
56 changes: 56 additions & 0 deletions casper/protocols/sharding/sharding_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from casper.protocols.sharding.sharding_view import ShardingView
from casper.protocols.sharding.block import Block
from casper.protocols.sharding.sharding_plot_tool import ShardingPlotTool
from casper.protocol import Protocol


class ShardingProtocol(Protocol):
View = ShardingView
Message = Block
PlotTool = ShardingPlotTool

shard_genesis_blocks = dict()
curr_shard_idx = 0
curr_shard_ids = ['']

"""Shard ID's look like this:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sick comment 😸

''
/ \
'0' '1'
/ \ / \
'00''01''10''11'


Blocks can be merge mined between shards if
there is an edge between shards
That is, for ids shard_1 and shard_2, there can be a merge block if
abs(len(shard_1) - len(shard_2)) = 1 AND
for i in range(min(len(shard_1), len(shard_2))):
shard_1[i] = shard_2[i]
"""

@classmethod
def initial_message(cls, validator):
"""Returns a dict from shard_id -> shard genesis block"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is incorrect. Returns a Block

shard_id = cls.get_new_shard_id()

estimate = {'prev_blocks': set([None]), 'shard_ids': set([shard_id])}
cls.shard_genesis_blocks[''] = Block(estimate, dict(), validator, -1, 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's going on right here? Why do we use '' as the key each time? Ultimately we are just storing the last initial_message and all the others are overwritten, no?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! Yeah, meant to save the blocks (in case we don't want to number of shards to be equal to the number of validators :) ).


return cls.shard_genesis_blocks['']

@classmethod
def get_new_shard_id(cls):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From my experience, I see these functions often named with "next" so get_next_shard_id. That said... follow your heart.

new_id = cls.curr_shard_ids[cls.curr_shard_idx]
cls.curr_shard_idx += 1

if cls.curr_shard_idx == len(cls.curr_shard_ids):
new_ids = []
for shard_id in cls.curr_shard_ids:
new_ids.append(shard_id + '0')
new_ids.append(shard_id + '1')

cls.curr_shard_idx = 0
cls.curr_shard_ids = new_ids

return new_id
Loading