From dd7aaafa9086436c6df4b1fb9e66cd707fe44daf Mon Sep 17 00:00:00 2001
From: Oscar Gustafsson <oscar.gustafsson@gmail.com>
Date: Mon, 6 Feb 2023 19:56:54 +0100
Subject: [PATCH] Add division support for generators

---
 b_asic/signal_generator.py    | 25 +++++++++++++++++++++++++
 test/test_signal_generator.py | 25 +++++++++++++++++++++++++
 2 files changed, 50 insertions(+)

diff --git a/b_asic/signal_generator.py b/b_asic/signal_generator.py
index be4543eb..8270736f 100644
--- a/b_asic/signal_generator.py
+++ b/b_asic/signal_generator.py
@@ -49,6 +49,16 @@ class SignalGenerator:
             return MultGenerator(self, Constant(other))
         return MultGenerator(self, other)
 
+    def __truediv__(self, other) -> "MulGenerator":
+        if isinstance(other, Number):
+            return DivGenerator(self, Constant(other))
+        return DivGenerator(self, other)
+
+    def __rtruediv__(self, other) -> "MulGenerator":
+        if isinstance(other, Number):
+            return DivGenerator(Constant(other), self)
+        return DivGenerator(other, self)
+
 
 class Impulse(SignalGenerator):
     """
@@ -187,3 +197,18 @@ class MultGenerator:
 
     def __call__(self, time: int) -> complex:
         return self._a(time) * self._b(time)
+
+
+class DivGenerator:
+    """
+    Signal generator that divides two signals.
+    """
+
+    def __init__(
+        self, a: SignalGenerator, b: SignalGenerator
+    ) -> Callable[[int], complex]:
+        self._a = a
+        self._b = b
+
+    def __call__(self, time: int) -> complex:
+        return self._a(time) / self._b(time)
diff --git a/test/test_signal_generator.py b/test/test_signal_generator.py
index 06051208..9492de37 100644
--- a/test/test_signal_generator.py
+++ b/test/test_signal_generator.py
@@ -127,3 +127,28 @@ def test_multiplication():
     assert g(1) == pytest.approx(sqrt(2))
     assert g(2) == pytest.approx(-sqrt(2))
     assert g(3) == pytest.approx(-sqrt(2))
+
+    g = Step(1) * Sinusoid(0.5, 0.25)
+    assert g(0) == 0
+    assert g(1) == pytest.approx(sqrt(2) / 2)
+    assert g(2) == pytest.approx(-sqrt(2) / 2)
+    assert g(3) == pytest.approx(-sqrt(2) / 2)
+
+
+def test_division():
+    g = Step() / 2
+    assert g(-1) == 0.0
+    assert g(0) == 0.5
+    assert g(1) == 0.5
+    assert g(2) == 0.5
+
+    g = 0.5 / Step()
+    assert g(0) == 0.5
+    assert g(1) == 0.5
+    assert g(2) == 0.5
+
+    g = Sinusoid(0.5, 0.25) / (0.5 * Step())
+    assert g(0) == pytest.approx(sqrt(2))
+    assert g(1) == pytest.approx(sqrt(2))
+    assert g(2) == pytest.approx(-sqrt(2))
+    assert g(3) == pytest.approx(-sqrt(2))
-- 
GitLab