class BaseReinforcementSpace(BaseSpace, ABC):
    """Reinforcement Base Space class.

    Runs single-agent reinforcement-learning evaluations on a
    `torchrl <https://pytorch.org/rl/>`_ environment, with optional
    cross-generation transfer of environment state, fitness, and agent
    memory (controlled by ``agent.config.{env,fit,mem}_transfer``).

    Args:
        env: The `torchrl <https://pytorch.org/rl/>`_ environment to run
            the evaluation on.
        config: The space configuration (see :class:`BaseSpaceConfig`).
    """

    def __init__(
        self: "BaseReinforcementSpace",
        config: BaseSpaceConfig,
        env: EnvBase,
    ) -> None:
        # Single population, CPU-side evaluation.
        super().__init__(config=config, num_pops=1, evaluates_on_gpu=False)
        self.env = env

    @final
    def run_pre_eval(
        self: "BaseReinforcementSpace",
        agent: BaseAgent,
        curr_gen: int,
    ) -> TensorDict:
        """Resets/loads the environment before evaluation begins.

        When environment transfer is enabled and this is not the first
        generation, restores the environment (and its latest output)
        saved by the previous evaluation instead of resetting.

        Args:
            agent: The agent being evaluated.
            curr_gen: See :paramref:`~.BaseSpace.curr_gen`.

        Returns:
            See :paramref:`run_post_eval.out`.
        """
        if curr_gen > 1 and agent.config.env_transfer:
            # Deep-copy so this evaluation cannot mutate the saved state.
            self.env = copy.deepcopy(agent.saved_env)
            return copy.deepcopy(agent.saved_env_out)
        # Seed deterministically per generation for reproducibility.
        self.env.set_seed(seed=curr_gen)
        return self.env.reset()

    def env_done_reset(
        self: "BaseReinforcementSpace",
        agent: BaseAgent,
        out: TensorDict,
        curr_gen: int,
    ) -> TensorDict | dict[str, bool]:
        """Resets the agent/environment when the environment terminates.

        Args:
            agent: See :paramref:`pre_eval_reset.agent`.
            out: The latest environment output.
            curr_gen: See :paramref:`~.BaseSpace.curr_gen`.

        Returns:
            See :paramref:`run_post_eval.out`.
        """
        # env,fit,env+fit,env+fit+mem: reset, mem,mem+fit: no reset
        # Simplified from `not (mem or (mem and fit))` — equivalent by the
        # absorption law. NOTE(review): as written, the agent is NOT reset
        # whenever mem_transfer is set, which contradicts the
        # "env+fit+mem: reset" case listed above — confirm intended
        # semantics before changing behavior.
        if not agent.config.mem_transfer:
            agent.reset()
        if agent.config.env_transfer:
            # Log the finished episode's score, then start a new episode
            # with a fresh per-generation seed.
            self.logged_score: float | None = agent.curr_episode_score
            agent.curr_episode_score = 0
            agent.curr_episode_num_steps = 0
            self.env.set_seed(seed=curr_gen)
            return self.env.reset()
        return out

    @final
    def run_post_eval(
        self: "BaseReinforcementSpace",
        agent: BaseAgent,
        out: TensorDict,
        curr_gen: int,
    ) -> None:
        """Resets the agent & saves the environment post-evaluation.

        Args:
            agent: See :paramref:`pre_eval_reset.agent`.
            out: The latest environment output.
            curr_gen: See :paramref:`~.BaseSpace.curr_gen`.
        """
        if not agent.config.mem_transfer:
            agent.reset()
        if agent.config.env_transfer:
            # Save environment state so the next generation can resume it.
            agent.saved_env = copy.deepcopy(self.env)
            agent.saved_env_out = copy.deepcopy(out)
        else:
            # Without env transfer, the whole evaluation is one episode.
            self.logged_score = agent.curr_eval_score
        if self.config.logging:
            gather(
                logged_score=self.logged_score,
                curr_gen=curr_gen,
                agent_total_num_steps=agent.total_num_steps,
            )

    @final
    def evaluate(
        self: "BaseReinforcementSpace",
        agents: list[list[BaseAgent]],
        curr_gen: An[int, ge(1)],
    ) -> np.ndarray[np.float32, Any]:
        """Evaluation function called once per generation.

        Args:
            agents: A 2D list containing the agent to evaluate.
            curr_gen: See :paramref:`~.BaseSpace.curr_gen`.

        Returns:
            A 1D array holding the fitness (continual fitness when
            fitness transfer is enabled, else the current evaluation
            score) and the number of evaluation steps taken.
        """
        agent = agents[0][0]
        agent.curr_eval_score = 0
        agent.curr_eval_num_steps = 0
        self.logged_score = None
        out = self.run_pre_eval(agent=agent, curr_gen=curr_gen)
        while not out["done"]:
            out = out.set(key="action", item=agent(x=out["observation"]))
            out = self.env.step(tensordict=out)["next"]
            agent.curr_eval_score += float(out["reward"])
            agent.curr_eval_num_steps += 1
            agent.total_num_steps += 1
            if agent.config.env_transfer:
                agent.curr_episode_score += float(out["reward"])
                agent.curr_episode_num_steps += 1
            if agent.config.fit_transfer:
                agent.continual_fitness += float(out["reward"])
            if out["done"]:
                out = self.env_done_reset(
                    agent=agent,
                    out=out,
                    curr_gen=curr_gen,
                )
            # Step budget exhausted: stop (steps increment by 1, so `==`
            # is reached exactly once).
            if agent.curr_eval_num_steps == self.config.eval_num_steps:
                break
        self.run_post_eval(agent=agent, out=out, curr_gen=curr_gen)
        return np.array(
            (
                (
                    agent.continual_fitness
                    if agent.config.fit_transfer
                    else agent.curr_eval_score
                ),
                agent.curr_eval_num_steps,
            ),
        )