Skip to content
Snippets Groups Projects
Commit 2ead0284 authored by Oscar Gustafsson's avatar Oscar Gustafsson :bicyclist:
Browse files

Make Add/Sub/Mul/DivGenerator private and improve type checking

parent ec515e35
No related branches found
No related tags found
1 merge request!167Make Add/Sub/Mul/DivGenerator private and improve type checking
Pipeline #88975 passed
...@@ -19,45 +19,53 @@ class SignalGenerator: ...@@ -19,45 +19,53 @@ class SignalGenerator:
def __call__(self, time: int) -> complex: def __call__(self, time: int) -> complex:
raise NotImplementedError raise NotImplementedError
def __add__(self, other) -> "AddGenerator": def __add__(self, other) -> "_AddGenerator":
if isinstance(other, Number): if isinstance(other, Number):
return AddGenerator(self, Constant(other)) return _AddGenerator(self, Constant(other))
return AddGenerator(self, other) if isinstance(other, SignalGenerator):
return _AddGenerator(self, other)
raise TypeError(f"Cannot add {other!r} to {type(self)}")
def __radd__(self, other) -> "AddGenerator": def __radd__(self, other) -> "_AddGenerator":
if isinstance(other, Number): if isinstance(other, Number):
return AddGenerator(self, Constant(other)) return _AddGenerator(Constant(other), self)
return AddGenerator(self, other) raise TypeError(f"Cannot add {type(self)} to {other!r}")
def __sub__(self, other) -> "SubGenerator": def __sub__(self, other) -> "_SubGenerator":
if isinstance(other, Number): if isinstance(other, Number):
return SubGenerator(self, Constant(other)) return _SubGenerator(self, Constant(other))
return SubGenerator(self, other) if isinstance(other, SignalGenerator):
return _SubGenerator(self, other)
raise TypeError(f"Cannot subtract {other!r} from {type(self)}")
def __rsub__(self, other) -> "SubGenerator": def __rsub__(self, other) -> "_SubGenerator":
if isinstance(other, Number): if isinstance(other, Number):
return SubGenerator(Constant(other), self) return _SubGenerator(Constant(other), self)
return SubGenerator(other, self) raise TypeError(f"Cannot subtract {type(self)} from {other!r}")
def __mul__(self, other) -> "MulGenerator": def __mul__(self, other) -> "_MulGenerator":
if isinstance(other, Number): if isinstance(other, Number):
return MultGenerator(self, Constant(other)) return _MulGenerator(self, Constant(other))
return MultGenerator(self, other) if isinstance(other, SignalGenerator):
return _MulGenerator(self, other)
raise TypeError(f"Cannot multiply {type(self)} with {other!r}")
def __rmul__(self, other) -> "MulGenerator": def __rmul__(self, other) -> "_MulGenerator":
if isinstance(other, Number): if isinstance(other, Number):
return MultGenerator(self, Constant(other)) return _MulGenerator(Constant(other), self)
return MultGenerator(self, other) raise TypeError(f"Cannot multiply {other!r} with {type(self)}")
def __truediv__(self, other) -> "MulGenerator": def __truediv__(self, other) -> "_DivGenerator":
if isinstance(other, Number): if isinstance(other, Number):
return DivGenerator(self, Constant(other)) return _DivGenerator(self, Constant(other))
return DivGenerator(self, other) if isinstance(other, SignalGenerator):
return _DivGenerator(self, other)
raise TypeError(f"Cannot divide {type(self)} with {other!r}")
def __rtruediv__(self, other) -> "MulGenerator": def __rtruediv__(self, other) -> "_DivGenerator":
if isinstance(other, Number): if isinstance(other, Number):
return DivGenerator(Constant(other), self) return _DivGenerator(Constant(other), self)
return DivGenerator(other, self) raise TypeError(f"Cannot divide {other!r} with {type(self)}")
class Impulse(SignalGenerator): class Impulse(SignalGenerator):
...@@ -70,7 +78,7 @@ class Impulse(SignalGenerator): ...@@ -70,7 +78,7 @@ class Impulse(SignalGenerator):
The delay before the signal goes to 1 for one sample. The delay before the signal goes to 1 for one sample.
""" """
def __init__(self, delay: int = 0) -> Callable[[int], complex]: def __init__(self, delay: int = 0) -> None:
self._delay = delay self._delay = delay
def __call__(self, time: int) -> complex: def __call__(self, time: int) -> complex:
...@@ -90,7 +98,7 @@ class Step(SignalGenerator): ...@@ -90,7 +98,7 @@ class Step(SignalGenerator):
The delay before the signal goes to 1. The delay before the signal goes to 1.
""" """
def __init__(self, delay: int = 0) -> Callable[[int], complex]: def __init__(self, delay: int = 0) -> None:
self._delay = delay self._delay = delay
def __call__(self, time: int) -> complex: def __call__(self, time: int) -> complex:
...@@ -110,7 +118,7 @@ class Constant(SignalGenerator): ...@@ -110,7 +118,7 @@ class Constant(SignalGenerator):
The constant. The constant.
""" """
def __init__(self, constant: complex = 1.0) -> Callable[[int], complex]: def __init__(self, constant: complex = 1.0) -> None:
self._constant = constant self._constant = constant
def __call__(self, time: int) -> complex: def __call__(self, time: int) -> complex:
...@@ -130,7 +138,7 @@ class ZeroPad(SignalGenerator): ...@@ -130,7 +138,7 @@ class ZeroPad(SignalGenerator):
The data that should be padded. The data that should be padded.
""" """
def __init__(self, data: Sequence[complex]) -> Callable[[int], complex]: def __init__(self, data: Sequence[complex]) -> None:
self._data = data self._data = data
self._len = len(data) self._len = len(data)
...@@ -156,9 +164,7 @@ class Sinusoid(SignalGenerator): ...@@ -156,9 +164,7 @@ class Sinusoid(SignalGenerator):
The normalized phase offset. The normalized phase offset.
""" """
def __init__( def __init__(self, frequency: float, phase: float = 0.0) -> None:
self, frequency: float, phase: float = 0.0
) -> Callable[[int], complex]:
self._frequency = frequency self._frequency = frequency
self._phase = phase self._phase = phase
...@@ -173,95 +179,87 @@ class Sinusoid(SignalGenerator): ...@@ -173,95 +179,87 @@ class Sinusoid(SignalGenerator):
) )
class AddGenerator: class _AddGenerator(SignalGenerator):
""" """
Signal generator that adds two signals. Signal generator that adds two signals.
""" """
def __init__( def __init__(self, a: SignalGenerator, b: SignalGenerator) -> None:
self, a: SignalGenerator, b: SignalGenerator
) -> Callable[[int], complex]:
self._a = a self._a = a
self._b = b self._b = b
def __call__(self, time: int) -> complex: def __call__(self, time: int) -> complex:
return self._a(time) + self._b(time) return self._a(time) + self._b(time)
def __str__(self): def __repr__(self):
return f"{self._a} + {self._b}" return f"{self._a} + {self._b}"
class SubGenerator: class _SubGenerator(SignalGenerator):
""" """
Signal generator that subtracts two signals. Signal generator that subtracts two signals.
""" """
def __init__( def __init__(self, a: SignalGenerator, b: SignalGenerator) -> None:
self, a: SignalGenerator, b: SignalGenerator
) -> Callable[[int], complex]:
self._a = a self._a = a
self._b = b self._b = b
def __call__(self, time: int) -> complex: def __call__(self, time: int) -> complex:
return self._a(time) - self._b(time) return self._a(time) - self._b(time)
def __str__(self): def __repr__(self):
return f"{self._a} - {self._b}" return f"{self._a} - {self._b}"
class MultGenerator: class _MulGenerator(SignalGenerator):
""" """
Signal generator that multiplies two signals. Signal generator that multiplies two signals.
""" """
def __init__( def __init__(self, a: SignalGenerator, b: SignalGenerator) -> None:
self, a: SignalGenerator, b: SignalGenerator
) -> Callable[[int], complex]:
self._a = a self._a = a
self._b = b self._b = b
def __call__(self, time: int) -> complex: def __call__(self, time: int) -> complex:
return self._a(time) * self._b(time) return self._a(time) * self._b(time)
def __str__(self): def __repr__(self):
a = ( a = (
f"({self._a})" f"({self._a})"
if isinstance(self._a, (AddGenerator, SubGenerator)) if isinstance(self._a, (_AddGenerator, _SubGenerator))
else f"{self._a}" else f"{self._a}"
) )
b = ( b = (
f"({self._b})" f"({self._b})"
if isinstance(self._b, (AddGenerator, SubGenerator)) if isinstance(self._b, (_AddGenerator, _SubGenerator))
else f"{self._b}" else f"{self._b}"
) )
return f"{a} * {b}" return f"{a} * {b}"
class DivGenerator: class _DivGenerator(SignalGenerator):
""" """
Signal generator that divides two signals. Signal generator that divides two signals.
""" """
def __init__( def __init__(self, a: SignalGenerator, b: SignalGenerator) -> None:
self, a: SignalGenerator, b: SignalGenerator
) -> Callable[[int], complex]:
self._a = a self._a = a
self._b = b self._b = b
def __call__(self, time: int) -> complex: def __call__(self, time: int) -> complex:
return self._a(time) / self._b(time) return self._a(time) / self._b(time)
def __str__(self): def __repr__(self):
a = ( a = (
f"({self._a})" f"({self._a})"
if isinstance(self._a, (AddGenerator, SubGenerator)) if isinstance(self._a, (_AddGenerator, _SubGenerator))
else f"{self._a}" else f"{self._a}"
) )
b = ( b = (
f"({self._b})" f"({self._b})"
if isinstance( if isinstance(
self._b, self._b,
(AddGenerator, SubGenerator, MultGenerator, DivGenerator), (_AddGenerator, _SubGenerator, _MulGenerator, _DivGenerator),
) )
else f"{self._b}" else f"{self._b}"
) )
......
...@@ -2,7 +2,17 @@ from math import sqrt ...@@ -2,7 +2,17 @@ from math import sqrt
import pytest import pytest
from b_asic.signal_generator import Constant, Impulse, Sinusoid, Step, ZeroPad from b_asic.signal_generator import (
Constant,
Impulse,
Sinusoid,
Step,
ZeroPad,
_AddGenerator,
_DivGenerator,
_MulGenerator,
_SubGenerator,
)
def test_impulse(): def test_impulse():
...@@ -96,6 +106,7 @@ def test_addition(): ...@@ -96,6 +106,7 @@ def test_addition():
assert g(3) == 0 assert g(3) == 0
assert str(g) == "Impulse() + Impulse(2)" assert str(g) == "Impulse() + Impulse(2)"
assert isinstance(g, _AddGenerator)
g = 1.0 + Impulse(2) g = 1.0 + Impulse(2)
assert g(-1) == 1 assert g(-1) == 1
...@@ -104,7 +115,8 @@ def test_addition(): ...@@ -104,7 +115,8 @@ def test_addition():
assert g(2) == 2 assert g(2) == 2
assert g(3) == 1 assert g(3) == 1
assert str(g) == "Impulse(2) + 1.0" assert str(g) == "1.0 + Impulse(2)"
assert isinstance(g, _AddGenerator)
g = Impulse(1) + 1.0 g = Impulse(1) + 1.0
assert g(-1) == 1 assert g(-1) == 1
...@@ -114,6 +126,7 @@ def test_addition(): ...@@ -114,6 +126,7 @@ def test_addition():
assert g(3) == 1 assert g(3) == 1
assert str(g) == "Impulse(1) + 1.0" assert str(g) == "Impulse(1) + 1.0"
assert isinstance(g, _AddGenerator)
def test_subtraction(): def test_subtraction():
...@@ -125,6 +138,7 @@ def test_subtraction(): ...@@ -125,6 +138,7 @@ def test_subtraction():
assert g(3) == 0 assert g(3) == 0
assert str(g) == "Impulse() - Impulse(2)" assert str(g) == "Impulse() - Impulse(2)"
assert isinstance(g, _SubGenerator)
g = 1.0 - Impulse(2) g = 1.0 - Impulse(2)
assert g(-1) == 1 assert g(-1) == 1
...@@ -134,6 +148,7 @@ def test_subtraction(): ...@@ -134,6 +148,7 @@ def test_subtraction():
assert g(3) == 1 assert g(3) == 1
assert str(g) == "1.0 - Impulse(2)" assert str(g) == "1.0 - Impulse(2)"
assert isinstance(g, _SubGenerator)
g = Impulse(2) - 1.0 g = Impulse(2) - 1.0
assert g(-1) == -1 assert g(-1) == -1
...@@ -143,6 +158,7 @@ def test_subtraction(): ...@@ -143,6 +158,7 @@ def test_subtraction():
assert g(3) == -1 assert g(3) == -1
assert str(g) == "Impulse(2) - 1.0" assert str(g) == "Impulse(2) - 1.0"
assert isinstance(g, _SubGenerator)
def test_multiplication(): def test_multiplication():
...@@ -153,6 +169,7 @@ def test_multiplication(): ...@@ -153,6 +169,7 @@ def test_multiplication():
assert g(2) == 0 assert g(2) == 0
assert str(g) == "Impulse() * 0.5" assert str(g) == "Impulse() * 0.5"
assert isinstance(g, _MulGenerator)
g = 2 * Sinusoid(0.5, 0.25) g = 2 * Sinusoid(0.5, 0.25)
assert g(0) == pytest.approx(sqrt(2)) assert g(0) == pytest.approx(sqrt(2))
...@@ -160,7 +177,8 @@ def test_multiplication(): ...@@ -160,7 +177,8 @@ def test_multiplication():
assert g(2) == pytest.approx(-sqrt(2)) assert g(2) == pytest.approx(-sqrt(2))
assert g(3) == pytest.approx(-sqrt(2)) assert g(3) == pytest.approx(-sqrt(2))
assert str(g) == "Sinusoid(0.5, 0.25) * 2" assert str(g) == "2 * Sinusoid(0.5, 0.25)"
assert isinstance(g, _MulGenerator)
g = Step(1) * (Sinusoid(0.5, 0.25) + 1.0) g = Step(1) * (Sinusoid(0.5, 0.25) + 1.0)
assert g(0) == 0 assert g(0) == 0
...@@ -169,6 +187,7 @@ def test_multiplication(): ...@@ -169,6 +187,7 @@ def test_multiplication():
assert g(3) == pytest.approx(-sqrt(2) / 2 + 1) assert g(3) == pytest.approx(-sqrt(2) / 2 + 1)
assert str(g) == "Step(1) * (Sinusoid(0.5, 0.25) + 1.0)" assert str(g) == "Step(1) * (Sinusoid(0.5, 0.25) + 1.0)"
assert isinstance(g, _MulGenerator)
def test_division(): def test_division():
...@@ -179,6 +198,7 @@ def test_division(): ...@@ -179,6 +198,7 @@ def test_division():
assert g(2) == 0.5 assert g(2) == 0.5
assert str(g) == "Step() / 2" assert str(g) == "Step() / 2"
assert isinstance(g, _DivGenerator)
g = 0.5 / Step() g = 0.5 / Step()
assert g(0) == 0.5 assert g(0) == 0.5
...@@ -186,6 +206,7 @@ def test_division(): ...@@ -186,6 +206,7 @@ def test_division():
assert g(2) == 0.5 assert g(2) == 0.5
assert str(g) == "0.5 / Step()" assert str(g) == "0.5 / Step()"
assert isinstance(g, _DivGenerator)
g = Sinusoid(0.5, 0.25) / (0.5 * Step()) g = Sinusoid(0.5, 0.25) / (0.5 * Step())
assert g(0) == pytest.approx(sqrt(2)) assert g(0) == pytest.approx(sqrt(2))
...@@ -193,4 +214,5 @@ def test_division(): ...@@ -193,4 +214,5 @@ def test_division():
assert g(2) == pytest.approx(-sqrt(2)) assert g(2) == pytest.approx(-sqrt(2))
assert g(3) == pytest.approx(-sqrt(2)) assert g(3) == pytest.approx(-sqrt(2))
assert str(g) == "Sinusoid(0.5, 0.25) / (Step() * 0.5)" assert str(g) == "Sinusoid(0.5, 0.25) / (0.5 * Step())"
assert isinstance(g, _DivGenerator)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment