Skip to content
Snippets Groups Projects
Commit 7f7152e5 authored by Simon Bjurek's avatar Simon Bjurek Committed by Oscar Gustafsson
Browse files

Added type argument to print_slacks_type and added tie-breaker for sort-y-times

parent 7051e74f
No related branches found
No related tags found
1 merge request!504Added type argument to print_slacks_type and added tie-breaker for sort-y-times
Pipeline #159514 passed
......@@ -410,7 +410,7 @@ class Schedule:
raise ValueError(f"No operation with graph_id {graph_id!r} in schedule")
return self.backward_slack(graph_id), self.forward_slack(graph_id)
def print_slacks(self, order: int = 0) -> None:
def print_slacks(self, order: int = 0, type_name: TypeName | None = None) -> None:
"""
Print the slack times for all operations in the schedule.
......@@ -422,14 +422,24 @@ class Schedule:
* 0: alphabetical on Graph ID
* 1: backward slack
* 2: forward slack
type_name : TypeName, optional
If given, only the slack times for operations of this type will be printed.
"""
if type_name is None:
operations = self._sfg.operations
else:
operations = [
op for op in self._sfg.operations if isinstance(op, type_name)
]
res: list[tuple[GraphID, int, int]] = [
(
op.graph_id,
cast(int, self.backward_slack(op.graph_id)),
self.forward_slack(op.graph_id),
)
for op in self._sfg.operations
for op in operations
]
res.sort(key=lambda tup: tup[order])
res_str = [
......@@ -1084,10 +1094,17 @@ class Schedule:
move_y_location
set_y_location
"""
for i, graph_id in enumerate(
sorted(self._start_times, key=self._start_times.get)
):
def sort_key(graph_id):
op = self._sfg.find_by_id(graph_id)
return (
self._start_times[op.graph_id],
-self._sfg.find_by_id(graph_id).latency,
)
for i, graph_id in enumerate(sorted(self._start_times, key=sort_key)):
self.set_y_location(graph_id, i)
for graph_id in self._start_times:
op = cast(Operation, self._sfg.find_by_id(graph_id))
# Position Outputs and Sinks adjacent to the operation generating them
......
......@@ -404,6 +404,28 @@ class TestSlacks:
)
assert captured.err == ""
def test_print_slacks_type_name_given(self, capsys, precedence_sfg_delays):
precedence_sfg_delays.set_latency_of_type_name(Addition.type_name(), 1)
precedence_sfg_delays.set_latency_of_type_name(
ConstantMultiplication.type_name(), 3
)
schedule = Schedule(precedence_sfg_delays, scheduler=ASAPScheduler())
schedule.print_slacks(1, type_name=ConstantMultiplication)
captured = capsys.readouterr()
assert captured.out == (
"Graph ID | Backward | Forward\n"
"---------|----------|---------\n"
"cmul0 | 0 | 1\n"
"cmul1 | 0 | 0\n"
"cmul2 | 0 | 0\n"
"cmul3 | 4 | 0\n"
"cmul6 | 4 | 0\n"
"cmul4 | 16 | 0\n"
"cmul5 | 16 | 0\n"
)
assert captured.err == ""
def test_slacks_errors(self, precedence_sfg_delays):
precedence_sfg_delays.set_latency_of_type_name(Addition.type_name(), 1)
precedence_sfg_delays.set_latency_of_type_name(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment