@ -1,3 +1,5 @@
from __future__ import annotations
import logging
from collections import OrderedDict , defaultdict
from collections . abc import MutableMapping , Mapping , Set
@ -5,9 +7,13 @@ from abc import ABCMeta
from copy import deepcopy , copy
from functools import partial , wraps
from itertools import islice , chain
import json
import inspect
import types
import textwrap
import networkx as nx
from typing import Any
from mesa import Agent as MesaAgent
from typing import Dict , List
@ -27,7 +33,31 @@ class DeadAgent(Exception):
pass
class BaseAgent ( MesaAgent , MutableMapping ) :
class MetaAgent ( ABCMeta ) :
def __new__ ( mcls , name , bases , namespace ) :
defaults = { }
# Re-use defaults from inherited classes
for i in bases :
if isinstance ( i , MetaAgent ) :
defaults . update ( i . _defaults )
new_nmspc = {
' _defaults ' : defaults ,
}
for attr , func in namespace . items ( ) :
if isinstance ( func , types . FunctionType ) or isinstance ( func , property ) or attr [ 0 ] == ' _ ' :
new_nmspc [ attr ] = func
elif attr == ' defaults ' :
defaults . update ( func )
else :
defaults [ attr ] = copy ( func )
return super ( ) . __new__ ( mcls = mcls , name = name , bases = bases , namespace = new_nmspc )
class BaseAgent ( MesaAgent , MutableMapping , metaclass = MetaAgent ) :
"""
A special type of Mesa Agent that :
@ -39,15 +69,12 @@ class BaseAgent(MesaAgent, MutableMapping):
Any attribute that is not preceded by an underscore ( ` _ ` ) will also be added to its state .
"""
defaults = { }
def __init__ ( self ,
unique_id ,
model ,
name = None ,
interval = None ,
* * kwargs
) :
* * kwargs ) :
# Check for REQUIRED arguments
# Initialize agent parameters
if isinstance ( unique_id , MesaAgent ) :
@ -58,15 +85,16 @@ class BaseAgent(MesaAgent, MutableMapping):
self . name = str ( name ) if name else ' {} [ {} ] ' . format ( type ( self ) . __name__ , self . unique_id )
self . _neighbors = None
self . alive = True
self . interval = interval or self . get ( ' interval ' , 1 )
self . logger = logging . getLogger ( self . model . id ) . getChild ( self . name )
logger = utils . logger . getChild ( getattr ( self . model , ' id ' , self . model ) ) . getChild ( self . name )
self . logger = logging . LoggerAdapter ( logger , { ' agent_name ' : self . name } )
if hasattr ( self , ' level ' ) :
self . logger . setLevel ( self . level )
for ( k , v ) in self . defaults . items ( ) :
for ( k , v ) in self . _defaults . items ( ) :
if not hasattr ( self , k ) or getattr ( self , k ) is None :
setattr ( self , k , deepcopy ( v ) )
@ -74,10 +102,6 @@ class BaseAgent(MesaAgent, MutableMapping):
setattr ( self , k , v )
for ( k , v ) in getattr ( self , ' defaults ' , { } ) . items ( ) :
if not hasattr ( self , k ) or getattr ( self , k ) is None :
setattr ( self , k , v )
def __hash__ ( self ) :
return hash ( self . unique_id )
@ -89,14 +113,6 @@ class BaseAgent(MesaAgent, MutableMapping):
def id ( self ) :
return self . unique_id
@property
def env ( self ) :
return self . model
@env.setter
def env ( self , model ) :
self . model = model
@property
def state ( self ) :
'''
@ -108,19 +124,16 @@ class BaseAgent(MesaAgent, MutableMapping):
@state.setter
def state ( self , value ) :
if not value :
return
for k , v in value . items ( ) :
self [ k ] = v
@property
def environment_params ( self ) :
return self . model . environment_params
@environment_params.setter
def environment_params ( self , value ) :
self . model . environment_params = value
def __getitem__ ( self , key ) :
return getattr ( self , key )
try :
return getattr ( self , key )
except AttributeError :
raise KeyError ( f ' key { key } not found in agent ' )
def __delitem__ ( self , key ) :
return delattr ( self , key )
@ -138,11 +151,15 @@ class BaseAgent(MesaAgent, MutableMapping):
return self . items ( )
def keys ( self ) :
return ( k for k in self . __dict__ if k [ 0 ] != ' _ ' )
def items ( self ) :
return ( ( k , v ) for ( k , v ) in self . __dict__ . items ( ) if k [ 0 ] != ' _ ' )
return ( k for k in self . __dict__ if k [ 0 ] != ' _ ' and k not in IGNORED_FIELDS )
def items ( self , keys = None , skip = None ) :
keys = keys if keys is not None else self . keys ( )
it = ( ( k , self . get ( k , None ) ) for k in keys )
if skip :
return filter ( lambda x : x [ 0 ] not in skip , it )
return it
def get ( self , key , default = None ) :
return self [ key ] if key in self else default
@ -154,11 +171,9 @@ class BaseAgent(MesaAgent, MutableMapping):
# No environment
return None
def die ( self , remove = False ):
self . info ( f ' agent { self . unique_id } is dying' )
def die ( self ):
self . info ( f ' agent dying' )
self . alive = False
if remove :
self . remove_node ( self . id )
return time . NEVER
def step ( self ) :
@ -170,7 +185,7 @@ class BaseAgent(MesaAgent, MutableMapping):
if not self . logger . isEnabledFor ( level ) :
return
message = message + " " . join ( str ( i ) for i in args )
message = " @{:>3 }: {} " . format ( self . now , message )
message = " [@{:>4} ] \t {:>10 }: {} " . format ( self . now , repr ( self ) , message )
for k , v in kwargs :
message + = " {k} = {v} " . format ( k , v )
extra = { }
@ -179,33 +194,48 @@ class BaseAgent(MesaAgent, MutableMapping):
extra [ ' agent_name ' ] = self . name
return self . logger . log ( level , message , extra = extra )
def debug ( self , * args , * * kwargs ) :
return self . log ( * args , level = logging . DEBUG , * * kwargs )
def info ( self , * args , * * kwargs ) :
return self . log ( * args , level = logging . INFO , * * kwargs )
# Alias
# Agent = BaseAgent
def count_agents ( self , * * kwargs ) :
return len ( list ( self . get_agents ( * * kwargs ) ) )
class NetworkAgent ( BaseAgent ) :
def get_agents ( self , * args , * * kwargs ) :
it = self . iter_agents ( * args , * * kwargs )
return list ( it )
@property
def topology ( self ) :
return self . env . topology_for ( self . unique_id )
def iter_agents ( self , * args , * * kwargs ) :
yield from filter_agents ( self . model . schedule . _agents , * args , * * kwargs )
@property
def node_id ( self ) :
return self . env . node_id_for ( self . unique_id )
def __str__ ( self ) :
return self . to_str ( )
def to_str ( self , keys = None , skip = None , pretty = False ) :
content = dict ( self . items ( keys = keys ) )
if pretty and content :
d = content
content = ' \n '
for k , v in d . items ( ) :
content + = f ' - { k } : { v } \n '
content = textwrap . indent ( content , ' ' )
return f " { repr ( self ) } { content } "
@property
def G ( self ) :
return self . model . topologies [ self . _topology ]
def __repr__ ( self ) :
return f " { self . __class__ . __name__ } ( { self . unique_id } ) "
def count_agents ( self , * * kwargs ) :
return len ( list ( self . get_agents ( * * kwargs ) ) )
class NetworkAgent ( BaseAgent ) :
def __init__ ( self , * args , topology , node_id , * * kwargs ) :
super ( ) . __init__ ( * args , * * kwargs )
self . topology = topology
self . node_id = node_id
self . G = self . model . topologies [ topology ]
assert self . G
def count_neighboring_agents ( self , state_id = None , * * kwargs ) :
return len ( self . get_neighboring_agents ( state_id = state_id , * * kwargs ) )
@ -213,57 +243,47 @@ class NetworkAgent(BaseAgent):
def get_neighboring_agents ( self , state_id = None , * * kwargs ) :
return self . get_agents ( limit_neighbors = True , state_id = state_id , * * kwargs )
def get_agents ( self , * args , limit = None , * * kwargs ) :
it = self . iter_agents ( * args , * * kwargs )
if limit is not None :
it = islice ( it , limit )
return list ( it )
def iter_agents ( self , unique_id = None , limit_neighbors = False , * * kwargs ) :
unique_ids = None
if isinstance ( unique_id , list ) :
unique_ids = set ( unique_id )
elif unique_id is not None :
unique_ids = set ( [ unique_id , ] )
if limit_neighbors :
unique_id = [ self . topology . nodes [ node ] [ ' agent_id ' ] for node in self . topology . neighbors ( self . node_id ) ]
if not unique_id :
neighbor_ids = set ( )
for node_id in self . G . neighbors ( self . node_id ) :
if self . G . nodes [ node_id ] . get ( ' agent_id ' ) is not None :
neighbor_ids . add ( node_id )
if unique_ids :
unique_ids = unique_ids & neighbor_ids
else :
unique_ids = neighbor_ids
if not unique_ids :
return
yield from self . model . agents ( unique_id = unique_id , * * kwargs )
unique_ids = list ( unique_ids )
yield from super ( ) . iter_agents ( unique_id = unique_ids , * * kwargs )
def subgraph ( self , center = True , * * kwargs ) :
include = [ self ] if center else [ ]
G = self . topology . subgraph ( n . node_id for n in list ( self . get_agents ( * * kwargs ) + include ) )
G = self . G . subgraph ( n . node_id for n in list ( self . get_agents ( * * kwargs ) + include ) )
return G
def remove_node ( self , unique_id ):
self . topology. remove_node ( uniqu e_id)
def remove_node ( self ):
self . G. remove_node ( self . nod e_id)
def add_edge ( self , other , edge_attr_dict = None , * edge_attrs ) :
# return super(NetworkAgent, self).add_edge(node1=self.id, node2=other, **kwargs)
if self . unique_id not in self . topology . nodes ( data = False ) :
if self . node_id not in self . G . nodes ( data = False ) :
raise ValueError ( ' {} not in list of existing agents in the network ' . format ( self . unique_id ) )
if other . unique_id not in self . topology . nodes ( data = False ) :
if other . node_id not in self . G . nodes ( data = False ) :
raise ValueError ( ' {} not in list of existing agents in the network ' . format ( other ) )
self . topology. add_edge ( self . unique_id , other . uniqu e_id, edge_attr_dict = edge_attr_dict , * edge_attrs )
self . G. add_edge ( self . node_id , other . nod e_id, edge_attr_dict = edge_attr_dict , * edge_attrs )
def ego_search ( self , steps = 1 , center = False , node = None , * * kwargs ) :
''' Get a list of nodes in the ego network of *node* of radius *steps* '''
node = as_node ( node if node is not None else self )
G = self . subgraph ( * * kwargs )
return nx . ego_graph ( G , node , center = center , radius = steps ) . nodes ( )
def degree ( self , node , force = False ) :
node = as_node ( node )
if force or ( not hasattr ( self . model , ' _degree ' ) ) or getattr ( self . model , ' _last_step ' , 0 ) < self . now :
self . model . _degree = nx . degree_centrality ( self . topology )
self . model . _last_step = self . now
return self . model . _degree [ node ]
def betweenness ( self , node , force = False ) :
node = as_node ( node )
if force or ( not hasattr ( self . model , ' _betweenness ' ) ) or getattr ( self . model , ' _last_step ' , 0 ) < self . now :
self . model . _betweenness = nx . betweenness_centrality ( self . topology )
self . model . _last_step = self . now
return self . model . _betweenness [ node ]
def die ( self , remove = True ) :
if remove :
self . remove_node ( )
return super ( ) . die ( )
def state ( name = None ) :
@ -273,24 +293,29 @@ def state(name=None):
The default value for state_id is the current state id .
The default value for when is the interval defined in the environment .
'''
@wraps ( func )
def func_wrapper ( self ) :
next_state = func ( self )
when = None
if next_state is None :
return when
try :
next_state , when = next_state
except ( ValueError , TypeError ) :
pass
if next_state :
self . set_state ( next_state )
return when
func_wrapper . id = name or func . __name__
func_wrapper . is_default = False
return func_wrapper
if inspect . isgeneratorfunction ( func ) :
orig_func = func
@wraps ( func )
def func ( self ) :
while True :
if not self . _coroutine :
self . _coroutine = orig_func ( self )
try :
n = next ( self . _coroutine )
if n :
return None , n
return
except StopIteration as ex :
self . _coroutine = None
next_state = ex . value
if next_state is not None :
self . set_state ( next_state )
return next_state
func . id = name or func . __name__
func . is_default = False
return func
if callable ( name ) :
return decorator ( name )
@ -303,60 +328,84 @@ def default_state(func):
return func
class MetaFSM ( ABCMeta ) :
def __init__ ( cls , name , bases , nmspc ) :
super ( MetaFSM , cls ) . __init__ ( name , bases , nmspc )
class MetaFSM ( MetaAgent ) :
def __new__ ( mcls , name , bases , namespace ) :
states = { }
# Re-use states from inherited classes
default_state = None
for i in bases :
if isinstance ( i , MetaFSM ) :
for state_id , state in i . states. items ( ) :
for state_id , state in i . _ states. items ( ) :
if state . is_default :
default_state = state
states [ state_id ] = state
# Add new states
for name, func in nmspc . items ( ) :
for attr, func in namespace . items ( ) :
if hasattr ( func , ' id ' ) :
if func . is_default :
default_state = func
states [ func . id ] = func
cls . default_state = default_state
cls . states = states
namespace . update ( {
' _default_state ' : default_state ,
' _states ' : states ,
} )
return super ( MetaFSM , mcls ) . __new__ ( mcls = mcls , name = name , bases = bases , namespace = namespace )
class FSM ( BaseAgent , metaclass = MetaFSM ) :
def __init__ ( self , * args , * * kwargs ) :
super ( FSM , self ) . __init__ ( * args , * * kwargs )
if not hasattr ( self , ' state_id ' ) :
if not self . default_state:
if not self . _ default_state:
raise ValueError ( ' No default state specified for {} ' . format ( self . unique_id ) )
self . state_id = self . default_state. id
self . state_id = self . _ default_state. id
self . _coroutine = None
self . set_state ( self . state_id )
def step ( self ) :
self . debug ( f ' Agent { self . unique_id } @ state { self . state_id } ' )
interval = super ( ) . step ( )
if ' id ' not in self . state :
if self . default_state :
self . set_state ( self . default_state . id )
default_interval = super ( ) . step ( )
next_state = self . _states [ self . state_id ] ( self )
when = None
try :
next_state , * when = next_state
if not when :
when = None
elif len ( when ) == 1 :
when = when [ 0 ]
else :
raise Exception ( ' {} has no valid state id or default state ' . format ( self ) )
interval = self . states [ self . state_id ] ( self ) or interval
if not self . alive :
return time . NEVER
return interval
raise ValueError ( ' Too many values returned. Only state (and time) allowed ' )
except TypeError :
pass
if next_state is not None :
self . set_state ( next_state )
def set_state ( self , state ) :
return when or default_interval
def set_state ( self , state , when = None ) :
if hasattr ( state , ' id ' ) :
state = state . id
if state not in self . states:
if state not in self . _ states:
raise ValueError ( ' {} is not a valid state ' . format ( state ) )
self . state_id = state
if when is not None :
self . model . schedule . add ( self , when = when )
return state
def die ( self ) :
return self . dead , super ( ) . die ( )
@state
def dead ( self ) :
return self . die ( )
def prob ( prob , random ) :
'''
@ -476,81 +525,81 @@ def _convert_agent_classs(ind, to_string=False, **kwargs):
return deserialize_definition ( ind , * * kwargs )
def _agent_from_definition ( definition , random , value = - 1 , unique_id = None ) :
""" Used in the initialization of agents given an agent distribution. """
if value < 0 :
value = random . random ( )
for d in sorted ( definition , key = lambda x : x . get ( ' threshold ' ) ) :
threshold = d . get ( ' threshold ' , ( - 1 , - 1 ) )
# Check if the definition matches by id (first) or by threshold
if ( unique_id is not None and unique_id in d . get ( ' ids ' , [ ] ) ) or \
( value > = threshold [ 0 ] and value < threshold [ 1 ] ) :
state = { }
if ' state ' in d :
state = deepcopy ( d [ ' state ' ] )
return d [ ' agent_class ' ] , state
raise Exception ( ' Definition for value {} not found in: {} ' . format ( value , definition ) )
def _definition_to_dict ( definition , random , size = None , default_state = None ) :
state = default_state or { }
agents = { }
remaining = { }
if size :
for ix in range ( size ) :
remaining [ ix ] = copy ( state )
else :
remaining = defaultdict ( lambda x : copy ( state ) )
distro = sorted ( [ item for item in definition if ' weight ' in item ] )
id = 0
def init_agent ( item , id = ix ) :
while id in agents :
id + = 1
agent = remaining [ id ]
agent [ ' state ' ] . update ( copy ( item . get ( ' state ' , { } ) ) )
agents [ agent . unique_id ] = agent
del remaining [ id ]
return agent
for item in definition :
if ' ids ' in item :
ids = item [ ' ids ' ]
del item [ ' ids ' ]
for id in ids :
agent = init_agent ( item , id )
for item in definition :
if ' number ' in item :
times = item [ ' number ' ]
del item [ ' number ' ]
for times in range ( times ) :
if size :
ix = random . choice ( remaining . keys ( ) )
agent = init_agent ( item , id )
else :
agent = init_agent ( item )
if not size :
return agents
if len ( remaining ) < 0 :
raise Exception ( ' Invalid definition. Too many agents to add ' )
total_weight = float ( sum ( s [ ' weight ' ] for s in distro ) )
unit = size / total_weight
for item in distro :
times = unit * item [ ' weight ' ]
del item [ ' weight ' ]
for times in range ( times ) :
ix = random . choice ( remaining . keys ( ) )
agent = init_agent ( item , id )
return agents
# def _agent_from_definition(definition, random, value=-1, unique_id=None) :
# """ Used in the initialization of agents given an agent distribution."""
# if value < 0 :
# value = random.random()
# for d in sorted(definition, key=lambda x: x.get('threshold')) :
# threshold = d.get('threshold', (-1, -1))
# # Check if the definition matches by id (first) or by threshold
# if (unique_id is not None and unique_id in d.get('ids', [])) or \
# (value >= threshold[0] and value < threshold[1]):
# state = {}
# if 'state' in d :
# state = deepcopy(d['state'] )
# return d['agent_class'], state
# raise Exception('Definition for value {} not found in: {}'.format(value, definition))
# def _definition_to_dict(definition, random, size=None, default_state=None) :
# state = default_state or {}
# agents = {}
# remaining = {}
# if size:
# for ix in range(size):
# remaining[ix] = copy(state)
# else :
# remaining = defaultdict(lambda x: copy(state))
# distro = sorted([item for item in definition if 'weight' in item])
# id = 0
# def init_agent(item, id=ix) :
# while id in agents :
# id += 1
# agent = remaining[id]
# agent['state'].update(copy(item.get('state', {})))
# agents[agent.unique_id] = agent
# del remaining[id]
# return agent
# for item in definition :
# if 'ids' in item:
# ids = item['ids']
# del item['ids']
# for id in ids:
# agent = init_agent(item, id)
# for item in definition :
# if 'number' in item:
# times = item['number']
# del item['number']
# for times in range(times):
# if size :
# ix = random.choice(remaining.keys() )
# agent = init_agent(item, id)
# else:
# agent = init_agent(item )
# if not size:
# return agents
# if len(remaining) < 0:
# raise Exception('Invalid definition. Too many agents to add')
# total_weight = float(sum(s['weight'] for s in distro))
# unit = size / total_weight
# for item in distro :
# times = unit * item['weight' ]
# del item['weight']
# for times in range(times):
# ix = random.choice(remaining.keys() )
# agent = init_agent(item, id)
# return agents
class AgentView ( Mapping , Set ) :
@ -571,59 +620,43 @@ class AgentView(Mapping, Set):
# Mapping methods
def __len__ ( self ) :
return sum( len( x ) for x in self . _agents . values ( ) )
return len( self . _agents )
def __iter__ ( self ) :
yield from iter ( chain . from_iterable ( g . values ( ) for g in self . _agents . values ( ) ) )
yield from self . _agents . values ( )
def __getitem__ ( self , agent_id ) :
if isinstance ( agent_id , slice ) :
raise ValueError ( f " Slicing is not supported " )
for group in self . _agents . values ( ) :
if agent_id in group :
return group [ agent_id ]
if agent_id in self . _agents :
return self . _agents [ agent_id ]
raise ValueError ( f " Agent { agent_id } not found " )
def filter ( self , * args , * * kwargs ) :
yield from filter_ group s( self . _agents , * args , * * kwargs )
yield from filter_ agent s( self . _agents , * args , * * kwargs )
def one ( self , * args , * * kwargs ) :
return next ( filter_ group s( self . _agents , * args , * * kwargs ) )
return next ( filter_ agent s( self . _agents , * args , * * kwargs ) )
def __call__ ( self , * args , * * kwargs ) :
return list ( self . filter ( * args , * * kwargs ) )
def __contains__ ( self , agent_id ) :
return any ( agent_id in g for g in self . _agents )
return agent_id in self . _agents
def __str__ ( self ) :
return str ( list ( a. unique_id for a in self ) )
return str ( list ( unique_id for unique_id in self . keys ( ) ) )
def __repr__ ( self ) :
return f " { self . __class__ . __name__ } ( { self } ) "
def filter_groups ( groups , * , group = None , * * kwargs ) :
assert isinstance ( groups , dict )
if group is not None and not isinstance ( group , list ) :
group = [ group ]
if group :
groups = list ( groups [ g ] for g in group if g in groups )
else :
groups = list ( groups . values ( ) )
agents = chain . from_iterable ( filter_group ( g , * * kwargs ) for g in groups )
yield from agents
def filter_group ( group , * id_args , unique_id = None , state_id = None , agent_class = None , ignore = None , state = None , * * kwargs ) :
def filter_agents ( agents , * id_args , unique_id = None , state_id = None , agent_class = None , ignore = None , state = None ,
limit = None , * * kwargs ) :
'''
Filter agents given as a dict , by the criteria given as arguments ( e . g . , certain type or state id ) .
'''
assert isinstance ( group , dict )
assert isinstance ( agents , dict )
ids = [ ]
@ -636,6 +669,11 @@ def filter_group(group, *id_args, unique_id=None, state_id=None, agent_class=Non
if id_args :
ids + = id_args
if ids :
f = ( agents [ aid ] for aid in ids if aid in agents )
else :
f = ( a for a in agents . values ( ) )
if state_id is not None and not isinstance ( state_id , ( tuple , list ) ) :
state_id = tuple ( [ state_id ] )
@ -646,12 +684,6 @@ def filter_group(group, *id_args, unique_id=None, state_id=None, agent_class=Non
except TypeError :
agent_class = tuple ( [ agent_class ] )
if ids :
agents = ( group [ aid ] for aid in ids if aid in group )
else :
agents = ( a for a in group . values ( ) )
f = agents
if ignore :
f = filter ( lambda x : x not in ignore , f )
@ -667,83 +699,125 @@ def filter_group(group, *id_args, unique_id=None, state_id=None, agent_class=Non
for k , v in state . items ( ) :
f = filter ( lambda agent : agent . state . get ( k , None ) == v , f )
if limit is not None :
f = islice ( f , limit )
yield from f
def from_config ( cfg : Dict[ str , config . AgentConfig ] , env , random ) :
def from_config ( cfg : config. AgentConfig , random , topologies : Dict [ str , nx . Graph ] = None ) - > List [ Dict [ str , Any ] ] :
'''
Agents are specified in groups .
Each group can be specified in two ways , either through a fixed list in which each item has
has the agent type , number of agents to create , and the other parameters , or through what we call
an ` agent distribution ` , which is similar but instead of number of agents , it specifies the weight
of each agent type .
This function turns an agentconfig into a list of individual " agent specifications " , which are just a dictionary
with the parameters that the environment will use to construct each agent .
This function does NOT return a list of agents , mostly because some attributes to the agent are not known at the
time of calling this function , such as ` unique_id ` .
'''
default = cfg . get ( ' default ' , None )
return { k : _group_from_config ( c , default = default , env = env , random = random ) for ( k , c ) in cfg . items ( ) if k is not ' default ' }
default = cfg or config . AgentConfig ( )
if not isinstance ( cfg , config . AgentConfig ) :
cfg = config . AgentConfig ( * * cfg )
return _agents_from_config ( cfg , topologies = topologies , random = random )
def _group_from_config ( cfg : config . AgentConfig , default : config . SingleAgentConfig , env , random ) :
def _agents_from_config ( cfg : config . AgentConfig ,
topologies : Dict [ str , nx . Graph ] ,
random ) - > List [ Dict [ str , Any ] ] :
if cfg and not isinstance ( cfg , config . AgentConfig ) :
cfg = config . AgentConfig ( * * cfg )
if default and not isinstance ( default , config . SingleAgentConfig ) :
default = config . SingleAgentConfig ( * * default )
agents = { }
agents = [ ]
assigned = defaultdict ( int )
if cfg . fixed is not None :
agents = _from_fixed ( cfg . fixed , topology = cfg . topology , default = default , env = env )
if cfg . distribution :
n = cfg . n or len ( env . topologies [ cfg . topology or default . topology ] )
target = n - len ( agents )
agents . update ( _from_distro ( cfg . distribution , target ,
topology = cfg . topology or default . topology ,
default = default ,
env = env , random = random ) )
assert len ( agents ) == n
if cfg . override :
for attrs in cfg . override :
if attrs . filter :
filtered = list ( filter_group ( agents , * * attrs . filter ) )
else :
filtered = list ( agents )
agents , counts = _from_fixed ( cfg . fixed , topology = cfg . topology , default = cfg )
assigned . update ( counts )
n = cfg . n
if attrs . n > len ( filtered ) :
raise ValueError ( f ' Not enough agents to sample. Got { len ( filtered ) } , expected >= { attrs . n } ' )
for agent in random . sample ( filtered , attrs . n ) :
agent . state . update ( attrs . state )
if cfg . distribution :
topo_size = { top : len ( topologies [ top ] ) for top in topologies }
grouped = defaultdict ( list )
total = [ ]
for d in cfg . distribution :
if d . strategy == config . Strategy . topology :
topology = d . topology if ( ' topology ' in d . __fields_set__ ) else cfg . topology
if not topology :
raise ValueError ( ' The " topology " strategy only works if the topology parameter is specified ' )
if topology not in topo_size :
raise ValueError ( f ' Unknown topology selected: { topology } . Make sure the topology has been defined ' )
grouped [ topology ] . append ( d )
if d . strategy == config . Strategy . total :
if not cfg . n :
raise ValueError ( ' Cannot use the " total " strategy without providing the total number of agents ' )
total . append ( d )
for ( topo , distro ) in grouped . items ( ) :
if not topologies or topo not in topo_size :
raise ValueError (
' You need to specify a target number of agents for the distribution \
or a configuration with a topology , along with a dictionary with \
all the available topologies ' )
n = len ( topologies [ topo ] )
target = topo_size [ topo ] - assigned [ topo ]
new_agents = _from_distro ( cfg . distribution , target ,
topology = topo ,
default = cfg ,
random = random )
assigned [ topo ] + = len ( new_agents )
agents + = new_agents
if total :
remaining = n - sum ( assigned . values ( ) )
agents + = _from_distro ( total , remaining ,
topology = ' ' , # DO NOT assign to any topology
default = cfg ,
random = random )
if sum ( assigned . values ( ) ) != sum ( topo_size . values ( ) ) :
utils . logger . warn ( f ' The total number of agents does not match the total number of nodes in '
' every topology. This may be due to a definition error: assigned: '
f ' { assigned } total sizes: { topo_size } ' )
return agents
def _from_fixed ( lst : List [ config . FixedAgentConfig ] , topology : str , default : config . SingleAgentConfig , env ) :
agents = { }
def _from_fixed ( lst : List [ config . FixedAgentConfig ] , topology : str , default : config . SingleAgentConfig ) - > List [ Dict [ str , Any ] ] :
agents = [ ]
counts = { }
for fixed in lst :
agent_id = fixed . agent_id
if agent_id is None :
agent_id = env . next_id ( )
cls = serialization . deserialize ( fixed . agent_class or default . agent_class )
state = fixed . state . copy ( )
state . update ( default . state )
agent = cls ( unique_id = agent_id ,
model = env ,
* * state )
topology = fixed . topology if ( fixed . topology is not None ) else ( topology or default . topology )
if topology :
env . agent_to_node ( agent_id , topology , fixed . node_id )
agents [ agent . unique_id ] = agent
agent = { }
if default :
agent = default . state . copy ( )
agent . update ( fixed . state )
cls = serialization . deserialize ( fixed . agent_class or ( default and default . agent_class ) )
agent [ ' agent_class ' ] = cls
topo = fixed . topology if ( ' topology ' in fixed . __fields_set__ ) else topology or default . topology
return agents
if topo :
agent [ ' topology ' ] = topo
if not fixed . hidden :
counts [ topo ] = counts . get ( topo , 0 ) + 1
agents . append ( agent )
return agents , counts
def _from_distro ( distro : List [ config . AgentDistro ] ,
n : int ,
topology : str ,
default : config . SingleAgentConfig ,
env ,
random ) :
random ) - > List [ Dict [ str , Any ] ] :
agents = { }
agents = []
if n is None :
if any ( lambda dist : dist . n is None , distro ) :
@ -775,19 +849,16 @@ def _from_distro(distro: List[config.AgentDistro],
for idx in indices :
d = distro [ idx ]
agent = d . state . copy ( )
cls = classes [ idx ]
agent_id = env . next_id ( )
state = d . state . copy ( )
agent [ ' agent_class ' ] = cls
if default :
state . update ( default . state )
agent = cls ( unique_id = agent_id , model = env , * * state )
topology = d . topology if ( d . topology is not None ) else topology or default . topology
agent . update ( default . state )
# agent = cls(unique_id=agent_id, model=env, **state )
topology = d . topology if ( ' topology ' in d . __fields_set__ ) else topology or default . topology
if topology :
env . agent_to_node ( agent . unique_id , topology )
assert agent . name is not None
assert agent . name != ' None '
assert agent . name
agents [ agent . unique_id ] = agent
agent [ ' topology ' ] = topology
agents . append ( agent )
return agents