diff --git a/src/handler.py b/src/handler.py index 0f0e375..93de246 100644 --- a/src/handler.py +++ b/src/handler.py @@ -45,6 +45,19 @@ def generator_handler(job): yield output +async def async_generator_handler(job): + ''' + Async generator type handler. + ''' + job_input = _side_effects(job['input']) + + # Prepare the job output + job_output = job_input.get('mock_return', ['Hello World!']) + + for output in job_output: + yield output + + # ------------------------------- Side Effects ------------------------------- # def _side_effects(job_input): ''' @@ -72,6 +85,8 @@ def _side_effects(job_input): parser = argparse.ArgumentParser() parser.add_argument('--generator', action='store_true', default=False, help='Starts serverless with the generator_handler') + parser.add_argument('--async_generator', action='store_true', default=False, + help='Starts serverless with the async_generator_handler') parser.add_argument('--return_aggregate_stream', action='store_true', default=False, help='Aggregate the stream of generator_handler and return it as a list') args = parser.parse_args() @@ -86,5 +101,14 @@ def _side_effects(job_input): "return_aggregate_stream": args.return_aggregate_stream }) + elif args.async_generator: + print('Starting serverless with async_generator_handler') + print(f"return_aggregate_stream: {args.return_aggregate_stream}") + + runpod.serverless.start({ + "handler": async_generator_handler, + "return_aggregate_stream": args.return_aggregate_stream + }) + else: runpod.serverless.start({"handler": handler})