diff --git a/python-api-src/lib_replay_unit.cpp b/python-api-src/lib_replay_unit.cpp index 0266b985462c810ce9c56ed7acabecec33f3339b..411b2e220a5af412434f6820b6823fce4796cb4b 100644 --- a/python-api-src/lib_replay_unit.cpp +++ b/python-api-src/lib_replay_unit.cpp @@ -4,6 +4,39 @@ namespace py = pybind11; void define_replay_unit(py::module & m) { - py::class_<UnitInformation>(m, "ReplayUnit") - .def_property_readonly("id", &Unit::getID); + py::class_<UnitInformation>(m, "ReplayUnit") + .def_property_readonly("id", &UnitInformation::getID) + .def_property_readonly("unit_type", &UnitInformation::getType, "The name of the type") + .def_property_readonly("position", &UnitInformation::getPosition, "The :class:`library.Point2D` of the unit") + .def_property_readonly("tile_position", &UnitInformation::getTilePosition, "The :class:`library.Point2DI` of the unit") + .def_property_readonly("hit_points", &UnitInformation::getHitPoints) + .def_property_readonly("shields", &UnitInformation::getShields) + .def_property_readonly("energy", &UnitInformation::getEnergy) + .def_property_readonly("player", &UnitInformation::getPlayer) + .def_property_readonly("build_percentage", &UnitInformation::getBuildPercentage) + .def_property_readonly("weapon_cooldown", &UnitInformation::getWeaponCooldown) + .def_property_readonly("is_completed", &UnitInformation::isCompleted) + .def_property_readonly("is_being_constructed", &UnitInformation::isBeingConstructed) + .def_property_readonly("is_cloaked", &UnitInformation::isCloaked) + .def_property_readonly("is_flying", &UnitInformation::isFlying) + .def_property_readonly("buffs", &UnitInformation::buffs) + .def_property_readonly("is_alive", &UnitInformation::isAlive) + .def_property_readonly("is_powered", &UnitInformation::isPowered) + .def_property_readonly("is_idle", &UnitInformation::isIdle) + .def_property_readonly("is_burrowed", &UnitInformation::isBurrowed) + .def_property_readonly("is_valid", &UnitInformation::isValid) + .def_property_readonly("is_training", &UnitInformation::isTraining) + .def_property_readonly("is_blip", &UnitInformation::isBlip) + .def_property_readonly("target", &UnitInformation::getTarget) + .def_property_readonly("has_target", &UnitInformation::hasTarget) + .def_property_readonly("max_hit_points", &UnitInformation::getMaxHitPoints) + .def_property_readonly("progress", &UnitInformation::getProgress) + .def_property_readonly("current_ability_id", &UnitInformation::getCurrentAbilityID, "The AbilityID of currently used ability") + .def_property_readonly("facing", &UnitInformation::getFacing) + .def_property_readonly("radius", &UnitInformation::getRadius) + .def_property_readonly("is_carrying_minerals", &UnitInformation::isCarryingMinerals) + .def("__hash__", [](const UnitInformation & unit) { return std::hash<const sc2::Unit *>{}(unit.getUnitPtr()); }) + .def(py::self == py::self) + .def("__repr__", [](const UnitInformation & unit) { return "<Unit of type: '" + unit.getType() + "'>"; }) + ; } diff --git a/python-api-src/library.cpp b/python-api-src/library.cpp index 5e157b765474502671a42bd59a13d510572a3664..1ec48212a0fa9b30beef9c3901cd7cc8c87d01dd 100644 --- a/python-api-src/library.cpp +++ b/python-api-src/library.cpp @@ -98,7 +98,8 @@ PYBIND11_MODULE(library, m) .def("on_game_start", &IDAReplayObserver::OnGameStart) .def("on_step", &IDAReplayObserver::OnStep) .def("on_game_end", &IDAReplayObserver::OnGameEnd) - .def("get_all_units", &IDAReplayObserver::GetAllUnits, "Returns a list of all units"); + .def("get_all_units", &IDAReplayObserver::GetAllUnits, "Returns a list of all units") + ; py::class_<sc2::PlayerSetup>(m, "PlayerSetup"); diff --git a/python-api-src/library.h b/python-api-src/library.h index 805d5ffc13ea1824694307862046d28f1973ab78..daee1f3fec1c0801a41b6a65810721f170a141a7 100644 --- a/python-api-src/library.h +++ b/python-api-src/library.h @@ -85,6 +85,16 @@ public: OnStep ); } + void OnGameEnd() override + { + PYBIND11_OVERLOAD_NAME( + void, + IDAReplayObserver, + "on_game_end", + OnGameEnd + + ); + } }; // The functions below are all defined in different .cpp files, in order diff --git a/src/IDAReplayObserver.cpp b/src/IDAReplayObserver.cpp index 75cfa061f845f1615536a9133308a45c5eeee83b..21e20e456c302ba407cf499e83ffca3b1b4ede93 100644 --- a/src/IDAReplayObserver.cpp +++ b/src/IDAReplayObserver.cpp @@ -29,12 +29,23 @@ void IDAReplayObserver::OnGameStart() void IDAReplayObserver::OnStep() { - setUnits(); - +} + +void IDAReplayObserver::OnGameEnd() +{ +} +UnitInformation IDAReplayObserver::GetUnit(const CCUnitID tag) const +{ + return UnitInformation(Observation()->GetUnit(tag), *(IDAReplayObserver *)this); } + + + + + const std::vector<UnitInformation>& IDAReplayObserver::GetAllUnits() const { diff --git a/src/IDAReplayObserver.h b/src/IDAReplayObserver.h index 88ca4e032480ee90cb152549a81862b9db5cbe43..de78d6b52274aa55a20bd58fdbf41da77a5f99f4 100644 --- a/src/IDAReplayObserver.h +++ b/src/IDAReplayObserver.h @@ -14,15 +14,18 @@ class IDAReplayObserver : public sc2::ReplayObserver void setUnits(); - + std::vector<UnitInformation> m_allUnits; public: IDAReplayObserver(); - std::vector<UnitInformation> m_allUnits; + void OnGameStart() override; void OnStep() override; + void OnGameEnd() override; + UnitInformation GetUnit(const CCUnitID tag) const; + const std::vector<UnitInformation> & GetAllUnits() const; diff --git a/src/Unit.h b/src/Unit.h index 2f7f28067e815678351323339cea873cc82f98e1..24a3d631624e4eaf1001431f85665b85396b9513 100644 --- a/src/Unit.h +++ b/src/Unit.h @@ -11,7 +11,8 @@ class Unit CCUnitID m_unitID; UnitType m_unitType; - const sc2::Unit * m_unit; +protected: + const sc2::Unit * m_unit; public: diff --git a/src/UnitInformation.cpp b/src/UnitInformation.cpp index 438682a9d616ff8e91fe2a7cd485bd13dafa4327..ab42e9f2bea653b00b833821541b313c1131fa97 100644 --- a/src/UnitInformation.cpp +++ b/src/UnitInformation.cpp @@ -8,4 +8,45 @@ UnitInformation::UnitInformation(const sc2::Unit * unit, IDAReplayObserver & rep } + std::string UnitInformation::getType() const +{ + return m_unit->unit_type.to_string(); + +} + +bool UnitInformation::hasTarget() const +{ + BOT_ASSERT(isValid(), "Unit is not valid"); + + if (getUnitPtr()->orders.size() > 0) { + if (getUnitPtr()->orders[0].target_unit_tag != NULL) { + CCUnitID t_id = getUnitPtr()->orders[0].target_unit_tag; + // IDABot finds the unit with this tag, and returns true if valid + return m_replayObserver->GetUnit(t_id).isValid(); + } + } + + return false; +} + +UnitInformation UnitInformation::getTarget() const +{ + BOT_ASSERT(isValid(), "Unit is not valid"); + + // if unit has order, check tag of target of first order + if (getUnitPtr()->orders.size() > 0) { + // t_id is set to the unit tag of the target + CCUnitID t_id = getUnitPtr()->orders[0].target_unit_tag; + //The tag is for somereason a null tag + if (t_id == sc2::NullTag) { + return *this; + } + // IDAReplayObserver finds the unit with this tag + return m_replayObserver->GetUnit(t_id); + } + + UnitInformation this_unit = UnitInformation(m_unit, *m_replayObserver); + return this_unit; +} + diff --git a/src/UnitInformation.h b/src/UnitInformation.h index 85510036f9a6ed2cf5bb9ad9ad72b1e5b57cd959..86a7656a65689127c04434acc5acb19c6314d0e7 100644 --- a/src/UnitInformation.h +++ b/src/UnitInformation.h @@ -8,9 +8,13 @@ class UnitInformation: public Unit { mutable IDAReplayObserver * m_replayObserver; + + public: UnitInformation(const sc2::Unit * unit, IDAReplayObserver & replayObserver); - const UnitType & getType() const; + std::string getType() const; + bool hasTarget() const; + UnitInformation getTarget() const; }; \ No newline at end of file