From 8eeffdd6890eb1aa3d3ee839248c5aeedf6e9f9f Mon Sep 17 00:00:00 2001
From: Johannes Kung <johku144@student.liu.se>
Date: Tue, 25 Jun 2024 11:59:34 +0200
Subject: [PATCH] Corrected faulty MIA LC behaviour

---
 src/simudator/processor/mia/modules/lc.py | 200 ++++++++--------------
 test/test_mia/test_lc.py                  | 126 +++++++++-----
 2 files changed, 156 insertions(+), 170 deletions(-)

diff --git a/src/simudator/processor/mia/modules/lc.py b/src/simudator/processor/mia/modules/lc.py
index 357e46b..a84c10c 100644
--- a/src/simudator/processor/mia/modules/lc.py
+++ b/src/simudator/processor/mia/modules/lc.py
@@ -10,17 +10,43 @@ class LC(Module):
     """A class representing the loop counter. It is controlled by the
     signal uM_control which determines if it should read from bus,
     increase by one or read from mM_uADR.
+
+    Parameters
+    ----------
+    mM_control : Signal
+        A signal connection from the micro memory
+        to the loop counter. This allows the micro memory to send bit
+        12 and 13 to the loop counter. Bit 12 and 13 decides the
+        loop counters behaviour.
+    bus_input : Signal
+        A signal connection from the bus to the loop
+        counter. The loop counter reads from this signal when it reads
+        from the bus.
+    l_flag : Signal
+        A signal connection from the loop counter to the
+        L-flag. The loop counter writes to this signal when it needs to
+        update the L-flag.
+    mM_uADR : Signal
+        A signal connection from the micro memory to
+        the loop counter. This allows the loop counter to read the 7
+        least significant bits from the uADR field.
+    name : str
+        Optional name of the loop counter.
+    value : int
+        Optional start value of the loop counter.
     """
 
     __slots__ = (
-        "value",
-        "read_from_bus",
-        "read_from_uADR",
-        "decrement_by_one",
-        "bit_length",
-        "mask",
+        "_value",
+        "_set_l_flag",
     )
 
+    BIT_LENGTH = 8
+    """
+    Number of bits used to store the loop counter value. Used to truncate too
+    large input.
+    """
+
     def __init__(
         self,
         mM_control: Signal,
@@ -29,35 +55,7 @@ class LC(Module):
         mM_uADR: Signal,
         name="LC",
         value=0,
-        bit_length=8,
     ) -> None:
-        """
-        Parameters
-        ----------
-        mM_control : Signal
-            A signal connection from the micro memory
-            to the loop counter. This allows the micro memory to send bit
-            12 and 13 to the loop counter. Bit 12 and 13 decides the
-            loop counters behaviour.
-        bus_input : Signal
-            A signal connection from the bus to the loop
-            counter. The loop counter reads from this signal when it reads
-            from the bus.
-        l_flag : Signal
-            A signal connection from the loop counter to the
-            L flag. The loop counter writes to this signal when it needs to
-            update the L flag.
-        mM_uADR : Signal
-            A signal connection from the micro memory to
-            the loop counter. This allows the loop counter to read the 7
-            least significant bits from the uADR field.
-        name : str
-            Optional name of the loop counter.
-        value : int
-            Optional start value of the loop counter.
-        bit_length : int
-            Optional bit length of the loop counter.
-        """
 
         # signals
         signals = {
@@ -70,127 +68,85 @@ class LC(Module):
         super().__init__(signals, name)
 
         # the value of the loop counter
-        self.value = value
+        self._value = value
 
-        # helper variables
-        self.read_from_bus = False
-        self.read_from_uADR = False
-        self.decrement_by_one = False
-
-        # bit length and mask
-        self.bit_length = bit_length
-        self.mask = 2**self.bit_length - 1
+        # helper to correctly set and reset the L-flag
+        self._set_l_flag = False
 
     def update_register(self) -> None:
-        """Reads bit 12 and 13 from the micro memory and updates the
-        loop counter.
-        0 - does nothing.
-        1 - decrements the loop counter by one.
-        2 - loads the 8 least significant bits from the bus.
-        3 - loads the 7 least significant bits from the uADR field.
+        """Read bit 12 and 13 from the micro memory and update the
+        loop counter accordingly.
+
+        0 - do nothing.
+        1 - decrement the loop counter by one.
+        2 - load the 8 least significant bits from the bus.
+        3 - load the 7 least significant bits from the uADR field.
         """
 
+        bit_mask = 2**LC.BIT_LENGTH - 1
         match self.signals["in_control"].get_value():
-            case 0b00:  # LC is not effected
-                self.decrement_by_one = False
-                self.read_from_bus = False
-                self.read_from_uADR = False
+            case 0b00:  # Do nothing
+                return
 
             case 0b01:  # Decrement by one
-                self.decrement_by_one = True
-                self.read_from_bus = False
-                self.read_from_uADR = False
-
-            case 0b10:  # Load 8 least significant bits from bus
-                self.decrement_by_one = False
-                self.read_from_bus = True
-                self.read_from_uADR = False
-
-            case 0b11:  # LCs 7 least significant bits are loaded from uADR
-                self.decrement_by_one = False
-                self.read_from_bus = False
-                self.read_from_uADR = True
+                self._value -= 1
 
-        if self.read_from_bus:
-            input_value = self.signals["in_input"].get_value()
-            self.value = input_value & self.mask
+                if self._value < 0:
+                    # underflow correctly
+                    # (the bit mask is the same as the maximum value)
+                    self._value = bit_mask
 
-        if self.read_from_uADR:
-            input_value = self.signals["in_address"].get_value()
-            self.value = input_value & self.mask
+            case 0b10:  # Load 8 least significant bits from bus
+                self._value = self.signals["in_input"].get_value()
 
-        if self.decrement_by_one:
-            self.value -= 1
+            case 0b11:  # Load 7 least significant bits from uADR
+                self._value = self.signals["in_address"].get_value()
 
-            # overflow correctly
-            if self.value < 0:
-                self.value = self.mask
+        # Truncate if the value is too large and determine if the L-flag
+        # should be set
+        self._value = self._value & bit_mask
+        self._set_l_flag = self._value == 0
 
     def output_register(self) -> None:
-        """The loop counter will only output to the L flag, this is
-        handled in 'update_logic'.
-        """
-        pass
-
-    def update_logic(self) -> None:
-        """When the loop counter reaches zero, set the l flag to 1.
-        Otherwise set it to zero.
+        """Set the L-flag to 1 if the loop counter has reached zero. Set the
+        L-flag to 0 if the loop counter is not zero.
         """
-        if self.value == 0:
+        if self._set_l_flag:
             self.signals["out_flag_l"].update_value(1)
         else:
             self.signals["out_flag_l"].update_value(0)
 
-    def get_state(self) -> dict[str, Any]:
-        """Returns a dict of the loop counter state.
-        These states are changable via set_states.
+    def update_logic(self) -> None:
+        """Do nothing.
 
-        Returns
-        -------
-        dict[Any]
-            The state of the loop counter.
+        The loop counter has no internal logic.
         """
+        pass
+
+    def get_state(self) -> dict[str, Any]:
         state = dict()
         state["name"] = self.name
-        state["value"] = self.value
-        state["bit_length"] = self.bit_length
-        state["mask"] = self.mask
-        state["read_from_bus"] = self.read_from_bus
-        state["read_from_uADR"] = self.read_from_uADR
-        state["decrement_by_one"] = self.decrement_by_one
+        state["value"] = self._value
+        state["bit_length"] = LC.BIT_LENGTH
         return state
 
     def set_state(self, state: dict[str, Any]) -> None:
-        """Sets the loop counter state to one given in dict."""
         self.name = state["name"]
-        self.value = state["value"]
-        if "bit_length" in state:
-            self.bit_length = state["bit_length"]
-        if "mask" in state:
-            self.mask = state["mask"]
-        if "read_from_bus" in state:
-            self.read_from_bus = state["read_from_bus"]
-        if "read_from_uADR" in state:
-            self.read_from_uADR = state["read_from_uADR"]
-        if "decrement_by_one" in state:
-            self.decrement_by_one = state["decrement_by_one"]
+        self._value = state["value"]
 
     def reset(self) -> None:
-        """Resets the loop counter to 0."""
-        self.value = 0
-        self.read_from_bus = False
-        self.read_from_uADR = False
-        self.decrement_by_one = False
+        """Reset the loop counter to 0."""
+        self._value = 0
 
     def save_state_to_file(self, file_path: str) -> bool:
         content = self.name + ":\n"
-        content += "value: " + hex(self.value)[2:] + "\n\n"
+        content += "value: " + hex(self._value)[2:] + "\n\n"
         return super()._helper_save_state_to_file(file_path, content)
 
     def load_from_str(self, state_string):
         string_pair = state_string.split(": ")
         # TODO: Maybe check if it starts with value: ?
-        self.value = int(string_pair[1], 16)
+        self._value = int(string_pair[1], 16)
 
     def print_module(self) -> None:
         print(
@@ -198,11 +154,5 @@ class LC(Module):
             self.name,
             "\n -----",
             "\n value: ",
-            hex(self.value),
-            "\n decrement: ",
-            self.decrement_by_one,
-            "\n read from uADR: ",
-            self.read_from_uADR,
-            "\n read from bus: ",
-            self.read_from_bus,
+            hex(self._value),
         )
diff --git a/test/test_mia/test_lc.py b/test/test_mia/test_lc.py
index 3baf1b8..f087760 100644
--- a/test/test_mia/test_lc.py
+++ b/test/test_mia/test_lc.py
@@ -1,4 +1,3 @@
-from simudator.core.modules.integer_register import IntegerRegister
 from simudator.core.processor import Processor
 from simudator.core.signal import Signal
 from simudator.processor.mia.modules.lc import LC
@@ -28,11 +27,37 @@ def test_read_from_bus():
 
     cpu.add_module(lc)
 
-    bus_input_s.update_value(10)
+    # Set the LC to read from the bus
     mM_control_s.update_value(2)
 
+    # Reading a non-zero value should set the L-flag to 0
+    bus_input_s.update_value(10)
+    cpu.do_tick()
+    assert lc._value == 10
+    assert l_flag_s.get_value() == 0
+
+    # Reading a zero while having a non-zero value should set the L-flag to 1
+    bus_input_s.update_value(0)
+    cpu.do_tick()
+    assert lc._value == 0
+    assert l_flag_s.get_value() == 1
+
+    # Reading a zero again should not reset the L-flag
     cpu.do_tick()
-    assert lc.value == 10
+    assert lc._value == 0
+    assert l_flag_s.get_value() == 1
+
+    # Reading a non-zero value now should reset the L-flag
+    bus_input_s.update_value(2)
+    cpu.do_tick()
+    assert lc._value == 2
+    assert l_flag_s.get_value() == 0
+
+    # Check that the inputted value is correctly truncated
+    bus_input_s.update_value(0xFF1A)
+    cpu.do_tick()
+    assert lc._value == 0x1A
+    assert l_flag_s.get_value() == 0
 
 
 def test_read_from_uADR():
@@ -45,14 +70,40 @@ def test_read_from_uADR():
 
     cpu.add_module(lc)
 
+    # Set the LC to read from uADR
+    mM_control_s.update_value(0b11)
+
+    # Reading a non-zero value should set the L-flag to 0
     mM_uADR_s.update_value(10)
-    mM_control_s.update_value(3)
+    cpu.do_tick()
+    assert lc._value == 10
+    assert l_flag_s.get_value() == 0
+
+    # Reading a zero while having a non-zero value should set the L-flag to 1
+    mM_uADR_s.update_value(0)
+    cpu.do_tick()
+    assert lc._value == 0
+    assert l_flag_s.get_value() == 1
 
+    # Reading a zero again should not reset the L-flag
     cpu.do_tick()
-    assert lc.value == 10
+    assert lc._value == 0
+    assert l_flag_s.get_value() == 1
+
+    # Reading a non-zero value now should reset the L-flag
+    mM_uADR_s.update_value(2)
+    cpu.do_tick()
+    assert lc._value == 2
+    assert l_flag_s.get_value() == 0
+
+    # Check that the inputted value is correctly truncated
+    mM_uADR_s.update_value(0xFF1A)
+    cpu.do_tick()
+    assert lc._value == 0x1A
+    assert l_flag_s.get_value() == 0
 
 
-def test_write_to_l_flag():
+def test_decrement():
     cpu = Processor()
 
     # signals needed
@@ -63,7 +114,7 @@ def test_write_to_l_flag():
 
     # modules needed to run the test
     # initialize loop counter to 1, so when it is decremented by
-    # one it reaches zero and sets the L flag
+    # one it reaches zero and sets the L-flag
     lc_value = 1
     lc = LC(mM_control_s, bus_input_s, l_flag_s, mM_uADR_s, "lc", lc_value)
 
@@ -75,31 +126,15 @@ def test_write_to_l_flag():
     mM_control_s.update_value(0b01)
 
     # The loop counter will now change its value from 1 to 0 which
-    # should set the L flag to 1 in the same cycle
+    # should set the L-flag to 1 in the same cycle
     cpu.do_tick()
+    assert lc._value == 0
     assert l_flag_s.get_value() == 1
 
-
-def test_reset_l_flag():
-    cpu = Processor()
-
-    mM_control_s = Signal(cpu)
-    mM_uADR_s = Signal(cpu)
-    l_flag_s = Signal(cpu)
-    bus_input_s = Signal(cpu)
-
-    lc = LC(mM_control_s, bus_input_s, l_flag_s, mM_uADR_s, "lc", 1)
-
-    cpu.add_module(lc)
-
-    mM_uADR_s.update_value(0)
-    mM_control_s.update_value(3)
-
-    cpu.do_tick()
-
-    mM_uADR_s.update_value(10)
-
+    # Decrementing again should make the LC wrap around to 0xFF and reset the
+    # L-flag
     cpu.do_tick()
+    assert lc._value == 0xFF
     assert l_flag_s.get_value() == 0
 
 
@@ -119,13 +154,29 @@ def test_lc_do_nothing():
     mM_control_s.update_value(0b01)
 
     cpu.do_tick()
-    assert lc.value == 3
+    assert lc._value == 3
+    assert l_flag_s.get_value() == 0
 
     # tell lc to do nothing
     mM_control_s.update_value(0b00)
 
     cpu.do_tick()
-    assert lc.value == 3
+    assert lc._value == 3
+    assert l_flag_s.get_value() == 0
+
+    # decrement to zero and then do nothing to check that the L-flag remains
+    # set to 1
+    mM_control_s.update_value(0b01)
+    cpu.do_tick()
+    cpu.do_tick()
+    cpu.do_tick()
+    assert lc._value == 0
+    assert l_flag_s.get_value() == 1
+
+    mM_control_s.update_value(0b00)
+    cpu.do_tick()
+    assert lc._value == 0
+    assert l_flag_s.get_value() == 1
 
 
 def test_get_state():
@@ -147,11 +198,6 @@ def test_get_state():
     state = lc.get_state()
     assert state["name"] == "LC"
     assert state["value"] == 255
-    assert state["mask"] == 255
-    assert state["bit_length"] == 8
-    assert state["read_from_bus"] is False
-    assert state["read_from_uADR"] is False
-    assert state["decrement_by_one"] is True
 
     mM_control_s.update_value(0b10)
     cpu.do_tick()
@@ -159,11 +205,6 @@ def test_get_state():
     state = lc.get_state()
     assert state["name"] == "LC"
     assert state["value"] == 100
-    assert state["mask"] == 255
-    assert state["bit_length"] == 8
-    assert state["read_from_bus"] is True
-    assert state["read_from_uADR"] is False
-    assert state["decrement_by_one"] is False
 
     mM_control_s.update_value(0b11)
     cpu.do_tick()
@@ -171,8 +212,3 @@ def test_get_state():
     state = lc.get_state()
     assert state["name"] == "LC"
     assert state["value"] == 10
-    assert state["mask"] == 255
-    assert state["bit_length"] == 8
-    assert state["read_from_bus"] is False
-    assert state["read_from_uADR"] is True
-    assert state["decrement_by_one"] is False
-- 
GitLab