Source code for rog_rl.scheduler

#!/usr/bin/env python

from mesa import Agent, Model
from mesa.time import RandomActivation
from rog_rl.agent_state import AgentState


class CustomScheduler(RandomActivation):
    """Random-activation scheduler that also indexes agents by their state.

    Besides the ``_agents`` mapping inherited from the base scheduler, this
    keeps ``_agent_state_index``: a dict mapping every ``AgentState`` member
    to a ``{unique_id: agent}`` dict of the agents currently in that state.
    Per-state queries therefore do not need to scan every agent.
    """

    def __init__(self, model: Model) -> None:
        """Create the scheduler with an empty bucket for every AgentState.

        :param model: the model this scheduler belongs to (passed to the
            base scheduler).
        """
        super().__init__(model)
        # Pre-create one bucket per state so that lookups for any state are
        # valid even before an agent has entered it.
        self._agent_state_index = {state: {} for state in AgentState}

    def add(self, agent: Agent) -> None:
        """Register ``agent`` and index it under its current ``agent.state``.

        No duplicate check is performed: adding an agent whose ``unique_id``
        is already registered overwrites the stored entries.
        """
        self._agents[agent.unique_id] = agent
        self._agent_state_index[agent.state][agent.unique_id] = agent

    def remove(self, agent: Agent) -> None:
        """Unregister ``agent`` from the scheduler and from the state index.

        :raises KeyError: if ``agent.unique_id`` is not registered with the
            scheduler (the state index is left untouched in that case).
        """
        del self._agents[agent.unique_id]
        # The agent may be indexed under a state other than ``agent.state``
        # (e.g. if its state changed without a registry update), so drop it
        # from every bucket; buckets that do not contain it are ignored.
        for agents_in_state in self._agent_state_index.values():
            agents_in_state.pop(agent.unique_id, None)

    def update_agent_state_in_registry(
            self, agent: Agent, previous_state: AgentState) -> None:
        """Re-index ``agent`` after its ``state`` changed from ``previous_state``.

        Moves the agent from the ``previous_state`` bucket to the bucket for
        its current ``agent.state``. It also writes to
        ``self.model.observation`` at the agent's position: the entry for
        ``previous_state`` is set to 0 and the entry for the new state to 1.

        :param agent: the agent whose ``state`` attribute has been updated.
        :param previous_state: the state the agent is currently indexed under.
        :raises KeyError: if the agent is not indexed under ``previous_state``.
        """
        del self._agent_state_index[previous_state][agent.unique_id]
        self._agent_state_index[agent.state][agent.unique_id] = agent

        # NOTE(review): assumes ``self.model.observation`` is an array indexed
        # as [x, y, state.value] — confirm against the model.
        agent_x, agent_y = agent.pos
        # Clear before set: if previous_state == agent.state the entry must
        # end up as 1, so the order of these two writes matters.
        self.model.observation[agent_x, agent_y, previous_state.value] = 0
        self.model.observation[agent_x, agent_y, agent.state.value] = 1

    def get_agents_by_state(self, state: AgentState) -> list:
        """Return a new list of the agents currently indexed under ``state``."""
        return list(self._agent_state_index[state].values())

    def get_agent_count_by_state(self, state: AgentState) -> int:
        """Return the current number of agents indexed under ``state``."""
        return len(self._agent_state_index[state])

    def get_agent_fraction_by_state(self, state: AgentState) -> float:
        """Return the fraction of all scheduled agents that are in ``state``.

        :returns: ``count_in_state / total_agent_count``, or ``0.0`` when the
            scheduler currently has no agents.
        """
        total = self.get_agent_count()
        # Guard the division: an empty scheduler would otherwise raise
        # ZeroDivisionError.
        if total == 0:
            return 0.0
        return self.get_agent_count_by_state(state) / total