Skip to content
Snippets Groups Projects
Commit 25300864 authored by Oscar Gustafsson's avatar Oscar Gustafsson :bicyclist:
Browse files

Initial work on constant propagation

parent 2a90ea3c
No related branches found
No related tags found
No related merge requests found
Pipeline #88596 passed
This commit is part of merge request !137. Comments created here will be created in the context of that merge request.
......@@ -5,13 +5,13 @@ Contains some of the most commonly used mathematical operations.
"""
from numbers import Number
from typing import Dict, Optional
from typing import Dict, Iterable, Optional, Set
from numpy import abs as np_abs
from numpy import conjugate, sqrt
from b_asic.graph_component import Name, TypeName
from b_asic.operation import AbstractOperation
from b_asic.operation import AbstractOperation, Operation
from b_asic.port import SignalSourceProvider
......@@ -125,6 +125,14 @@ class Addition(AbstractOperation):
def evaluate(self, a, b):
return a + b
def _propagate_some_constants(
self,
constants: Iterable[Optional[Number]],
valid_operations: Optional[Set["Operation"]] = None,
) -> None:
if any(c == 0.0 for c in constants):
print("One input is 0!")
class Subtraction(AbstractOperation):
"""
......@@ -185,6 +193,14 @@ class Subtraction(AbstractOperation):
def evaluate(self, a, b):
return a - b
def _propagate_some_constants(
self,
constants: Iterable[Optional[Number]],
valid_operations: Optional[Set["Operation"]] = None,
) -> None:
if any(c == 0.0 for c in constants):
print("One input is 0!")
class AddSub(AbstractOperation):
r"""
......@@ -266,6 +282,14 @@ class AddSub(AbstractOperation):
"""Set if operation is add."""
self.set_param("is_add", is_add)
def _propagate_some_constants(
self,
constants: Iterable[Optional[Number]],
valid_operations: Optional[Set["Operation"]] = None,
) -> None:
if any(c == 0.0 for c in constants):
print("One input is 0!")
class Multiplication(AbstractOperation):
r"""
......@@ -323,6 +347,17 @@ class Multiplication(AbstractOperation):
def evaluate(self, a, b):
return a * b
def _propagate_some_constants(
self,
constants: Iterable[Optional[Number]],
valid_operations: Optional[Set["Operation"]] = None,
) -> None:
if any(c == 0.0 for c in constants):
print("One input is 0!")
if any(c == 1.0 for c in constants):
print("One input is 1!")
print("Can turn into ConstantMultiplication")
class Division(AbstractOperation):
r"""
......@@ -361,6 +396,17 @@ class Division(AbstractOperation):
def evaluate(self, a, b):
return a / b
def _propagate_some_constants(
self,
constants: Iterable[Optional[Number]],
valid_operations: Optional[Set["Operation"]] = None,
) -> None:
numerator, denominator = constants
if numerator == 0.0:
print("Result is 0!")
if denominator is not None:
print("Can turn into ConstantMultiplication")
class Min(AbstractOperation):
r"""
......@@ -646,6 +692,14 @@ class Butterfly(AbstractOperation):
def evaluate(self, a, b):
return a + b, a - b
def _propagate_some_constants(
self,
constants: Iterable[Optional[Number]],
valid_operations: Optional[Set["Operation"]] = None,
) -> None:
if any(c == 0.0 for c in constants):
print("One input is 0!")
class MAD(AbstractOperation):
r"""
......@@ -685,6 +739,19 @@ class MAD(AbstractOperation):
def evaluate(self, a, b, c):
return a * b + c
def _propagate_some_constants(
self,
constants: Iterable[Optional[Number]],
valid_operations: Optional[Set["Operation"]] = None,
) -> None:
a, b, c = constants
if a == 0.0 or b == 0.0:
print("One multiplier input is zero!")
if a == 1.0 or b == 1.0:
print("One multiplier input is one!")
if any(c == 0.0):
print("Adder input is zero!")
class SymmetricTwoportAdaptor(AbstractOperation):
r"""
......@@ -736,3 +803,11 @@ class SymmetricTwoportAdaptor(AbstractOperation):
def value(self, value: Number) -> None:
"""Set the constant value of this operation."""
self.set_param("value", value)
def _propagate_some_constants(
self,
constants: Iterable[Optional[Number]],
valid_operations: Optional[Set["Operation"]] = None,
) -> None:
if any(c == 0.0 for c in constants):
print("One input is 0!")
......@@ -19,6 +19,7 @@ from typing import (
NewType,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
......@@ -419,6 +420,24 @@ class Operation(GraphComponent, SignalSourceProvider):
def _check_all_latencies_set(self) -> None:
raise NotImplementedError
@abstractmethod
def _propagate_constants(
self, valid_operations: Optional[Set["Operation"]] = None
) -> "Operation":
raise NotImplementedError
@abstractmethod
def _constant_inputs(self) -> Iterable[Number]:
raise NotImplementedError
@abstractmethod
def _propagate_some_constants(
self,
constants: Iterable[Optional[Number]],
valid_operations: Optional[Set["Operation"]] = None,
) -> None:
raise NotImplementedError
class AbstractOperation(Operation, AbstractGraphComponent):
"""
......@@ -1068,3 +1087,38 @@ class AbstractOperation(Operation, AbstractGraphComponent):
for k in range(len(self.outputs))
]
return input_coords, output_coords
def _propagate_constants(
self, valid_operations: Optional[Set["Operation"]] = None
) -> None:
# Must be implemented per operation, so just return otherwise
constants = self._constant_inputs()
if all(c is None for c in constants):
return
if all(c is not None for c in constants):
res = self.evalute(*constants)
print(f"Result is {res}!")
if any(c is not None for c in constants):
# This is operation dependent
self._propagate_some_constants(constants, valid_operations)
return
def _constant_inputs(self) -> Iterable[Optional[Number]]:
from b_asic.core_operations import Constant
ret = []
for port in self._input_ports:
if port.connected_source is None:
ret.append(None)
elif isinstance(port.connected_source.operation, Constant):
ret.append(port.connected_source.operation.value)
else:
ret.append(None)
return ret
def _propagate_some_constants(
self,
constants: Iterable[Optional[Number]],
valid_operations: Optional[Set["Operation"]] = None,
) -> None:
return
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment