Skip to content

Commit

Permalink
💥 version 2.0.0a2
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Nov 11, 2023
1 parent acd30eb commit b8bc532
Show file tree
Hide file tree
Showing 19 changed files with 168 additions and 379 deletions.
47 changes: 11 additions & 36 deletions benchmark.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,21 @@
import time
from arclet.alconna import Alconna, Args, ANY, command_manager, namespace
from arclet.alconna import Alconna, Option, Args, command_manager
import cProfile
import pstats


class Plain:
type = "Plain"
text: str

def __init__(self, t: str):
self.text = t

def __repr__(self):
return self.text


class At:
type = "At"
target: int

def __init__(self, t: int):
self.target = t

def __repr__(self):
return f"At:{self.target}"


with namespace("test") as np:
np.enable_message_cache = False
np.to_text = lambda x: x.text if x.__class__ is Plain else None
alc = Alconna(
["."],
"test",
Args["bar", ANY]
)
alc = Alconna(
"test",
Option("--foo", Args["f", str]),
Option("--bar", Args["b", str]),
Option("--baz", Args["z", str]),
Option("--qux", Args["q", str]),
)

argv = command_manager.resolve(alc)
analyser = command_manager.require(alc)
print(alc)
msg = [Plain(".test"), At(124)]
msg = ["test --qux 123"]

print(alc.parse(msg))
count = 20000

if __name__ == "__main__":
Expand All @@ -62,8 +39,6 @@ def __repr__(self):

print(f"Alconna: {li / count} ns per loop with {count} loops")

command_manager.records.clear()

prof = cProfile.Profile()
prof.enable()
for _ in range(count):
Expand Down
2 changes: 1 addition & 1 deletion src/arclet/alconna/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from .typing import UnpackVar as UnpackVar
from .typing import Up as Up

__version__ = "2.0.0a1"
__version__ = "2.0.0a2"

# backward compatibility
Arpamar = Arparma
Expand Down
63 changes: 17 additions & 46 deletions src/arclet/alconna/_internal/_analyser.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,6 @@
_SPECIAL = {"help": handle_help, "shortcut": handle_shortcut, "completion": handle_completion}


def _compile_opts(option: Option, data: dict[str, Option | list[Option] | SubAnalyser]):
"""处理选项
Args:
option (Option): 选项
data (dict[str, Sentence | Option | list[Option] | SubAnalyser]): 编译的节点
"""
for alias in option.aliases:
if li := data.get(alias):
if isinstance(li, SubAnalyser):
continue
if isinstance(li, list):
li.append(option)
li.sort(key=lambda x: x.priority, reverse=True)
else:
data[alias] = sorted([li, option], key=lambda x: x.priority, reverse=True)
else:
data[alias] = option


def default_compiler(analyser: SubAnalyser, pids: set[str]):
"""默认的编译方法
Expand All @@ -71,7 +51,10 @@ def default_compiler(analyser: SubAnalyser, pids: set[str]):
if isinstance(opts, Option) and not isinstance(opts, (Help, Shortcut, Completion)):
if opts.compact or opts.action.type == 2 or not set(analyser.command.separators).issuperset(opts.separators): # noqa: E501
analyser.compact_params.append(opts)
_compile_opts(opts, analyser.compile_params) # type: ignore
for alias in opts.aliases:
if alias in analyser.compile_params and isinstance(analyser.compile_params[alias], SubAnalyser):
continue
analyser.compile_params[alias] = opts
if opts.default is not Empty:
analyser.default_opt_result[opts.dest] = (opts.default, opts.action)
pids.update(opts.aliases)
Expand All @@ -82,7 +65,7 @@ def default_compiler(analyser: SubAnalyser, pids: set[str]):
default_compiler(sub, pids)
if not set(analyser.command.separators).issuperset(opts.separators):
analyser.compact_params.append(sub)
if sub.command.default:
if sub.command.default is not Empty:
analyser.default_sub_result[opts.dest] = sub.command.default


Expand All @@ -96,7 +79,7 @@ class SubAnalyser(Generic[TDC]):
"""命令是否只有主参数"""
need_main_args: bool = field(default=False)
"""是否需要主参数"""
compile_params: dict[str, Option | list[Option] | SubAnalyser[TDC]] = field(default_factory=dict)
compile_params: dict[str, Option | SubAnalyser[TDC]] = field(default_factory=dict)
"""编译的节点"""
compact_params: list[Option | SubAnalyser[TDC]] = field(default_factory=list)
"""可能紧凑的需要逐个解析的节点"""
Expand Down Expand Up @@ -157,12 +140,12 @@ def reset(self):
self.value_result = None
self.header_result = None

def process(self, argv: Argv[TDC]) -> Self:
def process(self, argv: Argv[TDC], trigger: str | None = None) -> Self:
"""处理传入的参数集合
Args:
argv (Argv[TDC]): 命令行参数
trigger (str | None, optional): 触发词. Defaults to None.
Returns:
Self: 自身
Expand All @@ -171,11 +154,12 @@ def process(self, argv: Argv[TDC]) -> Self:
FuzzyMatchSuccess: 模糊匹配成功
"""
sub = argv.context = self.command
name, _ = argv.next(sub.separators)
if name != sub.name: # 先匹配节点名称
if argv.fuzzy_match and levenshtein(name, sub.name) >= config.fuzzy_threshold:
raise FuzzyMatchSuccess(lang.require("fuzzy", "matched").format(source=name, target=sub.name))
raise ParamsUnmatched(lang.require("subcommand", "name_error").format(target=name, source=sub.name))
if not trigger:
name, _ = argv.next(sub.separators)
if name != sub.name: # 先匹配节点名称
if argv.fuzzy_match and levenshtein(name, sub.name) >= config.fuzzy_threshold:
raise FuzzyMatchSuccess(lang.require("fuzzy", "matched").format(source=name, target=sub.name))
raise ParamsUnmatched(lang.require("subcommand", "name_error").format(target=name, source=sub.name))

self.value_result = sub.action.value
return self.analyse(argv)
Expand Down Expand Up @@ -221,8 +205,6 @@ class Analyser(SubAnalyser[TDC], Generic[TDC]):

command: Alconna
"""命令实例"""
used_tokens: set[int]
"""已使用的token"""
command_header: Header
"""命令头部"""

Expand All @@ -235,15 +217,10 @@ def __init__(self, alconna: Alconna[TDC], compiler: TCompile | None = None):
"""
super().__init__(alconna)
self.fuzzy_match = alconna.meta.fuzzy_match
self.used_tokens = set()
self.command_header = Header.generate(alconna.command, alconna.prefixes, alconna.meta.compact)
compiler = compiler or default_compiler
compiler(self, command_manager.resolve(self.command).param_ids)

def _clr(self):
self.used_tokens.clear()
super()._clr()

def __repr__(self):
return f"<{self.__class__.__name__} of {self.command.path}>"

Expand Down Expand Up @@ -283,15 +260,14 @@ def shortcut(
if reg:
_handle_shortcut_reg(argv, reg.groups(), reg.groupdict())
argv.bak_data = argv.raw_data.copy()
if argv.message_cache:
argv.token = argv.generate_token(argv.raw_data)
return self.process(argv)

def process(self, argv: Argv[TDC]) -> Arparma[TDC]:
def process(self, argv: Argv[TDC], trigger=None) -> Arparma[TDC]:
"""主体解析函数, 应针对各种情况进行解析
Args:
argv (Argv[TDC]): 命令行参数
trigger (str | None, optional): 触发词. Defaults to None.
Returns:
Arparma[TDC]: Arparma 解析结果
Expand All @@ -301,8 +277,6 @@ def process(self, argv: Argv[TDC]) -> Arparma[TDC]:
InvalidParam: 参数不匹配
ArgumentMissing: 参数缺失
"""
if argv.message_cache and argv.token in self.used_tokens and (res := command_manager.get_record(argv.token)):
return res
try:
self.header_result = analyse_header(self.command_header, argv)
except InvalidParam as e:
Expand Down Expand Up @@ -336,7 +310,7 @@ def process(self, argv: Argv[TDC]) -> Arparma[TDC]:
if fail := self.analyse(argv):
return fail

if argv.done and (not self.need_main_args or self.args_result):
if argv.current_index == argv.ndata and (not self.need_main_args or self.args_result):
return self.export(argv)

rest = argv.release()
Expand Down Expand Up @@ -407,9 +381,6 @@ def export(
result.main_args = self.args_result
result.options = self.options_result
result.subcommands = self.subcommands_result
if argv.message_cache:
command_manager.record(argv.token, result)
self.used_tokens.add(argv.token)
self.reset()
return result # type: ignore

Expand Down
14 changes: 1 addition & 13 deletions src/arclet/alconna/_internal/_argv.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,9 @@ class Argv(Generic[TDC]):
"""备份的原始数据"""
raw_data: list[str | Any] = field(init=False)
"""原始数据"""
token: int = field(init=False)
"""命令的token"""
origin: TDC = field(init=False)
"""原始命令"""
_sep: tuple[str, ...] | None = field(init=False)

_cache: ClassVar[dict[type, dict[str, Any]]] = {}

def __post_init__(self):
Expand All @@ -70,16 +67,11 @@ def reset(self):
self.ndata = 0
self.bak_data = []
self.raw_data = []
self.token = 0
self.origin = "None"
self._sep = None
self._next = None
self.context = None

@staticmethod
def generate_token(data: list) -> int:
"""命令的`token`的生成函数"""
return hash(repr(data))

@property
def done(self) -> bool:
"""命令是否解析完毕"""
Expand Down Expand Up @@ -112,8 +104,6 @@ def build(self, data: TDC) -> Self:
raise NullMessage(lang.require("argv", "null_message").format(target=data))
self.ndata = i
self.bak_data = raw_data.copy()
if self.message_cache:
self.token = self.generate_token(raw_data)
return self

def addon(self, data: Iterable[str | Any]) -> Self:
Expand All @@ -138,8 +128,6 @@ def addon(self, data: Iterable[str | Any]) -> Self:
self.raw_data.append(d)
self.ndata += 1
self.bak_data = self.raw_data.copy()
if self.message_cache:
self.token = self.generate_token(self.raw_data)
return self

def next(self, separate: tuple[str, ...] | None = None, move: bool = True) -> tuple[str | Any, bool]:
Expand Down
Loading

0 comments on commit b8bc532

Please sign in to comment.