From 5ee9e6f7276308af07e748d0992389a3eff94024 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ivar=20H=C3=A4rnqvist?= <ivarhar@outlook.com>
Date: Thu, 9 Apr 2020 21:04:43 +0200
Subject: [PATCH] add interface for truncating input signals

---
 b_asic/operation.py          | 25 ++++++++++++++++++++++++-
 b_asic/signal.py             |  3 +--
 b_asic/signal_flow_graph.py  |  4 ++--
 b_asic/special_operations.py |  5 ++---
 4 files changed, 29 insertions(+), 8 deletions(-)

diff --git a/b_asic/operation.py b/b_asic/operation.py
index 0f664f6d..52b9054f 100644
--- a/b_asic/operation.py
+++ b/b_asic/operation.py
@@ -8,6 +8,7 @@ import collections
 from abc import abstractmethod
 from numbers import Number
 from typing import List, Sequence, Iterable, MutableMapping, Optional, Any, Set, Union
+from math import trunc
 
 from b_asic.graph_component import GraphComponent, AbstractGraphComponent, Name
 from b_asic.port import SignalSourceProvider, InputPort, OutputPort
@@ -162,6 +163,12 @@ class AbstractOperation(Operation, AbstractGraphComponent):
                 if src is not None:
                     self._input_ports[i].connect(src.source)
 
+    def truncate_input(self, index: int, value: Number, bits: int):
+        n = value
+        if not isinstance(n, int):
+            n = trunc(value)
+        return n & ((2 ** bits) - 1)
+
     @abstractmethod
     def evaluate(self, *inputs) -> Any:  # pylint: disable=arguments-differ
         """Evaluate the operation and generate a list of output values given a
@@ -179,6 +186,20 @@ class AbstractOperation(Operation, AbstractGraphComponent):
 
         results[key] = None
         return None
+    
+    def _truncate_inputs(self, input_values: Sequence[Number]):
+        args = []
+        for i in range(self.input_count):
+            input_port = self.input(i)
+            if input_port.signal_count >= 1:
+                bits = input_port.signals[0].bits
+                if bits is None:
+                    args.append(input_values[i])
+                else:
+                    args.append(self.truncate_input(i, input_values[i], bits))
+            else:
+                args.append(input_values[i])
+        return args
 
     def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Addition, ConstantAddition]":
         # Import here to avoid circular imports.
@@ -253,6 +274,8 @@ class AbstractOperation(Operation, AbstractGraphComponent):
     def evaluate_output(self, index: int, input_values: Sequence[Number], results: Optional[MutableMapping[str, Optional[Number]]] = None, registers: Optional[MutableMapping[str, Number]] = None, prefix: str = "") -> Number:
         if index < 0 or index >= self.output_count:
             raise IndexError(f"Output index out of range (expected 0-{self.output_count - 1}, got {index})")
+        if len(input_values) != self.input_count:
+            raise ValueError(f"Wrong number of input values supplied to operation (expected {self.input_count}, got {len(input_values)})")
         if results is None:
             results = {}
         if registers is None:
@@ -261,7 +284,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
         result = self._find_result(prefix, index, results)
         if result is not None:
             return result
-        values = self.evaluate(*input_values)
+        values = self.evaluate(*self._truncate_inputs(input_values))
         if isinstance(values, collections.Sequence):
             if len(values) != self.output_count:
                 raise RuntimeError(f"Operation evaluated to incorrect number of outputs (expected {self.output_count}, got {len(values)})")
diff --git a/b_asic/signal.py b/b_asic/signal.py
index c3e9183d..d322f161 100644
--- a/b_asic/signal.py
+++ b/b_asic/signal.py
@@ -15,8 +15,7 @@ class Signal(AbstractGraphComponent):
     _source: Optional["OutputPort"]
     _destination: Optional["InputPort"]
 
-    def __init__(self, source: Optional["OutputPort"] = None, \
-                 destination: Optional["InputPort"] = None, bits: Optional[int] = None, name: Name = ""):
+    def __init__(self, source: Optional["OutputPort"] = None, destination: Optional["InputPort"] = None, bits: Optional[int] = None, name: Name = ""):
         super().__init__(name)
         self._source = None
         self._destination = None
diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py
index b7919f64..3d0ba6cb 100644
--- a/b_asic/signal_flow_graph.py
+++ b/b_asic/signal_flow_graph.py
@@ -183,7 +183,7 @@ class SFG(AbstractOperation):
             return result
         
         # Set the values of our input operations to the given input values.
-        for op, arg in zip(self._input_operations, input_values):
+        for op, arg in zip(self._input_operations, self._truncate_inputs(input_values)):
             op.value = arg
         
         value = self._evaluate_source(self._output_operations[index].input(0).signals[0].source, results, registers, prefix)
@@ -197,7 +197,7 @@ class SFG(AbstractOperation):
     def id_number_offset(self) -> GraphIDNumber:
         """Get the graph id number offset of the graph id generator for this SFG."""
         return self._graph_id_generator.id_number_offset
-    
+
     @property
     def components(self) -> Iterable[GraphComponent]:
         """Get all components of this graph in the dfs-traversal order."""
diff --git a/b_asic/special_operations.py b/b_asic/special_operations.py
index a5a3e90f..140fa410 100644
--- a/b_asic/special_operations.py
+++ b/b_asic/special_operations.py
@@ -82,11 +82,10 @@ class Register(AbstractOperation):
         
         if prefix in results:
             return results[prefix]
-
         if prefix in registers:
             return registers[prefix]
-
+        
         value = registers.get(prefix, self.param("initial_value"))
-        registers[prefix] = input_values[0]
+        registers[prefix] = self._truncate_inputs(input_values)[0]
         results[prefix] = value
         return value
\ No newline at end of file
-- 
GitLab