From 1120d83b2e7e8bc59ed6afcd65740c4cd07ba05f Mon Sep 17 00:00:00 2001
From: TheZoq2 <frans.skarman@protonmail.com>
Date: Wed, 15 Feb 2023 17:40:37 +0100
Subject: [PATCH] "Working" but untested unfolding

---
 b_asic/signal_flow_graph.py | 118 ++++++++++++++++++++++++++++++++++++
 examples/twotapfirsfg.py    |  10 +--
 test/test_sfg.py            |   6 ++
 3 files changed, 129 insertions(+), 5 deletions(-)

diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py
index 417baa24..c5883722 100644
--- a/b_asic/signal_flow_graph.py
+++ b/b_asic/signal_flow_graph.py
@@ -1115,6 +1115,11 @@ class SFG(AbstractOperation):
         return new_component
 
     def _add_operation_connected_tree_copy(self, start_op: Operation) -> None:
+        print(
+            "Running _add_operation_connected_tree_copy with"
+            f" {self._operations_dfs_order}"
+        )
+        print(f"Start op: {start_op}")
         op_stack = deque([start_op])
         while op_stack:
             original_op = op_stack.pop()
@@ -1486,3 +1491,116 @@ class SFG(AbstractOperation):
         from b_asic.schedule import Schedule
 
         return Schedule(self, scheduling_algorithm="ASAP").schedule_time
+
+    def unfold(self, factor: int) -> "SFG":
+        if factor == 0:
+            raise ValueError("Unrollnig 0 times removes the SFG")
+
+        # Make `factor` copies of the sfg
+        new_ops = [
+            [cast(Operation, op.copy_component()) for op in self.operations]
+            for _ in range(factor)
+        ]
+
+        id_idx_map = {
+            op.graph_id: idx for (idx, op) in enumerate(self.operations)
+        }
+
+        # The rest of the process is easier if we clear the connections of the inputs
+        # and outputs of all operations
+        for list in new_ops:
+            for op in list:
+                for input in op.inputs:
+                    input.clear()
+                for output in op.outputs:
+                    output.clear()
+
+        # Walk through the operations, replacing delay nodes with connections
+        for layer in range(factor):
+            for op_idx, op in enumerate(self.operations):
+                new_ops[layer][
+                    op_idx
+                ].name = f"{new_ops[layer][op_idx].name}_{factor-layer}"
+                # NOTE: These are overwritten later, but it's useful to debug with them
+                new_ops[layer][op_idx].graph_id = GraphID(
+                    f"{new_ops[layer][op_idx].graph_id}_{factor-layer}"
+                )
+                if isinstance(op, Delay):
+                    # Port of the operation feeding into this delay
+                    source_port = op.inputs[0].connected_source
+                    if source_port is None:
+                        raise ValueError("Dangling delay input port in sfg")
+
+                    source_op_idx = id_idx_map[source_port.operation.graph_id]
+                    source_op_output_index = source_port.index
+                    new_source_op = new_ops[layer][source_op_idx]
+                    source_op_output = new_source_op.outputs[
+                        source_op_output_index
+                    ]
+
+                    # If this is the last layer, we need to create a new delay element and connect it instead
+                    # of the copied port
+                    if layer == factor - 1:
+                        delay = Delay(name=op.name)
+                        delay.graph_id = op.graph_id
+
+                        # Since we're adding a new operation instead of bypassing as in the
+                        # common case, we also need to hook up the inputs to the delay.
+                        delay.inputs[0].connect(source_op_output)
+
+                        new_source_op = delay
+                        new_source_port = new_source_op.outputs[0]
+                    else:
+                        # The new output port we should connect to
+                        new_source_port = source_op_output
+                        new_source_port.clear()
+
+                    for out_signal in op.outputs[0].signals:
+                        sink_port = out_signal.destination
+                        if sink_port is None:
+                            # It would be weird if we found a signal but it wasn't connected anywere
+                            raise ValueError("Dangling output port in sfg")
+
+                        sink_op_idx = id_idx_map[sink_port.operation.graph_id]
+                        sink_op_output_index = sink_port.index
+
+                        target_layer = 0 if layer == factor - 1 else layer + 1
+
+                        new_dest_op = new_ops[target_layer][sink_op_idx]
+                        new_destination = new_dest_op.inputs[
+                            sink_op_output_index
+                        ]
+                        new_destination.clear()
+                        new_destination.connect(new_source_port)
+                else:
+                    # Other opreations need to be re-targeted to the corresponding output in the
+                    # current layer, as long as that output is not a delay, as that has been solved
+                    # above.
+                    # To avoid double connections, we'll only re-connect inputs
+                    for input_num, original_input in enumerate(op.inputs):
+                        original_source = original_input.connected_source
+                        # We may not always have something connected to the input, if we don't
+                        # we can abort
+                        if original_source is None:
+                            continue
+
+                        # delay connections are handled elsewhere
+                        if not isinstance(original_source.operation, Delay):
+                            source_op_idx = id_idx_map[
+                                original_source.operation.graph_id
+                            ]
+                            source_op_output_idx = original_source.index
+
+                            target_output = new_ops[layer][
+                                source_op_idx
+                            ].outputs[source_op_output_idx]
+
+                            new_ops[layer][op_idx].inputs[input_num].connect(
+                                target_output
+                            )
+
+        all_ops = [op for op_list in new_ops for op in op_list]
+        all_inputs = [op for op in all_ops if isinstance(op, Input)]
+        all_outputs = [op for op in all_ops if isinstance(op, Output)]
+
+        return SFG(inputs=all_inputs, outputs=all_outputs)
diff --git a/examples/twotapfirsfg.py b/examples/twotapfirsfg.py
index e111e2a3..14377648 100644
--- a/examples/twotapfirsfg.py
+++ b/examples/twotapfirsfg.py
@@ -17,18 +17,18 @@ from b_asic import (
 in1 = Input(name="in1")
 
 # Outputs:
-out1 = Output(name="")
+out1 = Output(name="out1")
 
 # Operations:
-t1 = Delay(initial_value=0, name="")
+t1 = Delay(initial_value=0, name="t1")
 cmul1 = ConstantMultiplication(
-    value=0.5, name="cmul2", latency_offsets={'in0': None, 'out0': None}
+    value=0.5, name="cmul1", latency_offsets={'in0': None, 'out0': None}
 )
 add1 = Addition(
-    name="", latency_offsets={'in0': None, 'in1': None, 'out0': None}
+    name="add1", latency_offsets={'in0': None, 'in1': None, 'out0': None}
 )
 cmul2 = ConstantMultiplication(
-    value=0.5, name="cmul", latency_offsets={'in0': None, 'out0': None}
+    value=0.5, name="cmul2", latency_offsets={'in0': None, 'out0': None}
 )
 
 # Signals:
diff --git a/test/test_sfg.py b/test/test_sfg.py
index a4276fa6..02f4c492 100644
--- a/test/test_sfg.py
+++ b/test/test_sfg.py
@@ -1595,3 +1595,9 @@ class TestCriticalPath:
 
         sfg_simple_accumulator.set_latency_of_type(Addition.type_name(), 6)
         assert sfg_simple_accumulator.critical_path() == 6
+
+
+class TestUnroll:
+    def unrolling_by_factor_0_raises(self, sfg_simple_filter: SFG):
+        with pytest.raises(ValueError):
+            sfg_simple_filter.unfold(0)
-- 
GitLab