From 911d46ce91fe16096f62a07ca97fe9a1b0d65f2d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20Bergstr=C3=B6m?= <david.bergstrom@liu.se>
Date: Thu, 21 Nov 2019 17:07:46 +0100
Subject: [PATCH] Make it possible to compare UnitTypeID with UNIT_TYPEID

---
 python-api-src/lib_sc2_typeenums.cpp |  3 ++-
 python-api-src/library.cpp           |  3 ++-
 tests/unittypeid.py                  | 25 +++++++++++++++++++++++++
 3 files changed, 29 insertions(+), 2 deletions(-)
 create mode 100644 tests/unittypeid.py

diff --git a/python-api-src/lib_sc2_typeenums.cpp b/python-api-src/lib_sc2_typeenums.cpp
index 03ae7992f..235056ce2 100644
--- a/python-api-src/lib_sc2_typeenums.cpp
+++ b/python-api-src/lib_sc2_typeenums.cpp
@@ -217,7 +217,8 @@ void define_typeenums(py::module & m)
         .value("NEUTRAL_UNBUILDABLEPLATESDESTRUCTIBLE", sc2::UNIT_TYPEID::NEUTRAL_UNBUILDABLEPLATESDESTRUCTIBLE)
         .value("NEUTRAL_UTILITYBOT", sc2::UNIT_TYPEID::NEUTRAL_UTILITYBOT)
         .value("NEUTRAL_VESPENEGEYSER", sc2::UNIT_TYPEID::NEUTRAL_VESPENEGEYSER)
-        .value("NEUTRAL_XELNAGATOWER", sc2::UNIT_TYPEID::NEUTRAL_XELNAGATOWER);
+        .value("NEUTRAL_XELNAGATOWER", sc2::UNIT_TYPEID::NEUTRAL_XELNAGATOWER)
+        .def("__eq__", [](const sc2::UNIT_TYPEID &value, sc2::UnitTypeID &value2) { return value == value2; });
 
     py::enum_<sc2::ABILITY_ID>(m, "ABILITY_ID")
         .value("INVALID", sc2::ABILITY_ID::INVALID)
diff --git a/python-api-src/library.cpp b/python-api-src/library.cpp
index dd1bc0cfd..511944597 100644
--- a/python-api-src/library.cpp
+++ b/python-api-src/library.cpp
@@ -50,7 +50,8 @@ PYBIND11_MODULE(library, m)
 
 
     py::class_<sc2::UnitTypeID>(m, "UnitTypeID")
-        .def(py::init<sc2::UNIT_TYPEID>());
+        .def(py::init<sc2::UNIT_TYPEID>())
+        .def("__eq__", [](const sc2::UnitTypeID &value, sc2::UNIT_TYPEID &value2) { return value == value2; });
 
     py::implicitly_convertible<sc2::UNIT_TYPEID, sc2::UnitTypeID>();
 
diff --git a/tests/unittypeid.py b/tests/unittypeid.py
new file mode 100644
index 000000000..f3815d900
--- /dev/null
+++ b/tests/unittypeid.py
@@ -0,0 +1,25 @@
+import unittest
+import sys
+
+sys.path.append('build/python-api-src')
+
+from library import UnitTypeID, UNIT_TYPEID
+
+
+class TestUnitType(unittest.TestCase):
+
+    def test_equality(self):
+        self.assertTrue(UNIT_TYPEID.TERRAN_ARMORY == UNIT_TYPEID.TERRAN_ARMORY)
+    
+    def test_inequality(self):
+        self.assertFalse(UNIT_TYPEID.TERRAN_ARMORY != UNIT_TYPEID.TERRAN_ARMORY)
+
+    def test_convert_equality(self):
+        unit_typeid = UNIT_TYPEID.TERRAN_ARMORY
+        self.assertTrue(unit_typeid == UnitTypeID(unit_typeid))
+        self.assertTrue(UnitTypeID(unit_typeid) == unit_typeid)
+        self.assertFalse(UnitTypeID(unit_typeid) != unit_typeid)
+
+if __name__ == '__main__':
+    unittest.main()
+
-- 
GitLab