From 253008644bf4587d9fe30ad196e2baa091677fa0 Mon Sep 17 00:00:00 2001 From: Oscar Gustafsson <oscar.gustafsson@gmail.com> Date: Wed, 1 Feb 2023 18:10:45 +0100 Subject: [PATCH] Initial work on constant propagation --- b_asic/core_operations.py | 79 ++++++++++++++++++++++++++++++++++++++- b_asic/operation.py | 54 ++++++++++++++++++++++++++ 2 files changed, 131 insertions(+), 2 deletions(-) diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index 5b74b83e..0ed9679a 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -5,13 +5,13 @@ Contains some of the most commonly used mathematical operations. """ from numbers import Number -from typing import Dict, Optional +from typing import Dict, Iterable, Optional, Set from numpy import abs as np_abs from numpy import conjugate, sqrt from b_asic.graph_component import Name, TypeName -from b_asic.operation import AbstractOperation +from b_asic.operation import AbstractOperation, Operation from b_asic.port import SignalSourceProvider @@ -125,6 +125,14 @@ class Addition(AbstractOperation): def evaluate(self, a, b): return a + b + def _propagate_some_constants( + self, + constants: Iterable[Optional[Number]], + valid_operations: Optional[Set["Operation"]] = None, + ) -> None: + if any(c == 0.0 for c in constants): + print("One input is 0!") + class Subtraction(AbstractOperation): """ @@ -185,6 +193,14 @@ class Subtraction(AbstractOperation): def evaluate(self, a, b): return a - b + def _propagate_some_constants( + self, + constants: Iterable[Optional[Number]], + valid_operations: Optional[Set["Operation"]] = None, + ) -> None: + if any(c == 0.0 for c in constants): + print("One input is 0!") + class AddSub(AbstractOperation): r""" @@ -266,6 +282,14 @@ class AddSub(AbstractOperation): """Set if operation is add.""" self.set_param("is_add", is_add) + def _propagate_some_constants( + self, + constants: Iterable[Optional[Number]], + valid_operations: Optional[Set["Operation"]] = None, + ) -> None: + if any(c == 0.0 for c in constants): + print("One input is 0!") + class Multiplication(AbstractOperation): r""" @@ -323,6 +347,17 @@ class Multiplication(AbstractOperation): def evaluate(self, a, b): return a * b + def _propagate_some_constants( + self, + constants: Iterable[Optional[Number]], + valid_operations: Optional[Set["Operation"]] = None, + ) -> None: + if any(c == 0.0 for c in constants): + print("One input is 0!") + if any(c == 1.0 for c in constants): + print("One input is 1!") + print("Can turn into ConstantMultiplication") + class Division(AbstractOperation): r""" @@ -361,6 +396,17 @@ class Division(AbstractOperation): def evaluate(self, a, b): return a / b + def _propagate_some_constants( + self, + constants: Iterable[Optional[Number]], + valid_operations: Optional[Set["Operation"]] = None, + ) -> None: + numerator, denominator = constants + if numerator == 0.0: + print("Result is 0!") + if denominator is not None: + print("Can turn into ConstantMultiplication") + class Min(AbstractOperation): r""" @@ -646,6 +692,14 @@ class Butterfly(AbstractOperation): def evaluate(self, a, b): return a + b, a - b + def _propagate_some_constants( + self, + constants: Iterable[Optional[Number]], + valid_operations: Optional[Set["Operation"]] = None, + ) -> None: + if any(c == 0.0 for c in constants): + print("One input is 0!") + class MAD(AbstractOperation): r""" @@ -685,6 +739,19 @@ class MAD(AbstractOperation): def evaluate(self, a, b, c): return a * b + c + def _propagate_some_constants( + self, + constants: Iterable[Optional[Number]], + valid_operations: Optional[Set["Operation"]] = None, + ) -> None: + a, b, c = constants + if a == 0.0 or b == 0.0: + print("One multiplier input is zero!") + if a == 1.0 or b == 1.0: + print("One multiplier input is one!") + if any(c == 0.0): + print("Adder input is zero!") + class SymmetricTwoportAdaptor(AbstractOperation): r""" @@ -736,3 +803,11 @@ class SymmetricTwoportAdaptor(AbstractOperation): def value(self, value: Number) -> None: """Set the constant value of this operation.""" self.set_param("value", value) + + def _propagate_some_constants( + self, + constants: Iterable[Optional[Number]], + valid_operations: Optional[Set["Operation"]] = None, + ) -> None: + if any(c == 0.0 for c in constants): + print("One input is 0!") diff --git a/b_asic/operation.py b/b_asic/operation.py index f43d7d52..931d560d 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -19,6 +19,7 @@ from typing import ( NewType, Optional, Sequence, + Set, Tuple, Union, cast, @@ -419,6 +420,24 @@ class Operation(GraphComponent, SignalSourceProvider): def _check_all_latencies_set(self) -> None: raise NotImplementedError + @abstractmethod + def _propagate_constants( + self, valid_operations: Optional[Set["Operation"]] = None + ) -> "Operation": + raise NotImplementedError + + @abstractmethod + def _constant_inputs(self) -> Iterable[Number]: + raise NotImplementedError + + @abstractmethod + def _propagate_some_constants( + self, + constants: Iterable[Optional[Number]], + valid_operations: Optional[Set["Operation"]] = None, + ) -> None: + raise NotImplementedError + class AbstractOperation(Operation, AbstractGraphComponent): """ @@ -1068,3 +1087,38 @@ class AbstractOperation(Operation, AbstractGraphComponent): for k in range(len(self.outputs)) ] return input_coords, output_coords + + def _propagate_constants( + self, valid_operations: Optional[Set["Operation"]] = None + ) -> None: + # Must be implemented per operation, so just return otherwise + constants = self._constant_inputs() + if all(c is None for c in constants): + return + if all(c is not None for c in constants): + res = self.evalute(*constants) + print(f"Result is {res}!") + if any(c is not None for c in constants): + # This is operation dependent + self._propagate_some_constants(constants, valid_operations) + return + + def _constant_inputs(self) -> Iterable[Optional[Number]]: + from b_asic.core_operations import Constant + + ret = [] + for port in self._input_ports: + if port.connected_source is None: + ret.append(None) + elif isinstance(port.connected_source.operation, Constant): + ret.append(port.connected_source.operation.value) + else: + ret.append(None) + return ret + + def _propagate_some_constants( + self, + constants: Iterable[Optional[Number]], + valid_operations: Optional[Set["Operation"]] = None, + ) -> None: + return -- GitLab