Skip to content
Snippets Groups Projects
Commit 4cda8b7e authored by Johannes Kung's avatar Johannes Kung
Browse files

Refactor register

parent 3ba1c19b
No related branches found
No related tags found
1 merge request!27Docstring refactor of core
......@@ -72,6 +72,15 @@ class Register(Module):
return state
def set_state(self, state: dict[str, Any]) -> None:
"""
Set the state of the register.
Parameters
----------
state : dict[str, Any]
The state of the register to load. Should contain the keys "name"
and "value" with values of type ``str`` and ``int`` respectively.
"""
self.name = state["name"]
self._value = state["value"]
......@@ -99,45 +108,55 @@ class Register(Module):
class IntegerRegister(Register):
"""
A register inteded to store integers only.
A register intended to store integers only.
Parameters
----------
input : Signal
Signal from which the value is stored in the register.
output : Signal
Signal onto which the value of the register is outputted.
bit_length : int
Maximum number of bits of the input to store in the register. All extra
bits of the input value are discarded.
value : Any
Initial value of the register.
name : str
Name of the register.
"""
def __init__(
self,
input_signal: Signal,
output_signal: Signal,
input: Signal,
output: Signal,
bit_length: int,
value: int = 0,
name: str | None = None,
) -> None:
# set the registers name
# set the name
if name is None:
name = f"{bit_length}-bit register"
super().__init__(input_signal, output_signal, value=value, name=name)
super().__init__(input, output, value=value, name=name)
# set the bit length of the register
self.bit_length = bit_length
# set the registers mask. An 8 bit register should
# have the mask 1111 1111, aka one '1' for every bit
self.mask = 2**bit_length - 1
self._bit_length = bit_length
def update_register(self) -> None:
super().update_register()
self._value = self._value & self.mask
mask = 2**self._bit_length - 1
self._value = self._value & mask
def get_state(self) -> dict[str, Any]:
state = super().get_state()
state["bit_length"] = self.bit_length
state["bit_length"] = self._bit_length
return state
def set_state(self, state: dict[str, Any]) -> None:
super().set_state(state)
if "bit_length" in state:
self.bit_length = state["bit_length"]
self.mask = 2**self.bit_length - 1
self._bit_length = state["bit_length"]
def save_state_to_file(self, file_path: str) -> bool:
file = open(file_path, "a")
......@@ -168,8 +187,8 @@ class Flag(IntegerRegister):
# set the flags name
super().__init__(
input_signal=input_signal,
output_signal=output_signal,
input=input_signal,
output=output_signal,
bit_length=bit_length,
value=value,
name=name,
......@@ -207,7 +226,7 @@ class Flag(IntegerRegister):
"\n value: ",
self._value,
"\n bit length: ",
self.bit_length,
self._bit_length,
"\n mask: ",
self.mask,
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment