Skip to content

Commit

Permalink
🐛 fix sistana enter_instantly & stargazing
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Oct 31, 2024
1 parent 00c30c0 commit df92d34
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 25 deletions.
34 changes: 22 additions & 12 deletions src/arclet/alconna/_stargazing/compiler.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

from typing import Sequence, Any, overload
from typing import Sequence, Any

from arclet.alconna.base import Subcommand, Option, HeadResult
from arclet.alconna.base import Subcommand, Option, HeadResult, SPECIAL_OPTIONS
from arclet.alconna.exceptions import InvalidArgs, InvalidParam, ParamsUnmatched, UnexpectedElement, ArgumentMissing, \
NullMessage
from arclet.alconna.core import Alconna
Expand Down Expand Up @@ -41,7 +41,7 @@ def _alc_args_to_fragments(args: _Args) -> list[Fragment]:
return fragments


def step(node: Subcommand | Option, upper: SubcommandPattern):
def step(node: Subcommand | Option, upper: SubcommandPattern, _global_bind: tuple[dict, dict]):
if isinstance(node, Subcommand):
pat = upper.subcommand(
node.name,
Expand All @@ -50,8 +50,10 @@ def step(node: Subcommand | Option, upper: SubcommandPattern):
soft_keyword=node.soft_keyword,
separators=node.separators,
)
pat._options_bind.maps.append(_global_bind[0])
pat._subcommands_bind.maps.append(_global_bind[1])
for option in node.options:
step(option, pat)
step(option, pat, _global_bind)
return pat
else:
header_fragment = None
Expand Down Expand Up @@ -91,7 +93,17 @@ def into_sistana(cmd: Alconna):
separators=cmd.separators,
)
for option in cmd.options:
step(option, pat)
if isinstance(option, SPECIAL_OPTIONS):
pat.subcommand(
option.name,
*_alc_args_to_fragments(option.args),
aliases=option.aliases,
soft_keyword=False,
separators=option.separators,
enter_instantly=True,
)
else:
step(option, pat, (pat._options_bind.maps[0], pat._subcommands_bind.maps[0])) # type: ignore
return pat


Expand All @@ -100,12 +112,12 @@ def _reason_raise_alc_exception(reason: LoopflowExitReason):
return

if reason in {
LoopflowExitReason.unsatisfied,
LoopflowExitReason.previous_unsatisfied,
# LoopflowExitReason.unsatisfied,
# LoopflowExitReason.previous_unsatisfied,
LoopflowExitReason.unsatisfied_switch_option,
LoopflowExitReason.unsatisfied_switch_subcommand,
}:
return ParamsUnmatched(f"LoopflowDescription: {reason.value}")
raise ParamsUnmatched(f"LoopflowDescription: {reason.value}")

if reason in {
LoopflowExitReason.out_of_data_subcommand,
Expand Down Expand Up @@ -135,7 +147,6 @@ def _reason_raise_alc_exception(reason: LoopflowExitReason):
def dump_arparma(
alc: Alconna,
snapshot: AnalyzeSnapshot,
buffer: Buffer,
message: Sequence[Any],
matched: bool = True,
head_matched: bool = True
Expand Down Expand Up @@ -165,7 +176,7 @@ def dump_arparma(
args_result=args_result,
value_result=value_result,
)
arp.buffer = buffer # type: ignore
arp.buffer = snapshot.command # type: ignore
return arp


Expand All @@ -185,7 +196,7 @@ def _parse(self: Alconna, message: Sequence[Any], _) -> Arparma:
reason = analyzer.loopflow(snapshot, buffer)
head_matched = reason not in {LoopflowExitReason.prefix_mismatch, LoopflowExitReason.header_mismatch}
_reason_raise_alc_exception(reason)
return dump_arparma(self, snapshot, buffer, message, True, head_matched)
return dump_arparma(self, snapshot, message, True, head_matched)


_OLD_PARSE = Alconna._parse
Expand Down Expand Up @@ -222,7 +233,6 @@ def _sistana_debug(alc: Alconna, message):
dump_arparma(
alc,
snapshot,
buffer,
message,
res == LoopflowExitReason.completed,
res in {LoopflowExitReason.prefix_mismatch, LoopflowExitReason.header_mismatch},
Expand Down
9 changes: 6 additions & 3 deletions src/arclet/alconna/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ def handle_argv():

def add_builtin_options(options: list[Option | Subcommand], router: Router, conf: Config) -> None:
if "help" not in conf.disable_builtin_options:
options.append(Help("|".join(conf.builtin_option_name["help"]), dest="$help", help_text=lang.require("builtin", "option_help"), soft_keyword=False)) # noqa: E501
options.append(hlp := Help("|".join(conf.builtin_option_name["help"]), dest="$help", help_text=lang.require("builtin", "option_help"), soft_keyword=False)) # noqa: E501

@router.route(hlp.name)
@router.route("$help")
def _(command: Alconna, arp: Arparma):
_help_param = [str(i) for i in arp.buffer if str(i) not in conf.builtin_option_name["help"]]
Expand All @@ -62,7 +63,7 @@ def _(command: Alconna, arp: Arparma):

if "shortcut" not in conf.disable_builtin_options:
options.append(
Shortcut(
sct := Shortcut(
"|".join(conf.builtin_option_name["shortcut"]),
Args.action("delete|list", optional=True).name(str, optional=True).command(str, optional=True),
dest="$shortcut",
Expand All @@ -71,6 +72,7 @@ def _(command: Alconna, arp: Arparma):
)
)

@router.route(sct.name)
@router.route("$shortcut")
def _(command: Alconna, arp: Arparma):
res = arp.query[OptionResult]("$shortcut", force_return=True)
Expand All @@ -90,8 +92,9 @@ def _(command: Alconna, arp: Arparma):
router._routes.pop("$shortcut", None)

if "completion" not in conf.disable_builtin_options:
options.append(Completion("|".join(conf.builtin_option_name["completion"]), dest="$completion", help_text=lang.require("builtin", "option_completion"), soft_keyword=False)) # noqa: E501
options.append(comp := Completion("|".join(conf.builtin_option_name["completion"]), dest="$completion", help_text=lang.require("builtin", "option_completion"), soft_keyword=False)) # noqa: E501

@router.route(comp.name)
@router.route("$completion")
def _(command: Alconna, arp: Arparma):
rest = arp.buffer
Expand Down
2 changes: 1 addition & 1 deletion src/arclet/alconna/sistana/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def loopflow(self, snapshot: AnalyzeSnapshot, buffer: Buffer[T]) -> LoopflowExit
else:
current_track.complete(mix)

if not enter_forward and snapshot.stage_satisfied or not subcommand.enter_instantly:
if not enter_forward and snapshot.stage_satisfied or subcommand.enter_instantly:
token.apply()
mix.complete()

Expand Down
9 changes: 5 additions & 4 deletions src/arclet/alconna/sistana/model/pattern.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections import ChainMap
from dataclasses import dataclass, field
from functools import cached_property
from typing import TYPE_CHECKING, Iterable, MutableMapping
Expand Down Expand Up @@ -27,15 +28,15 @@ class SubcommandPattern:

prefixes: Trie[str] | None = field(default=None)
compact_header: bool = False
enter_instantly: bool = True
enter_instantly: bool = False
header_fragment: _Fragment | None = None

_options: list[OptionPattern] = field(default_factory=list)
_compact_keywords: Trie[str] | None = field(default=None)
_exit_options: list[str] = field(default_factory=list)

_options_bind: MutableMapping[str, OptionPattern] = field(default_factory=dict)
_subcommands_bind: MutableMapping[str, SubcommandPattern] = field(default_factory=dict)
_options_bind: ChainMap[str, OptionPattern] = field(default_factory=lambda: ChainMap())
_subcommands_bind: ChainMap[str, SubcommandPattern] = field(default_factory=lambda: ChainMap())

@classmethod
def build(
Expand Down Expand Up @@ -100,7 +101,7 @@ def subcommand(
separators: str = SEPARATORS,
compact_header: bool = False,
compact_aliases: bool = False,
enter_instantly: bool = True,
enter_instantly: bool = False,
header_fragment: _Fragment | None = None,
):
preset = Preset(Track(fragments, header=header_fragment), {})
Expand Down
10 changes: 5 additions & 5 deletions tests/sistana/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ class LoopflowTest:
exit_reason: LoopflowExitReason

def expect(self, *expected: LoopflowExitReason):
assert self.exit_reason in expected
assert self.exit_reason in expected, f"Expected {expected}, got {self.exit_reason}"

def expect_completed(self):
self.expect(LoopflowExitReason.completed)

def expect_uncompleted(self):
assert self.exit_reason != LoopflowExitReason.completed
assert self.exit_reason != LoopflowExitReason.completed, f"Expected uncompleted, got {self.exit_reason}"


@dataclass
Expand All @@ -49,13 +49,13 @@ def mix(self):
return MixTest(self.snapshot.mix)

def expect_determined(self, expected: bool = True):
assert self.snapshot.determined == expected
assert self.snapshot.determined == expected, f"Expected determined {expected}, got {self.snapshot.determined}"

def expect_state(self, *states: ProcessingState):
assert self.snapshot.state in (states or (ProcessingState.COMMAND,))
assert self.snapshot.state in (states or (ProcessingState.COMMAND,)), f"Expected {states}, got {self.snapshot.state}"

def expect_endpoint(self, *expected: str):
assert self.snapshot.endpoint == expected
assert self.snapshot.endpoint == expected, f"Expected {expected}, got {self.snapshot.endpoint}"


@dataclass
Expand Down

0 comments on commit df92d34

Please sign in to comment.