diff --git a/python-api-src/lib_replay_unit.cpp b/python-api-src/lib_replay_unit.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7a70770dffa7c1738104f451276f1b6f9aaaf3bb --- /dev/null +++ b/python-api-src/lib_replay_unit.cpp @@ -0,0 +1,44 @@ +#include "library.h" + +namespace py = pybind11; + +void define_replay_unit(py::module & m) +{ + py::class_<ReplayUnit>(m, "ReplayUnit") + .def_property_readonly("id", &ReplayUnit::getID) + .def_property_readonly("unit_type", &ReplayUnit::getType, "The id of the type") + .def_property_readonly("unit_type_name", &ReplayUnit::getTypeName, "The name of the type") + .def_property_readonly("position", &ReplayUnit::getPosition, "The :class:`library.Point2D` of the unit") + .def_property_readonly("tile_position", &ReplayUnit::getTilePosition, "The :class:`library.Point2DI` of the unit") + .def_property_readonly("hit_points", &ReplayUnit::getHitPoints) + .def_property_readonly("shields", &ReplayUnit::getShields) + .def_property_readonly("energy", &ReplayUnit::getEnergy) + .def_property_readonly("player", &ReplayUnit::getPlayer) + .def_property_readonly("build_percentage", &ReplayUnit::getBuildPercentage) + .def_property_readonly("weapon_cooldown", &ReplayUnit::getWeaponCooldown) + .def_property_readonly("is_completed", &ReplayUnit::isCompleted) + .def_property_readonly("is_being_constructed", &ReplayUnit::isBeingConstructed) + .def_property_readonly("is_cloaked", &ReplayUnit::isCloaked) + .def_property_readonly("is_flying", &ReplayUnit::isFlying) + .def_property_readonly("buffs", &ReplayUnit::buffs) + .def_property_readonly("is_alive", &ReplayUnit::isAlive) + .def_property_readonly("is_powered", &ReplayUnit::isPowered) + .def_property_readonly("is_idle", &ReplayUnit::isIdle) + .def_property_readonly("is_burrowed", &ReplayUnit::isBurrowed) + .def_property_readonly("is_valid", &ReplayUnit::isValid) + .def_property_readonly("is_training", &ReplayUnit::isTraining) + .def_property_readonly("is_blip", &ReplayUnit::isBlip) + // Has target and target crashes if the target died in the same frame + //.def_property_readonly("target", &ReplayUnit::getTarget) + //.def_property_readonly("has_target", &ReplayUnit::hasTarget) + .def_property_readonly("max_hit_points", &ReplayUnit::getMaxHitPoints) + .def_property_readonly("progress", &ReplayUnit::getProgress) + .def_property_readonly("current_ability_id", &ReplayUnit::getCurrentAbilityID, "The AbilityID of currently used ability") + .def_property_readonly("facing", &ReplayUnit::getFacing) + .def_property_readonly("radius", &ReplayUnit::getRadius) + .def_property_readonly("is_carrying_minerals", &ReplayUnit::isCarryingMinerals) + .def("__hash__", [](const ReplayUnit & unit) { return std::hash<const sc2::Unit *>{}(unit.getUnitPtr()); }) + .def(py::self == py::self) + .def("__repr__", [](const ReplayUnit & unit) { return "<Unit of type: '" + unit.getTypeName() + "'>"; }) + ; +} diff --git a/python-api-src/library.cpp b/python-api-src/library.cpp index b1818c27be7737444af08bed9406b16407132574..1ec48212a0fa9b30beef9c3901cd7cc8c87d01dd 100644 --- a/python-api-src/library.cpp +++ b/python-api-src/library.cpp @@ -8,6 +8,7 @@ PYBIND11_MODULE(library, m) define_typeenums(m); define_unit(m); + define_replay_unit(m); define_unittype(m); define_util(m); define_point(m); @@ -30,12 +31,6 @@ PYBIND11_MODULE(library, m) .def("load_replay_list",&sc2::Coordinator::SetReplayPath, "replay_path"_a) .def("add_replay_observer",&sc2::Coordinator::AddReplayObserver, "replay_observer"_a); - py::class_<sc2::ReplayObserver, PyReplayObserver>(m, "ReplayObserver") - .def(py::init()) - .def("on_game_start", &sc2::ReplayObserver::OnGameStart) - .def("on_step",&sc2::ReplayObserver::OnStep) - .def("on_game_end",&sc2::ReplayObserver::OnGameEnd); - py::enum_<sc2::Race>(m, "Race") .value("Terran", sc2::Race::Terran) .value("Zerg", sc2::Race::Zerg) @@ -93,6 +88,19 @@ PYBIND11_MODULE(library, m) .def_property_readonly("gas", &IDABot::GetGas, "How much gas we currently have") .def_property_readonly("current_frame", &IDABot::GetCurrentFrame, "Which frame we are currently on"); + + + py::class_<sc2::ReplayObserver>(m, "ReplayObserver") + .def(py::init()); + + py::class_<IDAReplayObserver, PyReplayObserver, sc2::ReplayObserver>(m, "IDAReplayObserver") + .def(py::init()) + .def("on_game_start", &IDAReplayObserver::OnGameStart) + .def("on_step", &IDAReplayObserver::OnStep) + .def("on_game_end", &IDAReplayObserver::OnGameEnd) + .def("get_all_units", &IDAReplayObserver::GetAllUnits, "Returns a list of all units") + ; + py::class_<sc2::PlayerSetup>(m, "PlayerSetup"); py::enum_<sc2::Difficulty>(m, "Difficulty") diff --git a/python-api-src/library.h b/python-api-src/library.h index f6a91826d3c43329ac5e001c3b7e20debb00cd4c..89118af074c70399f6cb23c63fe5689775567451 100644 --- a/python-api-src/library.h +++ b/python-api-src/library.h @@ -3,6 +3,7 @@ #include <pybind11/pybind11.h> #include <sc2api/sc2_api.h> #include "../src/IDABot.h" +#include "../src/IDAReplayObserver.h" #include <iostream> #include <pybind11/stl.h> /* Automatic conversion from std::vector to Python lists */ #include <pybind11/operators.h> /* Convenient operator support */ @@ -61,15 +62,16 @@ public: } }; -//todo fixa! -class PyReplayObserver : public sc2::ReplayObserver + +class PyReplayObserver : public IDAReplayObserver { public: + using IDAReplayObserver::IDAReplayObserver; void OnGameStart() override { PYBIND11_OVERLOAD_NAME( void, - sc2::ReplayObserver, + IDAReplayObserver, "on_game_start", OnGameStart ); @@ -78,17 +80,28 @@ public: { PYBIND11_OVERLOAD_NAME( void, - sc2::ReplayObserver, + IDAReplayObserver, "on_step", OnStep ); } + void OnGameEnd() override + { + PYBIND11_OVERLOAD_NAME( + void, + IDAReplayObserver, + "on_game_end", + OnGameEnd + + ); + } }; // The functions below are all defined in different .cpp files, in order // to keep compilation snappy void define_typeenums(pybind11::module & m); void define_unit(pybind11::module & m); +void define_replay_unit(pybind11::module & m); void define_unittype(pybind11::module &m); void define_util(pybind11::module &m); void define_point(pybind11::module &m); diff --git a/src/IDAReplayObserver.cpp b/src/IDAReplayObserver.cpp index 5a7885c3d51d1c833f7981c2d0a55ac9b1d49961..13728f586a0b3293b36dd0d5c89c4e3b1f7ff54e 100644 --- a/src/IDAReplayObserver.cpp +++ b/src/IDAReplayObserver.cpp @@ -1 +1,60 @@ #include "IDAReplayObserver.h" +#include "Util.h" + +void IDAReplayObserver::setUnits() +{ + m_allUnits.clear(); + for (auto & unit : Observation()->GetUnits()) + { + m_allUnits.push_back(ReplayUnit(unit, *this)); + } +} + +IDAReplayObserver::IDAReplayObserver(): + sc2::ReplayObserver() +{ +} + +void IDAReplayObserver::OnGameStart() +{ + setUnits(); +} + +void IDAReplayObserver::OnStep() +{ + setUnits(); +} + +void IDAReplayObserver::OnGameEnd() +{ +} + +void IDAReplayObserver::OnUnitDestroyed(const sc2::Unit* unit) +{ + ReplayUnit unitInformation = ReplayUnit(unit, *this); + OnUnitInfomationDestroyed(&unitInformation); +} + +void IDAReplayObserver::OnUnitInfomationDestroyed(const ReplayUnit *) +{ +} + + + +ReplayUnit IDAReplayObserver::GetUnit(const CCUnitID tag) const +{ + return ReplayUnit(Observation()->GetUnit(tag), *(IDAReplayObserver *)this); +} + + + + + + +const std::vector<ReplayUnit>& IDAReplayObserver::GetAllUnits() const +{ + + return m_allUnits; +} + + diff --git a/src/IDAReplayObserver.h b/src/IDAReplayObserver.h index 39b9aa420d4431ba786725985a760733ccebcf8e..b23393a2ef8b4e94e9e98e6c8f02b36fd90fd443 100644 --- a/src/IDAReplayObserver.h +++ b/src/IDAReplayObserver.h @@ -4,38 +4,27 @@ #include <limits> #include "Common.h" +#include "ReplayUnit.h" -#include "MapTools.h" -#include "BaseLocationManager.h" -#include "UnitInfoManager.h" -#include "BuildingPlacer.h" -#include "TechTree.h" -#include "TechTreeImproved.h" -#include "MetaType.h" -#include "Unit.h" +class ReplayUnit; class IDAReplayObserver : public sc2::ReplayObserver { - MapTools m_map; - BaseLocationManager m_bases; - UnitInfoManager m_unitInfo; - TechTree m_techTree; - BuildingPlacer m_buildingPlacer; - - std::vector<Unit> m_allUnits; - std::vector<CCPosition> m_baseLocations; - void setUnits(); - void OnError(const std::vector<sc2::ClientError> & client_errors, - const std::vector<std::string> & protocol_errors = {}) override; + std::vector<ReplayUnit> m_allUnits; public: IDAReplayObserver(); - + void OnGameStart() override; void OnStep() override; + void OnGameEnd() override; + void OnUnitDestroyed(const sc2::Unit*) override; + void OnUnitInfomationDestroyed(const ReplayUnit*); + + ReplayUnit GetUnit(const CCUnitID tag) const; - const std::vector<Unit> & GetAllUnits() const; + const std::vector<ReplayUnit> & GetAllUnits() const; }; diff --git a/src/ReplayUnit.cpp b/src/ReplayUnit.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1bab9381607394099b2d292bc4c6281db9da854f --- /dev/null +++ b/src/ReplayUnit.cpp @@ -0,0 +1,71 @@ +#include "ReplayUnit.h" + + + +ReplayUnit::ReplayUnit(const sc2::Unit * unit, IDAReplayObserver & replayObserver) + : m_replayObserver(&replayObserver), Unit(unit) +{ + +} + + std::string ReplayUnit::getType() const +{ + return m_unit->unit_type.to_string(); + +} + + std::string ReplayUnit::getTypeName() const + { + return sc2::UnitTypeToName(m_unit->unit_type); + } + +bool ReplayUnit::hasTarget() const +{ + BOT_ASSERT(isValid(), "Unit is not valid"); + std::cout << "HAS TARGET" << std::endl; + if (getUnitPtr()->orders.size() > 0) { + if (getUnitPtr()->orders[0].target_unit_tag != NULL) { + CCUnitID t_id = getUnitPtr()->orders[0].target_unit_tag; + //The tag is for somereason a null tag + if (t_id == sc2::NullTag) { + return false; + } + std::cout << "1MID HAS TARGET" << std::endl; + std::cout << "2MID HAS TARGET" << std::endl; + std::cout << "valid" << m_replayObserver->GetUnit(t_id).getType() << std::endl; + std::cout << "AFTER" << std::endl; + // IDABot finds the unit with this tag, and returns true if valid + return m_replayObserver->GetUnit(t_id).isValid(); + } + } + std::cout << "END HAS TARGET" << std::endl; + + return false; +} + +ReplayUnit ReplayUnit::getTarget() const +{ + BOT_ASSERT(isValid(), "Unit is not valid"); + + + // if unit has order, check tag of target of first order + if (getUnitPtr()->orders.size() > 0) { + // t_id is set to the unit tag of the target + CCUnitID t_id = getUnitPtr()->orders[0].target_unit_tag; + //The tag is for somereason a null tag + if (t_id == sc2::NullTag) { + + std::cout << "nullTAG" << std::endl; + std::cout << "type " << sc2::UnitTypeToName(m_unit->unit_type) <<"pos " << getPosition().x << " x y "<< getPosition().y << ", id " << getID() << "player " << getPlayer() << std::endl; + std::cout << getUnitPtr()->orders.size() << std::endl; + return *this; + } + // IDAReplayObserver finds the unit with this tag + return m_replayObserver->GetUnit(t_id); + } + + ReplayUnit this_unit = ReplayUnit(m_unit, *m_replayObserver); + return this_unit; +} + + diff --git a/src/ReplayUnit.h b/src/ReplayUnit.h new file mode 100644 index 0000000000000000000000000000000000000000..171dfc6823739b7180ee2a07d119f9a2ed5788f2 --- /dev/null +++ b/src/ReplayUnit.h @@ -0,0 +1,20 @@ +#pragma once +#include "Unit.h" +#include "IDAReplayObserver.h" + +class IDAReplayObserver; + +//! A Unit that have a replayobserver insted of an Agent, +class ReplayUnit: public Unit +{ + mutable IDAReplayObserver * m_replayObserver; + +public: + ReplayUnit(const sc2::Unit * unit, IDAReplayObserver & replayObserver); + + std::string getType() const; + std::string getTypeName() const; + bool hasTarget() const; + ReplayUnit getTarget() const; + +}; \ No newline at end of file diff --git a/src/Unit.cpp b/src/Unit.cpp index 9cc165a2b2a1c72aec083346c0a0235f3fc513da..3ca5c6ca465100368d06a14b0a05a4e8ee94a19a 100644 --- a/src/Unit.cpp +++ b/src/Unit.cpp @@ -18,6 +18,13 @@ Unit::Unit(const sc2::Unit * unit, IDABot & bot) } +Unit::Unit(const sc2::Unit * unit) + : m_unit(unit) + , m_unitID(unit->tag) +{ + +} + const sc2::Unit * Unit::getUnitPtr() const { return m_unit; diff --git a/src/Unit.h b/src/Unit.h index 5a7cef49664e37ee60a197e26a12cbf8065f9854..24a3d631624e4eaf1001431f85665b85396b9513 100644 --- a/src/Unit.h +++ b/src/Unit.h @@ -11,13 +11,15 @@ class Unit CCUnitID m_unitID; UnitType m_unitType; - const sc2::Unit * m_unit; +protected: + const sc2::Unit * m_unit; public: Unit(); Unit(const sc2::Unit * unit, IDABot & bot); + Unit(const sc2::Unit * unit); const sc2::Unit * getUnitPtr() const; const sc2::UnitTypeID & getAPIUnitType() const;