diff --git a/src/wikitextprocessor/parser.py b/src/wikitextprocessor/parser.py index 5f3158dc..df1624d0 100644 --- a/src/wikitextprocessor/parser.py +++ b/src/wikitextprocessor/parser.py @@ -496,7 +496,7 @@ def filter_empty_str_child(self) -> Iterator[Union[str, "WikiNode"]]: @overload def find_html( self, - target_tag: str, + target_tags: str | list[str], with_index: Literal[True], attr_name: str, attr_value: str, @@ -505,7 +505,7 @@ def find_html( @overload def find_html( self, - target_tag: str, + target_tags: str | list[str], with_index: Literal[False] = ..., attr_name: str = ..., attr_value: str = ..., @@ -513,7 +513,7 @@ def find_html( def find_html( self, - target_tag: str, + target_tags: str | list[str], with_index: bool = False, attr_name: str = "", attr_value: str = "", @@ -523,7 +523,9 @@ def find_html( if TYPE_CHECKING: assert isinstance(node, HTMLNode) # node.tag is an alias for node.sarg defined in HTMLNode - if node.tag == target_tag: + if isinstance(target_tags, str): + target_tags = [target_tags] + if node.tag in target_tags: if len(attr_name) > 0 and attr_value not in node.attrs.get( attr_name, {} ):