Skip to content
Snippets Groups Projects
Commit 40b6496a authored by Edvin Bergström's avatar Edvin Bergström
Browse files

Fix unit type

parent cd627e83
No related branches found
No related tags found
1 merge request!6Replays
...@@ -6,8 +6,7 @@ void define_replay_unit(py::module & m) ...@@ -6,8 +6,7 @@ void define_replay_unit(py::module & m)
{ {
py::class_<ReplayUnit>(m, "ReplayUnit") py::class_<ReplayUnit>(m, "ReplayUnit")
.def_property_readonly("id", &ReplayUnit::getID) .def_property_readonly("id", &ReplayUnit::getID)
.def_property_readonly("unit_type", &ReplayUnit::getType, "The id of the type") .def_property_readonly("unit_type", &ReplayUnit::getType, "The :class :`library.UnitType` of the unit")
.def_property_readonly("unit_type_name", &ReplayUnit::getTypeName, "The name of the type")
.def_property_readonly("position", &ReplayUnit::getPosition, "The :class:`library.Point2D` of the unit") .def_property_readonly("position", &ReplayUnit::getPosition, "The :class:`library.Point2D` of the unit")
.def_property_readonly("tile_position", &ReplayUnit::getTilePosition, "The :class:`library.Point2DI` of the unit") .def_property_readonly("tile_position", &ReplayUnit::getTilePosition, "The :class:`library.Point2DI` of the unit")
.def_property_readonly("hit_points", &ReplayUnit::getHitPoints) .def_property_readonly("hit_points", &ReplayUnit::getHitPoints)
...@@ -38,6 +37,6 @@ void define_replay_unit(py::module & m) ...@@ -38,6 +37,6 @@ void define_replay_unit(py::module & m)
.def_property_readonly("is_carrying_minerals", &ReplayUnit::isCarryingMinerals) .def_property_readonly("is_carrying_minerals", &ReplayUnit::isCarryingMinerals)
.def("__hash__", [](const ReplayUnit & unit) { return std::hash<const sc2::Unit *>{}(unit.getUnitPtr()); }) .def("__hash__", [](const ReplayUnit & unit) { return std::hash<const sc2::Unit *>{}(unit.getUnitPtr()); })
.def(py::self == py::self) .def(py::self == py::self)
.def("__repr__", [](const ReplayUnit & unit) { return "<Unit of type: '" + unit.getTypeName() +" player: " + std::to_string(unit.getPlayer()) +">"; }) .def("__repr__", [](const ReplayUnit & unit) { return "<Unit of type: '" + unit.getType().getName() + "'>"; })
; ;
} }
...@@ -26,5 +26,7 @@ void define_tech_tree(py::module & m) ...@@ -26,5 +26,7 @@ void define_tech_tree(py::module & m)
py::class_<TechTree>(m, "TechTree") py::class_<TechTree>(m, "TechTree")
.def("get_data", py::overload_cast<const UnitType &>(&TechTree::getData, py::const_)) .def("get_data", py::overload_cast<const UnitType &>(&TechTree::getData, py::const_))
.def("get_data", py::overload_cast<const CCUpgrade &>(&TechTree::getData, py::const_)); .def("get_data", py::overload_cast<const CCUpgrade &>(&TechTree::getData, py::const_))
.def("suppress_warnings", &TechTree::setSuppressWarnings, "Suppress type and uppgrade warnings" ,"b"_a)
;
} }
...@@ -142,6 +142,7 @@ PYBIND11_MODULE(library, m) ...@@ -142,6 +142,7 @@ PYBIND11_MODULE(library, m)
.def("get_result_for_player", &IDAReplayObserver::GetResultForPlayer, "player_id"_a) .def("get_result_for_player", &IDAReplayObserver::GetResultForPlayer, "player_id"_a)
.def("on_unit_destroyed", &IDAReplayObserver::OnReplayUnitDestroyed, "unit"_a) .def("on_unit_destroyed", &IDAReplayObserver::OnReplayUnitDestroyed, "unit"_a)
.def("on_unit_created", &IDAReplayObserver::OnReplayUnitCreated, "unit"_a) .def("on_unit_created", &IDAReplayObserver::OnReplayUnitCreated, "unit"_a)
.def_property_readonly("tech_tree", &IDAReplayObserver::GetTechTree)
; ;
......
...@@ -72,12 +72,11 @@ void IDABot::OnStep() ...@@ -72,12 +72,11 @@ void IDABot::OnStep()
m_map.onFrame(); m_map.onFrame();
m_unitInfo.onFrame(); m_unitInfo.onFrame();
m_bases.onFrame(); m_bases.onFrame();
// ----------------------------------------------------------------- // -----------------------------------------------------------------
// Draw debug interface, and send debug interface to the Sc2 client. // Draw debug interface, and send debug interface to the Sc2 client.
// ----------------------------------------------------------------- // -----------------------------------------------------------------
Debug()->SendDebug(); Debug()->SendDebug();
m_buildingPlacer.drawReservedTiles(); m_buildingPlacer.drawReservedTiles();
} }
void IDABot::setUnits() void IDABot::setUnits()
......
...@@ -16,13 +16,16 @@ void IDAReplayObserver::setUnits() ...@@ -16,13 +16,16 @@ void IDAReplayObserver::setUnits()
} }
IDAReplayObserver::IDAReplayObserver(): IDAReplayObserver::IDAReplayObserver():
sc2::ReplayObserver() sc2::ReplayObserver(),
m_techTree(*this)
{ {
} }
void IDAReplayObserver::OnGameStart() void IDAReplayObserver::OnGameStart()
{ {
setUnits(); setUnits();
m_techTree.onStart();
} }
void IDAReplayObserver::OnStep() void IDAReplayObserver::OnStep()
...@@ -103,4 +106,29 @@ sc2::GameResult IDAReplayObserver::GetResultForPlayer(int player) ...@@ -103,4 +106,29 @@ sc2::GameResult IDAReplayObserver::GetResultForPlayer(int player)
return ReplayControl()->GetReplayInfo().players[player].game_result; return ReplayControl()->GetReplayInfo().players[player].game_result;
} }
const TechTree & IDAReplayObserver::GetTechTree() const
{
return m_techTree;
}
const TypeData & IDAReplayObserver::Data(const UnitType & type) const
{
return m_techTree.getData(type);
}
const TypeData & IDAReplayObserver::Data(const CCUpgrade & type) const
{
return m_techTree.getData(type);
}
const TypeData & IDAReplayObserver::Data(const MetaType & type) const
{
return m_techTree.getData(type);
}
const TypeData & IDAReplayObserver::Data(const Unit & unit) const
{
return m_techTree.getData(unit.getType());
}
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
#include "Common.h" #include "Common.h"
#include "ReplayUnit.h" #include "ReplayUnit.h"
#include "TechTree.h"
class ReplayUnit; class ReplayUnit;
...@@ -13,6 +15,9 @@ class IDAReplayObserver : public sc2::ReplayObserver ...@@ -13,6 +15,9 @@ class IDAReplayObserver : public sc2::ReplayObserver
void setUnits(); void setUnits();
std::vector<ReplayUnit> m_allUnits; std::vector<ReplayUnit> m_allUnits;
std::set<CCUnitID> m_allUnitsID; std::set<CCUnitID> m_allUnitsID;
TechTree m_techTree;
public: public:
IDAReplayObserver(); IDAReplayObserver();
...@@ -29,10 +34,18 @@ public: ...@@ -29,10 +34,18 @@ public:
ReplayUnit GetUnit(const CCUnitID tag) const; ReplayUnit GetUnit(const CCUnitID tag) const;
bool UnitExists(const CCUnitID tag) const; bool UnitExists(const CCUnitID tag) const;
const std::vector<ReplayUnit> & GetAllUnits() const; const std::vector<ReplayUnit> & GetAllUnits() const;
CCRace GetPlayerRace(int player); CCRace GetPlayerRace(int player);
std::string GetReplayPath(); std::string GetReplayPath();
sc2::GameResult GetResultForPlayer(int player); sc2::GameResult GetResultForPlayer(int player);
const TechTree & GetTechTree() const;
const TypeData & Data(const UnitType & type) const;
const TypeData & Data(const CCUpgrade & type) const;
const TypeData & Data(const MetaType & type) const;
const TypeData & Data(const Unit & unit) const;
}; };
...@@ -3,22 +3,16 @@ ...@@ -3,22 +3,16 @@
ReplayUnit::ReplayUnit(const sc2::Unit * unit, IDAReplayObserver & replayObserver) ReplayUnit::ReplayUnit(const sc2::Unit * unit, IDAReplayObserver & replayObserver)
: m_replayObserver(&replayObserver), Unit(unit) : m_replayObserver(&replayObserver), Unit(unit), m_type(unit->unit_type, replayObserver, replayObserver)
{ {
} }
std::string ReplayUnit::getType() const const UnitType & ReplayUnit::getType() const
{ {
return m_unit->unit_type.to_string(); return m_type;
} }
std::string ReplayUnit::getTypeName() const
{
return sc2::UnitTypeToName(m_unit->unit_type);
}
bool ReplayUnit::hasTarget() const bool ReplayUnit::hasTarget() const
{ {
BOT_ASSERT(isValid(), "Unit is not valid"); BOT_ASSERT(isValid(), "Unit is not valid");
...@@ -51,8 +45,7 @@ ReplayUnit ReplayUnit::getTarget() const ...@@ -51,8 +45,7 @@ ReplayUnit ReplayUnit::getTarget() const
return m_replayObserver->GetUnit(t_id); return m_replayObserver->GetUnit(t_id);
} }
} }
ReplayUnit this_unit = ReplayUnit(m_unit, *m_replayObserver); return *this;
return this_unit;
} }
int ReplayUnit::getPlayer() const int ReplayUnit::getPlayer() const
......
...@@ -8,12 +8,12 @@ class IDAReplayObserver; ...@@ -8,12 +8,12 @@ class IDAReplayObserver;
class ReplayUnit: public Unit class ReplayUnit: public Unit
{ {
mutable IDAReplayObserver * m_replayObserver; mutable IDAReplayObserver * m_replayObserver;
UnitType m_type;
public: public:
ReplayUnit(const sc2::Unit * unit, IDAReplayObserver & replayObserver); ReplayUnit(const sc2::Unit * unit, IDAReplayObserver & replayObserver);
std::string getType() const; const UnitType & getType() const;
std::string getTypeName() const;
bool hasTarget() const; bool hasTarget() const;
ReplayUnit getTarget() const; ReplayUnit getTarget() const;
int getPlayer() const; int getPlayer() const;
......
This diff is collapsed.
...@@ -31,18 +31,21 @@ struct TypeData ...@@ -31,18 +31,21 @@ struct TypeData
class TechTree class TechTree
{ {
IDABot & m_bot; sc2::Client & m_client;
std::map<UnitType, TypeData> m_unitTypeData; std::map<UnitType, TypeData> m_unitTypeData;
std::map<CCUpgrade, TypeData> m_upgradeData; std::map<CCUpgrade, TypeData> m_upgradeData;
void initUnitTypeData(); void initUnitTypeData();
void initUpgradeData(); void initUpgradeData();
bool suppressWarnings;
public: public:
TechTree(IDABot & bot); TechTree(sc2::Client & client);
void onStart(); void onStart();
void setSuppressWarnings(bool b);
const TypeData & getData(const UnitType & type) const; const TypeData & getData(const UnitType & type) const;
const TypeData & getData(const CCUpgrade & type) const; const TypeData & getData(const CCUpgrade & type) const;
const TypeData & getData(const MetaType & type) const; const TypeData & getData(const MetaType & type) const;
......
...@@ -13,7 +13,7 @@ Unit::Unit(const sc2::Unit * unit, IDABot & bot) ...@@ -13,7 +13,7 @@ Unit::Unit(const sc2::Unit * unit, IDABot & bot)
: m_bot(&bot) : m_bot(&bot)
, m_unit(unit) , m_unit(unit)
, m_unitID(unit->tag) , m_unitID(unit->tag)
, m_unitType(unit->unit_type, bot) , m_unitType(unit->unit_type, bot, bot)
{ {
} }
...@@ -335,17 +335,14 @@ void Unit::ability(sc2::AbilityID ability, const Unit& target) const ...@@ -335,17 +335,14 @@ void Unit::ability(sc2::AbilityID ability, const Unit& target) const
Unit Unit::getTarget() const Unit Unit::getTarget() const
{ {
BOT_ASSERT(isValid(), "Unit is not valid"); BOT_ASSERT(isValid(), "Unit is not valid");
// if unit has order, check tag of target of first order // if unit has order, check tag of target of first order
if(getUnitPtr()->orders.size() > 0){ if(getUnitPtr()->orders.size() > 0){
// t_id is set to the unit tag of the target // t_id is set to the unit tag of the target
CCUnitID t_id = getUnitPtr()->orders[0].target_unit_tag; CCUnitID t_id = getUnitPtr()->orders[0].target_unit_tag;
// IDABot finds the unit with this tag // IDABot finds the unit with this tag
return m_bot->GetUnit(t_id); return m_bot->GetUnit(t_id);
} }
return *this;
Unit this_unit = Unit(m_unit, *m_bot);
return this_unit;
} }
bool Unit::hasTarget() const bool Unit::hasTarget() const
......
#include "UnitType.h" #include "UnitType.h"
#include "IDABot.h" #include "IDABot.h"
#include "IDAReplayObserver.h"
UnitType::UnitType() UnitType::UnitType()
: m_bot(nullptr) : m_client(nullptr)
, m_type(0) , m_type(0)
{ {
} }
UnitType::UnitType(const sc2::UnitTypeID & type, IDABot & bot) UnitType::UnitType(const sc2::UnitTypeID & type, sc2::Client & client)
: m_bot(&bot) : m_client(&client)
, m_type(type) , m_type(type)
, m_bot(nullptr)
, m_observer(nullptr)
{ {
} }
UnitType::UnitType(const sc2::UnitTypeID & type, sc2::Client & client, IDABot & bot)
: m_client(&client)
, m_type(type)
, m_bot(&bot)
, m_observer(nullptr)
{
}
UnitType::UnitType(const sc2::UnitTypeID & type, sc2::Client & client, IDAReplayObserver & observer)
: m_client(&client)
, m_type(type)
, m_observer(&observer)
, m_bot(nullptr)
{
}
sc2::UnitTypeID UnitType::getAPIUnitType() const sc2::UnitTypeID UnitType::getAPIUnitType() const
{ {
return m_type; return m_type;
...@@ -47,7 +70,7 @@ std::string UnitType::getName() const ...@@ -47,7 +70,7 @@ std::string UnitType::getName() const
CCRace UnitType::getRace() const CCRace UnitType::getRace() const
{ {
return m_bot->Observation()->GetUnitTypeData()[m_type].race; return m_client->Observation()->GetUnitTypeData()[m_type].race;
} }
bool UnitType::isCombatUnit() const bool UnitType::isCombatUnit() const
...@@ -192,7 +215,7 @@ bool UnitType::isWorker() const ...@@ -192,7 +215,7 @@ bool UnitType::isWorker() const
CCPositionType UnitType::getAttackRange() const CCPositionType UnitType::getAttackRange() const
{ {
#ifdef SC2API #ifdef SC2API
auto & weapons = m_bot->Observation()->GetUnitTypeData()[m_type].weapons; auto & weapons = m_client->Observation()->GetUnitTypeData()[m_type].weapons;
if (weapons.empty()) if (weapons.empty())
{ {
...@@ -220,7 +243,20 @@ int UnitType::tileWidth() const ...@@ -220,7 +243,20 @@ int UnitType::tileWidth() const
#ifdef SC2API #ifdef SC2API
if (isMineral()) { return 2; } if (isMineral()) { return 2; }
if (isGeyser()) { return 3; } if (isGeyser()) { return 3; }
else { return (int)(2 * m_bot->Observation()->GetAbilityData()[m_bot->Data(*this).buildAbility].footprint_radius); } else {
if (m_bot != nullptr)
{
return (int)(2 * m_client->Observation()->GetAbilityData()[m_bot->Data(*this).buildAbility].footprint_radius);
}
else if (m_observer != nullptr)
{
return (int)(2 * m_client->Observation()->GetAbilityData()[m_observer->Data(*this).buildAbility].footprint_radius);
}
else
{
return -1;
}
}
#else #else
return m_type.tileWidth(); return m_type.tileWidth();
#endif #endif
...@@ -231,16 +267,42 @@ int UnitType::tileHeight() const ...@@ -231,16 +267,42 @@ int UnitType::tileHeight() const
#ifdef SC2API #ifdef SC2API
if (isMineral()) { return 1; } if (isMineral()) { return 1; }
if (isGeyser()) { return 3; } if (isGeyser()) { return 3; }
else { return (int)(2 * m_bot->Observation()->GetAbilityData()[m_bot->Data(*this).buildAbility].footprint_radius); } else {
if (m_bot != nullptr)
{
return (int)(2 * m_client->Observation()->GetAbilityData()[m_bot->Data(*this).buildAbility].footprint_radius);
}
else if (m_observer != nullptr)
{
return (int)(2 * m_client->Observation()->GetAbilityData()[m_observer->Data(*this).buildAbility].footprint_radius);
}
else
{
return -1;
}
}
#else #else
return m_type.tileHeight(); return m_type.tileHeight();
#endif #endif
} }
bool UnitType::isAddon() const bool UnitType::isAddon() const
{ {
#ifdef SC2API #ifdef SC2API
return m_bot->Data(*this).isAddon; if (m_bot != nullptr)
{
return m_bot->Data(*this).isAddon;
}
else if (m_observer != nullptr)
{
return m_observer->Data(*this).isAddon;
}
else
{
return false;
}
#else #else
return m_type.isAddon(); return m_type.isAddon();
#endif #endif
...@@ -249,7 +311,19 @@ bool UnitType::isAddon() const ...@@ -249,7 +311,19 @@ bool UnitType::isAddon() const
bool UnitType::isBuilding() const bool UnitType::isBuilding() const
{ {
#ifdef SC2API #ifdef SC2API
return m_bot->Data(*this).isBuilding; if (m_bot != nullptr)
{
return m_bot->Data(*this).isBuilding;
}
else if (m_observer != nullptr)
{
return m_observer->Data(*this).isBuilding;
}
else
{
return false;
}
#else #else
return m_type.isBuilding(); return m_type.isBuilding();
#endif #endif
...@@ -258,7 +332,7 @@ bool UnitType::isBuilding() const ...@@ -258,7 +332,7 @@ bool UnitType::isBuilding() const
int UnitType::supplyProvided() const int UnitType::supplyProvided() const
{ {
#ifdef SC2API #ifdef SC2API
return (int)m_bot->Observation()->GetUnitTypeData()[m_type].food_provided; return (int)m_client->Observation()->GetUnitTypeData()[m_type].food_provided;
#else #else
return m_type.supplyProvided(); return m_type.supplyProvided();
#endif #endif
...@@ -267,7 +341,7 @@ int UnitType::supplyProvided() const ...@@ -267,7 +341,7 @@ int UnitType::supplyProvided() const
int UnitType::supplyRequired() const int UnitType::supplyRequired() const
{ {
#ifdef SC2API #ifdef SC2API
return (int)m_bot->Observation()->GetUnitTypeData()[m_type].food_required; return (int)m_client->Observation()->GetUnitTypeData()[m_type].food_required;
#else #else
return m_type.supplyRequired(); return m_type.supplyRequired();
#endif #endif
...@@ -276,7 +350,7 @@ int UnitType::supplyRequired() const ...@@ -276,7 +350,7 @@ int UnitType::supplyRequired() const
int UnitType::mineralPrice() const int UnitType::mineralPrice() const
{ {
#ifdef SC2API #ifdef SC2API
return (int)m_bot->Observation()->GetUnitTypeData()[m_type].mineral_cost; return (int)m_client->Observation()->GetUnitTypeData()[m_type].mineral_cost;
#else #else
return m_type.mineralPrice(); return m_type.mineralPrice();
#endif #endif
...@@ -285,7 +359,7 @@ int UnitType::mineralPrice() const ...@@ -285,7 +359,7 @@ int UnitType::mineralPrice() const
int UnitType::gasPrice() const int UnitType::gasPrice() const
{ {
#ifdef SC2API #ifdef SC2API
return (int)m_bot->Observation()->GetUnitTypeData()[m_type].vespene_cost; return (int)m_client->Observation()->GetUnitTypeData()[m_type].vespene_cost;
#else #else
return m_type.gasPrice(); return m_type.gasPrice();
#endif #endif
...@@ -377,30 +451,30 @@ bool UnitType::isMorphedBuilding() const ...@@ -377,30 +451,30 @@ bool UnitType::isMorphedBuilding() const
int UnitType::getMovementSpeed() const int UnitType::getMovementSpeed() const
{ {
return m_bot->Observation()->GetUnitTypeData()[m_type].movement_speed; return m_client->Observation()->GetUnitTypeData()[m_type].movement_speed;
} }
int UnitType::getSightRange() const int UnitType::getSightRange() const
{ {
return m_bot->Observation()->GetUnitTypeData()[m_type].sight_range; return m_client->Observation()->GetUnitTypeData()[m_type].sight_range;
} }
UnitTypeID UnitType::getRequiredStructure() const UnitTypeID UnitType::getRequiredStructure() const
{ {
return m_bot->Observation()->GetUnitTypeData()[m_type].tech_requirement; return m_client->Observation()->GetUnitTypeData()[m_type].tech_requirement;
} }
std::vector<sc2::UnitTypeID> UnitType::getEquivalentUnits() const std::vector<sc2::UnitTypeID> UnitType::getEquivalentUnits() const
{ {
return m_bot->Observation()->GetUnitTypeData()[m_type].tech_alias; return m_client->Observation()->GetUnitTypeData()[m_type].tech_alias;
} }
bool UnitType::requiredAttached() const bool UnitType::requiredAttached() const
{ {
return m_bot->Observation()->GetUnitTypeData()[m_type].require_attached; return m_client->Observation()->GetUnitTypeData()[m_type].require_attached;
} }
float UnitType::getBuildTime() const float UnitType::getBuildTime() const
{ {
return m_bot->Observation()->GetUnitTypeData()[m_type].build_time; return m_client->Observation()->GetUnitTypeData()[m_type].build_time;
} }
\ No newline at end of file
...@@ -4,17 +4,24 @@ ...@@ -4,17 +4,24 @@
class IDABot; class IDABot;
class IDAReplayObserver;
class UnitType class UnitType
{ {
mutable IDABot * m_bot; mutable sc2::Client * m_client;
mutable IDABot * m_bot;
mutable IDAReplayObserver * m_observer;
sc2::UnitTypeID m_type; sc2::UnitTypeID m_type;
public: public:
UnitType(); UnitType();
UnitType(const sc2::UnitTypeID & type, IDABot & bot); UnitType(const sc2::UnitTypeID & type, sc2::Client & client);
UnitType(const sc2::UnitTypeID & type, sc2::Client & client, IDABot & m_bot);
UnitType(const sc2::UnitTypeID & type, sc2::Client & client, IDAReplayObserver & observer);
sc2::UnitTypeID getAPIUnitType() const; sc2::UnitTypeID getAPIUnitType() const;
bool is(const sc2::UnitTypeID & type) const; bool is(const sc2::UnitTypeID & type) const;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment