diff --git a/src/rts/GameState.java b/src/rts/GameState.java index a87ef5f4..c5a4e33c 100644 --- a/src/rts/GameState.java +++ b/src/rts/GameState.java @@ -52,7 +52,7 @@ public class GameState { // 4: unit type // 5: current unit action // 6: wall - public static final int numVectorObservationFeatureMaps = 6; + private static final int NUM_VECTOR_OBSERVATION_FEATURE_MAPS = 6; /** * Initializes the GameState with a PhysicalGameState and a UnitTypeTable @@ -921,7 +921,7 @@ public static GameState fromJSON(String JSON, UnitTypeTable utt) { */ public int [][][] getVectorObservation(final int player){ if (vectorObservation == null) { - vectorObservation = new int[2][numVectorObservationFeatureMaps][pgs.height][pgs.width]; + vectorObservation = new int[2][NUM_VECTOR_OBSERVATION_FEATURE_MAPS][pgs.height][pgs.width]; } // hitpointsMatrix is vectorObservation[player][0] // resourcesMatrix is vectorObservation[player][1] diff --git a/src/rts/PartiallyObservableGameState.java b/src/rts/PartiallyObservableGameState.java index bfc78242..cff572fb 100644 --- a/src/rts/PartiallyObservableGameState.java +++ b/src/rts/PartiallyObservableGameState.java @@ -1,7 +1,10 @@ package rts; +import java.util.ArrayList; +import java.util.Arrays; import java.util.LinkedList; import java.util.List; + import rts.units.Unit; /** @@ -10,10 +13,19 @@ * @author santi */ public class PartiallyObservableGameState extends GameState { - /** - * - */ - protected int player; // the observer player + + protected int observer; // the observer player + + // Feature maps: + // 1: hit points + // 2: resources + // 3: player + // 4: unit type + // 5: current unit action + // 6: walls + // 7: which cells can I see? + // 8: for which cells do I know that my opponent can see them? + public static final int NUM_VECTOR_OBSERVATION_FEATURE_MAPS_PARTIAL_OBS = 8; /** * Creates a partially observable game state, from the point of view of 'player': @@ -25,19 +37,19 @@ public PartiallyObservableGameState(GameState gs, int a_player) { unitCancelationCounter = gs.unitCancelationCounter; time = gs.time; - player = a_player; + observer = a_player; unitActions.putAll(gs.unitActions); - List toDelete = new LinkedList<>(); - for (Unit u : pgs.getUnits()) { - if (u.getPlayer() != player) { + final List toDelete = new LinkedList<>(); + for (final Unit u : pgs.getUnits()) { + if (u.getPlayer() != observer) { if (!observable(u.getX(), u.getY())) { toDelete.add(u); } } } - for (Unit u : toDelete) + for (final Unit u : toDelete) removeUnit(u); } @@ -45,11 +57,12 @@ public PartiallyObservableGameState(GameState gs, int a_player) { * Returns whether the position is within view of the player * @see rts.GameState#observable(int, int) */ - public boolean observable(int x, int y) { - for (Unit u : pgs.getUnits()) { - if (u.getPlayer() == player) { - double d = Math.sqrt((u.getX() - x) * (u.getX() - x) + (u.getY() - y) * (u.getY() - y)); - if (d <= u.getType().sightRadius) + @Override + public boolean observable(final int x, final int y) { + for (final Unit u : pgs.getUnits()) { + if (u.getPlayer() == observer) { + final int dSquared = (u.getX() - x) * (u.getX() - x) + (u.getY() - y) * (u.getY() - y); + if (dSquared <= u.getType().sightRadius * u.getType().sightRadius) return true; } } @@ -60,7 +73,108 @@ public boolean observable(int x, int y) { /* (non-Javadoc) * @see rts.GameState#clone() */ - public PartiallyObservableGameState clone() { - return new PartiallyObservableGameState(super.clone(), player); + @Override + public PartiallyObservableGameState clone() { + return new PartiallyObservableGameState(super.clone(), observer); + } + + @Override + public int [][][] getVectorObservation(final int player){ + if (vectorObservation == null) { + vectorObservation = new int[2][NUM_VECTOR_OBSERVATION_FEATURE_MAPS_PARTIAL_OBS][pgs.height][pgs.width]; + } + + List friendlyUnits = new ArrayList<>(); + List enemyUnits = new ArrayList<>(); + + // hitpointsMatrix is vectorObservation[player][0] + // resourcesMatrix is vectorObservation[player][1] + // playersMatrix is vectorObservation[player][2] + // unitTypesMatrix is vectorObservation[player][3] + // unitActionMatrix is vectorObservation[player][4] + // wallMatrix is vectorObservation[player][5] + // myVisibilityMatrix is vectorObservation[player][6] + // opponentVisibilityMatrix is vectorObservation[player][7] + + for (int i=0; i= 0) { // Owned by a player, not neutral + vectorObservation[player][2][u.getY()][u.getX()] = ((u.getPlayer() + player) % 2) + 1; + + // Split units based on owner (used for last two layers of the observation) + if (owner == player) + friendlyUnits.add(new int[]{u.getX(), u.getY(), u.getType().sightRadius}); + else + enemyUnits.add(new int[]{u.getX(), u.getY(), u.getType().sightRadius}); + } + + vectorObservation[player][3][u.getY()][u.getX()] = u.getType().ID + 1; + + if (uaa != null) { + vectorObservation[player][4][u.getY()][u.getX()] = uaa.action.type; + } else { + // Commented line of code is unnecessary: already initialised to 0 + //vectorObservation[player][4][u.getY()][u.getX()] = UnitAction.TYPE_NONE; + } + } + + // Encode the presence of walls + final int[] terrain = pgs.terrain; + for (int y = 0; y < pgs.height; ++y) { + System.arraycopy(terrain, y * pgs.width, vectorObservation[player][5][y], 0, pgs.width); + } + + // Encode visibility + final int[][] playerVisibility = calculateVisibility(friendlyUnits, pgs.width, pgs.height); + final int[][] opponentVisibility = calculateVisibility(enemyUnits, pgs.width, pgs.height); + + for (int y = 0; y < pgs.height; y++) { + System.arraycopy(playerVisibility, y * pgs.width, vectorObservation[player][6][y], 0, pgs.width); + System.arraycopy(opponentVisibility, y * pgs.width, vectorObservation[player][7][y], 0, pgs.width); + } + + return vectorObservation[player]; + } + + private static int[][] calculateVisibility(final List units, final int width, final int height) { + final int[][] visibility = new int[height][width]; + for (final int[] unit : units) { + final int ux = unit[0]; + final int uy = unit[1]; + final int sightRadius = unit[2]; + final int sightRadiusSquared = sightRadius * sightRadius; + + for (int dy = -sightRadius; dy <= sightRadius; dy++) { + for (int dx = -sightRadius; dx <= sightRadius; dx++) { + final int x = ux + dx; + final int y = uy + dy; + + if (x >= 0 && x < width && y >= 0 && y < height) { + final int distanceSquared = dx * dx + dy * dy; + if (distanceSquared <= sightRadiusSquared) { + visibility[y][x] = 1; + } + } + } + } + } + return visibility; } }