Skip to content

Commit

Permalink
Merge pull request #330 from wguanicedew/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
wguanicedew authored Aug 15, 2024
2 parents 947ab98 + 453f67a commit 7e5d559
Show file tree
Hide file tree
Showing 16 changed files with 371 additions and 164 deletions.
1 change: 1 addition & 0 deletions common/lib/idds/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,7 @@ class MessageSource(IDDSEnum):
Carrier = 3
Conductor = 4
Rest = 5
OutSide = 6


class MessageDestination(IDDSEnum):
Expand Down
47 changes: 23 additions & 24 deletions main/lib/idds/agents/carrier/iutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,17 +184,20 @@ def handle_messages_asyncresult(messages, logger=None, log_prefix='', update_pro

req_msgs = {}

for msg in messages:
if 'from_idds' in msg and msg['from_idds']:
continue
for item in messages:
if 'from_idds' in item:
if type(item['from_idds']) in [bool] and item['from_idds'] or type(item['from_idds']) in [str] and item['from_idds'].lower() == 'true':
continue

msg = item['msg']

# ret = msg['ret']
# key = msg['key']
# internal_id = msg['internal_id']
# msg_type = msg['type']
request_id = msg['request_id']
transform_id = msg.get('transform_id', 0)
internal_id = msg.get('internal_id', None)
request_id = msg['body']['request_id']
transform_id = msg['body'].get('transform_id', 0)
internal_id = msg['body'].get('internal_id', None)
# if msg_type in ['iworkflow']:

if request_id not in req_msgs:
Expand All @@ -203,21 +206,17 @@ def handle_messages_asyncresult(messages, logger=None, log_prefix='', update_pro
req_msgs[request_id][transform_id] = {}
if internal_id not in req_msgs[request_id][transform_id]:
req_msgs[request_id][transform_id][internal_id] = []
req_msgs[request_id][transform_id][internal_id].append(msg)

for request_id in req_msgs:
for transform_id in req_msgs[request_id]:
for internal_id in req_msgs[request_id][transform_id]:
msgs = req_msgs[request_id][transform_id][internal_id]
core_messages.add_message(msg_type=MessageType.AsyncResult,
status=MessageStatus.NoNeedDelivery,
destination=MessageDestination.AsyncResult,
source=MessageSource.Outside,
request_id=request_id,
workload_id=None,
transform_id=transform_id,
internal_id=internal_id,
num_contents=len(msgs),
msg_content=msgs)

logger.debug(f"{log_prefix} handle_messages_asyncresult, add {len(msgs)} for request_id {request_id} transform_id {transform_id} internal_id {internal_id}")

msgs = [msg]
core_messages.add_message(msg_type=MessageType.AsyncResult,
status=MessageStatus.NoNeedDelivery,
destination=MessageDestination.AsyncResult,
source=MessageSource.OutSide,
request_id=request_id,
workload_id=None,
transform_id=transform_id,
internal_id=internal_id,
num_contents=len(msgs),
msg_content=msgs)

logger.debug(f"{log_prefix} handle_messages_asyncresult, add {len(msgs)} for request_id {request_id} transform_id {transform_id} internal_id {internal_id}")
13 changes: 9 additions & 4 deletions main/lib/idds/agents/carrier/receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,12 @@ def get_output_messages(self):
if msg_size < 10:
self.logger.debug("Received message(only log first 10 messages): %s" % str(msg))
name = msg['name']
body = msg['body']
# headers = msg['headers']
# body = msg['body']
# from_idds = msg['from_idds']
if name not in msgs:
msgs[name] = []
msgs[name].append(body)
msgs[name].append(msg)
msg_size += 1
if msg_size >= self.bulk_message_size:
break
Expand Down Expand Up @@ -151,7 +153,10 @@ def add_receiver_monitor_task(self):
self.add_task(task)

def handle_messages(self, output_messages, log_prefix):
ret_msg_handle = handle_messages_processing(output_messages,
output_messages_new = []
for msg in output_messages:
output_messages_new.append(msg['msg']['body'])
ret_msg_handle = handle_messages_processing(output_messages_new,
logger=self.logger,
log_prefix=log_prefix,
update_processing_interval=self.update_processing_interval)
Expand Down Expand Up @@ -202,7 +207,7 @@ def handle_messages_asyncresult(self, output_messages, log_prefix):

def handle_messages_channels(self, output_messages, log_prefix):
for channel in output_messages:
if channel in ['asyncresult']:
if channel in ['asyncresult', 'AsyncResult']:
self.handle_messages_asyncresult(output_messages[channel], log_prefix)
else:
self.handle_messages(output_messages[channel], log_prefix)
Expand Down
42 changes: 29 additions & 13 deletions main/lib/idds/agents/common/plugins/messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@ def on_error(self, frame):
self.logger.error('[broker] [%s]: %s', self.__broker, frame.body)

def on_message(self, frame):
self.logger.debug('[broker] %s [%s]: %s', self.name, self.__broker, frame.body)
self.__output_queue.put({'name': self.name, 'msg': frame.body})
self.logger.debug(f'[broker] {self.name} [{self.__broker}]: headers: {frame.headers}, body: {frame.body}')
headers = frame.headers
from_idds = headers.get('from_idds', 'false')
self.__output_queue.put({'name': self.name, 'from_idds': from_idds, 'msg': {'headers': frame.headers, 'body': json_loads(frame.body)}})
pass


Expand Down Expand Up @@ -209,17 +211,30 @@ def send_message(self, msg):
destination = msg['destination'] if 'destination' in msg else 'default'
conn, queue_dest, destination = self.get_connection(destination)

from_idds = 'false'
if 'from_idds' in msg and msg['from_idds']:
from_idds = 'true'

if conn:
self.logger.info("Sending message to message broker(%s): %s" % (destination, msg['msg_id']))
self.logger.debug("Sending message to message broker(%s): %s" % (destination, json_dumps(msg['msg_content'])))
conn.send(body=json_dumps(msg['msg_content']),
destination=queue_dest,
id='atlas-idds-messaging',
ack='auto',
headers={'persistent': 'true',
'ttl': self.timetolive,
'vo': 'atlas',
'msg_type': str(msg['msg_type']).lower()})
if type(msg['msg_content']) in [dict] and 'headers' in msg['msg_content'] and 'body' in msg['msg_content']:
msg['msg_content']['headers']['from_idds'] = from_idds
conn.send(body=json_dumps(msg['msg_content']['body']),
headers=msg['msg_content']['headers'],
destination=queue_dest,
id='atlas-idds-messaging',
ack='auto')
else:
conn.send(body=json_dumps(msg['msg_content']),
destination=queue_dest,
id='atlas-idds-messaging',
ack='auto',
headers={'persistent': 'true',
'ttl': self.timetolive,
'vo': 'atlas',
'from_idds': from_idds,
'msg_type': str(msg['msg_type']).lower()})
else:
self.logger.info("No brokers defined, discard(%s): %s" % (destination, msg['msg_id']))

Expand Down Expand Up @@ -260,8 +275,9 @@ def __init__(self, name="MessagingReceiver", logger=None, **kwargs):
def get_listener(self, broker, name):
if self.listener is None:
self.listener = {}
self.listener[name] = MessagingListener(broker, self.output_queue, name=name, logger=self.logger)
return self.listener
if name not in self.listener:
self.listener[name] = MessagingListener(broker, self.output_queue, name=name, logger=self.logger)
return self.listener[name]

def subscribe(self):
self.receiver_conns = self.connect_to_messaging_brokers()
Expand Down Expand Up @@ -298,7 +314,7 @@ def execute_subscribe(self):
for name in self.receiver_conns:
for conn in self.receiver_conns[name]:
if not conn.is_connected():
conn.set_listener('message-receiver', self.get_listener(conn.transport._Transport__host_and_ports[0]))
conn.set_listener('message-receiver', self.get_listener(conn.transport._Transport__host_and_ports[0], name))
# conn.start()
conn.connect(self.channels[name]['username'], self.channels[name]['password'], wait=True)
conn.subscribe(destination=self.channels[name]['destination'], id='atlas-idds-messaging', ack='auto')
Expand Down
5 changes: 4 additions & 1 deletion main/lib/idds/agents/conductor/conductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def get_messages(self):
if BaseAgent.min_request_id is None:
return []

destination = [MessageDestination.Outside, MessageDestination.ContentExt]
destination = [MessageDestination.Outside, MessageDestination.ContentExt, MessageDestination.AsyncResult]
messages = core_messages.retrieve_messages(status=MessageStatus.New,
min_request_id=BaseAgent.min_request_id,
bulk_size=self.retrieve_bulk_size,
Expand Down Expand Up @@ -196,6 +196,8 @@ def is_message_processed(self, message):
self.logger.info("message %s has reached max retries %s" % (message['msg_id'], self.max_retries))
return True
msg_type = message['msg_type']
if msg_type in [MessageType.AsyncResult]:
return True
if msg_type not in [MessageType.ProcessingFile]:
if retries < self.replay_times:
return False
Expand Down Expand Up @@ -286,6 +288,7 @@ def run(self):
to_discard_messages = []
for message in messages:
message['destination'] = message['destination'].name
message['from_idds'] = True

num_contents += message['num_contents']
if self.is_message_processed(message):
Expand Down
4 changes: 2 additions & 2 deletions main/lib/idds/core/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def add_messages(messages, bulk_size=1000, session=None):
def retrieve_messages(bulk_size=None, msg_type=None, status=None, destination=None,
source=None, request_id=None, workload_id=None, transform_id=None,
processing_id=None, use_poll_period=False, retries=None, delay=None,
min_request_id=None, fetching_id=None, session=None):
min_request_id=None, fetching_id=None, internal_id=None, session=None):
"""
Retrieve up to $bulk messages.
Expand All @@ -71,7 +71,7 @@ def retrieve_messages(bulk_size=None, msg_type=None, status=None, destination=No
request_id=request_id, workload_id=workload_id,
transform_id=transform_id, processing_id=processing_id,
retries=retries, delay=delay, fetching_id=fetching_id,
min_request_id=min_request_id,
min_request_id=min_request_id, internal_id=internal_id,
use_poll_period=use_poll_period, session=session)


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#!/usr/bin/env python
#
# Licensed under the Apache License, Version 2.0 (the "License");
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0OA
#
# Authors:
# - Wen Guan, <[email protected]>, 2024

"""add conditions and campaign
Revision ID: 3073c5de8f73
Revises: 40ead97e63c6
Create Date: 2024-08-05 13:21:37.265614+00:00
"""
import datetime

from alembic import op
from alembic import context
import sqlalchemy as sa

from idds.common.constants import ConditionStatus
from idds.orm.base.types import EnumWithValue
from idds.orm.base.types import JSON

# revision identifiers, used by Alembic.
revision = '3073c5de8f73'
down_revision = '40ead97e63c6'
branch_labels = None
depends_on = None


def upgrade() -> None:
if context.get_context().dialect.name in ['oracle', 'mysql', 'postgresql']:
schema = context.get_context().version_table_schema if context.get_context().version_table_schema else ''

op.add_column('requests', sa.Column('campaign', sa.String(100)), schema=schema)
op.add_column('requests', sa.Column('campaign_group', sa.String(250)), schema=schema)
op.add_column('requests', sa.Column('campaign_tag', sa.String(20)), schema=schema)

op.add_column('transforms', sa.Column('internal_id', sa.String(20)), schema=schema)
op.add_column('transforms', sa.Column('has_previous_conditions', sa.Integer()), schema=schema)
op.add_column('transforms', sa.Column('loop_index', sa.Integer()), schema=schema)
op.add_column('transforms', sa.Column('cloned_from', sa.BigInteger()), schema=schema)
op.add_column('transforms', sa.Column('triggered_conditions', JSON()), schema=schema)
op.add_column('transforms', sa.Column('untriggered_conditions', JSON()), schema=schema)

op.create_table('conditions',
sa.Column('condition_id', sa.BigInteger(), sa.Sequence('CONDITION_ID_SEQ', schema=schema)),
sa.Column('request_id', sa.BigInteger(), nullable=False),
sa.Column('internal_id', sa.String(20), nullable=False),
sa.Column('status', EnumWithValue(ConditionStatus), nullable=False),
sa.Column('is_loop', sa.Integer()),
sa.Column('loop_index', sa.Integer()),
sa.Column('cloned_from', sa.BigInteger()),
sa.Column("created_at", sa.DateTime, default=datetime.datetime.utcnow, nullable=False),
sa.Column("updated_at", sa.DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False),
sa.Column("evaluate_result", sa.String(200)),
sa.Column('previous_transforms', JSON()),
sa.Column('following_transforms', JSON()),
sa.Column('condition', JSON()),
schema=schema)
op.create_primary_key('CONDITION_PK', 'conditions', ['condition_id'], schema=schema)
op.create_unique_constraint('CONDITION_ID_UQ', 'conditions', ['request_id', 'internal_id'], schema=schema)


def downgrade() -> None:
if context.get_context().dialect.name in ['oracle', 'mysql', 'postgresql']:
schema = context.get_context().version_table_schema if context.get_context().version_table_schema else ''

op.drop_column('requests', 'campaign', schema=schema)
op.drop_column('requests', 'campaign_group', schema=schema)
op.drop_column('requests', 'campaign_tag', schema=schema)

op.drop_column('transforms', 'internal_id', schema=schema)
op.drop_column('transforms', 'has_previous_conditions', schema=schema)
op.drop_column('transforms', 'loop_index', schema=schema)
op.drop_column('transforms', 'cloned_from', schema=schema)
op.drop_column('transforms', 'triggered_conditions', schema=schema)
op.drop_column('transforms', 'untriggered_conditions', schema=schema)

op.drop_constraint('CONDITION_ID_UQ', table_name='conditions', schema=schema)
op.drop_constraint('CONDITION_PK', table_name='conditions', schema=schema)
op.drop_table('conditions', schema=schema)
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/usr/bin/env python
#
# Licensed under the Apache License, Version 2.0 (the "License");
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0OA
#
# Authors:
# - Wen Guan, <[email protected]>, 2024

"""messages table add internal_id
Revision ID: 40ead97e63c6
Revises: cc9f730e54c5
Create Date: 2024-07-01 14:02:47.670000+00:00
"""
from alembic import op
from alembic import context
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '40ead97e63c6'
down_revision = 'cc9f730e54c5'
branch_labels = None
depends_on = None


def upgrade() -> None:
if context.get_context().dialect.name in ['oracle', 'mysql', 'postgresql']:
schema = context.get_context().version_table_schema if context.get_context().version_table_schema else ''
op.add_column('messages', sa.Column('internal_id', sa.String(20)), schema=schema)


def downgrade() -> None:
if context.get_context().dialect.name in ['oracle', 'mysql', 'postgresql']:
schema = context.get_context().version_table_schema if context.get_context().version_table_schema else ''
op.drop_column('messages', 'internal_id', schema=schema)
4 changes: 3 additions & 1 deletion main/lib/idds/orm/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def retrieve_messages(bulk_size=1000, msg_type=None, status=None, source=None,
destination=None, request_id=None, workload_id=None,
transform_id=None, processing_id=None, fetching_id=None,
min_request_id=None, use_poll_period=False, retries=None,
delay=None, session=None):
delay=None, internal_id=None, session=None):
"""
Retrieve up to $bulk messages.
Expand Down Expand Up @@ -183,6 +183,8 @@ def retrieve_messages(bulk_size=1000, msg_type=None, status=None, source=None,
query = query.filter_by(transform_id=transform_id)
if processing_id is not None:
query = query.filter_by(processing_id=processing_id)
if internal_id is not None:
query = query.filter_by(internal_id=internal_id)
if retries:
query = query.filter_by(retries=retries)
if delay:
Expand Down
Loading

0 comments on commit 7e5d559

Please sign in to comment.