diff --git a/python-api-src/lib_sc2_typeenums.cpp b/python-api-src/lib_sc2_typeenums.cpp index 03ae7992ff11558485830f52ebb7d2aa7c934059..235056ce29cc1711fc5297aa93dd85ab90013fc3 100644 --- a/python-api-src/lib_sc2_typeenums.cpp +++ b/python-api-src/lib_sc2_typeenums.cpp @@ -217,7 +217,8 @@ void define_typeenums(py::module & m) .value("NEUTRAL_UNBUILDABLEPLATESDESTRUCTIBLE", sc2::UNIT_TYPEID::NEUTRAL_UNBUILDABLEPLATESDESTRUCTIBLE) .value("NEUTRAL_UTILITYBOT", sc2::UNIT_TYPEID::NEUTRAL_UTILITYBOT) .value("NEUTRAL_VESPENEGEYSER", sc2::UNIT_TYPEID::NEUTRAL_VESPENEGEYSER) - .value("NEUTRAL_XELNAGATOWER", sc2::UNIT_TYPEID::NEUTRAL_XELNAGATOWER); + .value("NEUTRAL_XELNAGATOWER", sc2::UNIT_TYPEID::NEUTRAL_XELNAGATOWER) + .def("__eq__", [](const sc2::UNIT_TYPEID &value, sc2::UnitTypeID &value2) { return value == value2; }); py::enum_<sc2::ABILITY_ID>(m, "ABILITY_ID") .value("INVALID", sc2::ABILITY_ID::INVALID) diff --git a/python-api-src/library.cpp b/python-api-src/library.cpp index dd1bc0cfd0cd945da04754358c7214bfcd48b487..5119445973792543ebcfde64533ce62c346e9fba 100644 --- a/python-api-src/library.cpp +++ b/python-api-src/library.cpp @@ -50,7 +50,8 @@ PYBIND11_MODULE(library, m) py::class_<sc2::UnitTypeID>(m, "UnitTypeID") - .def(py::init<sc2::UNIT_TYPEID>()); + .def(py::init<sc2::UNIT_TYPEID>()) + .def("__eq__", [](const sc2::UnitTypeID &value, sc2::UNIT_TYPEID &value2) { return value == value2; }); py::implicitly_convertible<sc2::UNIT_TYPEID, sc2::UnitTypeID>(); diff --git a/tests/unittypeid.py b/tests/unittypeid.py new file mode 100644 index 0000000000000000000000000000000000000000..f3815d9002078987949f31beac2dec8cb93123b3 --- /dev/null +++ b/tests/unittypeid.py @@ -0,0 +1,25 @@ +import unittest +import sys + +sys.path.append('build/python-api-src') + +from library import UnitTypeID, UNIT_TYPEID + + +class TestUnitType(unittest.TestCase): + + def test_equality(self): + self.assertTrue(UNIT_TYPEID.TERRAN_ARMORY == UNIT_TYPEID.TERRAN_ARMORY) + + def test_inequality(self): + self.assertFalse(UNIT_TYPEID.TERRAN_ARMORY != UNIT_TYPEID.TERRAN_ARMORY) + + def test_convert_equality(self): + unit_typeid = UNIT_TYPEID.TERRAN_ARMORY + self.assertTrue(unit_typeid == UnitTypeID(unit_typeid)) + self.assertTrue(UnitTypeID(unit_typeid) == unit_typeid) + self.assertFalse(UnitTypeID(unit_typeid) != unit_typeid) + +if __name__ == '__main__': + unittest.main() +