From 911d46ce91fe16096f62a07ca97fe9a1b0d65f2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Bergstr=C3=B6m?= <david.bergstrom@liu.se> Date: Thu, 21 Nov 2019 17:07:46 +0100 Subject: [PATCH] Make it possible to compare UnitTypeID with UNIT_TYPEID --- python-api-src/lib_sc2_typeenums.cpp | 3 ++- python-api-src/library.cpp | 3 ++- tests/unittypeid.py | 25 +++++++++++++++++++++++++ 3 files changed, 29 insertions(+), 2 deletions(-) create mode 100644 tests/unittypeid.py diff --git a/python-api-src/lib_sc2_typeenums.cpp b/python-api-src/lib_sc2_typeenums.cpp index 03ae7992f..235056ce2 100644 --- a/python-api-src/lib_sc2_typeenums.cpp +++ b/python-api-src/lib_sc2_typeenums.cpp @@ -217,7 +217,8 @@ void define_typeenums(py::module & m) .value("NEUTRAL_UNBUILDABLEPLATESDESTRUCTIBLE", sc2::UNIT_TYPEID::NEUTRAL_UNBUILDABLEPLATESDESTRUCTIBLE) .value("NEUTRAL_UTILITYBOT", sc2::UNIT_TYPEID::NEUTRAL_UTILITYBOT) .value("NEUTRAL_VESPENEGEYSER", sc2::UNIT_TYPEID::NEUTRAL_VESPENEGEYSER) - .value("NEUTRAL_XELNAGATOWER", sc2::UNIT_TYPEID::NEUTRAL_XELNAGATOWER); + .value("NEUTRAL_XELNAGATOWER", sc2::UNIT_TYPEID::NEUTRAL_XELNAGATOWER) + .def("__eq__", [](const sc2::UNIT_TYPEID &value, sc2::UnitTypeID &value2) { return value == value2; }); py::enum_<sc2::ABILITY_ID>(m, "ABILITY_ID") .value("INVALID", sc2::ABILITY_ID::INVALID) diff --git a/python-api-src/library.cpp b/python-api-src/library.cpp index dd1bc0cfd..511944597 100644 --- a/python-api-src/library.cpp +++ b/python-api-src/library.cpp @@ -50,7 +50,8 @@ PYBIND11_MODULE(library, m) py::class_<sc2::UnitTypeID>(m, "UnitTypeID") - .def(py::init<sc2::UNIT_TYPEID>()); + .def(py::init<sc2::UNIT_TYPEID>()) + .def("__eq__", [](const sc2::UnitTypeID &value, sc2::UNIT_TYPEID &value2) { return value == value2; }); py::implicitly_convertible<sc2::UNIT_TYPEID, sc2::UnitTypeID>(); diff --git a/tests/unittypeid.py b/tests/unittypeid.py new file mode 100644 index 000000000..f3815d900 --- /dev/null +++ b/tests/unittypeid.py @@ -0,0 +1,25 @@ +import unittest +import sys + +sys.path.append('build/python-api-src') + +from library import UnitTypeID, UNIT_TYPEID + + +class TestUnitType(unittest.TestCase): + + def test_equality(self): + self.assertTrue(UNIT_TYPEID.TERRAN_ARMORY == UNIT_TYPEID.TERRAN_ARMORY) + + def test_inequality(self): + self.assertFalse(UNIT_TYPEID.TERRAN_ARMORY != UNIT_TYPEID.TERRAN_ARMORY) + + def test_convert_equality(self): + unit_typeid = UNIT_TYPEID.TERRAN_ARMORY + self.assertTrue(unit_typeid == UnitTypeID(unit_typeid)) + self.assertTrue(UnitTypeID(unit_typeid) == unit_typeid) + self.assertFalse(UnitTypeID(unit_typeid) != unit_typeid) + +if __name__ == '__main__': + unittest.main() + -- GitLab