diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index 8e3f7ab9b5e0afa339cbc7a251b1a11fe51a4d0c..1e2f3125d39933aac3cc8d2bbd3786da6cb25190 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -40,6 +40,7 @@ class SFG(AbstractOperation): _components_by_id: Dict[GraphID, GraphComponent] _components_by_name: DefaultDict[Name, List[GraphComponent]] + _components_in_dfs_order: List[GraphComponent] _graph_id_generator: GraphIDGenerator _input_operations: List[Input] _output_operations: List[Output] @@ -132,25 +133,36 @@ class SFG(AbstractOperation): for sig, input_index in self._original_input_signals_to_indexes.items(): # Check if already added destination. new_sig = self._original_components_to_new[sig] - if new_sig.destination is not None and new_sig.destination.operation in output_operations_set: - # Add directly connected input to output to dfs order list - self._components_in_dfs_order.extend([new_sig.source.operation, new_sig, new_sig.destination.operation]) - elif sig.destination is None: - raise ValueError(f"Input signal #{input_index} is missing destination in SFG") - elif sig.destination.operation not in self._original_components_to_new: - self._copy_structure_from_operation_dfs( - sig.destination.operation) + if new_sig.destination is None: + if sig.destination is None: + raise ValueError( + f"Input signal #{input_index} is missing destination in SFG") + elif sig.destination.operation not in self._original_components_to_new: + self._copy_structure_from_operation_dfs( + sig.destination.operation) + else: + if new_sig.destination.operation in output_operations_set: + # Add directly connected input to output to dfs order list + self._components_in_dfs_order.extend([new_sig.source.operation, new_sig, new_sig.destination.operation]) # Search the graph inwards from each output signal. for sig, output_index in self._original_output_signals_to_indexes.items(): # Check if already added source. - mew_sig = self._original_components_to_new[sig] + new_sig = self._original_components_to_new[sig] if new_sig.source is None: if sig.source is None: raise ValueError(f"Output signal #{output_index} is missing source in SFG") if sig.source.operation not in self._original_components_to_new: self._copy_structure_from_operation_dfs(sig.source.operation) + def __call__(self, *src: Optional[SignalSourceProvider], name: Name = "") -> "SFG": + """Get a new independent SFG instance that is identical to this SFG except without any of its external connections.""" + input_sources = src + if not input_sources: + input_sources = None + return SFG(inputs = self._input_operations, outputs = self._output_operations, + id_number_offset = self._graph_id_generator.id_number_offset, name = name, input_sources = input_sources) + @property def type_name(self) -> TypeName: return "sfg" @@ -185,6 +197,11 @@ class SFG(AbstractOperation): def split(self) -> Iterable[Operation]: return filter(lambda comp: isinstance(comp, Operation), self._components_by_id.values()) + @property + def id_number_offset(self) -> GraphIDNumber: + """Get the graph id number offset of the graph id generator for this SFG.""" + return self._graph_id_generator.id_number_offset + @property def components(self) -> Iterable[GraphComponent]: """Get all components of this graph in the dfs-traversal order.""" @@ -209,11 +226,6 @@ class SFG(AbstractOperation): """ return self._components_by_name.get(name, []) - def deep_copy(self) -> "SFG": - """Returns a deep copy of self without any connections.""" - return SFG(inputs = self._input_operations, outputs = self._output_operations, - id_number_offset = self._graph_id_generator.id_number_offset, name = self.name) - def _add_component_copy_unconnected(self, original_comp: GraphComponent) -> GraphComponent: assert original_comp not in self._original_components_to_new, "Tried to add duplicate SFG component" new_comp = original_comp.copy_component() diff --git a/test/test_sfg.py b/test/test_signal_flow_graph.py similarity index 77% rename from test/test_sfg.py rename to test/test_signal_flow_graph.py index 1501e6cff3f8adfb09a3a629aea69b68a59a476e..87fc5d9937505ad31f4880a4652629c5499c6ad3 100644 --- a/test/test_sfg.py +++ b/test/test_signal_flow_graph.py @@ -1,6 +1,6 @@ import pytest -from b_asic import SFG, Signal, Input, Output, Addition, Multiplication +from b_asic import SFG, Signal, Input, Output, Constant, Addition, Multiplication class TestConstructor: @@ -67,11 +67,13 @@ class TestDeepCopy: out1 = Output(mul1, "OUT1") mac_sfg = SFG(inputs = [inp1, inp2], outputs = [out1], name = "mac_sfg") + mac_sfg_new = mac_sfg() - mac_sfg_deep_copy = mac_sfg.deep_copy() + assert mac_sfg.name == "mac_sfg" + assert mac_sfg_new.name == "" for g_id, component in mac_sfg._components_by_id.items(): - component_copy = mac_sfg_deep_copy.find_by_id(g_id) + component_copy = mac_sfg_new.find_by_id(g_id) assert component.name == component_copy.name def test_deep_copy(self): @@ -91,13 +93,33 @@ class TestDeepCopy: mul1.input(1).connect(add2, "S6") out1.input(0).connect(mul1, "S7") - mac_sfg = SFG(inputs = [inp1, inp2], outputs = [out1], name = "mac_sfg") + mac_sfg = SFG(inputs = [inp1, inp2], outputs = [out1], id_number_offset = 100, name = "mac_sfg") + mac_sfg_new = mac_sfg(name = "mac_sfg2") - mac_sfg_deep_copy = mac_sfg.deep_copy() + assert mac_sfg.name == "mac_sfg" + assert mac_sfg_new.name == "mac_sfg2" + assert mac_sfg.id_number_offset == 100 + assert mac_sfg_new.id_number_offset == 100 for g_id, component in mac_sfg._components_by_id.items(): - component_copy = mac_sfg_deep_copy.find_by_id(g_id) + component_copy = mac_sfg_new.find_by_id(g_id) assert component.name == component_copy.name + + def test_deep_copy_with_new_sources(self): + inp1 = Input("INP1") + inp2 = Input("INP2") + inp3 = Input("INP3") + add1 = Addition(inp1, inp2, "ADD1") + mul1 = Multiplication(add1, inp3, "MUL1") + out1 = Output(mul1, "OUT1") + + mac_sfg = SFG(inputs = [inp1, inp2], outputs = [out1], name = "mac_sfg") + + a = Addition(Constant(3), Constant(5)) + b = Constant(2) + mac_sfg_new = mac_sfg(a, b) + assert mac_sfg_new.input(0).signals[0].source.operation is a + assert mac_sfg_new.input(1).signals[0].source.operation is b class TestComponents: