Skip to content

Commit

Permalink
Merge pull request #81 from opapy/feature/add-endpoint-url-params
Browse files Browse the repository at this point in the history
add --endpint-url params
  • Loading branch information
spulec authored Nov 29, 2021
2 parents 1aeef3b + 05c7487 commit a983ce1
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 12 deletions.
16 changes: 14 additions & 2 deletions pyqs/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,15 @@ def main():
action="store",
)

parser.add_argument(
"--endpoint-url",
dest="endpoint_url",
type=str,
default=None,
help="AWS SQS endpoint url",
action="store",
)

parser.add_argument(
"--interval",
dest="interval",
Expand Down Expand Up @@ -143,7 +152,8 @@ def main():
interval=args.interval,
batchsize=_set_batchsize(args),
prefetch_multiplier=args.prefetch_multiplier,
simple_worker=args.simple_worker
simple_worker=args.simple_worker,
endpoint_url=args.endpoint_url,
)


Expand All @@ -156,7 +166,7 @@ def _add_cwd_to_path():
def _main(queue_prefixes, concurrency=5, logging_level="WARN",
region=None, access_key_id=None, secret_access_key=None,
interval=1, batchsize=DEFAULT_BATCH_SIZE, prefetch_multiplier=2,
simple_worker=False):
simple_worker=False, endpoint_url=None):
logging.basicConfig(
format="[%(levelname)s]: %(message)s",
level=getattr(logging, logging_level),
Expand All @@ -168,12 +178,14 @@ def _main(queue_prefixes, concurrency=5, logging_level="WARN",
queue_prefixes, concurrency, interval, batchsize,
region=region, access_key_id=access_key_id,
secret_access_key=secret_access_key,
endpoint_url=endpoint_url,
)
else:
manager = ManagerWorker(
queue_prefixes, concurrency, interval, batchsize,
prefetch_multiplier=prefetch_multiplier, region=region,
access_key_id=access_key_id, secret_access_key=secret_access_key,
endpoint_url=endpoint_url,
)

_add_cwd_to_path()
Expand Down
33 changes: 23 additions & 10 deletions pyqs/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,23 @@
logger = logging.getLogger("pyqs")


def get_conn(region=None, access_key_id=None, secret_access_key=None):
if not region:
region = get_aws_region_name()
def get_conn(
region=None, access_key_id=None, secret_access_key=None, endpoint_url=None
):
kwargs = {
"aws_access_key_id": access_key_id,
"aws_secret_access_key": secret_access_key,
"region_name": region,
}

if endpoint_url:
kwargs["endpoint_url"] = endpoint_url
if not kwargs["region_name"]:
kwargs["region_name"] = get_aws_region_name()

return boto3.client(
"sqs",
aws_access_key_id=access_key_id,
aws_secret_access_key=secret_access_key, region_name=region,
**kwargs,
)


Expand Down Expand Up @@ -359,11 +368,12 @@ class BaseManager(object):

def __init__(self, queue_prefixes, interval, batchsize,
region=None, access_key_id=None,
secret_access_key=None):
secret_access_key=None, endpoint_url=None):
self.connection_args = {
"region": region,
"access_key_id": access_key_id,
"secret_access_key": secret_access_key,
"endpoint_url": endpoint_url,
}
self.interval = interval
self.batchsize = batchsize
Expand Down Expand Up @@ -443,12 +453,14 @@ class SimpleManagerWorker(BaseManager):
WORKER_CHILDREN_CLASS = SimpleProcessWorker

def __init__(self, queue_prefixes, worker_concurrency, interval, batchsize,
region=None, access_key_id=None, secret_access_key=None):
region=None, access_key_id=None, secret_access_key=None,
endpoint_url=None):

super(SimpleManagerWorker, self).__init__(queue_prefixes, interval,
batchsize, region,
access_key_id,
secret_access_key)
secret_access_key,
endpoint_url)

self.worker_children = []
self._initialize_worker_children(worker_concurrency)
Expand Down Expand Up @@ -513,12 +525,13 @@ class ManagerWorker(BaseManager):

def __init__(self, queue_prefixes, worker_concurrency, interval, batchsize,
prefetch_multiplier=2, region=None, access_key_id=None,
secret_access_key=None):
secret_access_key=None, endpoint_url=None):

super(ManagerWorker, self).__init__(queue_prefixes, interval,
batchsize, region,
access_key_id,
secret_access_key)
secret_access_key,
endpoint_url)

self.prefetch_multiplier = prefetch_multiplier
self.worker_children = []
Expand Down

0 comments on commit a983ce1

Please sign in to comment.