diff --git a/fixcore/fixcore/cli/__init__.py b/fixcore/fixcore/cli/__init__.py index cb93c9334d..3af2a601e3 100644 --- a/fixcore/fixcore/cli/__init__.py +++ b/fixcore/fixcore/cli/__init__.py @@ -15,13 +15,12 @@ AsyncIterable, ) -from aiostream import stream -from aiostream.core import Stream from parsy import Parser, regex, string from fixcore.model.graph_access import Section from fixcore.types import JsonElement, Json from fixcore.util import utc, parse_utc, AnyT +from fixlib.asynchronous.stream import Stream from fixlib.durations import parse_duration, DurationRe from fixlib.parse_util import ( make_parser, @@ -47,7 +46,7 @@ # A sink function takes a stream and creates a result Sink = Callable[[JsStream], Awaitable[T]] -list_sink: Callable[[JsGen], Awaitable[Any]] = stream.list # type: ignore +list_sink: Callable[[JsGen], Awaitable[List[Any]]] = Stream.as_list @make_parser diff --git a/fixcore/fixcore/cli/command.py b/fixcore/fixcore/cli/command.py index b582cba90d..d94e25a967 100644 --- a/fixcore/fixcore/cli/command.py +++ b/fixcore/fixcore/cli/command.py @@ -29,7 +29,6 @@ Optional, Any, AsyncIterator, - Iterable, Callable, Awaitable, cast, @@ -46,9 +45,6 @@ import yaml from aiofiles.tempfile import TemporaryDirectory from aiohttp import ClientTimeout, JsonPayload, BasicAuth -from aiostream import stream, pipe -from aiostream.aiter_utils import is_async_iterable -from aiostream.core import Stream from attr import evolve, frozen from attrs import define, field from dateutil import parser as date_parser @@ -178,6 +174,7 @@ respond_cytoscape, ) from fixcore.worker_task_queue import WorkerTask, WorkerTaskName +from fixlib.asynchronous.stream import Stream from fixlib.core import CLIEnvelope from fixlib.durations import parse_duration from fixlib.parse_util import ( @@ -946,14 +943,13 @@ def group(keys: tuple[Any, ...]) -> Json: return result async def aggregate_data(content: JsStream) -> AsyncIterator[JsonElement]: - async with content.stream() as in_stream: - for key, value in (await self.aggregate_in(in_stream, var_names, aggregate.group_func)).items(): - entry: Json = {"group": group(key)} - for fn_name, (fn_val, fn_count) in value.fn_values.items(): - if fn_by_name.get(fn_name) == "avg" and fn_val is not None and fn_count > 0: - fn_val = fn_val / fn_count # type: ignore - entry[fn_name] = fn_val - yield entry + for key, value in (await self.aggregate_in(content, var_names, aggregate.group_func)).items(): + entry: Json = {"group": group(key)} + for fn_name, (fn_val, fn_count) in value.fn_values.items(): + if fn_by_name.get(fn_name) == "avg" and fn_val is not None and fn_count > 0: + fn_val = fn_val / fn_count # type: ignore + entry[fn_name] = fn_val + yield entry # noinspection PyTypeChecker return CLIFlow(aggregate_data) @@ -1000,7 +996,7 @@ def info(self) -> str: def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwargs: Any) -> CLIAction: size = self.parse_size(arg) - return CLIFlow(lambda in_stream: in_stream | pipe.take(size)) + return CLIFlow(lambda in_stream: Stream(in_stream).take(size)) def args_info(self) -> ArgsInfo: return [ArgInfo(expects_value=True, help_text="number of elements to take")] @@ -1054,7 +1050,7 @@ def args_info(self) -> ArgsInfo: def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwargs: Any) -> CLIAction: size = HeadCommand.parse_size(arg) - return CLIFlow(lambda in_stream: in_stream | pipe.takelast(size)) + return CLIFlow(lambda in_stream: Stream(in_stream).take_last(size)) class 
CountCommand(SearchCLIPart): @@ -1145,9 +1141,8 @@ def inc_identity(_: Any) -> None: fn = inc_prop if arg else inc_identity async def count_in_stream(content: JsStream) -> AsyncIterator[JsonElement]: - async with content.stream() as in_stream: - async for element in in_stream: - fn(element) + async for element in content: + fn(element) for key, value in sorted(counter.items(), key=lambda x: x[1]): yield f"{key}: {value}" @@ -1194,7 +1189,7 @@ def args_info(self) -> ArgsInfo: def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwargs: Any) -> CLISource: return CLISource.single( - lambda: stream.just(strip_quotes(arg if arg else "")), required_permissions={Permission.read} + lambda: Stream.just(strip_quotes(arg if arg else "")), required_permissions={Permission.read} ) @@ -1256,7 +1251,7 @@ def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwa else: raise AttributeError(f"json does not understand {arg}.") return CLISource.with_count( - lambda: stream.iterate(elements), len(elements), required_permissions={Permission.read} + lambda: Stream.iterate(elements), len(elements), required_permissions={Permission.read} ) @@ -1339,19 +1334,17 @@ def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwa async def to_count(in_stream: JsStream) -> AsyncIterator[JsonElement]: null_value = 0 total = 0 - in_streamer = in_stream if isinstance(in_stream, Stream) else stream.iterate(in_stream) - async with in_streamer.stream() as streamer: - async for elem in streamer: - name = js_value_at(elem, name_path) - count = js_value_get(elem, count_path, 0) - if name is None: - null_value = count - else: - total += count - yield f"{name}: {count}" - tm, tu = (total, null_value) if arg else (null_value + total, 0) - yield f"total matched: {tm}" - yield f"total unmatched: {tu}" + async for elem in in_stream: + name = js_value_at(elem, name_path) + count = js_value_get(elem, count_path, 0) + if name is None: + null_value = count + else: + total += count + yield f"{name}: {count}" + tm, tu = (total, null_value) if arg else (null_value + total, 0) + yield f"total matched: {tm}" + yield f"total unmatched: {tu}" return CLIFlow(to_count) @@ -1550,7 +1543,7 @@ def args_info(self) -> ArgsInfo: return [] def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwargs: Any) -> CLISource: - return CLISource.with_count(lambda: stream.just(ctx.env), len(ctx.env), required_permissions={Permission.read}) + return CLISource.with_count(lambda: Stream.just(ctx.env), len(ctx.env), required_permissions={Permission.read}) class ChunkCommand(CLICommand): @@ -1599,7 +1592,10 @@ def args_info(self) -> ArgsInfo: def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwargs: Any) -> CLIFlow: size = int(arg) if arg else 100 - return CLIFlow(lambda in_stream: in_stream | pipe.chunks(size), required_permissions={Permission.read}) + return CLIFlow( + lambda in_stream: Stream(in_stream).chunks(size).map(Stream.as_list), + required_permissions={Permission.read}, + ) class FlattenCommand(CLICommand): @@ -1646,13 +1642,7 @@ def args_info(self) -> ArgsInfo: return [] def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwargs: Any) -> CLIFlow: - def iterable(it: Any) -> bool: - return False if isinstance(it, str) else isinstance(it, Iterable) - - def iterate(it: Any) -> JsGen: - return stream.iterate(it) if is_async_iterable(it) or iterable(it) else stream.just(it) - - return CLIFlow(lambda i: i | 
pipe.flatmap(iterate), required_permissions={Permission.read}) # type: ignore + return CLIFlow(lambda i: Stream(i).flatten(), required_permissions={Permission.read}) class UniqCommand(CLICommand): @@ -1709,7 +1699,7 @@ def has_not_seen(item: Any) -> bool: visited.add(item) return True - return CLIFlow(lambda in_stream: stream.filter(in_stream, has_not_seen), required_permissions={Permission.read}) + return CLIFlow(lambda in_stream: Stream(in_stream).filter(has_not_seen), required_permissions={Permission.read}) class JqCommand(CLICommand, OutputTransformer): @@ -1809,7 +1799,7 @@ def process(in_json: JsonElement) -> JsonElement: result = out[0] if len(out) == 1 else out return cast(Json, result) - return CLIFlow(lambda i: i | pipe.map(process), required_permissions={Permission.read}) # type: ignore + return CLIFlow(lambda i: Stream(i).map(process), required_permissions={Permission.read}) class KindsCommand(CLICommand, PreserveOutputFormat): @@ -1962,16 +1952,16 @@ def show(k: ComplexKind) -> bool: result: JsonElement = ( kind_to_js(model, model[kind]) if kind in model else f"No kind with this name: {kind}" ) - return 1, stream.just(result) + return 1, Stream.just(result) elif args.property_path: no_section = Section.without_section(args.property_path) result = kind_to_js(model, model.kind_by_path(no_section)) if appears_in := property_defined_in(model, no_section): result["appears_in"] = appears_in - return 1, stream.just(result) + return 1, Stream.just(result) else: result = sorted([k.fqn for k in model.kinds.values() if isinstance(k, ComplexKind) and show(k)]) - return len(model.kinds), stream.iterate(result) + return len(model.kinds), Stream.iterate(result) return CLISource.only_count(source, required_permissions={Permission.read}) @@ -1985,17 +1975,15 @@ def patch(self, arg: Optional[str], ctx: CLIContext) -> Json: def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwargs: Any) -> CLIFlow: buffer_size = 1000 func = partial(self.set_desired, arg, ctx.graph_name, self.patch(arg, ctx)) - return CLIFlow( - lambda i: i | pipe.chunks(buffer_size) | pipe.flatmap(func), required_permissions={Permission.write} - ) + return CLIFlow(lambda i: Stream(i).chunks(buffer_size).flatmap(func), required_permissions={Permission.write}) async def set_desired( - self, arg: Optional[str], graph_name: GraphName, patch: Json, items: List[Json] + self, arg: Optional[str], graph_name: GraphName, patch: Json, items: Stream[Json] ) -> AsyncIterator[JsonElement]: model = await self.dependencies.model_handler.load_model(graph_name) db = self.dependencies.db_access.get_graph_db(graph_name) node_ids = [] - for item in items: + async for item in items: if "id" in item: node_ids.append(item["id"]) elif isinstance(item, str): @@ -2102,7 +2090,7 @@ def patch(self, arg: Optional[str], ctx: CLIContext) -> Json: return {"clean": True} async def set_desired( - self, arg: Optional[str], graph_name: GraphName, patch: Json, items: List[Json] + self, arg: Optional[str], graph_name: GraphName, patch: Json, items: Stream[Json] ) -> AsyncIterator[JsonElement]: reason = f"Reason: {strip_quotes(arg)}" if arg else "No reason provided." 
async for elem in super().set_desired(arg, graph_name, patch, items): @@ -2123,15 +2111,13 @@ def patch(self, arg: Optional[str], ctx: CLIContext) -> Json: def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwargs: Any) -> CLIFlow: buffer_size = 1000 func = partial(self.set_metadata, ctx.graph_name, self.patch(arg, ctx)) - return CLIFlow( - lambda i: i | pipe.chunks(buffer_size) | pipe.flatmap(func), required_permissions={Permission.write} - ) + return CLIFlow(lambda i: Stream(i).chunks(buffer_size).flatmap(func), required_permissions={Permission.write}) - async def set_metadata(self, graph_name: GraphName, patch: Json, items: List[Json]) -> AsyncIterator[JsonElement]: + async def set_metadata(self, graph_name: GraphName, patch: Json, items: Stream[Json]) -> AsyncIterator[JsonElement]: model = await self.dependencies.model_handler.load_model(graph_name) db = self.dependencies.db_access.get_graph_db(graph_name) node_ids = [] - for item in items: + async for item in items: if "id" in item: node_ids.append(item["id"]) elif isinstance(item, str): @@ -2331,9 +2317,8 @@ def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwa use = next(iter(format_to_use)) async def render_single(converter: ConvertFn, iss: JsStream) -> JsGen: - async with iss.stream() as streamer: - async for elem in converter(streamer): - yield elem + async for elem in converter(iss): + yield elem async def format_stream(in_stream: JsStream) -> JsGen: if use: @@ -2344,7 +2329,7 @@ async def format_stream(in_stream: JsStream) -> JsGen: else: raise ValueError(f"Unknown format: {use}") elif formatting_string: - return in_stream | pipe.map(ctx.formatter(arg)) if arg else in_stream # type: ignore + return in_stream.map(ctx.formatter(arg)) if arg else in_stream # type: ignore else: return in_stream @@ -2817,14 +2802,13 @@ def to_csv_string(lst: List[Any]) -> str: header_values = [prop.name for prop in props] yield to_csv_string(header_values) - async with in_stream.stream() as s: - async for elem in s: - if is_node(elem) or is_aggregate: - result = [] - for prop in props: - value = prop.value(elem) - result.append(value) - yield to_csv_string(result) + async for elem in in_stream: + if is_node(elem) or is_aggregate: + result = [] + for prop in props: + value = prop.value(elem) + result.append(value) + yield to_csv_string(result) async def json_table_stream(in_stream: JsStream, model: QueryModel) -> JsGen: def kind_of(path: str) -> Kind: @@ -2857,13 +2841,12 @@ def render_prop(elem: JsonElement) -> JsonElement: ], } # data columns - async with in_stream.stream() as s: - async for elem in s: - if isinstance(elem, dict) and (is_node(elem) or is_aggregate): - yield { - "id": None if is_aggregate else elem["id"], # aggregates have no id - "row": {prop.name: render_prop(prop.value(elem)) for prop in props}, - } + async for elem in in_stream: + if isinstance(elem, dict) and (is_node(elem) or is_aggregate): + yield { + "id": None if is_aggregate else elem["id"], # aggregates have no id + "row": {prop.name: render_prop(prop.value(elem)) for prop in props}, + } def markdown_stream(in_stream: JsStream) -> JsGen: chunk_size = 500 @@ -2881,7 +2864,7 @@ def extract_values(elem: JsonElement) -> List[Any | None]: result.append(value) return result - async def generate_markdown(chunk: Tuple[int, List[List[Any]]]) -> JsGen: + async def generate_markdown(chunk: Tuple[int, Stream[List[Any]]]) -> JsGen: idx, rows = chunk def to_str(elem: Any) -> str: @@ -2913,7 +2896,7 @@ def to_str(elem: Any) -> str: 
line += "|" yield line - for row in rows: + async for row in rows: line = "" for value, padding in zip(row, columns_padding): line += f"|{to_str(value).ljust(padding)}" @@ -2922,12 +2905,11 @@ def to_str(elem: Any) -> str: # noinspection PyUnresolvedReferences markdown_chunks = ( - in_stream - | pipe.filter(lambda x: is_node(x) or is_aggregate) - | pipe.map(extract_values) # type: ignore - | pipe.chunks(chunk_size) - | pipe.enumerate() - | pipe.flatmap(generate_markdown) # type: ignore + in_stream.filter(lambda x: is_node(x) or is_aggregate) + .map(extract_values) + .chunks(chunk_size) + .enumerate() + .flatmap(generate_markdown) # type: ignore ) return markdown_chunks @@ -2943,12 +2925,9 @@ async def load_model() -> QueryModel: model = await self.dependencies.model_handler.load_model(ctx.graph_name) return QueryModel(ctx.query or Query.empty(), model, ctx.env) - return stream.call(load_model) | pipe.flatmap(partial(json_table_stream, in_stream)) # type: ignore + return Stream.call(load_model).flatmap(partial(json_table_stream, in_stream)) # type: ignore else: - return stream.map( - in_stream, - lambda elem: fmt_json(elem) if isinstance(elem, dict) else str(elem), # type: ignore - ) + return Stream(in_stream).map(lambda elem: fmt_json(elem) if isinstance(elem, dict) else str(elem)) return CLIFlow(fmt, produces=MediaType.String, required_permissions={Permission.read}) @@ -3208,7 +3187,7 @@ async def activate_deactivate_job(job_id: str, active: bool) -> AsyncIterator[Js async def running_jobs() -> Tuple[int, JsStream]: tasks = await self.dependencies.task_handler.running_tasks() - return len(tasks), stream.iterate( + return len(tasks), Stream.iterate( {"job": t.descriptor.id, "started_at": to_json(t.task_started_at), "task-id": t.id} for t in tasks if isinstance(t.descriptor, Job) @@ -3271,7 +3250,7 @@ async def send_to_queue(task_name: str, task_args: Dict[str, str], data: Json) - await self.dependencies.forked_tasks.put((result_task, f"WorkerTask {task_name}:{task.id}")) return f"Spawned WorkerTask {task_name}:{task.id}" - return in_stream | pipe.starmap(send_to_queue, ordered=False, task_limit=self.task_limit()) # type: ignore + return in_stream.starmap(send_to_queue, ordered=False, task_limit=self.task_limit()) def load_by_id_merged( self, @@ -3281,12 +3260,12 @@ def load_by_id_merged( expected_kind: Optional[str] = None, **env: str, ) -> JsStream: - async def load_element(items: List[JsonElement]) -> AsyncIterator[JsonElement]: + async def load_element(items: JsStream) -> AsyncIterator[JsonElement]: # collect ids either from json dict or string - ids: List[str] = [i["id"] if is_node(i) else i for i in items] # type: ignore + ids: List[str] = [i["id"] if is_node(i) else i async for i in items] # if there is an entry which is not a string, use the list as is (e.g. 
chunked) if any(a for a in ids if not isinstance(a, str)): - for a in items: + async for a in items: yield a else: # one query to load all items that match given ids (max 1000 as defined in chunk size) @@ -3307,7 +3286,7 @@ async def load_element(items: List[JsonElement]) -> AsyncIterator[JsonElement]: async for a in crs: yield a - return stream.chunks(in_stream, 1000) | pipe.flatmap(load_element) # type: ignore + return in_stream.chunks(1000).flatmap(load_element) async def no_update(self, _: WorkerTask, future_result: Future[Json]) -> Json: return await future_result @@ -3448,17 +3427,17 @@ def setup_stream(in_stream: JsStream) -> JsStream: def with_dependencies(model: Model) -> JsStream: load = self.load_by_id_merged(model, in_stream, variables, allowed_on_kind, **ctx.env) handler = self.update_node_in_graphdb(model, **ctx.env) if expect_node_result else self.no_update - return self.send_to_queue_stream(load | pipe.map(fn), handler, True) # type: ignore + return self.send_to_queue_stream(load.map(fn), handler, True) # type: ignore # dependencies are not resolved directly (no async function is allowed here) async def load_model() -> Model: return await self.dependencies.model_handler.load_model(ctx.graph_name) - return stream.call(load_model) | pipe.flatmap(with_dependencies) # type: ignore + return Stream.call(load_model).flatmap(with_dependencies) # type: ignore def setup_source() -> JsStream: arg = {"args": args_parts_unquoted_parser.parse(formatter({}))} - return self.send_to_queue_stream(stream.just((command_name, {}, arg)), self.no_update, True) + return self.send_to_queue_stream(Stream.just((command_name, {}, arg)), self.no_update, True) return ( CLISource.single(setup_source, required_permissions={Permission.write}) @@ -3575,13 +3554,13 @@ def setup_stream(in_stream: JsStream) -> JsStream: def with_dependencies(model: Model) -> JsStream: load = self.load_by_id_merged(model, in_stream, variables, **ctx.env) result_handler = self.update_node_in_graphdb(model, **ctx.env) - return self.send_to_queue_stream(load | pipe.map(fn), result_handler, not ns.nowait) # type: ignore + return self.send_to_queue_stream(load.map(fn), result_handler, not ns.nowait) # type: ignore async def load_model() -> Model: return await self.dependencies.model_handler.load_model(ctx.graph_name) # dependencies are not resolved directly (no async function is allowed here) - return stream.call(load_model) | pipe.flatmap(with_dependencies) # type: ignore + return Stream.call(load_model).flatmap(with_dependencies) # type: ignore return CLIFlow(setup_stream, required_permissions={Permission.write}) @@ -3594,7 +3573,7 @@ def file_command() -> JsStream: elif not os.path.exists(arg): raise AttributeError(f"file does not exist: {arg}!") else: - return stream.just(arg if arg else "") + return Stream.just(arg if arg else "") return CLISource.single(file_command, MediaType.FilePath, required_permissions={Permission.admin}) @@ -3618,7 +3597,7 @@ def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwa def upload_command() -> JsStream: if file_id in ctx.uploaded_files: file = ctx.uploaded_files[file_id] - return stream.just(f"Received file {file} of size {os.path.getsize(file)}") + return Stream.just(f"Received file {file} of size {os.path.getsize(file)}") else: raise AttributeError(f"file was not uploaded: {arg}!") @@ -3932,19 +3911,17 @@ async def write_result_to_file(ctx: CLIContext, in_stream: JsStream, file_name: async with TemporaryDirectory() as temp_dir: path = file_name if ctx.intern else 
os.path.join(temp_dir, uuid_str()) async with aiofiles.open(path, "w") as f: - async with in_stream.stream() as streamer: - async for out in streamer: - if isinstance(out, str): - await f.write(out + "\n") - else: - raise AttributeError("No output format is defined! Consider to use the format command.") + async for out in in_stream: + if isinstance(out, str): + await f.write(out + "\n") + else: + raise AttributeError("No output format is defined! Consider to use the format command.") yield FilePath.user_local(user=file_name, local=path).json() @staticmethod async def already_file_stream(in_stream: JsStream, file_name: str) -> AsyncIterator[JsonElement]: - async with in_stream.stream() as streamer: - async for out in streamer: - yield evolve(FilePath.from_path(out), user=Path(file_name)).json() + async for out in in_stream: + yield evolve(FilePath.from_path(out), user=Path(file_name)).json() def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwargs: Any) -> CLIAction: if arg is None: @@ -4040,7 +4017,7 @@ async def get_template(name: str) -> AsyncIterator[JsonElement]: async def list_templates() -> Tuple[int, Stream[str]]: templates = await self.dependencies.template_expander.list_templates() - return len(templates), stream.iterate(template_str(t) for t in templates) + return len(templates), Stream.iterate(template_str(t) for t in templates) async def put_template(name: str, template_query: str) -> AsyncIterator[str]: # try to render_console the template with dummy values and see if the search can be parsed @@ -4283,10 +4260,9 @@ async def perform_request(e: JsonElement) -> int: async def iterate_stream(in_stream: JsStream) -> AsyncIterator[JsonElement]: results: Dict[int, int] = defaultdict(lambda: 0) - async with in_stream.stream() as streamer: - async for elem in streamer: - status_code = await perform_request(elem) - results[status_code] += 1 + async for elem in in_stream: + status_code = await perform_request(elem) + results[status_code] += 1 summary = ", ".join(f"{count} requests with status {status}" for status, count in results.items()) if results: yield f"{summary} sent." 
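Reviewer note: the hunks above and below all apply the same two mechanical rewrites — aiostream's `pipe` operators become fluent methods on `Stream`, and the `async with x.stream() as s:` ceremony disappears because `Stream` is itself an async iterator. A minimal before/after sketch of the pattern (`demo` is illustrative, not part of this PR):

# Before (aiostream):
#     async with in_stream.stream() as s:
#         async for elem in s:
#             ...
#     out = in_stream | pipe.map(fn) | pipe.chunks(100)
# After (fixlib.asynchronous.stream.Stream):
import asyncio

from fixlib.asynchronous.stream import Stream


async def demo() -> None:
    in_stream = Stream.iterate(range(10))
    # Stream is an AsyncIterator, so it is consumed with a plain `async for`
    out = in_stream.filter(lambda x: x % 2 == 0).map(lambda x: x * 10).chunks(2)
    async for chunk in out:
        # chunks() yields sub-streams; consume each fully before pulling the next
        print(await Stream.as_list(chunk))  # [0, 20], then [40, 60], then [80]


asyncio.run(demo())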
@@ -4514,18 +4490,18 @@ def info(rt: RunningTask) -> JsonElement: **progress, } - return len(tasks), stream.iterate(info(t) for t in tasks if isinstance(t.descriptor, Workflow)) + return len(tasks), Stream.iterate(info(t) for t in tasks if isinstance(t.descriptor, Workflow)) async def show_log(wf_id: str) -> Tuple[int, JsStream]: rtd = await self.dependencies.db_access.running_task_db.get(wf_id) if rtd: messages = [msg.info() for msg in rtd.info_messages()] if messages: - return len(messages), stream.iterate(messages) + return len(messages), Stream.iterate(messages) else: - return 0, stream.just("No error messages for this run.") + return 0, Stream.just("No error messages for this run.") else: - return 0, stream.just(f"No workflow task with this id: {wf_id}") + return 0, Stream.just(f"No workflow task with this id: {wf_id}") def running_task_data(rtd: RunningTaskData) -> Json: result = { @@ -4539,7 +4515,7 @@ def running_task_data(rtd: RunningTaskData) -> Json: async def history_aggregation() -> JsStream: info = await self.dependencies.db_access.running_task_db.aggregated_history() - return stream.just(info) + return Stream.just(info) async def history_of(history_args: List[str]) -> Tuple[int, JsStream]: parser = NoExitArgumentParser() @@ -4558,7 +4534,7 @@ async def history_of(history_args: List[str]) -> Tuple[int, JsStream]: ) cursor: AsyncCursor = context.cursor try: - return cursor.count() or 0, stream.map(cursor, running_task_data) # type: ignore + return cursor.count() or 0, Stream(cursor).map(running_task_data) finally: cursor.close() @@ -4591,7 +4567,7 @@ async def stop_workflow(task_id: TaskId) -> AsyncIterator[str]: return CLISource.only_count(list_workflows, required_permissions={Permission.read}) else: return CLISource.single( - lambda: stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} + lambda: Stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} ) @@ -4763,7 +4739,7 @@ async def update_config(cfg_id: ConfigId) -> AsyncIterator[str]: async def list_configs() -> Tuple[int, JsStream]: ids = [i async for i in self.dependencies.config_handler.list_config_ids()] - return len(ids), stream.iterate(ids) + return len(ids), Stream.iterate(ids) args = re.split("\\s+", arg, maxsplit=2) if arg else [] if arg and len(args) == 2 and (args[0] == "show" or args[0] == "get"): @@ -4800,7 +4776,7 @@ async def list_configs() -> Tuple[int, JsStream]: return CLISource.only_count(list_configs, required_permissions={Permission.read}) else: return CLISource.single( - lambda: stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} + lambda: Stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} ) @@ -4889,7 +4865,7 @@ async def welcome() -> str: res = ctx.render_console(grid) return res - return CLISource.single(lambda: stream.just(welcome()), required_permissions={Permission.read}) # type: ignore + return CLISource.single(lambda: Stream.just(welcome()), required_permissions={Permission.read}) class TipOfTheDayCommand(CLICommand): @@ -4926,7 +4902,7 @@ async def totd() -> str: res = ctx.render_console(info) return res - return CLISource.single(lambda: stream.just(totd()), required_permissions={Permission.read}) # type: ignore + return CLISource.single(lambda: Stream.just(totd()), required_permissions={Permission.read}) class CertificateCommand(CLICommand): @@ -5004,7 +4980,7 @@ async def create_certificate( ) else: return CLISource.single( - lambda: stream.just(self.rendered_help(ctx)), 
required_permissions={Permission.read} + lambda: Stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} ) @@ -5435,9 +5411,8 @@ async def app_run( raise ValueError(f"Config {config} not found.") async def stream_to_iterator() -> AsyncIterator[JsonElement]: - async with in_stream.stream() as streamer: - async for item in streamer: - yield item + async for item in in_stream: + yield item stdin = stream_to_iterator() if dry_run: @@ -5531,7 +5506,7 @@ async def stream_to_iterator() -> AsyncIterator[JsonElement]: return CLISource.no_count( partial( app_run, - in_stream=stream.empty(), + in_stream=Stream.empty(), app_name=InfraAppName(parsed.app_name), dry_run=parsed.dry_run, config=parsed.config, @@ -5552,7 +5527,7 @@ async def stream_to_iterator() -> AsyncIterator[JsonElement]: ) else: return CLISource.single( - lambda: stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} + lambda: Stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} ) @@ -5757,7 +5732,7 @@ def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwa return CLISource.no_count(partial(self.show_user, args[1]), required_permissions={Permission.read}) else: return CLISource.single( - lambda: stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} + lambda: Stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} ) @@ -5950,7 +5925,7 @@ async def lines_iterator() -> AsyncIterator[str]: else: return CLISource.single( - lambda: stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} + lambda: Stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} ) @@ -6102,7 +6077,7 @@ async def sync_database_result(p: Namespace, maybe_stream: Optional[JsStream]) - async with await graph_db.search_graph_gen( QueryModel(query, fix_model, ctx.env), timeout=timedelta(weeks=200000) ) as cursor: - await sync_fn(query=query, in_stream=stream.iterate(cursor)) + await sync_fn(query=query, in_stream=Stream.iterate(cursor)) if file_output is not None: assert p.database, "No database name provided. Use the --database argument." 
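Reviewer note: `Stream` wraps any `AsyncIterator` directly, which is why the ArangoDB cursor above can be handed to `Stream.iterate(cursor)` — and, in the next hunk, why `BatchStream` can take `in_stream` without the old streamer context manager. A sketch under that assumption; `fake_cursor` and `demo` are hypothetical stand-ins, not the real `AsyncCursor` API:

import asyncio
from typing import Any, AsyncIterator, Dict

from fixlib.asynchronous.stream import Stream


async def fake_cursor() -> AsyncIterator[Dict[str, Any]]:
    # stand-in for an async database cursor yielding node documents
    for i in range(3):
        yield {"id": str(i), "reported": {"kind": "instance"}}


async def demo() -> None:
    kinds = Stream.iterate(fake_cursor()).map(lambda row: str(row["reported"]["kind"]))
    assert await kinds.collect() == ["instance", "instance", "instance"]


asyncio.run(demo())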
@@ -6160,11 +6135,10 @@ def key_fn(node: Json) -> Union[str, Tuple[str, str]]: kind_by_id[node["id"]] = node["reported"]["kind"] return cast(str, node["reported"]["kind"]) - async with in_stream.stream() as streamer: - batched = BatchStream(streamer, key_fn, engine_config.batch_size, engine_config.batch_size * 10) - await update_sql( - engine_config, rcm, batched, edges, swap_temp_tables=True, drop_existing_tables=drop_existing_tables - ) + batched = BatchStream(in_stream, key_fn, engine_config.batch_size, engine_config.batch_size * 10) + await update_sql( + engine_config, rcm, batched, edges, swap_temp_tables=True, drop_existing_tables=drop_existing_tables + ) args = arg.split(maxsplit=1) if arg else [] if len(args) == 2 and args[0] == "sync": @@ -6339,16 +6313,16 @@ def parse_duration_or_int(s: str) -> Union[int, timedelta]: async def list_ts() -> Tuple[int, JsGen]: ts = await self.dependencies.db_access.time_series_db.list_time_series() - return len(ts), stream.iterate([to_js(a) for a in ts]) + return len(ts), Stream.iterate([to_js(a) for a in ts]) async def downsample() -> Tuple[int, JsGen]: ts = await self.dependencies.db_access.time_series_db.downsample() if isinstance(ts, str): - return 1, stream.just(ts) + return 1, Stream.just(ts) elif ts: - return len(ts), stream.iterate([{k: v} for k, v in ts.items()]) + return len(ts), Stream.iterate([{k: v} for k, v in ts.items()]) else: - return 1, stream.just("No time series to downsample.") + return 1, Stream.just("No time series to downsample.") args = re.split("\\s+", arg, maxsplit=1) if arg else [] if arg and len(args) == 2 and args[0] == "snapshot": @@ -6363,7 +6337,7 @@ async def downsample() -> Tuple[int, JsGen]: return CLISource.only_count(downsample, required_permissions={Permission.read}) else: return CLISource.single( - lambda: stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} + lambda: Stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} ) @@ -6459,29 +6433,28 @@ def walk_element(el: JsonElement) -> Iterator[Tuple[str, PotentialSecret]]: if r.startswith("True"): yield el, secret - async def detect_secrets_in(content: JsStream) -> JsGen: + async def detect_secrets_in(in_stream: JsStream) -> JsGen: self.configure_detect() # make sure all plugins are loaded - async with content.stream() as in_stream: - async for element in in_stream: - paths = [p for pl in parsed.path for p in pl] - paths = paths or [PropertyPath.from_list([ctx.section]) if is_node(element) else EmptyPath] - found_secrets = False - for path in paths: - if to_check_js := path.value_in(element): - for secret_string, possible_secret in walk_element(to_check_js): - found_secrets = True - if isinstance(element, dict): - element["info"] = { - "secret_detected": True, - "potential_secret": secret_string, - "secret_type": possible_secret.type, - } - yield element - break - if found_secrets: - break # no need to check other paths - if not found_secrets and not parsed.with_secrets: - yield element + async for element in in_stream: + paths = [p for pl in parsed.path for p in pl] + paths = paths or [PropertyPath.from_list([ctx.section]) if is_node(element) else EmptyPath] + found_secrets = False + for path in paths: + if to_check_js := path.value_in(element): + for secret_string, possible_secret in walk_element(to_check_js): + found_secrets = True + if isinstance(element, dict): + element["info"] = { + "secret_detected": True, + "potential_secret": secret_string, + "secret_type": possible_secret.type, + } + yield element + 
break + if found_secrets: + break # no need to check other paths + if not found_secrets and not parsed.with_secrets: + yield element return CLIFlow(detect_secrets_in) @@ -6546,9 +6519,9 @@ async def process_element(el: JsonElement) -> JsonElement: set_value_in_path(refinement.value, refinement.path, el) # type: ignore return el - return in_stream | pipe.map(process_element) # type: ignore + return in_stream.map(process_element) - return stream.call(load_model) | pipe.flatmap(with_dependencies) # type: ignore + return Stream.call(load_model).flatmap(with_dependencies) # type: ignore return CLIFlow(setup_stream, required_permissions={Permission.read}) @@ -6590,7 +6563,7 @@ async def delete_node(node_id: NodeId, keep_history: bool) -> AsyncIterator[str] fn=partial(delete_node, node_id=parsed.node_id, keep_history=parsed.keep_history), required_permissions={Permission.write}, ) - return CLISource.single(lambda: stream.just(self.rendered_help(ctx)), required_permissions={Permission.read}) + return CLISource.single(lambda: Stream.just(self.rendered_help(ctx)), required_permissions={Permission.read}) def all_commands(d: TenantDependencies) -> List[CLICommand]: diff --git a/fixcore/fixcore/cli/model.py b/fixcore/fixcore/cli/model.py index f4afcde957..ec2cac1e45 100644 --- a/fixcore/fixcore/cli/model.py +++ b/fixcore/fixcore/cli/model.py @@ -26,8 +26,6 @@ TYPE_CHECKING, ) -from aiostream import stream -from aiostream.core import Stream from attrs import define, field from parsy import test_char, string from rich.jupyter import JupyterMixin @@ -42,6 +40,7 @@ from fixcore.query.template_expander import render_template from fixcore.types import Json, JsonElement from fixcore.util import AccessJson, uuid_str, from_utc, utc, utc_str +from fixlib.asynchronous.stream import Stream from fixlib.parse_util import l_curly_dp, r_curly_dp from fixlib.utils import get_local_tzinfo @@ -236,7 +235,7 @@ def __init__( @staticmethod def make_stream(in_stream: JsGen) -> JsStream: - return in_stream if isinstance(in_stream, Stream) else stream.iterate(in_stream) + return in_stream if isinstance(in_stream, Stream) else Stream.iterate(in_stream) @define @@ -316,7 +315,7 @@ def single( @staticmethod def empty() -> CLISource: - return CLISource.with_count(stream.empty, 0) + return CLISource.with_count(Stream.empty, 0) class CLIFlow(CLIAction): @@ -739,7 +738,7 @@ async def execute(self) -> Tuple[CLISourceContext, JsStream]: flow = await flow_action.flow(flow) return context, flow else: - return CLISourceContext(count=0), stream.empty() + return CLISourceContext(count=0), Stream.empty() class CLI(ABC): diff --git a/fixcore/fixcore/infra_apps/local_runtime.py b/fixcore/fixcore/infra_apps/local_runtime.py index c2d2376291..22506071a0 100644 --- a/fixcore/fixcore/infra_apps/local_runtime.py +++ b/fixcore/fixcore/infra_apps/local_runtime.py @@ -3,7 +3,6 @@ from pydoc import locate from typing import List, AsyncIterator, Type, Optional, Any -from aiostream import stream, pipe from jinja2 import Environment from fixcore.cli import NoExitArgumentParser, JsStream, JsGen @@ -14,6 +13,7 @@ from fixcore.infra_apps.runtime import Runtime from fixcore.service import Service from fixcore.types import Json, JsonElement +from fixlib.asynchronous.stream import Stream from fixlib.asynchronous.utils import async_lines from fixlib.durations import parse_optional_duration @@ -46,9 +46,8 @@ async def execute( Runtime implementation that runs the app locally. 
""" async for line in self.generate_template(graph, manifest, config, stdin, argv): - async with (await self._interpret_line(line, ctx)).stream() as streamer: - async for item in streamer: - yield item + async for item in await self._interpret_line(line, ctx): + yield item async def generate_template( self, @@ -117,4 +116,4 @@ async def _interpret_line(self, line: str, ctx: CLIContext) -> JsStream: total_nr_outputs = total_nr_outputs + (src_ctx.count or 0) command_streams.append(command_output_stream) - return stream.iterate(command_streams) | pipe.concat(task_limit=1) + return Stream.iterate(command_streams).concat(task_limit=1) # type: ignore diff --git a/fixcore/fixcore/web/api.py b/fixcore/fixcore/web/api.py index e7c830160b..06b95ccd77 100644 --- a/fixcore/fixcore/web/api.py +++ b/fixcore/fixcore/web/api.py @@ -54,7 +54,6 @@ from aiohttp.web_fileresponse import FileResponse from aiohttp.web_response import json_response from aiohttp_swagger3 import SwaggerFile, SwaggerUiSettings -from aiostream import stream from attrs import evolve from dateutil import parser as date_parser from multidict import MultiDict @@ -134,6 +133,7 @@ WorkerTaskResult, WorkerTaskInProgress, ) +from fixlib.asynchronous.stream import Stream from fixlib.asynchronous.web.ws_handler import accept_websocket, clean_ws_handler from fixlib.durations import parse_duration from fixlib.jwt import encode_jwt @@ -664,7 +664,7 @@ async def perform_benchmark_on_checks(self, request: Request, deps: TenantDepend ) return await single_result(request, to_js(result)) - async def perform_benchmark(self, request: Request, deps: TenantDependencies) -> StreamResponse: # type: ignore + async def perform_benchmark(self, request: Request, deps: TenantDependencies) -> StreamResponse: benchmark = request.match_info["benchmark"] graph = GraphName(request.match_info["graph_id"]) acc = request.query.get("accounts") @@ -677,8 +677,8 @@ async def perform_benchmark(self, request: Request, deps: TenantDependencies) -> else: raise ValueError(f"Unknown action {action}. 
One of run or load is expected.") result_graph = results[benchmark].to_graph() - async with stream.iterate(result_graph).stream() as streamer: - return await self.stream_response_from_gen(request, streamer, count=len(result_graph)) + stream = Stream.iterate(result_graph) + return await self.stream_response_from_gen(request, stream, count=len(result_graph)) async def inspection_checks(self, request: Request, deps: TenantDependencies) -> StreamResponse: provider = request.query.get("provider") @@ -1433,7 +1433,7 @@ async def write_files(mpr: MultipartReader, tmp_dir: str) -> Dict[str, str]: if temp_dir: shutil.rmtree(temp_dir) - async def execute_parsed( # type: ignore + async def execute_parsed( self, request: Request, command: str, parsed: List[ParsedCommandLine], ctx: CLIContext ) -> StreamResponse: # what is the accepted content type @@ -1455,43 +1455,41 @@ async def execute_parsed( # type: ignore first_result = parsed[0] src_ctx, generator = await first_result.execute() # flat the results from 0 or 1 - async with generator.stream() as streamer: - gen = await force_gen(streamer) - if first_result.produces.text: - text_gen = ctx.text_generator(first_result, gen) - return await self.stream_response_from_gen( - request, - text_gen, - count=src_ctx.count, - total_count=src_ctx.total_count, - query_stats=src_ctx.stats, - additional_header=first_result.envelope, - ) - elif first_result.produces.file_path: - await mp_response.prepare(request) - await Api.multi_file_response(first_result, gen, boundary, mp_response) - await Api.close_multi_part_response(mp_response, boundary) - return mp_response - else: - raise AttributeError(f"Can not handle type: {first_result.produces}") + gen = await force_gen(generator) + if first_result.produces.text: + text_gen = ctx.text_generator(first_result, gen) + return await self.stream_response_from_gen( + request, + text_gen, + count=src_ctx.count, + total_count=src_ctx.total_count, + query_stats=src_ctx.stats, + additional_header=first_result.envelope, + ) + elif first_result.produces.file_path: + await mp_response.prepare(request) + await Api.multi_file_response(first_result, gen, boundary, mp_response) + await Api.close_multi_part_response(mp_response, boundary) + return mp_response + else: + raise AttributeError(f"Can not handle type: {first_result.produces}") elif len(parsed) > 1: await mp_response.prepare(request) for single in parsed: _, generator = await single.execute() - async with generator.stream() as streamer: - gen = await force_gen(streamer) - if single.produces.text: - with MultipartWriter(repr(single.produces), boundary) as mp: - text_gen = ctx.text_generator(single, gen) - content_type, result_stream = await result_binary_gen(request, text_gen) - mp.append_payload( - AsyncIterablePayload(result_stream, content_type=content_type, headers=single.envelope) - ) - await mp.write(mp_response, close_boundary=False) - elif single.produces.file_path: - await Api.multi_file_response(single, gen, boundary, mp_response) - else: - raise AttributeError(f"Can not handle type: {single.produces}") + gen = await force_gen(generator) + if single.produces.text: + with MultipartWriter(repr(single.produces), boundary) as mp: + text_gen = ctx.text_generator(single, gen) + content_type, result_stream = await result_binary_gen(request, text_gen) + mp.append_payload( + AsyncIterablePayload(result_stream, content_type=content_type, headers=single.envelope) + ) + await mp.write(mp_response, close_boundary=False) + elif single.produces.file_path: + await 
Api.multi_file_response(single, gen, boundary, mp_response) + else: + raise AttributeError(f"Can not handle type: {single.produces}") await Api.close_multi_part_response(mp_response, boundary) return mp_response else: diff --git a/fixcore/tests/fixcore/cli/command_test.py b/fixcore/tests/fixcore/cli/command_test.py index cf85a7ec38..514d0fc2c3 100644 --- a/fixcore/tests/fixcore/cli/command_test.py +++ b/fixcore/tests/fixcore/cli/command_test.py @@ -13,9 +13,9 @@ from _pytest.logging import LogCaptureFixture from aiohttp import ClientTimeout from aiohttp.web import Request -from aiostream import stream, pipe from attrs import evolve from pytest import fixture + from fixcore import version from fixcore.cli import is_node, JsStream, list_sink from fixcore.cli.cli import CLIService @@ -48,6 +48,7 @@ from fixcore.user import UsersConfigId from fixcore.util import AccessJson, utc_str, utc from fixcore.worker_task_queue import WorkerTask +from fixlib.asynchronous.stream import Stream from tests.fixcore.util_test import not_in_path @@ -279,7 +280,7 @@ async def test_list_sink(cli: CLI, dependencies: TenantDependencies) -> None: async def test_flat_sink(cli: CLI) -> None: parsed = await cli.evaluate_cli_command("json [1,2,3] | dump; json [4,5,6] | dump; json [7,8,9] | dump") expected = [1, 2, 3, 4, 5, 6, 7, 8, 9] - assert await stream.list(stream.iterate((await p.execute())[1] for p in parsed) | pipe.concat()) == expected + assert expected == await Stream.iterate(await (await p.execute())[1].collect() for p in parsed).flatten().collect() # type: ignore @pytest.mark.asyncio @@ -315,7 +316,7 @@ async def test_format(cli: CLI) -> None: async def test_workflows_command(cli: CLIService, task_handler: TaskHandlerService, test_workflow: Workflow) -> None: async def execute(cmd: str) -> List[JsonElement]: ctx = CLIContext(cli.cli_env) - return (await cli.execute_cli_command(cmd, list_sink, ctx))[0] # type: ignore + return (await cli.execute_cli_command(cmd, list_sink, ctx))[0] assert await execute("workflows list") == ["sleep_workflow", "wait_for_collect_done", "test_workflow"] assert await execute("workflows show test_workflow") == [to_js(test_workflow)] @@ -754,15 +755,14 @@ async def test_aggregation_to_count_command(cli: CLI) -> None: @pytest.mark.asyncio async def test_system_backup_command(cli: CLI) -> None: async def check_backup(res: JsStream) -> None: - async with res.stream() as streamer: - only_one = True - async for s in streamer: - path = FilePath.from_path(s) - assert path.local.exists() - # backup should have size between 30k and 1500k (adjust size if necessary) - assert 30000 < path.local.stat().st_size < 1500000 - assert only_one - only_one = False + only_one = True + async for s in res: + path = FilePath.from_path(s) + assert path.local.exists() + # backup should have size between 30k and 1500k (adjust size if necessary) + assert 30000 < path.local.stat().st_size < 1500000 + assert only_one + only_one = False await cli.execute_cli_command("system backup create", check_backup) @@ -781,10 +781,9 @@ async def test_system_restore_command(cli: CLI, tmp_directory: str) -> None: backup = os.path.join(tmp_directory, "backup") async def move_backup(res: JsStream) -> None: - async with res.stream() as streamer: - async for s in streamer: - path = FilePath.from_path(s) - os.rename(path.local, backup) + async for s in res: + path = FilePath.from_path(s) + os.rename(path.local, backup) await cli.execute_cli_command("system backup create", move_backup) ctx = CLIContext(uploaded_files={"backup": 
backup}) @@ -802,11 +801,10 @@ async def test_configs_command(cli: CLI, tmp_directory: str) -> None: config_file = os.path.join(tmp_directory, "config.yml") async def check_file_is_yaml(res: JsStream) -> None: - async with res.stream() as streamer: - async for s in streamer: - assert isinstance(s, str) - with open(s, "r") as file: - yaml.safe_load(file.read()) + async for s in res: + assert isinstance(s, str) + with open(s, "r") as file: + yaml.safe_load(file.read()) # create a new config entry create_result = await cli.execute_cli_command("configs set test_config t1=1, t2=2, t3=3 ", list_sink) @@ -865,19 +863,18 @@ async def test_templates_command(cli: CLI) -> None: @pytest.mark.asyncio async def test_write_command(cli: CLI) -> None: async def check_file(res: JsStream, check_content: Optional[str] = None) -> None: - async with res.stream() as streamer: - only_one = True - async for s in streamer: - fp = FilePath.from_path(s) - assert fp.local.exists() and fp.local.is_file() - assert 1 < fp.local.stat().st_size < 100000 - assert fp.user.name.startswith("write_test") - assert only_one - only_one = False - if check_content: - with open(fp.local, "r") as file: - data = file.read() - assert data == check_content + only_one = True + async for s in res: + fp = FilePath.from_path(s) + assert fp.local.exists() and fp.local.is_file() + assert 1 < fp.local.stat().st_size < 100000 + assert fp.user.name.startswith("write_test") + assert only_one + only_one = False + if check_content: + with open(fp.local, "r") as file: + data = file.read() + assert data == check_content # result can be read as json await cli.execute_cli_command("search all limit 3 | format --json | write write_test.json ", check_file) @@ -1095,14 +1092,12 @@ async def history_count(cmd: str) -> int: @pytest.mark.asyncio async def test_aggregate(dependencies: TenantDependencies) -> None: - in_stream = stream.iterate( - [{"a": 1, "b": 1, "c": 1}, {"a": 2, "b": 1, "c": 1}, {"a": 3, "b": 2, "c": 1}, {"a": 4, "b": 2, "c": 1}] - ) - - async def aggregate(agg_str: str) -> List[JsonElement]: # type: ignore + async def aggregate(agg_str: str) -> List[JsonElement]: + in_stream = Stream.iterate( + [{"a": 1, "b": 1, "c": 1}, {"a": 2, "b": 1, "c": 1}, {"a": 3, "b": 2, "c": 1}, {"a": 4, "b": 2, "c": 1}] + ) res = AggregateCommand(dependencies).parse(agg_str) - async with (await res.flow(in_stream)).stream() as flow: - return [s async for s in flow] + return [s async for s in (await res.flow(in_stream))] assert await aggregate("b as bla, c, r.d.f.name: sum(1) as count, min(a) as min, max(a) as max") == [ {"group": {"bla": 1, "c": 1, "r.d.f.name": None}, "count": 2, "min": 1, "max": 2}, @@ -1161,11 +1156,10 @@ async def execute(cmd: str, _: Type[T]) -> List[T]: return cast(List[T], result[0]) async def check_file_is_yaml(res: JsStream) -> None: - async with res.stream() as streamer: - async for s in streamer: - assert isinstance(s, str) - with open(s, "r") as file: - yaml.safe_load(file.read()) + async for s in res: + assert isinstance(s, str) + with open(s, "r") as file: + yaml.safe_load(file.read()) # install a package assert "installed successfully" in (await execute("apps install cleanup-untagged", str))[0] @@ -1235,7 +1229,7 @@ async def check_file_is_yaml(res: JsStream) -> None: async def test_user(cli: CLI) -> None: async def execute(cmd: str) -> List[JsonElement]: all_results = await cli.execute_cli_command(cmd, list_sink) - return all_results[0] # type: ignore + return all_results[0] # remove all existing users await 
cli.dependencies.config_handler.delete_config(UsersConfigId) @@ -1355,10 +1349,9 @@ async def execute(cmd: str, _: Type[T]) -> List[T]: dump = os.path.join(tmp_directory, "dump") async def move_dump(res: JsStream) -> None: - async with res.stream() as streamer: - async for s in streamer: - fp = FilePath.from_path(s) - os.rename(fp.local, dump) + async for s in res: + fp = FilePath.from_path(s) + os.rename(fp.local, dump) # graph export works await cli.execute_cli_command("graph export graphtest dump", move_dump) @@ -1387,28 +1380,27 @@ async def sync_and_check( ) -> Json: result: List[Json] = [] - async def check(in_: JsStream) -> None: - async with in_.stream() as streamer: - async for s in streamer: - assert isinstance(s, dict) - path = FilePath.from_path(s) - # open sqlite database - conn = sqlite3.connect(path.local) - c = conn.cursor() - tables = { - row[0] for row in c.execute("SELECT tbl_name FROM sqlite_master WHERE type='table'").fetchall() - } - if expected_tables is not None: - assert tables == expected_tables - if expected_table_count is not None: - assert len(tables) == expected_table_count - if expected_table is not None: - for table in tables: - count = c.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0] - assert expected_table(table, count), f"Table {table} has {count} rows" - c.close() - conn.close() - result.append(s) + async def check(streamer: JsStream) -> None: + async for s in streamer: + assert isinstance(s, dict) + path = FilePath.from_path(s) + # open sqlite database + conn = sqlite3.connect(path.local) + c = conn.cursor() + tables = { + row[0] for row in c.execute("SELECT tbl_name FROM sqlite_master WHERE type='table'").fetchall() + } + if expected_tables is not None: + assert tables == expected_tables + if expected_table_count is not None: + assert len(tables) == expected_table_count + if expected_table is not None: + for table in tables: + count = c.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0] + assert expected_table(table, count), f"Table {table} has {count} rows" + c.close() + conn.close() + result.append(s) await cli.execute_cli_command(cmd, check) assert len(result) == 1 diff --git a/fixlib/fixlib/asynchronous/stream.py b/fixlib/fixlib/asynchronous/stream.py new file mode 100644 index 0000000000..e0ff050211 --- /dev/null +++ b/fixlib/fixlib/asynchronous/stream.py @@ -0,0 +1,372 @@ +from __future__ import annotations + +import asyncio +from asyncio import TaskGroup, Task +from collections import deque +from typing import AsyncIterable, AsyncIterator, TypeVar, Optional, List, Dict, Callable, Generic, ParamSpec, TypeAlias +from typing import Iterable, Awaitable, Never, Tuple, Union + +T = TypeVar("T") +R = TypeVar("R", covariant=True) +P = ParamSpec("P") + +DirectOrAwaitable: TypeAlias = Union[T, Awaitable[T]] +IterOrAsyncIter: TypeAlias = Union[Iterable[T], AsyncIterable[T]] + +DefaultTaskLimit = 1 + + +def _async_iter(x: Iterable[T]) -> AsyncIterator[T]: + async def gen() -> AsyncIterator[T]: + for item in x: + yield item + + return gen() + + +def _to_async_iter(x: IterOrAsyncIter[T]) -> AsyncIterable[T]: + if isinstance(x, AsyncIterable): + return x + else: + return _async_iter(x) + + +def _flatmap( + source: AsyncIterable[IterOrAsyncIter[DirectOrAwaitable[T]]], + task_limit: Optional[int], + ordered: bool, +) -> AsyncIterator[T]: + if ordered: + return _flatmap_ordered(source, task_limit) + else: + return _flatmap_unordered(source, task_limit) + + +async def _flatmap_unordered( + source: AsyncIterable[IterOrAsyncIter[DirectOrAwaitable[T]]], + 
task_limit: Optional[int] = None,
+) -> AsyncIterator[T]:
+    semaphore = asyncio.Semaphore(task_limit or DefaultTaskLimit)
+    worker_done = object()  # sentinel: a worker finished (it may have produced no items)
+    queue: asyncio.Queue[object] = asyncio.Queue()
+    tasks_in_flight = 0
+
+    async def worker(sub_iter: IterOrAsyncIter[DirectOrAwaitable[T]]) -> None:
+        try:
+            if isinstance(sub_iter, AsyncIterable):
+                async for si in sub_iter:
+                    if isinstance(si, Awaitable):
+                        si = await si
+                    await queue.put(si)
+            else:
+                for si in sub_iter:
+                    if isinstance(si, Awaitable):
+                        si = await si
+                    await queue.put(si)
+        except Exception as e:
+            await queue.put(e)  # exception: put it in the queue to be handled
+        finally:
+            semaphore.release()
+            await queue.put(worker_done)  # always signal completion, even without items
+
+    async with TaskGroup() as tg:
+        # Start worker tasks
+        async for src in source:
+            await semaphore.acquire()
+            tg.create_task(worker(src))
+            tasks_in_flight += 1
+
+        # Consume items from the queue and yield them. Counting finished workers via
+        # the sentinel (instead of a shared counter decremented inside the workers)
+        # ensures this loop can never block on queue.get() after the last worker
+        # finished without producing an item.
+        while tasks_in_flight > 0 or not queue.empty():
+            item = await queue.get()
+            if item is worker_done:
+                tasks_in_flight -= 1
+            elif isinstance(item, Exception):
+                raise item
+            else:
+                yield item  # type: ignore
+
+
+async def _flatmap_ordered(
+    source: AsyncIterable[IterOrAsyncIter[DirectOrAwaitable[T]]],
+    task_limit: Optional[int] = None,
+) -> AsyncIterator[T]:
+    tlf = task_limit or DefaultTaskLimit
+    semaphore = asyncio.Semaphore(tlf)
+    tasks: Dict[int, Task[None]] = {}
+    results: Dict[int, List[T] | Exception] = {}
+    next_index_to_yield = 0
+    source_iter = aiter(source)
+    max_index_started = -1  # Highest index of tasks started
+    source_exhausted = False
+
+    async def worker(sub_iter: IterOrAsyncIter[T | Awaitable[T]], index: int) -> None:
+        items = []
+        try:
+            if isinstance(sub_iter, AsyncIterable):
+                async for item in sub_iter:
+                    if isinstance(item, Awaitable):
+                        item = await item
+                    items.append(item)
+            else:
+                for item in sub_iter:
+                    if isinstance(item, Awaitable):
+                        item = await item
+                    items.append(item)
+            results[index] = items
+        except Exception as e:
+            results[index] = e  # Store exception to be raised later
+        finally:
+            semaphore.release()
+
+    while True:
+        # Start new tasks up to task_limit ahead of next_index_to_yield
+        while not source_exhausted and (max_index_started - next_index_to_yield + 1) < tlf:
+            try:
+                await semaphore.acquire()
+                si = await anext(source_iter)
+                max_index_started += 1
+                tasks[max_index_started] = asyncio.create_task(worker(_to_async_iter(si), max_index_started))
+            except StopAsyncIteration:
+                source_exhausted = True
+                break
+
+        if next_index_to_yield in results:
+            result = results.pop(next_index_to_yield)
+            if isinstance(result, Exception):
+                raise result
+            else:
+                for res in result:
+                    yield res
+            # Remove completed task
+            tasks.pop(next_index_to_yield, None)  # noqa
+            next_index_to_yield += 1
+        else:
+            # Wait for the next task to complete
+            if next_index_to_yield in tasks:
+                task = tasks[next_index_to_yield]
+                await asyncio.wait({task})
+            elif not tasks and source_exhausted:
+                # No more tasks to process
+                break
+            else:
+                # Yield control to the event loop
+                await asyncio.sleep(0.01)
+
+
+class Stream(Generic[T], AsyncIterator[T]):
+    def __init__(self, iterator: AsyncIterator[T]):
+        self.iterator = iterator
+
+    def __aiter__(self) -> AsyncIterator[T]:
+        return self
+
+    async def __anext__(self) -> T:
+        return await anext(self.iterator)
+
+    def filter(self, fn: Callable[[T], DirectOrAwaitable[bool]]) -> Stream[T]:
+        async def gen() -> AsyncIterator[T]:
+            async for item in self:
+                af = fn(item)
+                flag = await af if isinstance(af, Awaitable) else af
+                if flag:
+                    yield item
+
+        return Stream(gen())
+
+    def starmap(
+        self,
+        fn: Callable[..., DirectOrAwaitable[R]],
+        task_limit: Optional[int] = None,
+        ordered: bool = True,
+    ) -> Stream[R]:
+        return self.map(lambda args: fn(*args), task_limit, ordered)  # type: ignore
+
+    def map(
+        self,
+        fn: Callable[[T], DirectOrAwaitable[R]],
+        task_limit: Optional[int] = None,
+        ordered: bool = True,
+    ) -> Stream[R]:
+        async def gen() -> AsyncIterator[AsyncIterator[R | Awaitable[R]]]:
+            async for item in self:
+                res = fn(item)
+                yield _async_iter([res])
+
+        return Stream(_flatmap(gen(), task_limit, ordered))
+
+    def flatmap(
+        self,
+        fn: Callable[[T], DirectOrAwaitable[IterOrAsyncIter[R]]],
+        task_limit: Optional[int] = None,
+        ordered: bool = True,
+    ) -> Stream[R]:
+        async def gen() -> AsyncIterator[IterOrAsyncIter[R]]:
+            async for item in self:
+                res = fn(item)
+                if isinstance(res, Awaitable):
+                    res = await res
+                yield res
+
+        return Stream(_flatmap(gen(), task_limit, ordered))
+
+    def concat(self: Stream[Stream[T]], task_limit: Optional[int] = None, ordered: bool = True) -> Stream[T]:
+        return self.flatmap(lambda x: x, task_limit, ordered)
+
+    def skip(self, num: int) -> Stream[T]:
+        async def gen() -> AsyncIterator[T]:
+            count = 0
+            async for item in self:
+                if count < num:
+                    count += 1
+                    continue
+                yield item
+
+        return Stream(gen())
+
+    def take(self, num: int) -> Stream[T]:
+        async def gen() -> AsyncIterator[T]:
+            count = 0
+            async for item in self:
+                if count >= num:
+                    break
+                yield item
+                count += 1
+
+        return Stream(gen())
+
+    def take_last(self, num: int) -> Stream[T]:
+        async def gen() -> AsyncIterator[T]:
+            queue: deque[T] = deque(maxlen=num)
+            async for item in self:
+                queue.append(item)
+            for item in queue:
+                yield item
+
+        return Stream(gen())
+
+    def enumerate(self) -> Stream[Tuple[int, T]]:
+        async def gen() -> AsyncIterator[Tuple[int, T]]:
+            i = 0
+            async for item in self:
+                yield i, item
+                i += 1
+
+        return Stream(gen())
+
+    def chunks(self, num: int) -> Stream[Stream[T]]:
+        def take_n(iterator: AsyncIterator[T], n: int) -> AsyncIterator[T]:
+            async def n_gen() -> AsyncIterator[T]:
+                count = 0
+                try:
+                    while count < n:
+                        item = await anext(iterator)
+                        yield item
+                        count += 1
+                except StopAsyncIteration:
+                    return
+
+            return n_gen()
+
+        async def gen() -> AsyncIterator[Stream[T]]:
+            iterator = aiter(self.iterator)
+            while True:
+                chunk_iterator = take_n(iterator, num)
+                try:
+                    first_item = await anext(chunk_iterator)
+                except StopAsyncIteration:
+                    break  # No more items
+
+                async def chunk_with_first() -> AsyncIterator[T]:
+                    yield first_item
+                    async for item in chunk_iterator:
+                        yield item
+
+                yield Stream(chunk_with_first())
+
+        return Stream(gen())
+
+    def flatten(self) -> Stream[T]:
+        async def gen() -> AsyncIterator[T]:
+            async for item in self:
+                if isinstance(item, AsyncIterable):
+                    async for subitem in item:
+                        yield subitem
+                elif isinstance(item, Iterable) and not isinstance(item, (str, bytes)):
+                    # keep strings intact instead of exploding them into characters
+                    for subitem in item:
+                        yield subitem
+                else:
+                    yield item
+
+        return Stream(gen())
+
+    def cycle(self) -> Stream[T]:
+        async def gen() -> AsyncIterator[T]:
+            items = []
+            async for item in self:
+                yield item
+                items.append(item)
+            while items:
+                for item in items:
+                    yield item
+
+        return Stream(gen())
+
+    async def collect(self) -> List[T]:
+        return [item async for item in self]
+
+    @staticmethod
+    def just(x: T | Awaitable[T]) -> Stream[T]:
+        async def gen() -> AsyncIterator[T]:
+            if isinstance(x, Awaitable):
+                yield await x
+            else:
+                yield x
+
+        return Stream(gen())
+
+    @staticmethod
+    def iterate(x: Iterable[T] | AsyncIterable[T] | AsyncIterator[T]) -> Stream[T]:
+        if isinstance(x, AsyncIterator):
+            return Stream(x)
+        elif isinstance(x, AsyncIterable):
+            return Stream(aiter(x))
+        else:
+            return Stream(_async_iter(x))
+
+    @staticmethod
+    def empty() -> Stream[T]:
+        async def empty() -> AsyncIterator[Never]:
+            if False:
+                yield  # noqa
+
+        return Stream(empty())
+
+    @staticmethod
+    def for_ever(fn: Callable[[], T]) -> Stream[T]:
+        async def gen() -> AsyncIterator[T]:
+            while True:
+                yield fn()
+
+        return Stream(gen())
+
+    @staticmethod
+    def call(fn: Callable[P, Awaitable[R]] | Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> Stream[R]:
+        async def gen() -> AsyncIterator[R]:
+            if asyncio.iscoroutinefunction(fn):
+                yield await fn(*args, **kwargs)
+            else:
+                yield fn(*args, **kwargs)  # type: ignore
+
+        return Stream(gen())
+
+    @staticmethod
+    async def as_list(x: Iterable[T] | AsyncIterable[T] | AsyncIterator[T]) -> List[T]:
+        if isinstance(x, AsyncIterator):
+            return [item async for item in x]
+        elif isinstance(x, AsyncIterable):
+            return [item async for item in aiter(x)]
+        else:
+            return [item for item in x]
diff --git a/fixlib/test/asynchronous/stream_test.py b/fixlib/test/asynchronous/stream_test.py
new file mode 100644
index 0000000000..5b4e6efc00
--- /dev/null
+++ b/fixlib/test/asynchronous/stream_test.py
@@ -0,0 +1,72 @@
+import asyncio
+from typing import AsyncIterator
+
+from fixlib.asynchronous.stream import Stream
+
+
+async def example_gen() -> AsyncIterator[int]:
+    for i in range(5):
+        yield i
+
+
+def example_stream() -> Stream:
+    return Stream(example_gen())
+
+
+async def test_just() -> None:
+    assert await Stream.just(1).collect() == [1]
+
+
+async def test_iterate() -> None:
+    assert await Stream.iterate([1, 2, 3]).collect() == [1, 2, 3]
+    assert await Stream.iterate(example_gen()).collect() == [0, 1, 2, 3, 4]
+    assert await Stream.iterate(example_stream()).collect() == [0, 1, 2, 3, 4]
+
+
+async def test_filter() -> None:
+    assert await example_stream().filter(lambda x: x % 2).collect() == [1, 3]
+
+
+async def test_map() -> None:
+    async def fn(x: int) -> int:
+        await asyncio.sleep(0)
+        return x * 2
+
+    assert await example_stream().map(lambda x: x * 2).collect() == [0, 2, 4, 6, 8]
+    assert await example_stream().map(fn).collect() == [0, 2, 4, 6, 8]
+
+
+async def test_flatmap() -> None:
+    async def gen(x: int) -> AsyncIterator[int]:
+        await asyncio.sleep(0)
+        for _ in range(2):
+            yield x * 2
+
+    assert await example_stream().flatmap(gen).collect() == [0, 0, 2, 2, 4, 4, 6, 6, 8, 8]
+
+
+async def test_take() -> None:
+    assert await example_stream().take(3).collect() == [0, 1, 2]
+
+
+async def test_take_last() -> None:
+    assert await example_stream().take_last(3).collect() == [2, 3, 4]
+
+
+async def test_skip() -> None:
+    assert await example_stream().skip(2).collect() == [2, 3, 4]
+    assert await example_stream().skip(10).collect() == []
+
+
+async def test_call() -> None:
+    def fn(foo: int, bla: str) -> int:
+        return 123
+
+    def with_int(foo: int) -> int:
+        return foo + 1
+
+    assert await Stream.call(fn, 1, "bla").map(with_int).collect() == [124]
+
+
+async def test_chunks() -> None:
+    # chunks() yields sub-streams over the shared source: materialize each chunk
+    # (map with its default task_limit=1 consumes one chunk fully before the next)
+    assert await example_stream().chunks(2).map(Stream.as_list).collect() == [[0, 1], [2, 3], [4]]
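Usage sketch (not part of this PR's test suite): the `task_limit`/`ordered` parameters on `map`/`flatmap` replace aiostream's keyword arguments of the same names and bound how many per-element tasks run concurrently. `slow_double` and `demo` are illustrative names:

import asyncio

from fixlib.asynchronous.stream import Stream


async def slow_double(x: int) -> int:
    await asyncio.sleep((5 - x) / 100)  # later elements finish first
    return x * 2


async def demo() -> None:
    # ordered=True (default): results keep source order, even though all five
    # tasks run concurrently under task_limit=5
    assert await Stream.iterate(range(5)).map(slow_double, task_limit=5).collect() == [0, 2, 4, 6, 8]
    # ordered=False: results are yielded as they complete, so only the set is stable
    unordered = await Stream.iterate(range(5)).map(slow_double, task_limit=5, ordered=False).collect()
    assert sorted(unordered) == [0, 2, 4, 6, 8]


asyncio.run(demo())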