@ -1,6 +1,7 @@
import logging
from collections import OrderedDict , defaultdict
from collections . abc import Mapping , Set
from collections . abc import MutableMapping , Mapping , Set
from abc import ABCMeta
from copy import deepcopy
from functools import partial , wraps
from itertools import islice , chain
@ -10,6 +11,8 @@ import networkx as nx
from mesa import Agent as MesaAgent
from typing import Dict , List
from random import shuffle
from . . import serialization , utils , time , config
@ -25,7 +28,7 @@ IGNORED_FIELDS = ('model', 'logger')
class DeadAgent ( Exception ) :
pass
class BaseAgent ( MesaAgent ):
class BaseAgent ( MesaAgent , MutableMapping ):
"""
A special type of Mesa Agent that :
@ -50,8 +53,10 @@ class BaseAgent(MesaAgent):
# Initialize agent parameters
if isinstance ( unique_id , MesaAgent ) :
raise Exception ( )
assert isinstance ( unique_id , int )
super ( ) . __init__ ( unique_id = unique_id , model = model )
self . name = name or ' {} [ {} ] ' . format ( type ( self ) . __name__ , self . unique_id )
self . name = str ( name ) if name else ' {} [ {} ] ' . format ( type ( self ) . __name__ , self . unique_id )
self . _neighbors = None
self . alive = True
@ -120,6 +125,12 @@ class BaseAgent(MesaAgent):
def __setitem__ ( self , key , value ) :
setattr ( self , key , value )
def __len__ ( self ) :
return sum ( 1 for n in self . keys ( ) )
def __iter__ ( self ) :
return self . items ( )
def keys ( self ) :
return ( k for k in self . __dict__ if k [ 0 ] != ' _ ' )
@ -284,7 +295,7 @@ def default_state(func):
return func
class MetaFSM ( type ) :
class MetaFSM ( ABCMeta ) :
def __init__ ( cls , name , bases , nmspc ) :
super ( MetaFSM , cls ) . __init__ ( name , bases , nmspc )
states = { }
@ -486,14 +497,15 @@ def _definition_to_dict(definition, size=None, default_state=None):
distro = sorted ( [ item for item in definition if ' weight ' in item ] )
ix = 0
id = 0
def init_agent ( item , id = ix ) :
while id in agents :
id + = 1
agent = remaining [ id ]
agent [ ' state ' ] . update ( copy ( item . get ( ' state ' , { } ) ) )
agents [ id] = agent
agents [ agent . unique_ id] = agent
del remaining [ id ]
return agent
@ -554,7 +566,7 @@ class AgentView(Mapping, Set):
return sum ( len ( x ) for x in self . _agents . values ( ) )
def __iter__ ( self ) :
return iter ( chain . from_iterable ( g . values ( ) for g in self . _agents . values ( ) ) )
yield from iter ( chain . from_iterable ( g . values ( ) for g in self . _agents . values ( ) ) )
def __getitem__ ( self , agent_id ) :
if isinstance ( agent_id , slice ) :
@ -564,54 +576,71 @@ class AgentView(Mapping, Set):
return group [ agent_id ]
raise ValueError ( f " Agent { agent_id } not found " )
def filter ( self , ids = None , groups = None , state_id = None , agent_type = None , ignore = None , iterator = False , * * kwargs ) :
def filter ( self , * group_ids , * * kwargs ) :
yield from filter_groups ( self . _agents , group_ids = group_ids , * * kwargs )
if state_id is not None and not isinstance ( state_id , ( tuple , list ) ) :
state_id = tuple ( [ state_id ] )
def __call__ ( self , * args , * * kwargs ) :
return list ( self . filter ( * args , * * kwargs ) )
agents = self . _agents
def __contains__ ( self , agent_id ) :
return any ( agent_id in g for g in self . _agents )
if groups :
agents = { ( k , v ) for ( k , v ) in agents . items ( ) if k in groups }
def __str__ ( self ) :
return str ( list ( a . id for a in self ) )
if agent_type is not None :
try :
agent_type = tuple ( agent_type )
except TypeError :
agent_type = tuple ( [ agent_type ] )
def __repr__ ( self ) :
return f " { self . __class__ . __name__ } ( { self } ) "
if ids :
agents = ( v [ aid ] for v in agents . values ( ) for aid in ids if aid in v )
else :
agents = ( a for v in agents . values ( ) for a in v . values ( ) )
f = agents
if ignore :
f = filter ( lambda x : x not in ignore , f )
def filter_groups ( groups , group_ids = None , * * kwargs ) :
assert isinstance ( groups , dict )
if group_ids :
groups = list ( groups [ g ] for g in group_ids if g in groups )
else :
groups = list ( groups . values ( ) )
agents = chain . from_iterable ( filter_group ( g , * * kwargs ) for g in groups )
if state_id is not None :
f = filter ( lambda agent : agent . get ( ' state_id ' , None ) in state_id , f )
yield from agents
if agent_type is not None :
f = filter ( lambda agent : isinstance ( agent , agent_type ) , f )
for k , v in kwargs . items ( ) :
f = filter ( lambda agent : agent . state . get ( k , None ) == v , f )
if iterator :
return f
return list ( f )
def filter_group ( group , ids = None , state_id = None , agent_type = None , ignore = None , state = None , * * kwargs ) :
'''
Filter agents given as a dict , by the criteria given as arguments ( e . g . , certain type or state id ) .
'''
assert isinstance ( group , dict )
def __call__ ( self , * args , * * kwargs ) :
return self . filter ( * args , * * kwargs )
if state_id is not None and not isinstance ( state_id , ( tuple , list ) ) :
state_id = tuple ( [ state_id ] )
def __contains__ ( self , agent_id ) :
return any ( agent_id in g for g in self . _agents )
if agent_type is not None :
try :
agent_type = tuple ( agent_type )
except TypeError :
agent_type = tuple ( [ agent_type ] )
def __str__ ( self ) :
return str ( list ( a . id for a in self ) )
if ids :
agents = ( v [ aid ] for aid in ids if aid in group )
else :
agents = ( a for a in group . values ( ) )
def __repr__ ( self ) :
return f " { self . __class__ . __name__ } ( { self } ) "
f = agents
if ignore :
f = filter ( lambda x : x not in ignore , f )
if state_id is not None :
f = filter ( lambda agent : agent . get ( ' state_id ' , None ) in state_id , f )
if agent_type is not None :
f = filter ( lambda agent : isinstance ( agent , agent_type ) , f )
state = state or dict ( )
state . update ( kwargs )
for k , v in state . items ( ) :
f = filter ( lambda agent : agent . state . get ( k , None ) == v , f )
yield from f
def from_config ( cfg : Dict [ str , config . AgentConfig ] , env ) :
@ -632,10 +661,22 @@ def _group_from_config(cfg: config.AgentConfig, default: config.SingleAgentConfi
agents = _from_fixed ( cfg . fixed , topology = cfg . topology , default = default , env = env )
if cfg . distribution :
n = cfg . n or len ( env . topologies [ cfg . topology ] )
agents . update ( _from_distro ( cfg . distribution , n - len ( agents ) ,
target = n - len ( agents )
agents . update ( _from_distro ( cfg . distribution , target ,
topology = cfg . topology or default . topology ,
default = default ,
env = env ) )
assert len ( agents ) == n
if cfg . override :
for attrs in cfg . override :
if attrs . filter :
filtered = list ( filter_group ( agents , * * attrs . filter ) )
else :
filtered = list ( agents )
for agent in random . sample ( filtered , attrs . n ) :
agent . state . update ( attrs . state )
return agents
@ -650,10 +691,11 @@ def _from_fixed(lst: List[config.FixedAgentConfig], topology: str, default: conf
cls = serialization . deserialize ( fixed . agent_class or default . agent_class )
state = fixed . state . copy ( )
state . update ( default . state )
agents [ agent_id ] = cls ( unique_id = agent_id ,
model = env ,
graph_name = fixed . topology or topology or default . topology ,
* * state )
agent = cls ( unique_id = agent_id ,
model = env ,
graph_name = fixed . topology or topology or default . topology ,
* * state )
agents [ agent . unique_id ] = agent
return agents
@ -671,31 +713,40 @@ def _from_distro(distro: List[config.AgentDistro],
raise ValueError ( ' You must provide a total number of agents, or the number of each type ' )
n = sum ( dist . n for dist in distro )
weights = list ( dist . weight if dist . weight is not None else 1 for dist in distro )
minw = min ( weights )
norm = list ( weight / minw for weight in weights )
total = sum ( norm )
chunk = n / / total
total = sum ( ( dist . weight if dist . weight is not None else 1 ) for dist in distro )
thres = { }
last = 0
for i in sorted ( distro , key = lambda x : x . weight ) :
# random.choices would be enough to get a weighted distribution. But it can vary a lot for smaller k
# So instead we calculate our own distribution to make sure the actual ratios are close to what we would expect
cls = serialization . deserialize ( i . agent_class or default . agent_class )
thres [ ( last , last + i . weight / total ) ] = ( cls , i )
# Calculate how many times each has to appear
indices = list ( chain . from_iterable ( [ idx ] * int ( n * chunk ) for ( idx , n ) in enumerate ( norm ) ) )
# Complete with random agents following the original weight distribution
if len ( indices ) < n :
indices + = random . choices ( list ( range ( len ( distro ) ) ) , weights = [ d . weight for d in distro ] , k = n - len ( indices ) )
# Deserialize classes for efficiency
classes = list ( serialization . deserialize ( i . agent_class or default . agent_class ) for i in distro )
# Add them in random order
random . shuffle ( indices )
acc = 0
# using np.choice would be more efficient, but this allows us to use soil without
# numpy
for i in range ( n ) :
r = random . random ( )
for ( t , ( cls , d ) ) in thres . items ( ) :
if r > = t [ 0 ] and r < = t [ 1 ] :
agent_id = d . agent_id
if agent_id is None :
agent_id = env . next_id ( )
state = d . state . copy ( )
state . update ( default . state )
agents [ agent_id ] = cls ( unique_id = agent_id , model = env , graph_name = d . topology or topology or default . topology , * * state )
break
for idx in indices :
d = distro [ idx ]
cls = classes [ idx ]
agent_id = env . next_id ( )
state = d . state . copy ( )
state . update ( default . state )
agent = cls ( unique_id = agent_id , model = env , graph_name = d . topology or topology or default . topology , * * state )
assert agent . name is not None
assert agent . name != ' None '
assert agent . name
agents [ agent . unique_id ] = agent
return agents