Skip to content

Commit

Permalink
Show shortened layer name in html schedule
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinGeens committed Aug 1, 2024
1 parent 49623fd commit 29b6245
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
8 changes: 7 additions & 1 deletion stream/classes/workload/computation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from zigzag.datatypes import Constants, LayerDim, LayerOperand, MemoryOperand
from zigzag.workload.layer_attributes import LayerPadding
from zigzag.workload.layer_node import LayerNode, LayerNodeAttributes
from zigzag.visualization.results.plot_cme import shorten_onnx_layer_name

OperandTensorReshape: TypeAlias = dict[LayerOperand, tuple[int, int, int, int]]
LoopRanges: TypeAlias = dict[LayerDim, tuple[int, int]]
Expand Down Expand Up @@ -120,6 +121,10 @@ def get_operand_tensor_reshape_default(self) -> OperandTensorReshape | None:
except KeyError:
return None

@property
def short_name(self) -> str:
return shorten_onnx_layer_name(self.name)

def __str__(self):
return f"ComputationNode{self.id}_{self.sub_id}"

Expand Down Expand Up @@ -216,6 +221,7 @@ def set_nb_real_predecessors(self, nb_real_predecessors: int):
self.nb_real_predecessors = nb_real_predecessors

def update_loop_ranges(self, new_ranges: LoopRanges):
"""Override the loop ranges with a new value for each of the given LayerDims. Keep the old range for the LayerDims not defined in `new_ranges`"""
"""Override the loop ranges with a new value for each of the given LayerDims. Keep the old range for the
LayerDims not defined in `new_ranges`"""
for layer_dim in new_ranges:
self.loop_ranges[layer_dim] = new_ranges[layer_dim]
2 changes: 1 addition & 1 deletion stream/visualization/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def get_dataframe_from_scme(scme, layer_ids, add_communication=False):
tensors = get_real_input_tensors(node, scme.workload)
task_type = "compute"
d = dict(
Task=str(node),
Task=node.short_name,
Start=start,
End=end,
Resource=f"Core {core_id}",
Expand Down

0 comments on commit 29b6245

Please sign in to comment.