
Commit

merge
RobinGeens committed Jan 4, 2025
2 parents 17d716b + f825a70 commit cd96ae9
Showing 5 changed files with 128 additions and 124 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
zigzag-dse==3.8.0
zigzag-dse==3.8.1
rtree
deap
matplotlib
226 changes: 109 additions & 117 deletions stream/stages/generation/tiled_workload_generation.py
@@ -1,5 +1,6 @@
import logging
import os
import time
from collections import defaultdict
from copy import deepcopy
from math import ceil, prod
@@ -19,7 +20,6 @@
from stream.stages.stage import Stage, StageCallable
from stream.utils import contains_wildcard
from stream.workload.computation.computation_node import ComputationNode, LoopRanges
from stream.workload.dependency_propagation.concat_node import ConcatNode
from stream.workload.dependency_propagation.dummy_node import DummyNode
from stream.workload.dependency_propagation.propagation_node import PropagationNode
from stream.workload.node import Node
@@ -28,7 +28,7 @@

logger = logging.getLogger(__name__)

EDGE_T = tuple[ComputationNode, ComputationNode, dict]
EDGE_T = tuple[ComputationNode, ComputationNode, dict[str, Any]]


class TensorDimensionMismatchException(Exception):
@@ -645,19 +645,18 @@ def get_inter_edges_numpy(
consumer: ComputationNode,
):

numpy_tensors: dict[ComputationNode, dict[LayerOperand, NodeTensor]] = {}
numpy_tensors: dict[tuple[ComputationNode, LayerOperand], NodeTensor] = {}
all_inter_edges: list[tuple[ComputationNode, ComputationNode, dict[str, Any]]] = []

def get_tensor_cn_for_op(node: ComputationNode, dependent_operand: LayerOperand):
"""And update the known tensors of computation nodes"""
if node in numpy_tensors:
tensor_cns = numpy_tensors[node]
if (node, dependent_operand) in numpy_tensors:
tensor = numpy_tensors[(node, dependent_operand)]
else:
tiles = self.tiles_dict[node]
tensor_cns = self.get_tensor_cns(node, tiles)
tensor = self.get_node_tensor(node, tiles, dependent_operand)
# Store result for later use
numpy_tensors[node] = tensor_cns
tensor = tensor_cns[dependent_operand]
numpy_tensors[(node, dependent_operand)] = tensor
return tensor

paths_between = list(self.workload.all_simple_paths(producer, consumer))
@@ -670,76 +669,33 @@ def get_tensor_cn_for_op(node: ComputationNode, dependent_operand: LayerOperand)
), "No paths between producer and consumer found without ComputationNode in intermediates."

for path_between in paths_between:
ts = [time.time()]
# First node in the path is a ComputationNode, from which we extract the output operand dependency tensor
first_node = path_between[0]
assert isinstance(first_node, ComputationNode), "First node in path should be ComputationNode"
tensor = get_tensor_cn_for_op(first_node, dependent_operand=Constants.OUTPUT_LAYER_OP)
ts.append(time.time())

# Propagate through intermediate, non-computation nodes
relevant_axes = [False] * len(tensor.tensor_shape)
for i, node in enumerate(path_between[1:-1], start=1):
assert isinstance(node, PropagationNode), "Intermediate nodes should not be of type ComputationNode"
next_node = path_between[i + 1]
tensor = node.propagate(tensor, next_node)
tensor, relevant_axes = node.propagate(tensor, next_node, relevant_axes)

# Final node: Computation node
final_node: ComputationNode = path_between[-1] # type: ignore
assert isinstance(final_node, ComputationNode), "Last node in path should be ComputationNode"

# Find the operand for which this last node connects to its predecessor
dependent_operand = next(
op for op, dependent_node_id in final_node.input_operand_source.items() if dependent_node_id == node.id
)
inter_edges = self.get_inter_edges_hybrid(tensor, final_node, dependent_operand, relevant_axes)

# Error handling of shape mismatches in tensor propagation
def _get_final_tensor_alt_operand():
"""Error handling case 1: sources for `W` and `I` operand are swapped for this node
-> try the other one"""
try:
alt_operand = next(op for op in final_node.input_operand_source if op != dependent_operand)
except StopIteration:
# No alt operand was found -> we're still in trouble
raise TensorDimensionMismatchException
return get_tensor_cn_for_op(final_node, alt_operand)

def _get_shape_inferred_propagated_tensor(tensor: NodeTensor, final_tensor: NodeTensor):
"""Error handling case 2: dimensions of ComputationNode (`final_tensor`) were altered by stream
(e.g. to be properly divisible) but this is not reflected in `ConcatNode` with constant shape.
-> manually fix shape"""
if not any(isinstance(node, ConcatNode) for node in path_between[1:-1]):
raise TensorDimensionMismatchException(
"This function only solves the case of errors due to constant shapes in ConcatNode"
)

target_shape = final_tensor.tensor_shape
propagated_shape = tensor.tensor_shape
extension_axis = next(i for i in range(len(target_shape)) if target_shape[i] != propagated_shape[i])
extension_value = target_shape[extension_axis] - propagated_shape[extension_axis]
if extension_value <= 0:
raise TensorDimensionMismatchException(
"Propagated shape cannot be larger than (extended) found shape"
)
extension_shape = tuple(
val if i != extension_axis else extension_value for i, val in enumerate(target_shape)
)
return tensor.concat_with_empty(extension_shape, extension_axis, variable_input_first=False)

try: # Regular case
final_tensor = get_tensor_cn_for_op(final_node, dependent_operand)
inter_edges = self.get_inter_edges_tensor_based(tensor, final_tensor)
except TensorDimensionMismatchException:
try: # Error case 1
final_tensor = _get_final_tensor_alt_operand()
inter_edges = self.get_inter_edges_tensor_based(tensor, final_tensor)
except TensorDimensionMismatchException:
try: # Error case 2
final_tensor = get_tensor_cn_for_op(final_node, dependent_operand)
tensor = _get_shape_inferred_propagated_tensor(tensor, final_tensor)
inter_edges = self.get_inter_edges_tensor_based(tensor, final_tensor)
except TensorDimensionMismatchException:
# Error case 1 and 2 combined
final_tensor = _get_final_tensor_alt_operand()
tensor = _get_shape_inferred_propagated_tensor(tensor, final_tensor)
inter_edges = self.get_inter_edges_tensor_based(tensor, final_tensor)
ts.append(time.time())
ts_deltas = [ts[i] - ts[i - 1] for i in range(1, len(ts))]
ts_deltas_str = ", ".join([f"{delta:.3f}" for delta in ts_deltas])
logger.info(f"Path {path_between} time deltas: {ts_deltas_str}")

for producer, cons in inter_edges:
all_inter_edges.append(
@@ -754,6 +710,39 @@ def _get_shape_inferred_propagated_tensor(tensor: NodeTensor, final_tensor: Node
)
return all_inter_edges

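Side note on the caching change in `get_inter_edges_numpy` above: the lookup is now keyed by `(node, operand)`, so only the requested operand tensor is built and stored. A minimal, self-contained sketch of that pattern follows; the names `_tensor_cache`, `get_cached_tensor`, and `build` are illustrative stand-ins, not part of the codebase.

from typing import Any, Callable

# Sketch (not part of this commit) of (node, operand)-keyed memoization: only the
# requested operand tensor is built, instead of a dict of tensors for every operand.
_tensor_cache: dict[tuple[Any, Any], Any] = {}

def get_cached_tensor(node: Any, operand: Any, build: Callable[[Any, Any], Any]) -> Any:
    key = (node, operand)
    if key not in _tensor_cache:
        _tensor_cache[key] = build(node, operand)  # e.g. get_node_tensor(node, tiles, operand)
    return _tensor_cache[key]

# Toy usage with string stand-ins for the real ComputationNode / LayerOperand objects:
print(get_cached_tensor("conv1", "O", lambda n, op: f"tensor({n},{op})"))  # built once
print(get_cached_tensor("conv1", "O", lambda n, op: f"tensor({n},{op})"))  # served from cache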
def get_inter_edges_hybrid(
self, tensor: NodeTensor, final_node: ComputationNode, op: LayerOperand, relevant_axes: list[bool]
):
"""This method obtains the tile dependencies between producers in tensor and the consumer final_node.
This is done by iterating through all consumer tiles; for each consumer tile we create a window over the relevant axes and collect all producer tiles that produced the data in that window.
Args:
tensor (NodeTensor): A tensor containing for each position which CNs will produce it
final_node (ComputationNode): The node for which to get the inter-edges
op (LayerOperand): The input operand of final_node for which to get the inter-edges
relevant_axes (list): A list of boolean values indicating which axes are relevant for the final_node
"""
inter_edges: list[tuple[ComputationNode, ComputationNode]] = []
dims = final_node.operand_dimensionality_order[op]
assert len(dims) == len(relevant_axes)
for consumer_tile in self.tiles_dict[final_node]:
relevant_loop_ranges = [consumer_tile.loop_ranges[dim] for dim in dims]
# Override loop ranges of irrelevant axes to only include a single slice
for i, relevant in enumerate(relevant_axes):
if not relevant:
relevant_loop_ranges[i] = (0, 1)
# Ellipsis adds the entire last axis for the extra dimension in NodeTensor
slices = tuple(slice(start, stop) for start, stop in relevant_loop_ranges) + (Ellipsis,)
sliced_tensor = tensor[slices]
producer_tiles = set(
elem for elem in sliced_tensor.flat.flat if elem and isinstance(elem, ComputationNode)
)
for producer_tile in producer_tiles:
inter_edges.append((producer_tile, consumer_tile))
return inter_edges

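To make the windowing above concrete, here is a rough, self-contained sketch using a plain numpy object array in place of `NodeTensor`; the extra NodeTensor axis and the trailing `Ellipsis` are omitted, and all names below are illustrative assumptions rather than the real Stream API.

import numpy as np

# A 2D "tensor" whose cells hold the id of the producer tile that wrote them.
producer_map = np.empty((4, 6), dtype=object)
producer_map[:2, :] = "tile_A"
producer_map[2:, :] = "tile_B"

consumer_ranges = [(1, 3), (0, 6)]   # loop ranges of one consumer tile, per axis
relevant_axes = [True, False]        # axis 1 is irrelevant -> collapse it to a single slice

ranges = [r if rel else (0, 1) for r, rel in zip(consumer_ranges, relevant_axes)]
window = tuple(slice(start, stop) for start, stop in ranges)
producers = {p for p in producer_map[window].flat if p}
print(producers)  # {'tile_A', 'tile_B'}: this consumer tile depends on both producers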
@staticmethod
def get_inter_edges_tensor_based(producer_output_tensor: NodeTensor, consumer_input_tensor: NodeTensor):
"""This method obtains the edges between a producer and consumer.
@@ -780,67 +769,62 @@ def get_inter_edges_tensor_based(producer_output_tensor: NodeTensor, consumer_in
inter_edges.add((producer, consumer))
return inter_edges

def get_tensor_cns(self, node: ComputationNode, tiles: list[ComputationNode]) -> dict[LayerOperand, NodeTensor]:
def get_node_tensor(
self,
node: ComputationNode,
tiles: list[ComputationNode],
op: LayerOperand,
) -> NodeTensor:
is_source_node = len(self.get_non_type_predecessors(node, [DummyNode])) == 0
variable_operands = [op for op in node.input_operands if op not in node.constant_operands] + [
node.output_operand
]
tensor_dims = {op: node.operand_dimensionality_order[op] for op in variable_operands}
tensor_dims = node.operand_dimensionality_order[op]
all_loop_dim_sizes = node.layer_dim_sizes + node.pr_layer_dim_sizes # union
tensor_shapes = {op: tuple(all_loop_dim_sizes[dim] for dim in dims) for (op, dims) in tensor_dims.items()}

# Initial arrays.
tensors_cns: dict[LayerOperand, NodeTensor] = {
op: NodeTensor.initialize_empty(shape) for (op, shape) in tensor_shapes.items()
}

# For each input operand iterate through the tiles in reverse order
# because we want the first cn with a dependency saved in the tensor
# For the output operand iterate through the tiles in regular order
# because we want the last CN that handles an output tensor window to be saved
for op, dims in tensor_dims.items():
tensor_shapes: tuple[int, ...] = tuple(all_loop_dim_sizes[dim] for dim in tensor_dims)

if op == node.output_operand:
ir_dims_output = node.loop_relevancy_info.get_ir_layer_dims(Constants.OUTPUT_LAYER_OP)
tile_list = tiles # list in regular order
should_add_to_tensor_list = [
all(tile.loop_ranges[ir_dim][1] >= node.loop_ranges[ir_dim][1] for ir_dim in ir_dims_output)
for tile in tile_list
]
precision = node.operand_precision[Constants.FINAL_OUTPUT_LAYER_OP]
else:
tile_list = list(reversed(tiles)) # list in reversed order
should_add_to_tensor_list = [True for _ in tile_list]
# if this layer is the first layer, we assume the inputs are streamed and "free"
precision = node.operand_precision[op] * (not is_source_node)

nb_unique_data_seen = 0
node_tensor = NodeTensor.initialize_empty(tensor_shapes)
for tile, should_add_to_tensor in zip(tile_list, should_add_to_tensor_list):
if not should_add_to_tensor:
continue # Skip if we're not at the max ir loop value for output

op_dim_ranges = [tile.loop_ranges[loop_dim] for loop_dim in tensor_dims]
op_dim_ranges_max_stop = tuple(tensor_shapes)
# start can be negative for padding, which makes np flip
window = tuple([slice(max(0, start), stop) for (start, stop) in op_dim_ranges])

# Count how many nans we have in this window, as this is the amount of unique data consumed/produced by
# this tile # TODO this call takes a loooong time, can we optimize this?
nb_unique_data_bits = node_tensor.get_nb_empty_elements(window) * precision
nb_unique_data_seen += nb_unique_data_bits
# Add this amount of unique data to the data produced/consumed unique by this tile
if op == node.output_operand:
ir_dims_output = node.loop_relevancy_info.get_ir_layer_dims(Constants.OUTPUT_LAYER_OP)
tile_list = tiles # list in regular order
should_add_to_tensor_list = [
all(tile.loop_ranges[ir_dim][1] >= node.loop_ranges[ir_dim][1] for ir_dim in ir_dims_output)
for tile in tile_list
]
# attr_to_add_to = "data_produced_unique"
precision = node.operand_precision[Constants.FINAL_OUTPUT_LAYER_OP]
tile.data_produced_unique += nb_unique_data_bits
else:
tile_list = list(reversed(tiles)) # list in reversed order
should_add_to_tensor_list = [True for _ in tile_list]
# attr_to_add_to = "data_consumed_unique"
# if this layer is the first layer, we assume the inputs are streamed and "free"
precision = node.operand_precision[op] * (not is_source_node)

nb_unique_data_seen = 0
for tile, should_add_to_tensor in zip(tile_list, should_add_to_tensor_list):
if not should_add_to_tensor:
continue # Skip if we're not at the max ir loop value for output
op_dim_ranges = [tile.loop_ranges[loop_dim] for loop_dim in dims]
op_dim_ranges_max_stop = tuple(tensor_shapes[op])
# start can be negative for padding, which makes np flip
window = tuple([slice(max(0, start), stop) for (start, stop) in op_dim_ranges])
# Count how many nans we have in this window, as this is the amount of unique data consumed/produced by
# this tile # TODO this call takes a loooong time, can we optimize this?
nb_unique_data_bits = tensors_cns[op].get_nb_empty_elements(window) * precision
nb_unique_data_seen += nb_unique_data_bits
if op == node.output_operand:
tile.data_produced_unique += nb_unique_data_bits
else:
tile.data_consumed_unique += nb_unique_data_bits
tile.data_consumed_unique += nb_unique_data_bits

# This is not guaranteed: tiles of nodes whose ranges have been extended can exceed the NodeTensor shape
# assert all(start < max_stop for (start, _), max_stop in zip(op_dim_ranges, op_dim_ranges_max_stop))
# Set this window of the tensor to indicate it will be consumed/produced by this tile
# NOTE assert is not guaranteed: tiles of nodes whose ranges have been extended can exceed the NodeTensor shape
# assert all(start < max_stop for (start, _), max_stop in zip(op_dim_ranges, op_dim_ranges_max_stop))

# Slices that exceed the max stop are reduced to a size-1 slice at `max_stop-1`
bounded_op_dim_ranges = tuple(
slice(max(0, min(max_stop - 1, start)), min(max_stop, stop))
for ((start, stop), max_stop) in zip(op_dim_ranges, op_dim_ranges_max_stop)
)
tensors_cns[op] = tensors_cns[op].extend_with_node(bounded_op_dim_ranges, tile)
# Slices that exceed the max stop are reduced to a size-1 slice at `max_stop-1`
bounded_op_dim_ranges = tuple(
slice(max(0, min(max_stop - 1, start)), min(max_stop, stop))
for ((start, stop), max_stop) in zip(op_dim_ranges, op_dim_ranges_max_stop)
)
node_tensor = node_tensor.extend_with_node(bounded_op_dim_ranges, tile)

if nb_unique_data_seen < (prod(tensor_shapes[op]) * precision):
logger.warning(f"Downsampling node detected: {node}, operand= {op}.")
@@ -852,9 +836,17 @@ def get_tensor_cns(self, node: ComputationNode, tiles: list[ComputationNode]) ->
# input operand with dimensionality_order = ['B', 'G', 'C', 'IY', 'IX']
# -> gets reduced to dimensionality_order = ['B', 'CH', 'IY', 'IX']
# (in this case the 'CH' represents the absolute "channel" dimension)
for op, tensor in tensors_cns.items():
tensors_cns[op] = node.reshape_operand_tensor(tensor, operand=op)
node_tensor = node.reshape_operand_tensor(node_tensor, operand=op)

return node_tensor

def get_node_tensors(self, node: ComputationNode, tiles: list[ComputationNode]) -> dict[LayerOperand, NodeTensor]:
variable_operands = [op for op in node.input_operands if op not in node.constant_operands] + [
node.output_operand
]
tensors_cns: dict[LayerOperand, NodeTensor] = {}
for op in variable_operands:
tensors_cns[op] = self.get_node_tensor(node, tiles, op)
return tensors_cns

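A hedged sketch of the refactor above, with toy dict-based stand-ins for the real `ComputationNode`/`NodeTensor` types: the single-operand builder does the work, and the dict-returning wrapper simply delegates per variable operand. This mirrors `get_node_tensor` / `get_node_tensors`, though the real methods build full dependency tensors rather than shapes.

# Simplified stand-ins, not the real Stream classes.
def get_node_tensor(node: dict, op: str) -> list[int]:
    """Build the dependency tensor for a single operand (here: just its shape)."""
    return [node["dim_sizes"][d] for d in node["dim_order"][op]]

def get_node_tensors(node: dict) -> dict[str, list[int]]:
    """Assemble the per-operand dict by delegating to the single-operand builder."""
    return {op: get_node_tensor(node, op) for op in node["variable_operands"]}

node = {
    "variable_operands": ["I", "O"],
    "dim_order": {"I": ["B", "C"], "O": ["B", "K"]},
    "dim_sizes": {"B": 2, "C": 8, "K": 4},
}
print(get_node_tensors(node))  # {'I': [2, 8], 'O': [2, 4]}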
@staticmethod
6 changes: 4 additions & 2 deletions stream/workload/dependency_propagation/slice_node.py
@@ -39,7 +39,9 @@ def __init__(
self.input_operand_source = {Constants.LAYER_OP_I: predecessor}
self.output_names = output_names

def propagate(self, tensor: NodeTensor, next_node: Node | None = None):
def propagate(self, tensor: NodeTensor, next_node: Node | None = None, relevant_axes: list[bool] = []):
"""Slice the tensor.
Currently assumes only one slice is created."""
return tensor.slice(starts=self.starts[0], ends=self.ends[0], axis=self.axes[0], steps=self.steps[0])
sliced_tensor = tensor.slice(starts=self.starts[0], ends=self.ends[0], axis=self.axes[0], steps=self.steps[0])
relevant_axes[self.axes[0]] = True
return sliced_tensor, relevant_axes
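A minimal sketch of the new return contract, with numpy in place of `NodeTensor` and a free function standing in for `SliceNode.propagate`; all names here are illustrative assumptions, not the real API.

import numpy as np

# Toy illustration: propagate now returns the sliced tensor together with the
# updated relevant_axes mask, marking the sliced axis as relevant.
def propagate_slice(tensor: np.ndarray, axis: int, start: int, end: int,
                    relevant_axes: list[bool]) -> tuple[np.ndarray, list[bool]]:
    index = [slice(None)] * tensor.ndim
    index[axis] = slice(start, end)
    relevant_axes[axis] = True  # the sliced axis now matters for dependency windows
    return tensor[tuple(index)], relevant_axes

tensor = np.arange(24).reshape(2, 3, 4)
sliced, axes = propagate_slice(tensor, axis=1, start=0, end=2, relevant_axes=[False] * 3)
print(sliced.shape, axes)  # (2, 2, 4) [False, True, False]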
7 changes: 5 additions & 2 deletions stream/workload/dependency_propagation/split_node.py
@@ -37,7 +37,7 @@ def __init__(
self.input_operand_source = {Constants.LAYER_OP_I: predecessor}
self.output_names = output_names

def propagate(self, tensor: NodeTensor, next_node: Node):
def propagate(self, tensor: NodeTensor, next_node: Node, relevant_axes: list[bool]):
"""Split the tensor back to the representation needed for producer/consumer."""

# Numpy requires the indices where to split instead of the sizes of the resulting splits
@@ -52,5 +52,8 @@ def propagate(self, tensor: NodeTensor, next_node: Node):
f"Cannot find this nodes' ({self.name}) outputs {self.output_names} in next nodes' inputs {next_node.input_names}"
)

# Mark the axis involved in the split as relevant in relevant_axes
relevant_axes[self.axis] = True

output_tensor = output_tensors[index]
return output_tensor
return output_tensor, relevant_axes
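To show how the updated signatures thread through a chain of propagation nodes, as `get_inter_edges_numpy` does above, here is a schematic loop with toy duck-typed nodes rather than the real Stream classes.

# Schematic: each propagation node updates both the tensor and the relevant_axes mask,
# so the final window extraction only slices along axes some node actually touched.
class ToySplit:
    def __init__(self, axis: int):
        self.axis = axis

    def propagate(self, tensor, relevant_axes):
        relevant_axes[self.axis] = True
        return tensor, relevant_axes  # tensor left unchanged in this toy version

path = [ToySplit(axis=0), ToySplit(axis=2)]
tensor, relevant_axes = "dependency-tensor", [False] * 3
for node in path:
    tensor, relevant_axes = node.propagate(tensor, relevant_axes)
print(relevant_axes)  # [True, False, True]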
