import dill
import numpy as np
from typing import List, Type, Optional, Dict, Tuple
from spacetime import Dataframe
from .data_model import Observation, ServerState, Player
from .FrameRateKeeper import FrameRateKeeper
from .BaseEnvironment import BaseEnvironment
import logging
# Logger for this module, named after the module itself.
logger: logging.Logger = logging.getLogger(name=__name__)
class ClientEnvironment:
    """ Basic client environment that will work with any server environment and handles most of the connection
    and playing capabilities.
    """

    # Rate handed to the FrameRateKeeper that paces every polling loop in this class
    # (presumably ticks per second -- confirm against FrameRateKeeper).
    _TickRate = 60

    def __init__(self,
                 dataframe: Dataframe,
                 dimensions: List[str],
                 observation_class: Type[Observation],
                 host: str,
                 server_environment: Optional[Type[BaseEnvironment]] = None,
                 auth_key: str = ''):
        """ The primary class for interacting with the environment as a remote client.

        Parameters
        ----------
        dataframe : Dataframe
            The spacetime dataframe connected to the game server.
        dimensions : List[str]
            The names of the observation dimensions.
        observation_class : Type[Observation]
            The base class of observations in the dataframe.
        host : str
            The hostname of the game server.
        server_environment : Optional[Type[BaseEnvironment]]
            The full server environment class if we have access to it. When given, it is
            instantiated with the server's ``env_config``.
        auth_key : str
            Your authorization key for entering the game if the server has a whitelist.

        Raises
        ------
        AssertionError
            If the server has no active game or is no longer accepting new connections.
        """
        self.player_df: Dataframe = dataframe
        # Created in connect() once the server has assigned us an observation port.
        self.observation_df: Optional[Dataframe] = None

        # Take the first ServerState object in the player dataframe as the state of the game we join.
        self._server_state: ServerState = self.player_df.read_all(ServerState)[0]
        assert self._server_state.terminal is False, "Connecting to a server with no active game."
        assert self._server_state.server_no_longer_joinable is False, "Server is not accepting new connection."

        self._player: Optional[Player] = None
        self._dimensions: List[str] = dimensions
        # Set by connect() from the observation dataframe.
        self._observation: Optional[Observation] = None
        self._observation_class: Type[Observation] = observation_class
        self._host: str = host
        self._auth_key: str = auth_key

        # Build a local copy of the server environment from the server's config when its class is supplied.
        self._server_environment: Optional[BaseEnvironment] = None
        if server_environment is not None:
            self._server_environment = server_environment(self._server_state.env_config)

        self.fr: FrameRateKeeper = FrameRateKeeper(self._TickRate)
        self.connected: bool = False

    def pull_dataframe(self) -> None:
        """ Helper function to update all dataframes for this environment.

        Pulls and checks out the player dataframe, and also the observation dataframe once
        connect() has created it.
        """
        self.player_df.pull()
        self.player_df.checkout()
        if self.observation_df is not None:
            self.observation_df.pull()
            self.observation_df.checkout()

    def push_dataframe(self) -> None:
        """ Helper function to push all dataframes for this environment.

        Only the player dataframe is committed and pushed; this class only ever reads the
        observation dataframe.
        """
        self.player_df.commit()
        self.player_df.push()

    def check_connection(self) -> None:
        """ Helper function to error out if we are not yet connected to a game server.

        Raises
        ------
        ConnectionError
            If connect() has not been called yet.
        """
        if not self.connected:
            raise ConnectionError("Not connected to game server.")

    def tick(self) -> bool:
        """ Helper function to wait for a tick of the framerate.

        Returns
        -------
        bool
            Whether or not the framerate keeper has raised a timeout.
        """
        return self.fr.tick()

    @property
    def observation(self) -> Dict[str, np.ndarray]:
        """ Get the current observation present for this agent.

        Returns
        -------
        Dict[str, np.ndarray]
            The observation dictionary for this environment, keyed by dimension name.

        Raises
        ------
        ConnectionError
            If connect() has not been called yet.
        """
        self.check_connection()
        return {dimension: getattr(self._observation, dimension) for dimension in self.dimensions}

    @property
    def terminal(self) -> bool:
        """ Check if the game has ended for us or not.

        Returns
        -------
        bool
            Whether or not the game has reached a terminal state.

        Raises
        ------
        ConnectionError
            If connect() has not been called yet.
        """
        self.check_connection()
        return self._server_state.terminal

    @property
    def winners(self) -> Optional[List[int]]:
        """ Get the current list of winners for the game.

        Returns
        -------
        List[int]
            The list of player numbers of the winners.

        Raises
        ------
        ConnectionError
            If connect() has not been called yet.
        ValueError
            If the game is not over yet.
        """
        if not self.terminal:
            raise ValueError("Game has not ended yet.")
        # NOTE(review): dill.loads can execute arbitrary code embedded in the payload;
        # only connect to game servers you trust.
        return dill.loads(self._server_state.winners)

    @property
    def server_environment(self) -> Optional[BaseEnvironment]:
        """ Get the full server environment object if we have it available.

        Returns
        -------
        Optional[BaseEnvironment]
            Server environment, or None if it is not available.
        """
        return self._server_environment

    @property
    def dimensions(self) -> List[str]:
        """ Get all of the observations that we receive from the server.

        Returns
        -------
        List[str]
            The keys in the observation dictionary.
        """
        return self._dimensions

    @property
    def full_state(self):
        """ Full server state for the game if the environment and the server support it.

        Returns
        -------
        object
            Current server state, as deserialized by the local server environment.

        Raises
        ------
        ConnectionError
            If connect() has not been called yet.
        ValueError
            If we do not have access to the full state (no server environment was supplied,
            or it does not support serialization).
        """
        self.check_connection()
        server_environment = self.server_environment
        # Without a local server environment there is nothing to deserialize the state with.
        if server_environment is None:
            raise ValueError("Full server environment is not available to this client.")
        if not server_environment.serializable():
            raise ValueError("Current Environment does not support full state for clients.")
        return server_environment.deserialize_state(self._server_state.serialized_state)

    def connect(self, username: str, timeout: Optional[float] = None) -> int:
        """ Connect to the remote server and wait for the game to start.

        Parameters
        ----------
        username: str
            Your desired Player Name.
        timeout: float
            Optional timeout for how long to wait before abandoning connection.

        Returns
        -------
        player_number: int
            The assigned player number in the global game.

        Raises
        ------
        ConnectionError
            If we could not connect to the game server successfully.
        AssertionError
            If the server did not provide an observation dataframe, or the observation does not
            contain every expected dimension.

        Notes
        -----
        This is your absolute player number that will be used for interpreting the full server state
        and the winners after the end of the game.
        """
        # Add this player to the game.
        self.pull_dataframe()
        self._player = Player(name=username, auth_key=self._auth_key)
        self.player_df.add_one(Player, self._player)
        self.push_dataframe()

        # Poll until the server either assigns us a player number or rejects us.
        self.pull_dataframe()
        if timeout:
            self.fr.start_timeout(timeout)
        while True:
            if self.tick() and timeout:
                self._player = None
                raise ConnectionError("Timed out connecting to server.")

            # The server should remove our player object if it doesnt want us to connect.
            if self.player_df.read_one(Player, self._player.pid) is None:
                self._player = None
                raise ConnectionError("Server rejected adding your player.")

            # If the game start timed out, then we break out now.
            if self._server_state.terminal:
                self._player = None
                raise ConnectionError("Server could not successfully start game.")

            # If we have been given a player number, it means the server is ready for a game to start.
            if self._player.number >= 0:
                break

            self.pull_dataframe()

        # Connect to observation dataframe, and get the initial observation.
        assert self._player.observation_port > 0, "Server failed to create an observation dataframe."
        self.observation_df = Dataframe("{}_observation_df".format(self._player.name),
                                        [self._observation_class],
                                        details=(self._host, self._player.observation_port))

        # Receive the first observation and ensure it matches the dimensions this client expects.
        self.pull_dataframe()
        self._observation = self.observation_df.read_all(self._observation_class)[0]
        assert all(hasattr(self._observation, dimension) for dimension in self.dimensions), \
            "Mismatch in game between server and client."

        # Let the server know that we are ready to start.
        self._player.ready_for_start = True
        self.push_dataframe()
        self.connected = True
        return self._player.number

    def wait_for_turn(self, timeout: Optional[float] = None):
        """ Block until it is your turn. This is usually only used in the beginning of the game.

        Parameters
        ----------
        timeout: float
            An optional hard timeout on waiting for the game to start.

        Returns
        -------
        observation: Dict[str, np.ndarray]
            The player's observation once its turn has arrived.

        Raises
        ------
        ConnectionError
            If connect() has not been called yet, the game ends while waiting, or the timeout expires.
        """
        # Consistent with the rest of the class (and not stripped under ``python -O`` like an assert).
        self.check_connection()
        if timeout:
            self.fr.start_timeout(timeout)
        while not self._player.turn:
            if self.terminal:
                raise ConnectionError("Server finished game while we were waiting.")
            if self.tick() and timeout:
                raise ConnectionError("Timed out waiting for a game.")
            self.pull_dataframe()
        return self.observation

    def wait_for_start(self, timeout: Optional[float] = None):
        """ Alias for wait_for_turn() that reads more clearly when starting a game.

        Parameters
        ----------
        timeout: float
            An optional hard timeout on waiting for the game to start.

        Returns
        -------
        observation: Dict[str, np.ndarray]
            The player's observation once its turn has arrived.
        """
        # Forward the observation instead of silently dropping it.
        return self.wait_for_turn(timeout)

    def valid_actions(self):
        """ Get a list of all valid moves for the current state.

        Returns
        -------
        moves: list[str]

        Raises
        ------
        NotImplementedError
            If the client environment does not have access to all of your available moves.
        ConnectionError
            If connect() has not been called yet.
        """
        if self._server_environment is not None:
            return self._server_environment.valid_actions(self.full_state, self._player.number)
        else:
            raise NotImplementedError("No valid_action is implemented in this client and "
                                      "we do not have access to the full server environment")

    def step(self, action: str) -> Tuple[Dict[str, np.ndarray], float, bool, Optional[List[int]]]:
        """ Perform an action and send it to the server. This will block until it is your turn again.

        Parameters
        ----------
        action: str
            Your action string.

        Returns
        -------
        observation : Dict[str, np.ndarray]
            The new observation dictionary for the new state.
        reward : float
            The reward for the previous action.
        terminal : bool
            Whether or not the game has ended.
        winners : Optional[List[int]]
            If terminal is true, this will be a list of the player numbers that have won.
            If terminal is false, this will be None.

        Raises
        ------
        ConnectionError
            If connect() has not been called yet.
        """
        if not self.terminal:
            self._player.action = action
            self._player.ready_for_action_to_be_taken = True
            self.push_dataframe()

        # Wait until it is our turn again and the ready flag has been cleared (presumably by the
        # server once it has applied the action).
        # NOTE(review): this loop has no timeout or terminal exit; it relies on the server always
        # handing the turn back.
        while not self._player.turn or self._player.ready_for_action_to_be_taken:
            self.tick()
            self.pull_dataframe()

        reward = self._player.reward_from_last_turn
        terminal = self.terminal
        winners = None
        if terminal:
            winners = self.winners
            self._player.acknowledges_game_over = True
            self.push_dataframe()
        return self.observation, reward, terminal, winners