diff --git a/src/simudator/core/modules/mux.py b/src/simudator/core/modules/mux.py
index 0d50b4463323740d3f8d9e7b4d6fb529ea1e06e5..06a0fdda1631320d8c5679257add6483c7605850 100644
--- a/src/simudator/core/modules/mux.py
+++ b/src/simudator/core/modules/mux.py
@@ -6,23 +6,41 @@ from simudator.core.signal import Signal
 
 class Mux(Module):
     """
-    A general mux that allows an arbitrary amount of input to map
-    to a single output. The input side is controlled by the control
-    signal 'to_mux_s'.
+    A general mux that allows an arbitrary amount of input to map to a single
+    output.
+
+    Parameters
+    ----------
+    control : Signal
+        Control signal of which the value is used to index which
+        input signal to read from.
+    bit_length : int
+        Maximum number of bits for the value outputted by the mux. All extra
+        bits of the input value are discarded.
+    inputs : list[Signal]
+        List of input signals to select from to output onto the output signal.
+    output : Signal
+        Signal onto which the value of the currently selected input is
+        outputted.
+    name : str
+        Name of the multiplexer module.
+    value : Any
+        Initial value to output to the output signal.
     """
 
-    def __init__(self,
-                 to_mux: Signal,
-                 bit_length: int,
-                 output: Signal,
-                 inputs: list[Signal] = [],
-                 name="mux",
-                 value=0
-                 ) -> None:
+    def __init__(
+        self,
+        control: Signal,
+        bit_length: int,
+        output: Signal,
+        inputs: list[Signal] = [],
+        name="mux",
+        value=0,
+    ) -> None:
 
         # Signals
         signals = {
-            "in_control": to_mux,
+            "in_control": control,
             "out": output,
         }
         for i, s in enumerate(inputs):
@@ -33,15 +51,14 @@ class Mux(Module):
 
         # mask and bit_length
         self.bit_length = bit_length
-        self.mask = 2**self.bit_length -1
+        self.mask = 2**self.bit_length - 1
 
         # Value to be read/written
         self.value = value
 
     def update_register(self) -> None:
-        """
-        Read which signal to read from the control signal
-        'to_mux_s' and forward that signal to the output.
+        """Read which signal to read from the control signal and forward that
+        signal to the output.
         """
         input_index = self.signals["in_control"].get_value()
         input_signal_key = f"in_input_{input_index}"
@@ -49,21 +66,19 @@ class Mux(Module):
         self.value = input_value & self.mask
 
     def output_register(self):
-        """
-        Output the value of the mux to its output.
+        """Output the value of the currently selected input signal to the
+        output signal.
         """
         self.signals["out"].update_value(self.value)
 
     def update_logic(self):
-        """
+        """Do nothing.
+
         The mux has no logic.
         """
         pass
 
     def get_state(self) -> dict[str, Any]:
-        """
-        Returns a dict of the mux state.
-        """
         state = super().get_state()
         state["value"] = self.value
         state["bit_length"] = self.bit_length
@@ -71,9 +86,15 @@ class Mux(Module):
 
         return state
 
-    def set_state(self, state: dict) -> None:
-        """
-        Sets the name and value of the mux.
+    def set_state(self, state: dict[str, Any]) -> None:
+        """Set the name, value and bit length of the demux.
+
+        Parameters
+        ----------
+        state : dict[str, Any]
+            The state of the demux to load. Should contain the keys "name",
+            "value" and "bit_length" with values of type ``str``, ``int`` and
+            ``int`` respectively.
         """
         super().set_state(state)
         self.value = state["value"]
@@ -81,8 +102,13 @@ class Mux(Module):
         self.mask = state["mask"]
 
     def print_module(self) -> None:
-        print(self.name, "\n-----",
-              "\nvalue: ", self.value,
-              "\nbit length: ", self.bit_length,
-              "\nmask: ", self.mask,
-              )
+        print(
+            self.name,
+            "\n-----",
+            "\nvalue: ",
+            self.value,
+            "\nbit length: ",
+            self.bit_length,
+            "\nmask: ",
+            self.mask,
+        )