"""Core policy base classes for gym_nethack (module ``gym_nethack.policies.core``)."""
import os
from gym_nethack.misc import verboseprint
from gym_nethack.fileio import get_dir_for_params
class Policy(object):
    """Base policy class, adapted from Keras-RL's ``Policy`` with a few extensions.

    Subclasses must implement :meth:`select_action`. The remaining hooks
    (metrics, config) have no-op defaults.
    """

    # Class-level fallback name; instances created via __init__ override it.
    name = 'unnamed'

    def __init__(self, name='obsolete'):
        """Create a policy.

        Args:
            name: instance name stored on ``self.name``.
        """
        self.name = name

    def _set_agent(self, agent):
        """Attach the agent that owns this policy (stored as ``self.agent``)."""
        self.agent = agent

    @property
    def metrics_names(self):
        """Names of the metrics reported by this policy (none by default)."""
        return []

    @property
    def metrics(self):
        """Current metric values, parallel to ``metrics_names`` (none by default)."""
        return []

    def select_action(self, **kwargs):
        """Choose an action; must be implemented by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def get_config(self):
        """Return the policy configuration as a dict (empty by default)."""
        return {}

    def set_config(self, **kwargs):
        """Configure the policy; the base implementation ignores all arguments.

        Accepting keyword arguments keeps the signature compatible with
        subclasses (e.g. ParameterizedPolicy) that are configured with keywords,
        so callers can configure any policy uniformly.
        """
        pass
class ParameterizedPolicy(Policy):
    """Extension of policy class that allows for grid-search on specified parameters.

    With grid search enabled (see :meth:`set_config`), the policy cycles through
    a list of parameter combinations. It plays ``num_episodes_per_combo`` episodes
    with each one and saves/loads the environment's records per combination.

    NOTE(review): ``self.env`` is read by several methods but never assigned in
    this class; it is presumably attached by the owning agent/environment.
    Confirm against callers.
    """

    def __init__(self):
        """Initialize a policy that has parameters which can be modified (e.g., through grid search)."""
        self.cur_combo = -1  # index into self.param_combos; -1 means no combo selected yet
        self.num_games_this_combo = 0  # episodes finished with the current combo

    @staticmethod
    def _split_combos(param_combos, num_procs):
        """Split param_combos into num_procs contiguous, near-equal chunks.

        Every combination lands in exactly one chunk. When the count does not
        divide evenly, the first ``len % num_procs`` chunks get one extra item.
        A num_procs below 1 is treated as 1.
        """
        num_procs = max(num_procs, 1)
        base, extra = divmod(len(param_combos), num_procs)
        chunks = []
        start = 0
        for i in range(num_procs):
            size = base + (1 if i < extra else 0)
            chunks.append(param_combos[start:start + size])
            start += size
        return chunks

    def set_config(self, grid_search=False, top_models=False, num_episodes_per_combo=200, proc_id=0, num_procs=1, param_combos=None, param_abbrvs=None):
        """Set config.

        Without grid search, the default parameters (``get_default_params()``)
        are applied immediately. With grid search, parameter combinations are
        chosen as follows:

        - A ``combos_to_try.txt`` file in the working directory takes precedence.
        - Otherwise, if ``top_models`` is set, ``<env.basedir>top_models.txt`` is used.
        - Otherwise, the ``param_combos`` argument is used.

        The chosen list is split across ``num_procs`` processes, and this
        process's share is installed on the next :meth:`switch_encounter` call.

        Args:
            grid_search: whether to change parameters every certain number of episodes.
            top_models: whether to load from a text file and use the specified param
                combos inside. (Must have grid_search=True)
            num_episodes_per_combo: if grid search, number of episodes per each
                combination of alg. parameters.
            proc_id: if grid search, process ID of this environment, to be matched
                with the argument passed to the daemon launching script.
            num_procs: if grid search, number of processes that will be running in parallel.
            param_combos: list of lists of parameter combinations.
            param_abbrvs: abbreviated parameter names (for directory name).

        Raises:
            ValueError: if top_models is set without grid_search, if grid search
                has no parameter combinations, if proc_id is out of range, or if
                this process receives no combinations (fewer combos than processes).
        """
        if top_models and not grid_search:
            raise ValueError("top_models=True requires grid_search=True")
        self.grid_search = grid_search
        if not grid_search:
            self.set_params(self.get_default_params())
            return

        self.num_episodes_per_combo = num_episodes_per_combo
        self.param_abbrvs = param_abbrvs
        if os.path.isfile('combos_to_try.txt'):
            # Local import: these helpers were referenced but never imported at file level.
            from gym_nethack.fileio import read_list
            param_combos = read_list('combos_to_try')
        elif top_models and os.path.isfile(self.env.basedir + 'top_models.txt'):
            from gym_nethack.fileio import get_params_for_dir, read_list
            param_combos = [get_params_for_dir(dirname) for dirname in read_list(self.env.basedir + 'top_models')]

        if param_combos is None or len(param_combos) == 0:
            raise ValueError("grid search requires at least one parameter combination")

        param_combos_per_proc = self._split_combos(param_combos, num_procs)
        index = proc_id if num_procs > 1 else 0
        if not 0 <= index < len(param_combos_per_proc):
            raise ValueError("proc_id %r is out of range for num_procs=%r" % (proc_id, num_procs))
        self.combos_to_set = param_combos_per_proc[index]
        verboseprint(param_combos_per_proc)
        if len(self.combos_to_set) == 0:
            raise ValueError("process %r has no parameter combinations to try (%d combos split over %d processes)"
                             % (proc_id, len(param_combos), num_procs))

    def reset(self):
        """Called on starting a new episode."""
        self.switch_encounter()

    def switch_encounter(self):
        """Alter alg. parameters if using grid search.

        Does nothing unless grid search is enabled. On the first call after
        set_config(), it installs this process's combinations. After that, it
        switches only once num_episodes_per_combo episodes have been played with
        the current combination.

        Switching does the following:
        - Saves the current records, if any episodes were played.
        - Advances to the next combination, wrapping to 0 after the last one.
        - Applies the new parameters and points env.savedir at their directory.
        - Loads any records already stored there.
        """
        if not self.grid_search:
            return
        if self.combos_to_set is not None:  # initial setup
            self.set_combos(self.combos_to_set)
            self.combos_to_set = None
        elif self.cur_combo > -1 and self.num_games_this_combo < self.num_episodes_per_combo:
            verboseprint("Not finished with current param combo yet")
            return

        if self.num_games_this_combo > 0:
            self.env.save_records()  # save current records to the current directory
        self.cur_combo += 1
        if self.cur_combo >= len(self.param_combos):
            verboseprint("Past max combo, going back to combo 0")
            self.cur_combo = 0
        print("Combo", self.cur_combo, "/", len(self.param_combos))
        cur_params = self.param_combos[self.cur_combo]
        self.set_params(cur_params)
        self.env.savedir = self.env.basedir + get_dir_for_params(cur_params, self.param_abbrvs)
        self.env.load_records()  # load any existing records from the new directory
        self.num_games_this_combo = 0
        # Presumably one 'expl' record per previously played episode (e.g. from an
        # earlier run) -- TODO confirm against env.load_records().
        self.env.total_num_games += len(self.env.records['expl'])
        verboseprint("Cur params:", self.env.savedir)

    def end_episode(self):
        """Record new episode ended."""
        self.num_games_this_combo += 1

    def set_combos(self, combos):
        """Update list of parameter combinations to try.

        Also sets ``env.max_num_episodes`` so that every combination gets
        ``num_episodes_per_combo`` episodes.

        Args:
            combos: list of combinations to use.
        """
        self.param_combos = combos
        self.env.max_num_episodes = self.num_episodes_per_combo * len(combos)

    def get_default_params(self):
        """Get the default parameters for the policy (none in this base class)."""
        return []

    def set_params(self, params):
        """Set the current parameters for the policy (no-op in this base class).

        Args:
            params: policy parameters.
        """
        pass