diff --git a/b_asic/save_load_structure.py b/b_asic/save_load_structure.py index 6b424dafb4401ac93f55df7313e14261fed8bdb5..b3970dc5f4822d25fbaa772e7fac570138a337dd 100644 --- a/b_asic/save_load_structure.py +++ b/b_asic/save_load_structure.py @@ -7,13 +7,16 @@ as files. from datetime import datetime from inspect import signature +from typing import Dict, Optional, Tuple, cast from b_asic.graph_component import GraphComponent +from b_asic.port import InputPort from b_asic.signal_flow_graph import SFG -from b_asic.special_operations import Input, Output -def sfg_to_python(sfg: SFG, counter: int = 0, suffix: str = None) -> str: +def sfg_to_python( + sfg: SFG, counter: int = 0, suffix: Optional[str] = None +) -> str: """ Given an SFG structure try to serialize it for saving to a file. @@ -39,20 +42,20 @@ def sfg_to_python(sfg: SFG, counter: int = 0, suffix: str = None) -> str: ) result += "\nfrom b_asic import SFG, Signal, Input, Output" - for op in {type(op) for op in sfg.operations}: - result += f", {op.__name__}" + for op_type in {type(op) for op in sfg.operations}: + result += f", {op_type.__name__}" def kwarg_unpacker(comp: GraphComponent, params=None) -> str: if params is None: params_filtered = { - attr: getattr(op, attr) - for attr in signature(op.__init__).parameters - if attr != "latency" and hasattr(op, attr) + attr: getattr(comp, attr) + for attr in signature(comp.__init__).parameters + if attr != "latency" and hasattr(comp, attr) } params = { - attr: getattr(op, attr) - if not isinstance(getattr(op, attr), str) - else f'"{getattr(op, attr)}"' + attr: getattr(comp, attr) + if not isinstance(getattr(comp, attr), str) + else f'"{getattr(comp, attr)}"' for attr in params_filtered } @@ -64,12 +67,14 @@ def sfg_to_python(sfg: SFG, counter: int = 0, suffix: str = None) -> str: io_ops = [*sfg._input_operations, *sfg._output_operations] result += "\n# Inputs:\n" - for op in sfg._input_operations: - result += f"{op.graph_id} = Input({kwarg_unpacker(op)})\n" + for input_op in sfg._input_operations: + result += f"{input_op.graph_id} = Input({kwarg_unpacker(input_op)})\n" result += "\n# Outputs:\n" - for op in sfg._output_operations: - result += f"{op.graph_id} = Output({kwarg_unpacker(op)})\n" + for output_op in sfg._output_operations: + result += ( + f"{output_op.graph_id} = Output({kwarg_unpacker(output_op)})\n" + ) result += "\n# Operations:\n" for op in sfg.split(): @@ -90,10 +95,11 @@ def sfg_to_python(sfg: SFG, counter: int = 0, suffix: str = None) -> str: for op in sfg.split(): for out in op.outputs: for signal in out.signals: - dest_op = signal.destination.operation + destination = cast(InputPort, signal.destination) + dest_op = destination.operation connection = ( f"\nSignal(source={op.graph_id}.output({op.outputs.index(signal.source)})," - f" destination={dest_op.graph_id}.input({dest_op.inputs.index(signal.destination)}))" + f" destination={dest_op.graph_id}.input({dest_op.inputs.index(destination)}))" ) if connection in connections: continue @@ -123,7 +129,7 @@ def sfg_to_python(sfg: SFG, counter: int = 0, suffix: str = None) -> str: return result -def python_to_sfg(path: str) -> SFG: +def python_to_sfg(path: str) -> Tuple[SFG, Dict[str, Tuple[int, int]]]: """ Given a serialized file try to deserialize it and load it to the library.