"""Source code for rog_rl.model."""

from mesa import Model
from mesa.space import SingleGrid
from mesa.datacollection import DataCollector

import numpy as np

from rog_rl.agent import DiseaseSimAgent
from rog_rl.disease_planner import SEIRDiseasePlanner
from rog_rl.scheduler import CustomScheduler
from rog_rl.agent_state import AgentState
from rog_rl.vaccination_response import VaccinationResponse
from rog_rl.contact_network import ContactNetwork


class DiseaseSimModel(Model):
    """
    The model class holds the model-level attributes, manages the agents,
    and generally handles the global level of our model.

    When a new model is started, it populates a ``width`` x ``height`` grid
    with ``int(width * height * population_density)`` agents. It seeds an
    initial fraction of them with the infection and another fraction with
    vaccination. The simulation then advances one step at a time via
    :meth:`step` (or its alias :meth:`tick`).

    The scheduler is a special model component which controls the order in
    which agents are activated.
    """

    # Default timings handed to the SEIR disease planner. This dict is never
    # handed out directly. Each model gets its own copy, so no two models
    # share (and could accidentally mutate) the same config dict.
    _DEFAULT_DISEASE_PLANNER_CONFIG = {
        "latent_period_mu": 2 * 4,
        "latent_period_sigma": 0,
        "incubation_period_mu": 5 * 4,
        "incubation_period_sigma": 0,
        "recovery_period_mu": 14 * 4,
        "recovery_period_sigma": 0,
    }

    def __init__(
        self,
        width=50,
        height=50,
        population_density=0.75,
        vaccine_density=0,
        initial_infection_fraction=0.1,
        initial_vaccination_fraction=0.00,
        prob_infection=0.2,
        prob_agent_movement=0.0,
        disease_planner_config=None,
        max_timesteps=200,
        early_stopping_patience=14,
        toric=True,
        seed=None
    ):
        """
        Build the grid, agents, scheduler and data collector.

        Parameters
        ----------
        width, height : int
            Grid dimensions.
        population_density : float
            Fraction of grid cells initialized with an agent; must be in
            (0, 1].
        vaccine_density : float
            Size of the vaccine budget for :meth:`vaccinate_cell`, as a
            fraction of the number of agents.
        initial_infection_fraction : float
            Fraction of agents seeded with the infection.
        initial_vaccination_fraction : float
            Fraction of agents that start in the VACCINATED state.
        prob_infection : float
            Probability passed to ``trigger_infection`` for every attempt an
            infectious agent makes on a susceptible neighbour.
        prob_agent_movement : float
            Forwarded to every ``DiseaseSimAgent``.
        disease_planner_config : dict or None
            Latent/incubation/recovery ``mu``/``sigma`` values for
            ``SEIRDiseasePlanner``. ``None`` (the default) uses a fresh copy
            of ``_DEFAULT_DISEASE_PLANNER_CONFIG`` (the same values as the
            previous dict default).
        max_timesteps : int
            The simulation stops once ``schedule.steps`` reaches this value.
        early_stopping_patience : int
            The simulation stops once the last ``early_stopping_patience``
            recorded "Susceptible" fractions are all identical. Values
            <= 0 disable this check.
        toric : bool
            Whether the grid wraps around at its edges.
        seed
            Stored on the model as ``self.seed``.
        """
        super().__init__()

        # Never share a mutable default between instances.
        if disease_planner_config is None:
            disease_planner_config = dict(
                self._DEFAULT_DISEASE_PLANNER_CONFIG)

        self.width = width
        self.height = height
        # fraction of the whole grid that is initialized with agents
        self.population_density = population_density
        self.vaccine_density = vaccine_density
        # Placeholders; the real values are computed in initialize_agents()
        self.n_agents = 0
        self.n_vaccines = 0
        self.initial_infection_fraction = initial_infection_fraction
        self.initial_vaccination_fraction = initial_vaccination_fraction
        self.prob_infection = prob_infection
        self.prob_agent_movement = prob_agent_movement
        self.disease_planner_config = disease_planner_config
        self.max_timesteps = max_timesteps
        self.early_stopping_patience = early_stopping_patience
        self.toric = toric
        self.seed = seed

        self.initialize_observation()
        self.initialize_disease_planner()
        self.initialize_scheduler()
        self.initialize_grid()
        self.initialize_contact_network()
        self.initialize_agents(
            infection_fraction=self.initial_infection_fraction,
            vaccination_fraction=self.initial_vaccination_fraction
        )
        self.initialize_datacollector()

        self.running = True
        # Record the initial population fractions before any step runs.
        self.datacollector.collect(self)

    ###########################################################################
    # Setup Initialization Helper Functions
    ###########################################################################

    def initialize_observation(self):
        """
        Allocate the observation tensor.

        The observation is an nd-array of shape (width, height, num_states).
        Each AgentState is marked in a separate channel for each of the
        cells.
        """
        self.observation = np.zeros(
            (self.width, self.height, len(AgentState)))

    def initialize_disease_planner(self):
        """
        Initialize a disease planner that the agents can use to "schedule"
        their infection progressions.
        """
        config = self.disease_planner_config
        self.disease_planner = SEIRDiseasePlanner(
            latent_period_mu=config["latent_period_mu"],
            latent_period_sigma=config["latent_period_sigma"],
            incubation_period_mu=config["incubation_period_mu"],
            incubation_period_sigma=config["incubation_period_sigma"],
            recovery_period_mu=config["recovery_period_mu"],
            recovery_period_sigma=config["recovery_period_sigma"]
        )

    def initialize_scheduler(self):
        """Initialize the scheduler."""
        self.schedule = CustomScheduler(self)

    def initialize_grid(self):
        """Initialize the grid; at most one agent per cell."""
        self.grid = SingleGrid(
            width=self.width,
            height=self.height,
            torus=self.toric)

    def initialize_contact_network(self):
        """Initialize the contact network that records infection spread."""
        self.contact_network = ContactNetwork()

    def initialize_agents(self, infection_fraction, vaccination_fraction):
        """
        Initialize the agents on the grid.

        Places ``int(width * height * population_density)`` agents at random
        cells. The first ``int(infection_fraction * n_agents)`` agents are
        seeded with the infection. The next
        ``int(vaccination_fraction * n_agents)`` agents are set to
        VACCINATED. Also computes ``n_vaccines`` (the budget for
        :meth:`vaccinate_cell`) and ``max_vaccines``.
        """
        assert 0 < self.population_density <= 1, \
            "population_density should be between (0, 1]"

        # Assess the actual population
        self.n_agents = int(
            self.width * self.height * self.population_density)
        # Assess the available number of vaccines
        self.n_vaccines = int(self.n_agents * self.vaccine_density)

        # Assess the number of agents that have to be infected (the seed
        # infection) and vaccinated up front
        number_of_agents_to_infect = int(infection_fraction * self.n_agents)
        number_of_agents_to_vaccinate = int(
            vaccination_fraction * self.n_agents)

        # Assess the maximum number of vaccines available in the whole
        # simulation: the vaccine budget plus the pre-vaccinated agents
        self.max_vaccines = self.n_vaccines + number_of_agents_to_vaccinate

        # Agents with index in [infect, vaccination_end) start vaccinated
        vaccination_end = (
            number_of_agents_to_infect + number_of_agents_to_vaccinate)

        for i in range(self.n_agents):
            agent = DiseaseSimAgent(
                unique_id=i,
                model=self,
                prob_agent_movement=self.prob_agent_movement
            )
            self.schedule.add(agent)
            self.grid.position_agent(agent, x="random", y="random")

            # Update model observation
            # TODO- This has to be refactored to avoid repetition
            # NOTE(review): presumably trigger_infection/set_state keep the
            # observation in sync afterwards - confirm in DiseaseSimAgent.
            agent_x, agent_y = agent.pos
            self.observation[agent_x, agent_y, agent.state.value] = 1

            if i < number_of_agents_to_infect:
                # Seed the infection in a fraction of the agents
                agent.trigger_infection(prob_infection=1.0)
            elif i < vaccination_end:
                # Seed the vaccination in the next fraction of the agents
                agent.set_state(AgentState.VACCINATED)

    def initialize_datacollector(self):
        """
        Set up the data collector.

        Each step it records the population fraction in every AgentState
        and the contact network's R0 estimate divided by 10.
        """
        self.datacollector = DataCollector(
            model_reporters={
                "Susceptible": lambda m: m.get_population_fraction_by_state(AgentState.SUSCEPTIBLE),  # noqa
                "Exposed": lambda m: m.get_population_fraction_by_state(AgentState.EXPOSED),  # noqa
                "Infectious": lambda m: m.get_population_fraction_by_state(AgentState.INFECTIOUS),  # noqa
                "Symptomatic": lambda m: m.get_population_fraction_by_state(AgentState.SYMPTOMATIC),  # noqa
                "Recovered": lambda m: m.get_population_fraction_by_state(AgentState.RECOVERED),  # noqa
                "Vaccinated": lambda m: m.get_population_fraction_by_state(AgentState.VACCINATED),  # noqa
                "R0/10": lambda m: m.contact_network.compute_R0() / 10.0
            }
        )

    ###########################################################################
    # State Aggregation
    # - Functions for easy access/aggregation of simulation wide state
    ###########################################################################

    def get_observation(self):
        """Return the (width, height, num_states) observation array."""
        # assert self.observation.sum(axis=-1).max() <= 1.0
        # Assertion disabled for perf reasons
        return self.observation

    ###########################################################################
    # Scheduler
    # - Functions for easy access to scheduler
    ###########################################################################

    def get_scheduler(self):
        """Return the model's scheduler."""
        return self.schedule

    def get_population_fraction_by_state(self, state: AgentState):
        """Return the fraction of agents currently in ``state``."""
        return self.schedule.get_agent_fraction_by_state(state)

    def is_running(self):
        """Return False once a completion check has stopped the model."""
        return self.running

    ###########################################################################
    # Actions
    # - Functions for actions that can be performed on the model
    ###########################################################################

    def step(self):
        """
        Run one model step.

        Propagates infections, collects data, advances the schedule and
        then runs the completion checks.
        """
        self.propagate_infections()
        self.datacollector.collect(self)
        self.schedule.step()
        self.simulation_completion_checks()

    def vaccinate_cell(self, cell_x, cell_y):
        """
        Vaccinate the agent at (cell_x, cell_y), if present.

        Returns a ``(is_vaccination_successful, vaccination_response)``
        tuple of types ``(bool, VaccinationResponse)``.

        Once the budget is non-empty, a vaccine is consumed by every call.
        This includes calls on empty cells and on agents that are not
        susceptible.
        """
        # Case 0 : No vaccines left
        if self.n_vaccines <= 0:
            return False, VaccinationResponse.AGENT_VACCINES_EXHAUSTED
        self.n_vaccines -= 1

        # Case 1 : Cell is empty
        if self.grid.is_cell_empty((cell_x, cell_y)):
            return False, VaccinationResponse.CELL_EMPTY

        agent = self.grid[cell_x][cell_y]
        if agent.state == AgentState.SUSCEPTIBLE:
            # Case 2 : Agent is susceptible, and can be vaccinated
            agent.set_state(AgentState.VACCINATED)
            return True, VaccinationResponse.VACCINATION_SUCCESS
        elif agent.state == AgentState.EXPOSED:
            # Case 3 : Agent is already exposed; the vaccine is wasted
            return False, VaccinationResponse.AGENT_EXPOSED
        elif agent.state == AgentState.INFECTIOUS:
            # Case 4 : Agent is already infectious; the vaccine is wasted
            return False, VaccinationResponse.AGENT_INFECTIOUS
        elif agent.state == AgentState.SYMPTOMATIC:
            # Case 5 : Agent is already symptomatic; the vaccine is wasted
            return False, VaccinationResponse.AGENT_SYMPTOMATIC
        elif agent.state == AgentState.RECOVERED:
            # Case 6 : Agent is already recovered; the vaccine is wasted
            return False, VaccinationResponse.AGENT_RECOVERED
        elif agent.state == AgentState.VACCINATED:
            # Case 7 : Agent is already vaccinated; the vaccine is wasted
            return False, VaccinationResponse.AGENT_VACCINATED
        raise NotImplementedError()

    ###########################################################################
    # Misc
    ###########################################################################

    def simulation_completion_checks(self):
        """
        Set ``self.running = False`` when the simulation is complete.

        The simulation is complete if any of the following holds:
          - ``schedule.steps`` has reached ``max_timesteps``;
          - the fraction of susceptible population is <= 0;
          - ``early_stopping_patience`` > 0, more than that many steps have
            run, and the last ``early_stopping_patience`` recorded
            susceptible fractions are all identical.
        """
        if self.schedule.steps > self.max_timesteps - 1:
            self.running = False
            return

        susceptible_population = self.get_population_fraction_by_state(
            AgentState.SUSCEPTIBLE)
        if susceptible_population <= 0:
            self.running = False
            return

        patience = self.early_stopping_patience
        # Guard against patience <= 0: history[-0:] would silently select
        # the *entire* history instead of a window of recent values.
        if patience > 0 and self.schedule.steps > patience:
            recent_susceptible = \
                self.datacollector.model_vars["Susceptible"][-patience:]
            if len(set(recent_susceptible)) == 1:
                self.running = False
                return

    def tick(self):
        """
        Mirror of :meth:`step`. Avoids confusion with the RL ``step`` in RL
        codebases.
        """
        self.step()

    def propagate_infections(self):
        """
        Propagate infection during a single simulation step.

        Every INFECTIOUS or SYMPTOMATIC agent attempts, with probability
        ``prob_infection``, to infect each SUSCEPTIBLE agent in its Moore
        neighbourhood of radius 1. Successful spreads are registered in the
        contact network.
        """
        valid_infectious_agents = []
        valid_infectious_agents += self.schedule.get_agents_by_state(
            AgentState.INFECTIOUS)
        valid_infectious_agents += self.schedule.get_agents_by_state(
            AgentState.SYMPTOMATIC)

        for _infectious_agent in valid_infectious_agents:
            target_candidates = self.grid.get_neighbors(
                pos=_infectious_agent.pos,
                moore=True,
                include_center=False,
                radius=1
            )
            for _target_candidate in target_candidates:
                # NOTE(review): presumably a successful trigger_infection
                # moves the target out of SUSCEPTIBLE, so later infectious
                # neighbours skip it - confirm in DiseaseSimAgent.
                if _target_candidate.state != AgentState.SUSCEPTIBLE:
                    continue
                was_infection_successful = \
                    _target_candidate.trigger_infection(
                        prob_infection=self.prob_infection)
                if was_infection_successful:
                    # Register infection in the contact network
                    self.contact_network.register_infection_spread(
                        _infectious_agent,
                        _target_candidate
                    )
if __name__ == "__main__":
    import time

    # Benchmark: time 100 consecutive model steps on a densely populated,
    # almost fully infected grid.
    model = DiseaseSimModel(
        width=50,
        height=50,
        population_density=0.99,
        vaccine_density=0.0,
        initial_infection_fraction=0.99,
        initial_vaccination_fraction=0.0,
        prob_infection=1.0,
        prob_agent_movement=0.0,
        disease_planner_config={
            "latent_period_mu": 2 * 4,
            "latent_period_sigma": 0,
            "incubation_period_mu": 5 * 4,
            "incubation_period_sigma": 0,
            "recovery_period_mu": 14 * 4,
            "recovery_period_sigma": 0,
        },
        max_timesteps=5,
        early_stopping_patience=14,
        toric=True)

    # NOTE: the loop does not check model.is_running(), so step() keeps
    # being called after max_timesteps; this measures raw per-step cost.
    per_step_times = []
    for _ in range(100):
        start = time.perf_counter()
        model.step()
        per_step_times.append(time.perf_counter() - start)

    per_step_times = np.array(per_step_times)
    # Format the placeholders explicitly; passing the values as extra
    # print() arguments would print "{}" literally.
    print("Per Step Time : {} +/- {}".format(
        per_step_times.mean(), per_step_times.std()))