diff --git a/b_asic/schedule.py b/b_asic/schedule.py index 8fc13eba0a5c98068bd007f5a5e1c4b48f582d9d..0ce9e42f634c42c5d935f320730a99745c2caf56 100644 --- a/b_asic/schedule.py +++ b/b_asic/schedule.py @@ -410,7 +410,7 @@ class Schedule: raise ValueError(f"No operation with graph_id {graph_id!r} in schedule") return self.backward_slack(graph_id), self.forward_slack(graph_id) - def print_slacks(self, order: int = 0) -> None: + def print_slacks(self, order: int = 0, type_name: TypeName | None = None) -> None: """ Print the slack times for all operations in the schedule. @@ -422,14 +422,24 @@ class Schedule: * 0: alphabetical on Graph ID * 1: backward slack * 2: forward slack + + type_name : TypeName, optional + If given, only the slack times for operations of this type will be printed. """ + if type_name is None: + operations = self._sfg.operations + else: + operations = [ + op for op in self._sfg.operations if isinstance(op, type_name) + ] + res: list[tuple[GraphID, int, int]] = [ ( op.graph_id, cast(int, self.backward_slack(op.graph_id)), self.forward_slack(op.graph_id), ) - for op in self._sfg.operations + for op in operations ] res.sort(key=lambda tup: tup[order]) res_str = [ @@ -1084,10 +1094,17 @@ class Schedule: move_y_location set_y_location """ - for i, graph_id in enumerate( - sorted(self._start_times, key=self._start_times.get) - ): + + def sort_key(graph_id): + op = self._sfg.find_by_id(graph_id) + return ( + self._start_times[op.graph_id], + -self._sfg.find_by_id(graph_id).latency, + ) + + for i, graph_id in enumerate(sorted(self._start_times, key=sort_key)): self.set_y_location(graph_id, i) + for graph_id in self._start_times: op = cast(Operation, self._sfg.find_by_id(graph_id)) # Position Outputs and Sinks adjacent to the operation generating them diff --git a/test/unit/test_schedule.py b/test/unit/test_schedule.py index bc2071f62c4c23d2075206ac2f27582813fee0fd..0a94212e7260d3369e4b671a1551579003a23ead 100644 --- a/test/unit/test_schedule.py +++ b/test/unit/test_schedule.py @@ -404,6 +404,28 @@ class TestSlacks: ) assert captured.err == "" + def test_print_slacks_type_name_given(self, capsys, precedence_sfg_delays): + precedence_sfg_delays.set_latency_of_type_name(Addition.type_name(), 1) + precedence_sfg_delays.set_latency_of_type_name( + ConstantMultiplication.type_name(), 3 + ) + + schedule = Schedule(precedence_sfg_delays, scheduler=ASAPScheduler()) + schedule.print_slacks(1, type_name=ConstantMultiplication) + captured = capsys.readouterr() + assert captured.out == ( + "Graph ID | Backward | Forward\n" + "---------|----------|---------\n" + "cmul0 | 0 | 1\n" + "cmul1 | 0 | 0\n" + "cmul2 | 0 | 0\n" + "cmul3 | 4 | 0\n" + "cmul6 | 4 | 0\n" + "cmul4 | 16 | 0\n" + "cmul5 | 16 | 0\n" + ) + assert captured.err == "" + def test_slacks_errors(self, precedence_sfg_delays): precedence_sfg_delays.set_latency_of_type_name(Addition.type_name(), 1) precedence_sfg_delays.set_latency_of_type_name(