From 58809f9e333977f8f8cab1e3ccb86e23f77c4d66 Mon Sep 17 00:00:00 2001
From: Rojikku98 <be.edvin@gmail.com>
Date: Sun, 26 Jul 2020 14:10:22 +0200
Subject: [PATCH] Added information for units

---
 python-api-src/lib_replay_unit.cpp | 37 +++++++++++++++++++++++++--
 python-api-src/library.cpp         |  3 ++-
 python-api-src/library.h           | 10 ++++++++
 src/IDAReplayObserver.cpp          | 15 +++++++++--
 src/IDAReplayObserver.h            |  7 +++--
 src/Unit.h                         |  3 ++-
 src/UnitInformation.cpp            | 41 ++++++++++++++++++++++++++++++
 src/UnitInformation.h              |  6 ++++-
 8 files changed, 113 insertions(+), 9 deletions(-)

diff --git a/python-api-src/lib_replay_unit.cpp b/python-api-src/lib_replay_unit.cpp
index 0266b9854..411b2e220 100644
--- a/python-api-src/lib_replay_unit.cpp
+++ b/python-api-src/lib_replay_unit.cpp
@@ -4,6 +4,39 @@ namespace py = pybind11;
 
 void define_replay_unit(py::module & m)
 {
-    py::class_<UnitInformation>(m, "ReplayUnit")
-        .def_property_readonly("id", &Unit::getID);
+	py::class_<UnitInformation>(m, "ReplayUnit")
+		.def_property_readonly("id", &UnitInformation::getID)
+		.def_property_readonly("unit_type", &UnitInformation::getType, "The name of the type")
+		.def_property_readonly("position", &UnitInformation::getPosition, "The :class:`library.Point2D` of the unit")
+		.def_property_readonly("tile_position", &UnitInformation::getTilePosition, "The :class:`library.Point2DI` of the unit")
+		.def_property_readonly("hit_points", &UnitInformation::getHitPoints)
+		.def_property_readonly("shields", &UnitInformation::getShields)
+		.def_property_readonly("energy", &UnitInformation::getEnergy)
+		.def_property_readonly("player", &UnitInformation::getPlayer)
+		.def_property_readonly("build_percentage", &UnitInformation::getBuildPercentage)
+		.def_property_readonly("weapon_cooldown", &UnitInformation::getWeaponCooldown)
+		.def_property_readonly("is_completed", &UnitInformation::isCompleted)
+		.def_property_readonly("is_being_constructed", &UnitInformation::isBeingConstructed)
+		.def_property_readonly("is_cloaked", &UnitInformation::isCloaked)
+		.def_property_readonly("is_flying", &UnitInformation::isFlying)
+		.def_property_readonly("buffs", &UnitInformation::buffs)
+		.def_property_readonly("is_alive", &UnitInformation::isAlive)
+		.def_property_readonly("is_powered", &UnitInformation::isPowered)
+		.def_property_readonly("is_idle", &UnitInformation::isIdle)
+		.def_property_readonly("is_burrowed", &UnitInformation::isBurrowed)
+		.def_property_readonly("is_valid", &UnitInformation::isValid)
+		.def_property_readonly("is_training", &UnitInformation::isTraining)
+		.def_property_readonly("is_blip", &UnitInformation::isBlip)
+		.def_property_readonly("target", &UnitInformation::getTarget)
+		.def_property_readonly("has_target", &UnitInformation::hasTarget)
+		.def_property_readonly("max_hit_points", &UnitInformation::getMaxHitPoints)
+		.def_property_readonly("progress", &UnitInformation::getProgress)
+		.def_property_readonly("current_ability_id", &UnitInformation::getCurrentAbilityID, "The AbilityID of currently used ability")
+		.def_property_readonly("facing", &UnitInformation::getFacing)
+		.def_property_readonly("radius", &UnitInformation::getRadius)
+		.def_property_readonly("is_carrying_minerals", &UnitInformation::isCarryingMinerals)
+		.def("__hash__", [](const UnitInformation & unit) { return std::hash<const sc2::Unit *>{}(unit.getUnitPtr()); })
+		.def(py::self == py::self)
+		.def("__repr__", [](const UnitInformation & unit) { return "<Unit of type: '" + unit.getType() + "'>"; })
+		;
 }
diff --git a/python-api-src/library.cpp b/python-api-src/library.cpp
index 5e157b765..1ec48212a 100644
--- a/python-api-src/library.cpp
+++ b/python-api-src/library.cpp
@@ -98,7 +98,8 @@ PYBIND11_MODULE(library, m)
 		.def("on_game_start", &IDAReplayObserver::OnGameStart)
 		.def("on_step", &IDAReplayObserver::OnStep)
 		.def("on_game_end", &IDAReplayObserver::OnGameEnd)
-		.def("get_all_units", &IDAReplayObserver::GetAllUnits, "Returns a list of all units");
+		.def("get_all_units", &IDAReplayObserver::GetAllUnits, "Returns a list of all units")
+		;
 
     py::class_<sc2::PlayerSetup>(m, "PlayerSetup");
 
diff --git a/python-api-src/library.h b/python-api-src/library.h
index 805d5ffc1..daee1f3fe 100644
--- a/python-api-src/library.h
+++ b/python-api-src/library.h
@@ -85,6 +85,16 @@ public:
 			OnStep
 		);
 	}
+	void OnGameEnd() override
+	{
+		PYBIND11_OVERLOAD_NAME(
+			void,
+			IDAReplayObserver,
+			"on_game_end",
+			OnGameEnd
+
+		);
+	}
 };
 
 // The functions below are all defined in different .cpp files, in order 
diff --git a/src/IDAReplayObserver.cpp b/src/IDAReplayObserver.cpp
index 75cfa061f..21e20e456 100644
--- a/src/IDAReplayObserver.cpp
+++ b/src/IDAReplayObserver.cpp
@@ -29,12 +29,23 @@ void IDAReplayObserver::OnGameStart()
 
 void IDAReplayObserver::OnStep()
 {
-	
 	setUnits();
-	
+}
+
+void IDAReplayObserver::OnGameEnd()
+{
+}
 
+UnitInformation IDAReplayObserver::GetUnit(const CCUnitID tag) const
+{		
+	return UnitInformation(Observation()->GetUnit(tag), *(IDAReplayObserver *)this);
 }
 
+
+
+
+
+
 const std::vector<UnitInformation>& IDAReplayObserver::GetAllUnits() const
 {
 
diff --git a/src/IDAReplayObserver.h b/src/IDAReplayObserver.h
index 88ca4e032..de78d6b52 100644
--- a/src/IDAReplayObserver.h
+++ b/src/IDAReplayObserver.h
@@ -14,15 +14,18 @@ class IDAReplayObserver : public sc2::ReplayObserver
 	
 
 	void setUnits();
-
+	std::vector<UnitInformation>       m_allUnits;
 
 public:
 	IDAReplayObserver();
 
-	std::vector<UnitInformation>       m_allUnits;
+	
 
 	void OnGameStart() override;
 	void OnStep() override;
+	void OnGameEnd() override;
+	UnitInformation GetUnit(const CCUnitID tag) const;
+
 
 	const std::vector<UnitInformation> & GetAllUnits() const;
 
diff --git a/src/Unit.h b/src/Unit.h
index 2f7f28067..24a3d6316 100644
--- a/src/Unit.h
+++ b/src/Unit.h
@@ -11,7 +11,8 @@ class Unit
     CCUnitID    m_unitID;
     UnitType    m_unitType;
 
-    const sc2::Unit * m_unit;
+protected:   
+	const sc2::Unit * m_unit;
 
 public:
 
diff --git a/src/UnitInformation.cpp b/src/UnitInformation.cpp
index 438682a9d..ab42e9f2b 100644
--- a/src/UnitInformation.cpp
+++ b/src/UnitInformation.cpp
@@ -8,4 +8,45 @@ UnitInformation::UnitInformation(const sc2::Unit * unit, IDAReplayObserver & rep
 	
 }
 
+ std::string UnitInformation::getType() const
+{
+	 return m_unit->unit_type.to_string();
+		
+}
+
+bool UnitInformation::hasTarget() const
+{
+	BOT_ASSERT(isValid(), "Unit is not valid");
+
+	if (getUnitPtr()->orders.size() > 0) {
+		if (getUnitPtr()->orders[0].target_unit_tag != NULL) {
+			CCUnitID t_id = getUnitPtr()->orders[0].target_unit_tag;
+			// IDABot finds the unit with this tag, and returns true if valid
+			return m_replayObserver->GetUnit(t_id).isValid();
+		}
+	}
+
+	return false;
+}
+
+UnitInformation UnitInformation::getTarget() const
+{
+	BOT_ASSERT(isValid(), "Unit is not valid");
+
+	// if unit has order, check tag of target of first order
+	if (getUnitPtr()->orders.size() > 0) {
+		// t_id is set to the unit tag of the target
+		CCUnitID t_id = getUnitPtr()->orders[0].target_unit_tag;
+		//The tag is for somereason a null tag
+		if (t_id == sc2::NullTag) {
+			return *this;
+		}
+		// IDAReplayObserver finds the unit with this tag
+		return m_replayObserver->GetUnit(t_id);
+	}
+
+	UnitInformation this_unit = UnitInformation(m_unit, *m_replayObserver);
+	return this_unit;
+}
+
 
diff --git a/src/UnitInformation.h b/src/UnitInformation.h
index 85510036f..86a7656a6 100644
--- a/src/UnitInformation.h
+++ b/src/UnitInformation.h
@@ -8,9 +8,13 @@ class UnitInformation: public Unit
 {
 	mutable IDAReplayObserver * m_replayObserver;
 
+
+
 public:
 		UnitInformation(const sc2::Unit * unit, IDAReplayObserver & replayObserver);
 
-		const UnitType & getType() const;
+		std::string getType() const;
+		bool hasTarget() const;
+		UnitInformation getTarget() const;
 
 };
\ No newline at end of file
-- 
GitLab