diff --git a/vivarium/environments/braitenberg/selective_sensing/init.py b/vivarium/environments/braitenberg/selective_sensing/init.py index 192af4a..34303dd 100644 --- a/vivarium/environments/braitenberg/selective_sensing/init.py +++ b/vivarium/environments/braitenberg/selective_sensing/init.py @@ -98,6 +98,13 @@ def get_agents_params_and_sensed_arr(agents_stacked_behaviors_list): return params, sensed, behaviors def get_positions(positions, n, box_size): + """Check if the positions are valid and return them if they are + + :param positions: positions of the entities + :param n: number of entities + :param box_size: size of the box + :return: positions + """ if positions is None: return [None] * n assert len(positions) == n, f"The number of positions: {len(positions)} must match the number of entities: {n}" @@ -107,6 +114,12 @@ def get_positions(positions, n, box_size): return positions def check_position_redundancies(agents_pos, objects_pos): + """Check if there are redundant positions in the agents and objects positions + + :param agents_pos: agents positions + :param objects_pos: objects positions + :return: redundant_positions + """ positions = agents_pos + objects_pos position_indices = defaultdict(list) @@ -119,13 +132,24 @@ def check_position_redundancies(agents_pos, objects_pos): return redundant_positions if (len(redundant_positions) > 0) else False def get_exists(exists, n): + """Check if the exists array is valid and return it if it is + + :param exists: exists array + :param n: number of entities + :return: exists + """ if exists is None: return [1] * n - assert isinstance(exists, int) and (exists < n), f"Exists must be an int inferior than {n}, {exists} is not" + assert isinstance(exists, int) and (exists <= n), f"Exists must be an int inferior or equal than {n}, {exists} is not" exists = [1] * exists + [None] * (n - exists) return exists def set_to_none_if_all_none(lst): + """Set the list to None if all elements are None + + :param lst: list to check + :return: lst + """ if not any(element is not None for element in lst): return None return lst @@ -286,7 +310,6 @@ def init_objects( color=objects_color ) - def init_complete_state( entities, agents, @@ -317,8 +340,13 @@ def init_complete_state( ent_sub_types=total_ent_sub_types ) - def process_entity(data, box_size): + """Process the entity data to extract the color, positions, exists and diameter + + :param data: entity data + :param box_size: box size + :return: entity_data, positions, exists, diameter_lst + """ n = data['num'] color_str = data['color'] color = _string_to_rgb_array(color_str) @@ -328,7 +356,7 @@ def process_entity(data, box_size): diameter_lst = [diameter] * n return {'n': n, 'color': color}, positions, exists, diameter_lst - +# TODO : should refactor all the yaml configs to only define a dict of agents and a dict of objects, instead of specifying it for each subtype of entities def init_state( entities_data=CONFIG.entities_data, box_size=CONFIG.box_size, @@ -360,6 +388,11 @@ def init_state( ent_sub_types_enum = Enum('ent_sub_types_enum', {ent_sub_types[i]: i for i in range(len(ent_sub_types))}) ent_data = entities_data['Entities'] + # check if at least one agent and one object are defined in the entities data + has_agent, has_object = check_agent_and_object(ent_data) + assert has_agent, "At least one agent must be defined in the entities data" + assert has_object, "At least one object must be defined in the entities data" + # create max agents and max objects max_agents = 0 max_objects = 0 @@ -529,3 +562,18 @@ def init_state( return state +def check_agent_and_object(ent_data): + has_agent = False + has_object = False + + for entity in ent_data.values(): + if entity['type'] == 'AGENT': + has_agent = True + elif entity['type'] == 'OBJECT': + has_object = True + + # If both are found, no need to continue checking + if has_agent and has_object: + break + + return has_agent, has_object