Skip to content

Commit

Permalink
Add URLMatcher.match_universal().
Browse files Browse the repository at this point in the history
  • Loading branch information
wRAR committed Apr 11, 2024
1 parent 4bd19b7 commit d3d9815
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 11 deletions.
13 changes: 11 additions & 2 deletions tests/test_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def test_match_all():
assert list(matcher.match_all("http://bar.example.com/products")) == [3, 4, 1]


def test_include_universal():
def test_match_all_include_universal():
matcher = URLMatcher()
matcher.add_or_update(1, Patterns(include=["example.com"]))
matcher.add_or_update(2, Patterns(include=[]))
Expand All @@ -174,4 +174,13 @@ def test_include_universal():
assert list(matcher.match_all("http://foo.example.com")) == [3, 1, 4, 2]
assert list(matcher.match_all("http://foo.example.com", include_universal=False)) == [3, 1]
assert list(matcher.match_all("http://example.net")) == [4, 2]
assert list(matcher.match_all("http://example.net", include_universal=False)) == [4, 2]
assert list(matcher.match_all("http://example.net", include_universal=False)) == []


def test_match_universal():
matcher = URLMatcher()
matcher.add_or_update(1, Patterns(include=["example.com"]))
matcher.add_or_update(2, Patterns(include=[]))
matcher.add_or_update(3, Patterns(include=["foo.example.com"]))
matcher.add_or_update(4, Patterns(include=[""]))
assert list(matcher.match_universal()) == [4, 2]
25 changes: 16 additions & 9 deletions url_matcher/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

from dataclasses import dataclass, field
from itertools import chain
from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Tuple, Union

from url_matcher.patterns import PatternMatcher, get_pattern_domain, hierarchical_str
Expand Down Expand Up @@ -106,6 +107,7 @@ def __init__(self, data: Union[Mapping[Any, Patterns], Iterable[Tuple[Any, Patte
initialize the object from
"""
self.matchers_by_domain: Dict[str, List[PatternsMatcher]] = {}
self.matchers_universal: List[PatternsMatcher] = []
self.patterns: Dict[Any, Patterns] = {}

if data:
Expand Down Expand Up @@ -155,17 +157,15 @@ def match(self, url: str, *, include_universal=True) -> Optional[Any]:

def match_all(self, url: str, *, include_universal=True) -> Iterator[Any]:
domain = get_domain(url)
domain_matchers = self.matchers_by_domain.get(domain) or []
domain_match = False
for matcher in domain_matchers:
matchers: Iterable[PatternsMatcher] = self.matchers_by_domain.get(domain) or []
if include_universal:
matchers = chain(matchers, self.matchers_universal)
for matcher in matchers:
if matcher.match(url):
domain_match = True
yield matcher.identifier
if include_universal or not domain_match:
universal_matchers = self.matchers_by_domain.get("") or []
for matcher in universal_matchers:
if matcher.match(url):
yield matcher.identifier

def match_universal(self) -> Iterator[Any]:
return (m.identifier for m in self.matchers_universal)

def _sort_domain(self, domain: str):
"""
Expand All @@ -186,6 +186,7 @@ def sort_key(matcher: PatternsMatcher) -> Tuple:
return (matcher.patterns.priority, sorted_includes, matcher.identifier)

self.matchers_by_domain[domain].sort(key=sort_key, reverse=True)
self.matchers_universal.sort(key=sort_key, reverse=True)

def _del_matcher(self, domain: str, identifier: Any):
matchers = self.matchers_by_domain[domain]
Expand All @@ -195,10 +196,16 @@ def _del_matcher(self, domain: str, identifier: Any):
break
if not matchers:
del self.matchers_by_domain[domain]
for idx in range(len(self.matchers_universal)):
if self.matchers_universal[idx].identifier == identifier:
del self.matchers_universal[idx]
break

def _add_matcher(self, domain: str, matcher: PatternsMatcher):
# FIXME: This can be made much more efficient if we insert the data directly in order instead of resorting.
# The bisect module could be used for this purpose.
# I'm leaving it for the future as insertion time is not critical.
self.matchers_by_domain.setdefault(domain, []).append(matcher)
if domain == "":
self.matchers_universal.append(matcher)
self._sort_domain(domain)

0 comments on commit d3d9815

Please sign in to comment.