From fe2b75a1b56850497973f7e74ddc6c139e3dfde8 Mon Sep 17 00:00:00 2001 From: Oscar Gustafsson <oscar.gustafsson@gmail.com> Date: Fri, 17 Feb 2023 22:50:38 +0100 Subject: [PATCH] Add port name --- b_asic/port.py | 21 +-- b_asic/schedule.py | 3 +- b_asic/signal_flow_graph.py | 250 ++++++++++-------------------------- test/test_sfg.py | 234 ++++++++++----------------------- 4 files changed, 147 insertions(+), 361 deletions(-) diff --git a/b_asic/port.py b/b_asic/port.py index e1d35e3e..ef692473 100644 --- a/b_asic/port.py +++ b/b_asic/port.py @@ -49,8 +49,7 @@ class Port(ABC): @latency_offset.setter @abstractmethod def latency_offset(self, latency_offset: int) -> None: - """Set the latency_offset of the port to the integer specified value. - """ + """Set the latency_offset of the port to the integer specified value.""" raise NotImplementedError @property @@ -94,6 +93,12 @@ class Port(ABC): """Removes all connected signals from the Port.""" raise NotImplementedError + @property + @abstractmethod + def name(self) -> str: + """Return a name consisting of *graph_id* of the related operation and the port number. + """ + class AbstractPort(Port): """ @@ -134,6 +139,10 @@ class AbstractPort(Port): def latency_offset(self, latency_offset: Optional[int]): self._latency_offset = latency_offset + @property + def name(self): + return f"{self.operation.graph_id}.{self.index}" + class SignalSourceProvider(ABC): """ @@ -196,13 +205,9 @@ class InputPort(AbstractPort): Get the output port that is currently connected to this input port, or None if it is unconnected. """ - return ( - None if self._source_signal is None else self._source_signal.source - ) + return None if self._source_signal is None else self._source_signal.source - def connect( - self, src: SignalSourceProvider, name: Name = Name("") - ) -> Signal: + def connect(self, src: SignalSourceProvider, name: Name = Name("")) -> Signal: """ Connect the provided signal source to this input port by creating a new signal. Returns the new signal. diff --git a/b_asic/schedule.py b/b_asic/schedule.py index b94b9044..2b725089 100644 --- a/b_asic/schedule.py +++ b/b_asic/schedule.py @@ -577,7 +577,7 @@ class Schedule: start_time + cast(int, outport.latency_offset), outport, reads, - outport.operation.graph_id, + outport.name, ) ) return ret @@ -800,6 +800,7 @@ class Schedule: The vertical distance between operations in the schedule. The height of the operation is always 1. """ + self._plot_schedule(ax, operation_gap=operation_gap) def plot(self, operation_gap: Optional[float] = None) -> None: """ diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 43216d62..366f93cd 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -63,9 +63,7 @@ class GraphIDGenerator: """Construct a GraphIDGenerator.""" self._next_id_number = defaultdict(lambda: id_number_offset) - def next_id( - self, type_name: TypeName, used_ids: MutableSet = set() - ) -> GraphID: + def next_id(self, type_name: TypeName, used_ids: MutableSet = set()) -> GraphID: """Get the next graph id for a certain graph id type.""" self._next_id_number[type_name] += 1 new_id = type_name + str(self._next_id_number[type_name]) @@ -139,15 +137,11 @@ class SFG(AbstractOperation): output_signals: Optional[Sequence[Signal]] = None, id_number_offset: GraphIDNumber = GraphIDNumber(0), name: Name = Name(""), - input_sources: Optional[ - Sequence[Optional[SignalSourceProvider]] - ] = None, + input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None, ): input_signal_count = 0 if input_signals is None else len(input_signals) input_operation_count = 0 if inputs is None else len(inputs) - output_signal_count = ( - 0 if output_signals is None else len(output_signals) - ) + output_signal_count = 0 if output_signals is None else len(output_signals) output_operation_count = 0 if outputs is None else len(outputs) super().__init__( input_count=input_signal_count + input_operation_count, @@ -162,9 +156,7 @@ class SFG(AbstractOperation): self._components_dfs_order = [] self._operations_dfs_order = [] self._operations_topological_order = [] - self._graph_id_generator = GraphIDGenerator( - GraphIDNumber(id_number_offset) - ) + self._graph_id_generator = GraphIDGenerator(GraphIDNumber(id_number_offset)) self._input_operations = [] self._output_operations = [] self._original_components_to_new = {} @@ -176,15 +168,11 @@ class SFG(AbstractOperation): if input_signals is not None: for input_index, signal in enumerate(input_signals): if signal in self._original_components_to_new: - raise ValueError( - f"Duplicate input signal {signal!r} in SFG" - ) + raise ValueError(f"Duplicate input signal {signal!r} in SFG") new_input_op = cast( Input, self._add_component_unconnected_copy(Input()) ) - new_signal = cast( - Signal, self._add_component_unconnected_copy(signal) - ) + new_signal = cast(Signal, self._add_component_unconnected_copy(signal)) new_signal.set_source(new_input_op.output(0)) self._input_operations.append(new_input_op) self._original_input_signals_to_indices[signal] = input_index @@ -193,9 +181,7 @@ class SFG(AbstractOperation): if inputs is not None: for input_index, input_op in enumerate(inputs, input_signal_count): if input_op in self._original_components_to_new: - raise ValueError( - f"Duplicate input operation {input_op!r} in SFG" - ) + raise ValueError(f"Duplicate input operation {input_op!r} in SFG") new_input_op = cast( Input, self._add_component_unconnected_copy(input_op) ) @@ -209,9 +195,7 @@ class SFG(AbstractOperation): Signal, self._add_component_unconnected_copy(signal) ) new_signal.set_source(new_input_op.output(0)) - self._original_input_signals_to_indices[ - signal - ] = input_index + self._original_input_signals_to_indices[signal] = input_index self._input_operations.append(new_input_op) @@ -223,9 +207,7 @@ class SFG(AbstractOperation): ) if signal in self._original_components_to_new: # Signal was already added when setting up inputs. - new_signal = cast( - Signal, self._original_components_to_new[signal] - ) + new_signal = cast(Signal, self._original_components_to_new[signal]) new_signal.set_destination(new_output_op.input(0)) else: # New signal has to be created. @@ -239,13 +221,9 @@ class SFG(AbstractOperation): # Setup output operations, starting from indices after output signals. if outputs is not None: - for output_index, output_op in enumerate( - outputs, output_signal_count - ): + for output_index, output_op in enumerate(outputs, output_signal_count): if output_op in self._original_components_to_new: - raise ValueError( - f"Duplicate output operation {output_op!r} in SFG" - ) + raise ValueError(f"Duplicate output operation {output_op!r} in SFG") new_output_op = cast( Output, self._add_component_unconnected_copy(output_op) @@ -264,9 +242,7 @@ class SFG(AbstractOperation): ) new_signal.set_destination(new_output_op.input(0)) - self._original_output_signals_to_indices[ - signal - ] = output_index + self._original_output_signals_to_indices[signal] = output_index self._output_operations.append(new_output_op) @@ -282,13 +258,9 @@ class SFG(AbstractOperation): if new_signal.destination is None: if signal.destination is None: raise ValueError( - f"Input signal #{input_index} is missing destination" - " in SFG" + f"Input signal #{input_index} is missing destination in SFG" ) - if ( - signal.destination.operation - not in self._original_components_to_new - ): + if signal.destination.operation not in self._original_components_to_new: self._add_operation_connected_tree_copy( signal.destination.operation ) @@ -319,16 +291,10 @@ class SFG(AbstractOperation): if new_signal.source is None: if signal.source is None: raise ValueError( - f"Output signal #{output_index} is missing source" - " in SFG" - ) - if ( - signal.source.operation - not in self._original_components_to_new - ): - self._add_operation_connected_tree_copy( - signal.source.operation + f"Output signal #{output_index} is missing source in SFG" ) + if signal.source.operation not in self._original_components_to_new: + self._add_operation_connected_tree_copy(signal.source.operation) def __str__(self) -> str: """Return a string representation of this SFG.""" @@ -458,18 +424,14 @@ class SFG(AbstractOperation): return False # For each input_signal, connect it to the corresponding operation - for input_port, input_operation in zip( - self.inputs, self.input_operations - ): + for input_port, input_operation in zip(self.inputs, self.input_operations): destination = input_operation.output(0).signals[0].destination if destination is None: raise ValueError("Missing destination in signal.") destination.clear() input_port.signals[0].set_destination(destination) # For each output_signal, connect it to the corresponding operation - for output_port, output_operation in zip( - self.outputs, self.output_operations - ): + for output_port, output_operation in zip(self.outputs, self.output_operations): src = output_operation.input(0).signals[0].source if src is None: raise ValueError("Missing source in signal.") @@ -521,8 +483,7 @@ class SFG(AbstractOperation): input_indexes_required = [] sfg_input_operations_to_indexes = { - input_op: index - for index, input_op in enumerate(self._input_operations) + input_op: index for index, input_op in enumerate(self._input_operations) } output_op = self._output_operations[output_index] queue: Deque[Operation] = deque([output_op]) @@ -531,9 +492,7 @@ class SFG(AbstractOperation): op = queue.popleft() if isinstance(op, Input): if op in sfg_input_operations_to_indexes: - input_indexes_required.append( - sfg_input_operations_to_indexes[op] - ) + input_indexes_required.append(sfg_input_operations_to_indexes[op]) del sfg_input_operations_to_indexes[op] for input_port in op.inputs: @@ -573,9 +532,7 @@ class SFG(AbstractOperation): """Get all operations of this graph in depth-first order.""" return list(self._operations_dfs_order) - def find_by_type_name( - self, type_name: TypeName - ) -> Sequence[GraphComponent]: + def find_by_type_name(self, type_name: TypeName) -> Sequence[GraphComponent]: """ Find all components in this graph with the specified type name. Returns an empty sequence if no components were found. @@ -640,9 +597,7 @@ class SFG(AbstractOperation): keys.append(comp.key(output_index, comp.graph_id)) return keys - def replace_component( - self, component: Operation, graph_id: GraphID - ) -> "SFG": + def replace_component(self, component: Operation, graph_id: GraphID) -> "SFG": """ Find and replace all components matching either on GraphID, Type or both. Then return a new deepcopy of the sfg with the replaced component. @@ -662,13 +617,9 @@ class SFG(AbstractOperation): if component_copy is None or not isinstance(component_copy, Operation): raise ValueError("No operation matching the criteria found") if component_copy.output_count != component.output_count: - raise TypeError( - "The output count may not differ between the operations" - ) + raise TypeError("The output count may not differ between the operations") if component_copy.input_count != component.input_count: - raise TypeError( - "The input count may not differ between the operations" - ) + raise TypeError("The input count may not differ between the operations") for index_in, inp in enumerate(component_copy.inputs): for signal in inp.signals: @@ -717,8 +668,7 @@ class SFG(AbstractOperation): ) if len(output_comp.output_signals) != component.output_count: raise TypeError( - "Destination operation input count does not match output for" - " component." + "Destination operation input count does not match output for component." ) for index, signal_in in enumerate(output_comp.output_signals): @@ -789,9 +739,7 @@ class SFG(AbstractOperation): return self._precedence_list # Find all operations with only outputs and no inputs. - no_input_ops = list( - filter(lambda op: op.input_count == 0, self.operations) - ) + no_input_ops = list(filter(lambda op: op.input_count == 0, self.operations)) delay_ops = self.find_by_type_name(Delay.type_name()) # Find all first iter output ports for precedence @@ -801,9 +749,7 @@ class SFG(AbstractOperation): for i in range(op.output_count) ] - self._precedence_list = self._traverse_for_precedence_list( - first_iter_ports - ) + self._precedence_list = self._traverse_for_precedence_list(first_iter_ports) return self._precedence_list @@ -821,7 +767,7 @@ class SFG(AbstractOperation): with pg.subgraph(name=f"cluster_{i}") as sub: sub.attr(label=f"N{i}") for port in ports: - port_string = f"{port.operation.graph_id}.{port.index}" + port_string = port.name if port.operation.output_count > 1: sub.node(port_string) else: @@ -838,7 +784,7 @@ class SFG(AbstractOperation): ports = p_list[i] for port in ports: source_label = port.operation.graph_id - node_node = f"{source_label}.{port.index}" + node_node = port.name for signal in port.signals: destination = cast(InputPort, signal.destination) destination_label = destination.operation.graph_id @@ -851,9 +797,7 @@ class SFG(AbstractOperation): pg.node( destination_node, label=destination_label, - shape=_OPERATION_SHAPE[ - destination.operation.type_name() - ], + shape=_OPERATION_SHAPE[destination.operation.type_name()], ) source_node = ( source_label + "Out" @@ -914,9 +858,7 @@ class SFG(AbstractOperation): no_inputs_queue = deque( list(filter(lambda op: op.input_count == 0, self.operations)) ) - remaining_inports_per_operation = { - op: op.input_count for op in self.operations - } + remaining_inports_per_operation = {op: op.input_count for op in self.operations} # Maps number of input counts to a queue of seen objects with such a size. seen_with_inputs_dict: Dict[int, Deque] = defaultdict(deque) @@ -931,9 +873,7 @@ class SFG(AbstractOperation): p_queue = PriorityQueue() p_queue_entry_num = it.count() # Negative priority as max-heap popping is wanted - p_queue.put( - (-first_op.output_count, -next(p_queue_entry_num), first_op) - ) + p_queue.put((-first_op.output_count, -next(p_queue_entry_num), first_op)) operations_left = len(self.operations) - 1 @@ -950,9 +890,7 @@ class SFG(AbstractOperation): for neighbor_op in op.subsequent_operations: if neighbor_op not in visited: remaining_inports_per_operation[neighbor_op] -= 1 - remaining_inports = remaining_inports_per_operation[ - neighbor_op - ] + remaining_inports = remaining_inports_per_operation[neighbor_op] if remaining_inports == 0: p_queue.put( @@ -965,16 +903,14 @@ class SFG(AbstractOperation): elif remaining_inports > 0: if neighbor_op in seen: - seen_with_inputs_dict[ - remaining_inports + 1 - ].remove(neighbor_op) + seen_with_inputs_dict[remaining_inports + 1].remove( + neighbor_op + ) else: seen.add(neighbor_op) seen_but_not_visited_count += 1 - seen_with_inputs_dict[remaining_inports].append( - neighbor_op - ) + seen_with_inputs_dict[remaining_inports].append(neighbor_op) # Check if have to fetch Operations from somewhere else since p_queue # is empty @@ -1065,9 +1001,7 @@ class SFG(AbstractOperation): self, first_iter_ports: List[OutputPort] ) -> List[List[OutputPort]]: # Find dependencies of output ports and input ports. - remaining_inports_per_operation = { - op: op.input_count for op in self.operations - } + remaining_inports_per_operation = {op: op.input_count for op in self.operations} # Traverse output ports for precedence curr_iter_ports = first_iter_ports @@ -1102,10 +1036,7 @@ class SFG(AbstractOperation): raise ValueError("Tried to add duplicate SFG component") new_component = original_component.copy_component() self._original_components_to_new[original_component] = new_component - if ( - not new_component.graph_id - or new_component.graph_id in self._used_ids - ): + if not new_component.graph_id or new_component.graph_id in self._used_ids: new_id = self._graph_id_generator.next_id( new_component.type_name(), self._used_ids ) @@ -1128,9 +1059,7 @@ class SFG(AbstractOperation): self._components_dfs_order.append(new_op) self._operations_dfs_order.append(new_op) else: - new_op = cast( - Operation, self._original_components_to_new[original_op] - ) + new_op = cast(Operation, self._original_components_to_new[original_op]) # Connect input ports to new signals. for original_input_port in original_op.inputs: @@ -1139,10 +1068,7 @@ class SFG(AbstractOperation): for original_signal in original_input_port.signals: # Check if the signal is one of the SFG's input signals. - if ( - original_signal - in self._original_input_signals_to_indices - ): + if original_signal in self._original_input_signals_to_indices: # New signal already created during first step of constructor. new_signal = cast( Signal, @@ -1161,9 +1087,7 @@ class SFG(AbstractOperation): self._operations_dfs_order.append(source.operation) # Check if the signal has not been added before. - elif ( - original_signal not in self._original_components_to_new - ): + elif original_signal not in self._original_components_to_new: if original_signal.source is None: dest = ( original_signal.destination.operation.name @@ -1177,9 +1101,7 @@ class SFG(AbstractOperation): new_signal = cast( Signal, - self._add_component_unconnected_copy( - original_signal - ), + self._add_component_unconnected_copy(original_signal), ) new_signal.set_destination( @@ -1188,19 +1110,12 @@ class SFG(AbstractOperation): self._components_dfs_order.append(new_signal) - original_connected_op = ( - original_signal.source.operation - ) + original_connected_op = original_signal.source.operation # Check if connected Operation has been added before. - if ( - original_connected_op - in self._original_components_to_new - ): + if original_connected_op in self._original_components_to_new: component = cast( Operation, - self._original_components_to_new[ - original_connected_op - ], + self._original_components_to_new[original_connected_op], ) # Set source to the already added operations port. new_signal.set_source( @@ -1215,9 +1130,7 @@ class SFG(AbstractOperation): ), ) new_signal.set_source( - new_connected_op.output( - original_signal.source.index - ) + new_connected_op.output(original_signal.source.index) ) self._components_dfs_order.append(new_connected_op) @@ -1230,32 +1143,23 @@ class SFG(AbstractOperation): for original_output_port in original_op.outputs: for original_signal in original_output_port.signals: # Check if the signal is one of the SFG's output signals. - if ( - original_signal - in self._original_output_signals_to_indices - ): + if original_signal in self._original_output_signals_to_indices: # New signal already created during first step of constructor. new_signal = cast( Signal, self._original_components_to_new[original_signal], ) - new_signal.set_source( - new_op.output(original_output_port.index) - ) + new_signal.set_source(new_op.output(original_output_port.index)) destination = cast(InputPort, new_signal.destination) self._components_dfs_order.extend( [new_signal, destination.operation] ) - self._operations_dfs_order.append( - destination.operation - ) + self._operations_dfs_order.append(destination.operation) # Check if signal has not been added before. - elif ( - original_signal not in self._original_components_to_new - ): + elif original_signal not in self._original_components_to_new: if original_signal.source is None: raise ValueError( "Dangling signal ({original_signal}) without" @@ -1264,13 +1168,9 @@ class SFG(AbstractOperation): new_signal = cast( Signal, - self._add_component_unconnected_copy( - original_signal - ), - ) - new_signal.set_source( - new_op.output(original_output_port.index) + self._add_component_unconnected_copy(original_signal), ) + new_signal.set_source(new_op.output(original_output_port.index)) self._components_dfs_order.append(new_signal) original_destination = cast( @@ -1278,8 +1178,7 @@ class SFG(AbstractOperation): ) if original_destination is None: raise ValueError( - f"Signal ({original_signal}) without" - " destination in SFG" + f"Signal ({original_signal}) without destination in SFG" ) original_connected_op = original_destination.operation @@ -1289,10 +1188,7 @@ class SFG(AbstractOperation): f" ({original_destination}) in SFG" ) # Check if connected operation has been added. - if ( - original_connected_op - in self._original_components_to_new - ): + if original_connected_op in self._original_components_to_new: # Set destination to the already connected operations port. new_signal.set_destination( cast( @@ -1313,9 +1209,7 @@ class SFG(AbstractOperation): ), ) new_signal.set_destination( - new_connected_op.input( - original_destination.index - ) + new_connected_op.input(original_destination.index) ) self._components_dfs_order.append(new_connected_op) @@ -1447,9 +1341,7 @@ class SFG(AbstractOperation): return dg def _repr_mimebundle_(self, include=None, exclude=None): - return self.sfg_digraph()._repr_mimebundle_( - include=include, exclude=exclude - ) + return self.sfg_digraph()._repr_mimebundle_(include=include, exclude=exclude) def _repr_jpeg_(self): return self.sfg_digraph()._repr_mimebundle_(include=["image/jpeg"])[ @@ -1457,9 +1349,7 @@ class SFG(AbstractOperation): ] def _repr_png_(self): - return self.sfg_digraph()._repr_mimebundle_(include=["image/png"])[ - "image/png" - ] + return self.sfg_digraph()._repr_mimebundle_(include=["image/png"])["image/png"] def show(self, fmt=None, show_id=False, engine=None) -> None: """ @@ -1517,9 +1407,7 @@ class SFG(AbstractOperation): for _ in range(factor) ] - id_idx_map = { - op.graph_id: idx for (idx, op) in enumerate(self.operations) - } + id_idx_map = {op.graph_id: idx for (idx, op) in enumerate(self.operations)} # The rest of the process is easier if we clear the connections of the inputs # and outputs of all operations @@ -1532,9 +1420,7 @@ class SFG(AbstractOperation): suffix = layer - new_ops[layer][ - op_idx - ].name = f"{new_ops[layer][op_idx].name}_{suffix}" + new_ops[layer][op_idx].name = f"{new_ops[layer][op_idx].name}_{suffix}" # NOTE: Since these IDs are what show up when printing the graph, it # is helpful to set them. However, this can cause name collisions when # names in a graph are already suffixed with _n @@ -1554,9 +1440,7 @@ class SFG(AbstractOperation): source_op_idx = id_idx_map[source_port.operation.graph_id] source_op_output_index = source_port.index new_source_op = new_ops[layer][source_op_idx] - source_op_output = new_source_op.outputs[ - source_op_output_index - ] + source_op_output = new_source_op.outputs[source_op_output_index] # If this is the last layer, we need to create a new delay element and connect it instead # of the copied port @@ -1586,9 +1470,7 @@ class SFG(AbstractOperation): target_layer = 0 if layer == factor - 1 else layer + 1 new_dest_op = new_ops[target_layer][sink_op_idx] - new_destination = new_dest_op.inputs[ - sink_op_output_index - ] + new_destination = new_dest_op.inputs[sink_op_output_index] new_destination.connect(new_source_port) else: # Other opreations need to be re-targeted to the corresponding output in the @@ -1609,9 +1491,9 @@ class SFG(AbstractOperation): ] source_op_output_idx = original_source.index - target_output = new_ops[layer][ - source_op_idx - ].outputs[source_op_output_idx] + target_output = new_ops[layer][source_op_idx].outputs[ + source_op_output_idx + ] new_ops[layer][op_idx].inputs[input_num].connect( target_output diff --git a/test/test_sfg.py b/test/test_sfg.py index 601df109..6ece1458 100644 --- a/test/test_sfg.py +++ b/test/test_sfg.py @@ -85,8 +85,7 @@ class TestPrintSfg: assert ( sfg.__str__() - == "id: no_id, \tname: SFG1, \tinputs: {0: '-'}, \toutputs: {0:" - " '-'}\n" + == "id: no_id, \tname: SFG1, \tinputs: {0: '-'}, \toutputs: {0: '-'}\n" + "Internal Operations:\n" + "----------------------------------------------------------------------------------------------------\n" + str(sfg.find_by_name("INP1")[0]) @@ -111,8 +110,7 @@ class TestPrintSfg: assert ( sfg.__str__() - == "id: no_id, \tname: mac_sfg, \tinputs: {0: '-'}, \toutputs: {0:" - " '-'}\n" + == "id: no_id, \tname: mac_sfg, \tinputs: {0: '-'}, \toutputs: {0: '-'}\n" + "Internal Operations:\n" + "----------------------------------------------------------------------------------------------------\n" + str(sfg.find_by_name("INP1")[0]) @@ -140,8 +138,7 @@ class TestPrintSfg: assert ( sfg.__str__() - == "id: no_id, \tname: sfg, \tinputs: {0: '-'}, \toutputs: {0:" - " '-'}\n" + == "id: no_id, \tname: sfg, \tinputs: {0: '-'}, \toutputs: {0: '-'}\n" + "Internal Operations:\n" + "----------------------------------------------------------------------------------------------------\n" + str(sfg.find_by_name("CONST")[0]) @@ -257,9 +254,7 @@ class TestEvaluateOutput: def test_evaluate_output_cycle(self, operation_graph_with_cycle): sfg = SFG(outputs=[Output(operation_graph_with_cycle)]) - with pytest.raises( - RuntimeError, match="Direct feedback loop detected" - ): + with pytest.raises(RuntimeError, match="Direct feedback loop detected"): sfg.evaluate_output(0, []) @@ -306,9 +301,7 @@ class TestReplaceComponents: sfg = SFG(outputs=[Output(operation_tree)]) component_id = "add1" - sfg = sfg.replace_component( - Multiplication(name="Multi"), graph_id=component_id - ) + sfg = sfg.replace_component(Multiplication(name="Multi"), graph_id=component_id) assert component_id not in sfg._components_by_id.keys() assert "Multi" in sfg._components_by_name.keys() @@ -316,9 +309,7 @@ class TestReplaceComponents: sfg = SFG(outputs=[Output(large_operation_tree)]) component_id = "add3" - sfg = sfg.replace_component( - Multiplication(name="Multi"), graph_id=component_id - ) + sfg = sfg.replace_component(Multiplication(name="Multi"), graph_id=component_id) assert "Multi" in sfg._components_by_name.keys() assert component_id not in sfg._components_by_id.keys() @@ -381,9 +372,7 @@ class TestConstructSFG: for _ in range(499): prev_op_sub = Subtraction(prev_op_sub, Constant(2)) butterfly = Butterfly(prev_op_add, prev_op_sub) - sfg = SFG( - outputs=[Output(butterfly.output(0)), Output(butterfly.output(1))] - ) + sfg = SFG(outputs=[Output(butterfly.output(0)), Output(butterfly.output(1))]) sim = FastSimulation(sfg) sim.step() assert sim.results["0"][0].real == 0 @@ -467,21 +456,14 @@ class TestInsertComponent: sfg = SFG(outputs=[Output(large_operation_tree_names)]) sqrt = SquareRoot() - _sfg = sfg.insert_operation( - sqrt, sfg.find_by_name("constant4")[0].graph_id - ) + _sfg = sfg.insert_operation(sqrt, sfg.find_by_name("constant4")[0].graph_id) assert _sfg.evaluate() != sfg.evaluate() assert any([isinstance(comp, SquareRoot) for comp in _sfg.operations]) - assert not any( - [isinstance(comp, SquareRoot) for comp in sfg.operations] - ) + assert not any([isinstance(comp, SquareRoot) for comp in sfg.operations]) assert not isinstance( - sfg.find_by_name("constant4")[0] - .output(0) - .signals[0] - .destination.operation, + sfg.find_by_name("constant4")[0].output(0).signals[0].destination.operation, SquareRoot, ) assert isinstance( @@ -529,17 +511,11 @@ class TestInsertComponent: # Correctly connected old output -> new input assert ( - _sfg.find_by_name("bfly3")[0] - .output(0) - .signals[0] - .destination.operation + _sfg.find_by_name("bfly3")[0].output(0).signals[0].destination.operation is _sfg.find_by_name("n_bfly")[0] ) assert ( - _sfg.find_by_name("bfly3")[0] - .output(1) - .signals[0] - .destination.operation + _sfg.find_by_name("bfly3")[0].output(1).signals[0].destination.operation is _sfg.find_by_name("n_bfly")[0] ) @@ -555,17 +531,11 @@ class TestInsertComponent: # Correctly connected new output -> next input assert ( - _sfg.find_by_name("n_bfly")[0] - .output(0) - .signals[0] - .destination.operation + _sfg.find_by_name("n_bfly")[0].output(0).signals[0].destination.operation is _sfg.find_by_name("bfly2")[0] ) assert ( - _sfg.find_by_name("n_bfly")[0] - .output(1) - .signals[0] - .destination.operation + _sfg.find_by_name("n_bfly")[0].output(1).signals[0].destination.operation is _sfg.find_by_name("bfly2")[0] ) @@ -600,28 +570,24 @@ class TestFindComponentsWithTypeName: mac_sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="mac_sfg") - assert { - comp.name for comp in mac_sfg.find_by_type_name(inp1.type_name()) - } == { + assert {comp.name for comp in mac_sfg.find_by_type_name(inp1.type_name())} == { "INP1", "INP2", "INP3", } - assert { - comp.name for comp in mac_sfg.find_by_type_name(add1.type_name()) - } == { + assert {comp.name for comp in mac_sfg.find_by_type_name(add1.type_name())} == { "ADD1", "ADD2", } - assert { - comp.name for comp in mac_sfg.find_by_type_name(mul1.type_name()) - } == {"MUL1"} + assert {comp.name for comp in mac_sfg.find_by_type_name(mul1.type_name())} == { + "MUL1" + } - assert { - comp.name for comp in mac_sfg.find_by_type_name(out1.type_name()) - } == {"OUT1"} + assert {comp.name for comp in mac_sfg.find_by_type_name(out1.type_name())} == { + "OUT1" + } assert { comp.name for comp in mac_sfg.find_by_type_name(Signal.type_name()) @@ -697,9 +663,7 @@ class TestGetPrecedenceList: def test_inputs_constants_delays_multiple_outputs( self, precedence_sfg_delays_and_constants ): - precedence_list = ( - precedence_sfg_delays_and_constants.get_precedence_list() - ) + precedence_list = precedence_sfg_delays_and_constants.get_precedence_list() assert len(precedence_list) == 7 @@ -914,25 +878,15 @@ class TestPrintPrecedence: class TestDepends: def test_depends_sfg(self, sfg_two_inputs_two_outputs): - assert set( - sfg_two_inputs_two_outputs.inputs_required_for_output(0) - ) == {0, 1} - assert set( - sfg_two_inputs_two_outputs.inputs_required_for_output(1) - ) == {0, 1} + assert set(sfg_two_inputs_two_outputs.inputs_required_for_output(0)) == {0, 1} + assert set(sfg_two_inputs_two_outputs.inputs_required_for_output(1)) == {0, 1} - def test_depends_sfg_independent( - self, sfg_two_inputs_two_outputs_independent - ): + def test_depends_sfg_independent(self, sfg_two_inputs_two_outputs_independent): assert set( - sfg_two_inputs_two_outputs_independent.inputs_required_for_output( - 0 - ) + sfg_two_inputs_two_outputs_independent.inputs_required_for_output(0) ) == {0} assert set( - sfg_two_inputs_two_outputs_independent.inputs_required_for_output( - 1 - ) + sfg_two_inputs_two_outputs_independent.inputs_required_for_output(1) ) == {1} @@ -1051,8 +1005,7 @@ class TestConnectExternalSignalsToComponentsMultipleComp: assert not test_sfg.connect_external_signals_to_components() def create_sfg(self, op_tree): - """Create a simple SFG with either operation_tree or large_operation_tree - """ + """Create a simple SFG with either operation_tree or large_operation_tree""" sfg1 = SFG(outputs=[Output(op_tree)]) inp1 = Input("INP1") @@ -1069,9 +1022,7 @@ class TestConnectExternalSignalsToComponentsMultipleComp: return SFG(inputs=[inp1, inp2], outputs=[out1]) - def test_connect_external_signals_to_components_many_op( - self, large_operation_tree - ): + def test_connect_external_signals_to_components_many_op(self, large_operation_tree): """Replaces an sfg component in a larger SFG with several component operations """ inp1 = Input("INP1") @@ -1103,9 +1054,7 @@ class TestConnectExternalSignalsToComponentsMultipleComp: class TestTopologicalOrderOperations: def test_feedback_sfg(self, sfg_simple_filter): - topological_order = ( - sfg_simple_filter.get_operations_topological_order() - ) + topological_order = sfg_simple_filter.get_operations_topological_order() assert [comp.name for comp in topological_order] == [ "IN1", @@ -1115,9 +1064,7 @@ class TestTopologicalOrderOperations: "OUT1", ] - def test_multiple_independent_inputs( - self, sfg_two_inputs_two_outputs_independent - ): + def test_multiple_independent_inputs(self, sfg_two_inputs_two_outputs_independent): topological_order = ( sfg_two_inputs_two_outputs_independent.get_operations_topological_order() ) @@ -1132,9 +1079,7 @@ class TestTopologicalOrderOperations: ] def test_complex_graph(self, precedence_sfg_delays): - topological_order = ( - precedence_sfg_delays.get_operations_topological_order() - ) + topological_order = precedence_sfg_delays.get_operations_topological_order() assert [comp.name for comp in topological_order] == [ "IN1", @@ -1161,39 +1106,28 @@ class TestRemove: assert set( op.name - for op in sfg_simple_filter.find_by_name("T1")[ - 0 - ].subsequent_operations + for op in sfg_simple_filter.find_by_name("T1")[0].subsequent_operations ) == {"CMUL1", "OUT1"} assert set( - op.name - for op in new_sfg.find_by_name("T1")[0].subsequent_operations + op.name for op in new_sfg.find_by_name("T1")[0].subsequent_operations ) == {"ADD1", "OUT1"} assert set( op.name - for op in sfg_simple_filter.find_by_name("ADD1")[ - 0 - ].preceding_operations + for op in sfg_simple_filter.find_by_name("ADD1")[0].preceding_operations ) == {"CMUL1", "IN1"} assert set( - op.name - for op in new_sfg.find_by_name("ADD1")[0].preceding_operations + op.name for op in new_sfg.find_by_name("ADD1")[0].preceding_operations ) == {"T1", "IN1"} assert "S1" in set( [ sig.name - for sig in sfg_simple_filter.find_by_name("T1")[0] - .output(0) - .signals + for sig in sfg_simple_filter.find_by_name("T1")[0].output(0).signals ] ) assert "S2" in set( - [ - sig.name - for sig in new_sfg.find_by_name("T1")[0].output(0).signals - ] + [sig.name for sig in new_sfg.find_by_name("T1")[0].output(0).signals] ) def test_remove_multiple_inputs_outputs(self, butterfly_operation_tree): @@ -1207,9 +1141,7 @@ class TestRemove: assert sfg.find_by_name("bfly3")[0].output(0).signal_count == 1 assert new_sfg.find_by_name("bfly3")[0].output(0).signal_count == 1 - sfg_dest_0 = ( - sfg.find_by_name("bfly3")[0].output(0).signals[0].destination - ) + sfg_dest_0 = sfg.find_by_name("bfly3")[0].output(0).signals[0].destination new_sfg_dest_0 = ( new_sfg.find_by_name("bfly3")[0].output(0).signals[0].destination ) @@ -1222,9 +1154,7 @@ class TestRemove: assert sfg.find_by_name("bfly3")[0].output(1).signal_count == 1 assert new_sfg.find_by_name("bfly3")[0].output(1).signal_count == 1 - sfg_dest_1 = ( - sfg.find_by_name("bfly3")[0].output(1).signals[0].destination - ) + sfg_dest_1 = sfg.find_by_name("bfly3")[0].output(1).signals[0].destination new_sfg_dest_1 = ( new_sfg.find_by_name("bfly3")[0].output(1).signals[0].destination ) @@ -1238,9 +1168,7 @@ class TestRemove: assert new_sfg.find_by_name("bfly1")[0].input(0).signal_count == 1 sfg_source_0 = sfg.find_by_name("bfly1")[0].input(0).signals[0].source - new_sfg_source_0 = ( - new_sfg.find_by_name("bfly1")[0].input(0).signals[0].source - ) + new_sfg_source_0 = new_sfg.find_by_name("bfly1")[0].input(0).signals[0].source assert sfg_source_0.index == 0 assert new_sfg_source_0.index == 0 @@ -1248,9 +1176,7 @@ class TestRemove: assert new_sfg_source_0.operation.name == "bfly3" sfg_source_1 = sfg.find_by_name("bfly1")[0].input(1).signals[0].source - new_sfg_source_1 = ( - new_sfg.find_by_name("bfly1")[0].input(1).signals[0].source - ) + new_sfg_source_1 = new_sfg.find_by_name("bfly1")[0].input(1).signals[0].source assert sfg_source_1.index == 1 assert new_sfg_source_1.index == 1 @@ -1268,9 +1194,7 @@ class TestSaveLoadSFG: def get_path(self, existing=False): path_ = "".join(random.choices(string.ascii_uppercase, k=4)) + ".py" while path.exists(path_) if not existing else not path.exists(path_): - path_ = ( - "".join(random.choices(string.ascii_uppercase, k=4)) + ".py" - ) + path_ = "".join(random.choices(string.ascii_uppercase, k=4)) + ".py" return path_ @@ -1357,46 +1281,36 @@ class TestGetComponentsOfType: def test_get_multple_operations_of_type(self, sfg_two_inputs_two_outputs): assert [ op.name - for op in sfg_two_inputs_two_outputs.find_by_type_name( - Addition.type_name() - ) + for op in sfg_two_inputs_two_outputs.find_by_type_name(Addition.type_name()) ] == ["ADD1", "ADD2"] assert [ op.name - for op in sfg_two_inputs_two_outputs.find_by_type_name( - Input.type_name() - ) + for op in sfg_two_inputs_two_outputs.find_by_type_name(Input.type_name()) ] == ["IN1", "IN2"] assert [ op.name - for op in sfg_two_inputs_two_outputs.find_by_type_name( - Output.type_name() - ) + for op in sfg_two_inputs_two_outputs.find_by_type_name(Output.type_name()) ] == ["OUT1", "OUT2"] class TestPrecedenceGraph: def test_precedence_graph(self, sfg_simple_filter): res = ( - 'digraph {\n\trankdir=LR\n\tsubgraph cluster_0' - ' {\n\t\tlabel=N0\n\t\t"in1.0" [label=in1 height=0.1' - ' shape=rectangle width=0.1]\n\t\t"t1.0" [label=t1 height=0.1' - ' shape=rectangle width=0.1]\n\t}\n\tsubgraph cluster_1' - ' {\n\t\tlabel=N1\n\t\t"cmul1.0" [label=cmul1 height=0.1' - ' shape=rectangle width=0.1]\n\t}\n\tsubgraph cluster_2' - ' {\n\t\tlabel=N2\n\t\t"add1.0" [label=add1 height=0.1' - ' shape=rectangle width=0.1]\n\t}\n\t"in1.0" -> add1\n\tadd1' - ' [label=add1 shape=ellipse]\n\tin1 -> "in1.0"\n\tin1 [label=in1' - ' shape=cds]\n\t"t1.0" -> cmul1\n\tcmul1 [label=cmul1' - ' shape=ellipse]\n\t"t1.0" -> out1\n\tout1 [label=out1' - ' shape=cds]\n\tt1Out -> "t1.0"\n\tt1Out [label=t1' - ' shape=square]\n\t"cmul1.0" -> add1\n\tadd1 [label=add1' - ' shape=ellipse]\n\tcmul1 -> "cmul1.0"\n\tcmul1 [label=cmul1' + 'digraph {\n\trankdir=LR\n\tsubgraph cluster_0 {\n\t\tlabel=N0\n\t\t"in1.0"' + ' [label=in1 height=0.1 shape=rectangle width=0.1]\n\t\t"t1.0" [label=t1' + ' height=0.1 shape=rectangle width=0.1]\n\t}\n\tsubgraph cluster_1' + ' {\n\t\tlabel=N1\n\t\t"cmul1.0" [label=cmul1 height=0.1 shape=rectangle' + ' width=0.1]\n\t}\n\tsubgraph cluster_2 {\n\t\tlabel=N2\n\t\t"add1.0"' + ' [label=add1 height=0.1 shape=rectangle width=0.1]\n\t}\n\t"in1.0" ->' + ' add1\n\tadd1 [label=add1 shape=ellipse]\n\tin1 -> "in1.0"\n\tin1' + ' [label=in1 shape=cds]\n\t"t1.0" -> cmul1\n\tcmul1 [label=cmul1' + ' shape=ellipse]\n\t"t1.0" -> out1\n\tout1 [label=out1 shape=cds]\n\tt1Out' + ' -> "t1.0"\n\tt1Out [label=t1 shape=square]\n\t"cmul1.0" -> add1\n\tadd1' + ' [label=add1 shape=ellipse]\n\tcmul1 -> "cmul1.0"\n\tcmul1 [label=cmul1' ' shape=ellipse]\n\t"add1.0" -> t1In\n\tt1In [label=t1' - ' shape=square]\n\tadd1 -> "add1.0"\n\tadd1 [label=add1' - ' shape=ellipse]\n}' + ' shape=square]\n\tadd1 -> "add1.0"\n\tadd1 [label=add1 shape=ellipse]\n}' ) assert sfg_simple_filter.precedence_graph().source in (res, res + "\n") @@ -1536,9 +1450,7 @@ class TestSFGErrors: adaptor = SymmetricTwoportAdaptor(0.5, in1, signal) out1 = Output(adaptor.output(0)) out2 = Output(adaptor.output(1)) - with pytest.raises( - ValueError, match="Dangling signal without source in SFG" - ): + with pytest.raises(ValueError, match="Dangling signal without source in SFG"): SFG([in1], [out1, out2]) def test_remove_signal_with_different_number_of_inputs_and_outputs(self): @@ -1552,9 +1464,7 @@ class TestSFGErrors: assert sfg1 is None with pytest.raises( ValueError, - match=( - "Different number of input and output ports of operation with" - ), + match="Different number of input and output ports of operation with", ): sfg.remove_operation('add1') @@ -1586,10 +1496,7 @@ class TestInputDuplicationBug: twotapfir = SFG(inputs=[in1], outputs=[out1], name='twotapfir') - assert ( - len([op for op in twotapfir.operations if isinstance(op, Input)]) - == 1 - ) + assert len([op for op in twotapfir.operations if isinstance(op, Input)]) == 1 class TestCriticalPath: @@ -1649,9 +1556,7 @@ class TestUnfold: double_unfolded = sfg.unfold(factor).unfold(factor) - self.assert_counts_is_correct( - sfg, double_unfolded, factor * factor - ) + self.assert_counts_is_correct(sfg, double_unfolded, factor * factor) NUM_TESTS = 5 # Evaluate with some random values @@ -1667,9 +1572,7 @@ class TestUnfold: ref = sim.results # We have i copies of the inputs, each sourcing their input from the orig - unfolded_input_lists = [ - [] for _ in range(len(sfg.inputs) * factor) - ] + unfolded_input_lists = [[] for _ in range(len(sfg.inputs) * factor)] for t in range(0, NUM_TESTS): for n in range(0, factor): for k in range(0, len(sfg.inputs)): @@ -1689,10 +1592,7 @@ class TestUnfold: # indicies where we find the outputs out_indices = [n + k * len(sfg.outputs) for k in range(factor)] u_values = [ - [ - unfolded_results[ResultKey(f"{idx}")][k] - for idx in out_indices - ] + [unfolded_results[ResultKey(f"{idx}")][k] for idx in out_indices] for k in range(int(NUM_TESTS)) ] @@ -1702,7 +1602,5 @@ class TestUnfold: def test_value_error(self, sfg_two_inputs_two_outputs: SFG): sfg = sfg_two_inputs_two_outputs - with pytest.raises( - ValueError, match="Unfolding 0 times removes the SFG" - ): + with pytest.raises(ValueError, match="Unfolding 0 times removes the SFG"): sfg.unfold(0) -- GitLab