Skip to content

Commit

Permalink
sistana: improve prof
Browse files Browse the repository at this point in the history
  • Loading branch information
GreyElaina committed Oct 4, 2024
1 parent bd9574a commit 66cd737
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 55 deletions.
57 changes: 36 additions & 21 deletions src/arclet/alconna/sistana/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class Analyzer(Generic[T]):

def loopflow(self, snapshot: AnalyzeSnapshot, buffer: Buffer[T]) -> LoopflowExitReason:
while True:
if snapshot.determined and self.complete_on_determined and snapshot.stage_satisfied:
if self.complete_on_determined and snapshot.determined and snapshot.stage_satisfied:
return LoopflowExitReason.completed

context = snapshot.context
Expand Down Expand Up @@ -85,7 +85,6 @@ def loopflow(self, snapshot: AnalyzeSnapshot, buffer: Buffer[T]) -> LoopflowExit
buffer.pushleft(token.val[len(prefix) :])

snapshot.set_alter(current.parent.header())
# snapshot.current = current.parent.header() # 直接进 header.
elif pointer_type is PointerRole.HEADER:
if not isinstance(token.val, str):
return LoopflowExitReason.header_expect_str
Expand All @@ -103,7 +102,7 @@ def loopflow(self, snapshot: AnalyzeSnapshot, buffer: Buffer[T]) -> LoopflowExit
return LoopflowExitReason.header_mismatch

next_current = current.parent
track = mix.tracks[next_current]
track = mix.tracks[next_current.data]
track.emit_header(mix, token.val)

snapshot.unset_alter()
Expand All @@ -120,23 +119,23 @@ def loopflow(self, snapshot: AnalyzeSnapshot, buffer: Buffer[T]) -> LoopflowExit
target_ref = current.subcommand(subcommand.header)
mix.update(target_ref, subcommand.preset)

target_track = mix.tracks[target_ref]
target_track = mix.tracks[target_ref.data]
target_track.emit_header(mix, token.val)
snapshot.pop_pendings()

snapshot.context = subcommand
snapshot.context = target_ref, subcommand
snapshot.update_pending()
continue
elif not subcommand.soft_keyword:
return LoopflowExitReason.unsatisfied_switch_subcommand
elif (option_info := snapshot.get_option(token.val)) is not None:
owned_subcommand_ref, option_keyword = option_info
owned_subcommand = snapshot.traverses[owned_subcommand_ref]
owned_subcommand = snapshot.traverses[owned_subcommand_ref.data]
target_option = owned_subcommand._options_bind[option_keyword]

if not target_option.soft_keyword or snapshot.stage_satisfied:
option_ref = owned_subcommand_ref.option(option_keyword)
track = mix.tracks[option_ref]
track = mix.tracks[option_ref.data]

if not target_option.allow_duplicate and track.emitted:
return LoopflowExitReason.option_duplicated_prohibited
Expand All @@ -154,7 +153,7 @@ def loopflow(self, snapshot: AnalyzeSnapshot, buffer: Buffer[T]) -> LoopflowExit
elif pointer_type is PointerRole.OPTION:
if token.val in context._subcommands_bind:
subcommand = context._subcommands_bind[token.val]
track = mix.tracks[current]
track = mix.tracks[current.data]

if not track.satisfied:
if not subcommand.soft_keyword:
Expand All @@ -168,32 +167,48 @@ def loopflow(self, snapshot: AnalyzeSnapshot, buffer: Buffer[T]) -> LoopflowExit
token.apply()
snapshot.complete()

# context hard switch
mix.update(current, subcommand.preset)
target_ref = current.subcommand(subcommand.header)
mix.update(target_ref, subcommand.preset)

target_track = mix.tracks[target_ref.data]
target_track.emit_header(mix, token.val)
snapshot.pop_pendings()

snapshot.context = subcommand
snapshot.context = target_ref, subcommand
snapshot.update_pending()
continue
elif not subcommand.soft_keyword: # and not snapshot.stage_satisfied
return LoopflowExitReason.unsatisfied_switch_subcommand

elif (option_info := snapshot.get_option(token.val)) is not None:
track = mix.tracks[current]
owned_subcommand_ref, option_ref = option_info
owned_subcommand = snapshot.traverses[owned_subcommand_ref]
target_option = owned_subcommand._options_bind[option_ref]
current_track = mix.tracks[current.data]

if not track.satisfied:
owned_subcommand_ref, option_name = option_info
owned_subcommand = snapshot.traverses[owned_subcommand_ref.data]
target_option = owned_subcommand._options_bind[option_name]
target_option_ref = owned_subcommand_ref.option(option_name)

if not current_track.satisfied:
if not target_option.soft_keyword:
mix.reset_track(current)
return LoopflowExitReason.previous_unsatisfied
else:
track.complete(mix)
current_track.complete(mix)
snapshot.unset_alter()


if not target_option.soft_keyword or snapshot.stage_satisfied:
# 这里的逻辑基本上和上面的一致。
target_track = mix.tracks[target_option_ref.data]
if not target_option.allow_duplicate and target_track.emitted:
return LoopflowExitReason.option_duplicated_prohibited

if target_track:
target_track.reset()
snapshot.set_alter(target_option_ref)
snapshot._ref_cache_option[target_option_ref] = target_option

target_track.emit_header(mix, token.val)
token.apply()
continue
# else: 进了 track process.

Expand Down Expand Up @@ -232,7 +247,7 @@ def loopflow(self, snapshot: AnalyzeSnapshot, buffer: Buffer[T]) -> LoopflowExit
# NOTE: 这里其实有个有趣的点需要提及:pattern 中的 subcommands, options 和这里的 compacts 都是多对一的关系,
# 所以如果要取 track 之类的,就需要先绕个路,因为数据结构上的主索引总是采用的 node 上的单个 keyword。
opt = context._options_bind[prefix]
track = mix.tracks[current.option(opt.keyword)]
track = mix.tracks[current.option(opt.keyword).data]

redirect = track.assignable or opt.allow_duplicate
# 这也排除了没有 fragments 设定的情况,因为这里 token.val 是形如 "-xxx11112222",已经传了一个 fragment 进去。
Expand All @@ -251,7 +266,7 @@ def loopflow(self, snapshot: AnalyzeSnapshot, buffer: Buffer[T]) -> LoopflowExit
# else: 进了 track process.

if pointer_type is PointerRole.SUBCOMMAND:
track = mix.tracks[current]
track = mix.tracks[current.data]

try:
response = track.forward(mix, buffer, context.separators)
Expand All @@ -270,7 +285,7 @@ def loopflow(self, snapshot: AnalyzeSnapshot, buffer: Buffer[T]) -> LoopflowExit
# 即使有,上面也已经给你处理了。
elif pointer_type is PointerRole.OPTION:
# option fragments 的处理是原子性的,整段成功才会 apply changes,否则会被 reset。
track = mix.tracks[current]
track = mix.tracks[current.data]
opt = snapshot._ref_cache_option[current]

try:
Expand Down
29 changes: 15 additions & 14 deletions src/arclet/alconna/sistana/model/mix.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import TYPE_CHECKING, Any

from .pointer import Pointer
from .pointer import Pointer, PointerData, PointerRole
from ..err import CaptureRejected, ReceivePanic, TransformPanic, ValidateRejected
from ..some import Value
from .fragment import _Fragment, assert_fragments_order
Expand Down Expand Up @@ -31,20 +31,17 @@ def __init__(self, fragments: tuple[_Fragment, ...], header: _Fragment | None =
def satisfied(self):
return self.cursor >= self.max_length or self.fragments[0].default is not None or self.fragments[0].variadic

def apply_defaults(self, mix: Mix):
for frag in self.fragments:
def complete(self, mix: Mix):
if self.cursor >= self.max_length:
return

for frag in self.fragments[self.cursor:]:
if frag.name not in mix.assignes and frag.default is not None:
mix.assignes[frag.name] = frag.default.value

if self.header is not None and self.header.name not in mix.assignes and self.header.default is not None:
mix.assignes[self.header.name] = self.header.default.value

def complete(self, mix: Mix):
if not self.fragments:
return

self.apply_defaults(mix)

first = self.fragments[-1]
if first.variadic and first.name not in mix.assignes:
mix.assignes[first.name] = []
Expand Down Expand Up @@ -186,7 +183,7 @@ class Mix:
__slots__ = ("assignes", "tracks")

assignes: dict[str, Any]
tracks: dict[Pointer, Track]
tracks: dict[PointerData, Track]

def __init__(self):
self.assignes = {}
Expand All @@ -197,15 +194,19 @@ def complete(self):
track.complete(self)

def reset_track(self, ref: Pointer):
track = self.tracks[ref]
track = self.tracks[ref.data]
track.reset()

@property
def satisfied(self):
return all(track.satisfied for track in self.tracks.values())
for track in self.tracks.values():
if not track.satisfied:
return False

return True

def update(self, root: Pointer, preset: Preset):
self.tracks[root] = preset.subcommand_track.copy()
self.tracks[root.data] = preset.subcommand_track.copy()

for track_id, track in preset.option_tracks.items():
self.tracks[root.option(track_id)] = track.copy()
self.tracks[root.data + ((PointerRole.OPTION, track_id),)] = track.copy()
4 changes: 2 additions & 2 deletions src/arclet/alconna/sistana/model/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from .fragment import assert_fragments_order
from .mix import Preset, Track
from .pointer import Pointer
from .pointer import Pointer, PointerRole
from .snapshot import AnalyzeSnapshot

if TYPE_CHECKING:
Expand Down Expand Up @@ -71,7 +71,7 @@ def root_ref(self):
return Pointer().subcommand(self.header)

def create_snapshot(self, ref: Pointer):
snapshot = AnalyzeSnapshot(main_ref=self.root_ref, alter_ref=ref, traverses={self.root_ref: self})
snapshot = AnalyzeSnapshot(main_ref=self.root_ref, alter_ref=ref, traverses={((PointerRole.SUBCOMMAND, self.header),): self})
snapshot.mix.update(self.root_ref, self.preset)
return snapshot

Expand Down
6 changes: 4 additions & 2 deletions src/arclet/alconna/sistana/model/pointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ class PointerRole(int, Enum):
PREFIX = 3

PointerContent = Tuple[PointerRole, str]
PointerData = Tuple[PointerContent, ...]


HEADER_STR = "::"
PREFIX_STR = "^"
Expand All @@ -19,9 +21,9 @@ class PointerRole(int, Enum):
class Pointer:
__slots__ = ("data",)

data: tuple[PointerContent, ...]
data: PointerData

def __init__(self, data: tuple[PointerContent, ...] = ()) -> None:
def __init__(self, data: PointerData = ()) -> None:
self.data = data

def subcommand(self, name: str):
Expand Down
38 changes: 22 additions & 16 deletions src/arclet/alconna/sistana/model/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from typing import TYPE_CHECKING

from arclet.alconna.sistana.model.mix import Mix
from .mix import Mix
from .pointer import Pointer, PointerData

if TYPE_CHECKING:
from .pattern import OptionPattern, SubcommandPattern
from .pointer import Pointer


# @dataclass
Expand All @@ -21,7 +21,7 @@ class AnalyzeSnapshot:
"_alter_ref",
)

traverses: dict[Pointer, SubcommandPattern]
traverses: dict[PointerData, SubcommandPattern]
endpoint: Pointer | None
mix: Mix

Expand All @@ -31,7 +31,7 @@ class AnalyzeSnapshot:
_pending_options: list[tuple[Pointer, str, set[str]]]
_ref_cache_option: dict[Pointer, OptionPattern]

def __init__(self, main_ref: Pointer, alter_ref: Pointer | None, traverses: dict[Pointer, SubcommandPattern]):
def __init__(self, main_ref: Pointer, alter_ref: Pointer | None, traverses: dict[PointerData, SubcommandPattern]):
self._main_ref = main_ref
self._alter_ref = alter_ref

Expand All @@ -49,33 +49,34 @@ def __init__(self, main_ref: Pointer, alter_ref: Pointer | None, traverses: dict
def current_ref(self):
return self._alter_ref or self._main_ref

def set_alter(self, option: Pointer):
self._alter_ref = option
def set_alter(self, ref: Pointer):
self._alter_ref = ref

def unset_alter(self):
self._alter_ref = None

@property
def context(self):
return self.traverses[self._main_ref]
return self.traverses[self._main_ref.data]

@context.setter
def context(self, value: SubcommandPattern):
self._main_ref = self._main_ref.subcommand(value.header)
self.traverses[self._main_ref] = value
def context(self, value: tuple[Pointer, SubcommandPattern]):
ref, cmd = value
self._main_ref = ref
self.traverses[ref.data] = cmd

@property
def determined(self):
return self.endpoint is not None

@property
def stage_satisfied(self):
conda = self.mix.tracks[self._main_ref].satisfied
conda = self.mix.tracks[self._main_ref.data].satisfied
if conda:
subcommand = self.traverses[self._main_ref]
subcommand = self.traverses[self._main_ref.data]
for ref, keyword, _ in self._pending_options:
if keyword in subcommand._exit_options:
if not self.mix.tracks[ref.option(keyword)].satisfied:
if not self.mix.tracks[ref.option(keyword).data].satisfied:
return False

return conda
Expand All @@ -85,9 +86,14 @@ def determine(self, endpoint: Pointer | None = None):

def update_pending(self):
subcommand_ref = self._main_ref
subcommand_pattern = self.traverses[subcommand_ref]
for option in subcommand_pattern._options:
self._pending_options.append((subcommand_ref, option.keyword, {option.keyword, *option.aliases}))
subcommand_pattern = self.traverses[subcommand_ref.data]
# for option in subcommand_pattern._options:
# self._pending_options.append((subcommand_ref, option.keyword, {option.keyword, *option.aliases}))

self._pending_options.extend([
(subcommand_ref, option.keyword, {option.keyword, *option.aliases})
for option in subcommand_pattern._options
])

def get_option(self, trigger: str):
for subcommand_ref, option_keyword, triggers in self._pending_options:
Expand Down

0 comments on commit 66cd737

Please sign in to comment.