Skip to content

Commit

Permalink
in NodeTensor: if slice > tensor.shape, assign dependency at last size-1 slice
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinGeens committed Jan 4, 2025
1 parent 829f82d commit 17d716b
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 14 deletions.
29 changes: 15 additions & 14 deletions stream/stages/generation/tiled_workload_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def pad_until_divisible(layer_dim: LayerDim, n: int) -> int:
logger.warning(f"Padded layer dimension {dim}: {size} -> {new_size} to be divisible by tiling factors")

# Save these extended sizes for later
extended_layer_dim_sizes = deepcopy(tile_attrs.layer_dim_sizes)
original_node.extended_layer_dim_sizes = deepcopy(tile_attrs.layer_dim_sizes)

# Take away the outer_temporal_loops to create tiled CNs for this node
for loop in outer_temporal_loops:
Expand Down Expand Up @@ -327,7 +327,7 @@ def pad_until_divisible(layer_dim: LayerDim, n: int) -> int:
tiles: list[ComputationNode] = []
tensors: list[Tensor] = []
group_id_manager = GroupIdManager(
layer_dim_sizes=extended_layer_dim_sizes,
layer_dim_sizes=original_node.extended_layer_dim_sizes,
intra_core_tiling=original_node.intra_core_tiling,
inter_core_tiling=original_node.inter_core_tiling,
)
Expand Down Expand Up @@ -806,12 +806,12 @@ def get_tensor_cns(self, node: ComputationNode, tiles: list[ComputationNode]) ->
all(tile.loop_ranges[ir_dim][1] >= node.loop_ranges[ir_dim][1] for ir_dim in ir_dims_output)
for tile in tile_list
]
attr_to_add_to = "data_produced_unique"
# attr_to_add_to = "data_produced_unique"
precision = node.operand_precision[Constants.FINAL_OUTPUT_LAYER_OP]
else:
tile_list = list(reversed(tiles)) # list in reversed order
should_add_to_tensor_list = [True for _ in tile_list]
attr_to_add_to = "data_consumed_unique"
# attr_to_add_to = "data_consumed_unique"
# if this layer is the first layer, we assume the inputs are streamed and "free"
precision = node.operand_precision[op] * (not is_source_node)

Expand All @@ -824,19 +824,20 @@ def get_tensor_cns(self, node: ComputationNode, tiles: list[ComputationNode]) ->
# start can be negative due to padding, which would make numpy index from the end; clamp it to 0
window = tuple([slice(max(0, start), stop) for (start, stop) in op_dim_ranges])
# Count how many nans we have in this window, as this is the amount of unique data consumed/produced by
# this tile
# this tile # TODO this call takes a loooong time, can we optimize this?
nb_unique_data_bits = tensors_cns[op].get_nb_empty_elements(window) * precision
nb_unique_data_seen += nb_unique_data_bits
# Add this amount of unique data to the "data_consumed_unique" or "data_produced_unique" depending on
# input/output operand
setattr(
tile,
attr_to_add_to,
getattr(tile, attr_to_add_to) + nb_unique_data_bits,
)
# Set this window of the tensor to indicate it will be consumed/produced by this tile
if op == node.output_operand:
tile.data_produced_unique += nb_unique_data_bits
else:
tile.data_consumed_unique += nb_unique_data_bits

# This is not guaranteed: tiles of nodes whose ranges have been extended can exceed the NodeTensor shape
# assert all(start < max_stop for (start, _), max_stop in zip(op_dim_ranges, op_dim_ranges_max_stop))

# Slices that exceed the max stop are reduced to a size-1 slice at `max_stop-1`
bounded_op_dim_ranges = tuple(
slice(max(0, start), min(max_stop, stop))
slice(max(0, min(max_stop - 1, start)), min(max_stop, stop))
for ((start, stop), max_stop) in zip(op_dim_ranges, op_dim_ranges_max_stop)
)
tensors_cns[op] = tensors_cns[op].extend_with_node(bounded_op_dim_ranges, tile)
Expand Down
3 changes: 3 additions & 0 deletions stream/workload/computation/computation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from zigzag.utils import hash_sha512
from zigzag.visualization.results.plot_cme import shorten_onnx_layer_name
from zigzag.workload.layer_attributes import (
LayerDimSizes,
LayerPadding,
)
from zigzag.workload.layer_node import LayerNode, LayerNodeAttributes
Expand Down Expand Up @@ -83,6 +84,8 @@ def __init__(
self.operand_dimensionality_order: dict[LayerOperand, list[LayerDim]] = {
layer_op: self.equation.get_r_layer_dims(layer_op) for layer_op in self.equation.get_contained_operands()
}
# Sizes can be extended to fit division factors
self.extended_layer_dim_sizes: LayerDimSizes = deepcopy(self.layer_dim_sizes)

# adds pr dimensions loop ranges to self.loop_ranges
self.calculate_pr_loop_ranges()
Expand Down

0 comments on commit 17d716b

Please sign in to comment.