From 05c74877410475d2e90d20747f040c6831bc3b14 Mon Sep 17 00:00:00 2001 From: tkhr Date: Mon, 29 Nov 2021 19:48:07 +0900 Subject: [PATCH] add --endpint-url params add --endpint-url params --- pyqs/main.py | 16 ++++++++++++++-- pyqs/worker.py | 33 +++++++++++++++++++++++---------- 2 files changed, 37 insertions(+), 12 deletions(-) diff --git a/pyqs/main.py b/pyqs/main.py index 7dec331..f2fd0c6 100644 --- a/pyqs/main.py +++ b/pyqs/main.py @@ -94,6 +94,15 @@ def main(): action="store", ) + parser.add_argument( + "--endpoint-url", + dest="endpoint_url", + type=str, + default=None, + help="AWS SQS endpoint url", + action="store", + ) + parser.add_argument( "--interval", dest="interval", @@ -143,7 +152,8 @@ def main(): interval=args.interval, batchsize=_set_batchsize(args), prefetch_multiplier=args.prefetch_multiplier, - simple_worker=args.simple_worker + simple_worker=args.simple_worker, + endpoint_url=args.endpoint_url, ) @@ -156,7 +166,7 @@ def _add_cwd_to_path(): def _main(queue_prefixes, concurrency=5, logging_level="WARN", region=None, access_key_id=None, secret_access_key=None, interval=1, batchsize=DEFAULT_BATCH_SIZE, prefetch_multiplier=2, - simple_worker=False): + simple_worker=False, endpoint_url=None): logging.basicConfig( format="[%(levelname)s]: %(message)s", level=getattr(logging, logging_level), @@ -168,12 +178,14 @@ def _main(queue_prefixes, concurrency=5, logging_level="WARN", queue_prefixes, concurrency, interval, batchsize, region=region, access_key_id=access_key_id, secret_access_key=secret_access_key, + endpoint_url=endpoint_url, ) else: manager = ManagerWorker( queue_prefixes, concurrency, interval, batchsize, prefetch_multiplier=prefetch_multiplier, region=region, access_key_id=access_key_id, secret_access_key=secret_access_key, + endpoint_url=endpoint_url, ) _add_cwd_to_path() diff --git a/pyqs/worker.py b/pyqs/worker.py index bed100d..4e655ed 100644 --- a/pyqs/worker.py +++ b/pyqs/worker.py @@ -27,14 +27,23 @@ logger = logging.getLogger("pyqs") -def get_conn(region=None, access_key_id=None, secret_access_key=None): - if not region: - region = get_aws_region_name() +def get_conn( + region=None, access_key_id=None, secret_access_key=None, endpoint_url=None +): + kwargs = { + "aws_access_key_id": access_key_id, + "aws_secret_access_key": secret_access_key, + "region_name": region, + } + + if endpoint_url: + kwargs["endpoint_url"] = endpoint_url + if not kwargs["region_name"]: + kwargs["region_name"] = get_aws_region_name() return boto3.client( "sqs", - aws_access_key_id=access_key_id, - aws_secret_access_key=secret_access_key, region_name=region, + **kwargs, ) @@ -359,11 +368,12 @@ class BaseManager(object): def __init__(self, queue_prefixes, interval, batchsize, region=None, access_key_id=None, - secret_access_key=None): + secret_access_key=None, endpoint_url=None): self.connection_args = { "region": region, "access_key_id": access_key_id, "secret_access_key": secret_access_key, + "endpoint_url": endpoint_url, } self.interval = interval self.batchsize = batchsize @@ -443,12 +453,14 @@ class SimpleManagerWorker(BaseManager): WORKER_CHILDREN_CLASS = SimpleProcessWorker def __init__(self, queue_prefixes, worker_concurrency, interval, batchsize, - region=None, access_key_id=None, secret_access_key=None): + region=None, access_key_id=None, secret_access_key=None, + endpoint_url=None): super(SimpleManagerWorker, self).__init__(queue_prefixes, interval, batchsize, region, access_key_id, - secret_access_key) + secret_access_key, + endpoint_url) self.worker_children = [] self._initialize_worker_children(worker_concurrency) @@ -513,12 +525,13 @@ class ManagerWorker(BaseManager): def __init__(self, queue_prefixes, worker_concurrency, interval, batchsize, prefetch_multiplier=2, region=None, access_key_id=None, - secret_access_key=None): + secret_access_key=None, endpoint_url=None): super(ManagerWorker, self).__init__(queue_prefixes, interval, batchsize, region, access_key_id, - secret_access_key) + secret_access_key, + endpoint_url) self.prefetch_multiplier = prefetch_multiplier self.worker_children = []