forked from yt-dlp/yt-dlp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
traversal.py
460 lines (361 loc) · 17.2 KB
/
traversal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
from __future__ import annotations

import collections
import collections.abc
import contextlib
import functools
import http.cookies
import inspect
import itertools
import re
import typing
import xml.etree.ElementTree

from ._utils import (
    IDENTITY,
    NO_DEFAULT,
    ExtractorError,
    LazyList,
    deprecation_warning,
    get_element_by_attribute,
    get_element_by_class,
    get_element_by_id,
    get_element_html_by_attribute,
    get_element_html_by_class,
    get_element_html_by_id,
    get_element_text_and_html_by_tag,
    get_elements_by_attribute,
    get_elements_by_class,
    get_elements_html_by_attribute,
    get_elements_html_by_class,
    is_iterable_like,
    try_call,
    url_or_none,
    variadic,
)
def traverse_obj(
        obj, *paths, default=NO_DEFAULT, expected_type=None, get_all=True,
        casesense=True, is_user_input=NO_DEFAULT, traverse_string=False):
    """
    Safely traverse nested `dict`s and `Iterable`s

    >>> obj = [{}, {"key": "value"}]
    >>> traverse_obj(obj, (1, "key"))
    'value'

    Each of the provided `paths` is tested and the first producing a valid result will be returned.
    The next path will also be tested if the path branched but no results could be found.
    Supported values for traversal are `Mapping`, `Iterable`, `re.Match`,
    `xml.etree.ElementTree.Element` (xpath) and `http.cookies.Morsel`.
    Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded.

    The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.

    The keys in the path can be one of:
        - `None`:           Return the current object.
        - `set`:            Either one or more types (`{type}`/`{type, type, ...}`), returning
                            the current object only if it is an instance of one of them,
                            or exactly one function (`{func}`), returning `func(obj)`.
        - `str`/`int`:      Return `obj[key]`. For `re.Match`, return `obj.group(key)`.
        - `slice`:          Branch out and return all values in `obj[key]`.
        - `Ellipsis`:       Branch out and return a list of all values.
        - `tuple`/`list`:   Branch out and return a list of all matching values.
                            Read as: `[traverse_obj(obj, branch) for branch in branches]`.
        - `function`:       Branch out and return values filtered by the function.
                            Read as: `[value for key, value in obj if function(key, value)]`.
                            For `Iterable`s, `key` is the index of the value.
                            For `re.Match`es, `key` is the group number (0 = full match)
                            as well as additionally any group names, if given.
        - `dict`:           Transform the current object and return a matching dict.
                            Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.
        - `any`-builtin:    Take the first matching object and return it, resetting branching.
        - `all`-builtin:    Take all matching objects and return them as a list, resetting branching.
        - `filter`-builtin: Discard the current values that are falsy.

        `tuple`, `list`, and `dict` all support nested paths and branches.

    @param paths            Paths by which to traverse.
    @param default          Value to return if the paths do not match.
                            If the last key in the path is a `dict`, it will apply to each value inside
                            the dict instead, depth first. Try to avoid if using nested `dict` keys.
    @param expected_type    If a `type`, only accept final values of this type.
                            If any other callable, try to call the function on each result.
                            If the last key in the path is a `dict`, it will apply to each value inside
                            the dict instead, recursively. This does respect branching paths.
    @param get_all          If `False`, return the first matching result, otherwise all matching ones.
    @param casesense        If `False`, consider string dictionary keys as case insensitive.

    `traverse_string` is only meant to be used by YoutubeDL.prepare_outtmpl and is not part of the API

    @param traverse_string  Whether to traverse into objects as strings.
                            If `True`, any non-compatible object will first be
                            converted into a string and then traversed into.
                            The return value of that path will be a string instead,
                            not respecting any further branching.

    @raises ExtractorError  If a `require(...)` check fails while evaluating the last path.

    @returns                The result of the object traversal.
                            If successful, `get_all=True`, and the path branches at least once,
                            then a list of results is returned instead.
                            If no `default` is given and the last path branches, a `list` of results
                            is always returned. If a path ends on a `dict` that result will always be a `dict`.
    """
    if is_user_input is not NO_DEFAULT:
        # Still accepted so that old callers keep working, but it has no effect
        deprecation_warning('The is_user_input parameter is deprecated and no longer works')

    # Normalizes string keys for case-insensitive comparison; other keys pass through
    casefold = lambda k: k.casefold() if isinstance(k, str) else k

    if isinstance(expected_type, type):
        type_test = lambda val: val if isinstance(val, expected_type) else None
    else:
        # NOTE(review): relies on `try_call` (from ._utils) yielding None when the
        # call fails, so that the value is discarded later -- confirm there
        type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))

    def apply_key(key, obj, is_last):
        """
        Apply a single path `key` to `obj`.

        Returns `(branching, results)`. When the key branched, `results` is the raw
        iterable of values; otherwise the single result is wrapped in a 1-tuple.
        `is_last` is forwarded as `test_type` to nested path evaluations, so the
        `expected_type` check only happens at the end of the outermost path.
        """
        branching = False
        result = None

        if obj is None and traverse_string:
            # In string mode, branching keys on a missing value yield nothing;
            # non-branching keys leave `result` as None
            if key is ... or callable(key) or isinstance(key, slice):
                branching = True
                result = ()

        elif key is None:
            result = obj

        elif isinstance(key, set):
            item = next(iter(key))
            if len(key) > 1 or isinstance(item, type):
                # `{type, ...}`: keep `obj` only if it is an instance of one of the types
                assert all(isinstance(item, type) for item in key)
                if isinstance(obj, tuple(key)):
                    result = obj
            else:
                # `{func}`: transform `obj` with the single function
                result = try_call(item, args=(obj,))

        elif isinstance(key, (list, tuple)):
            # Evaluate every alternative sub-path and concatenate their results
            branching = True
            result = itertools.chain.from_iterable(
                apply_path(obj, branch, is_last)[0] for branch in key)

        elif key is ...:
            # Branch over all values: mapping values, iterable items or regex groups
            branching = True
            if isinstance(obj, http.cookies.Morsel):
                # Expose the cookie's own key/value alongside its attributes
                obj = dict(obj, key=obj.key, value=obj.value)
            if isinstance(obj, collections.abc.Mapping):
                result = obj.values()
            elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element):
                result = obj
            elif isinstance(obj, re.Match):
                result = obj.groups()
            elif traverse_string:
                branching = False
                result = str(obj)
            else:
                result = ()

        elif callable(key):
            # Branch over the values whose (key, value) pair satisfies the predicate;
            # the predicate is invoked through `try_call`
            branching = True
            if isinstance(obj, http.cookies.Morsel):
                obj = dict(obj, key=obj.key, value=obj.value)
            if isinstance(obj, collections.abc.Mapping):
                iter_obj = obj.items()
            elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element):
                iter_obj = enumerate(obj)
            elif isinstance(obj, re.Match):
                # Index 0 is the full match, 1.. the numbered groups, then named groups
                iter_obj = itertools.chain(
                    enumerate((obj.group(), *obj.groups())),
                    obj.groupdict().items())
            elif traverse_string:
                branching = False
                iter_obj = enumerate(str(obj))
            else:
                iter_obj = ()

            result = (v for k, v in iter_obj if try_call(key, args=(k, v)))
            if not branching:  # string traversal
                result = ''.join(result)

        elif isinstance(key, dict):
            # Build a dict from sub-paths; missing entries fall back to `default` if one
            # was given, otherwise they are dropped. An empty dict becomes None.
            iter_obj = ((k, _traverse_obj(obj, v, False, is_last)) for k, v in key.items())
            result = {
                k: v if v is not None else default for k, v in iter_obj
                if v is not None or default is not NO_DEFAULT
            } or None

        elif isinstance(obj, collections.abc.Mapping):
            if isinstance(obj, http.cookies.Morsel):
                obj = dict(obj, key=obj.key, value=obj.value)
            # `key` is already casefolded by `apply_path` when not casesense; try the
            # exact key first, then fall back to a linear casefolded scan
            result = (try_call(obj.get, args=(key,)) if casesense or try_call(obj.__contains__, args=(key,)) else
                      next((v for k, v in obj.items() if casefold(k) == key), None))

        elif isinstance(obj, re.Match):
            # Group lookup by number, or by name (case-insensitive if not casesense)
            if isinstance(key, int) or casesense:
                with contextlib.suppress(IndexError):
                    result = obj.group(key)

            elif isinstance(key, str):
                result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)

        elif isinstance(key, (int, slice)):
            # Index or slice sequence-like objects (per `is_iterable_like`); slices branch
            if is_iterable_like(obj, (collections.abc.Sequence, xml.etree.ElementTree.Element)):
                branching = isinstance(key, slice)
                with contextlib.suppress(IndexError):
                    result = obj[key]
            elif traverse_string:
                with contextlib.suppress(IndexError):
                    result = str(obj)[key]

        elif isinstance(obj, xml.etree.ElementTree.Element) and isinstance(key, str):
            # XPath lookup; a trailing `@attr`, `@` or `text()` segment selects an
            # attribute, the attribute dict, or the text of each matched element
            xpath, _, special = key.rpartition('/')
            if not special.startswith('@') and not special.endswith('()'):
                xpath = key
                special = None

            # Allow abbreviations of relative paths, absolute paths error
            if xpath.startswith('/'):
                xpath = f'.{xpath}'
            elif xpath and not xpath.startswith('./'):
                xpath = f'./{xpath}'

            def apply_specials(element):
                if special is None:
                    return element
                if special == '@':
                    return element.attrib
                if special.startswith('@'):
                    return try_call(element.attrib.get, args=(special[1:],))
                if special == 'text()':
                    return element.text
                # Any other `...()` suffix is unsupported
                raise SyntaxError(f'apply_specials is missing case for {special!r}')

            # Note: the list of matches is returned as one (non-branching) value
            if xpath:
                result = list(map(apply_specials, obj.iterfind(xpath)))
            else:
                result = apply_specials(obj)

        return branching, result if branching else (result,)

    def lazy_last(iterable):
        """Yield `(is_last, item)` pairs so the caller knows which item is final."""
        iterator = iter(iterable)
        prev = next(iterator, NO_DEFAULT)
        if prev is NO_DEFAULT:
            return

        for item in iterator:
            yield False, prev
            prev = item

        yield True, prev

    def apply_path(start_obj, path, test_type):
        """
        Apply every key of `path` to `start_obj` in turn.

        Returns `(objs, has_branched, ends_with_dict)` where `objs` is a lazy iterable
        of results. If `test_type` is set and the final key is not a dict/list/tuple
        (those apply the test in their nested evaluations), `type_test` is mapped over them.
        """
        objs = (start_obj,)
        has_branched = False

        key = None
        for last, key in lazy_last(variadic(path, (str, bytes, dict, set))):
            if not casesense and isinstance(key, str):
                key = key.casefold()

            if key in (any, all):
                # Collapse all current branches into a single value; branching resets
                has_branched = False
                filtered_objs = (obj for obj in objs if obj not in (None, {}))
                if key is any:
                    objs = (next(filtered_objs, None),)
                else:
                    objs = (list(filtered_objs),)
                continue

            if key is filter:
                objs = filter(None, objs)
                continue

            if __debug__ and callable(key):
                # Verify function signature
                # (raises TypeError early if `key` cannot be called as key(k, v))
                inspect.signature(key).bind(None, None)

            new_objs = []
            for obj in objs:
                branching, results = apply_key(key, obj, last)
                has_branched |= branching
                new_objs.append(results)

            objs = itertools.chain.from_iterable(new_objs)

        if test_type and not isinstance(key, (dict, list, tuple)):
            objs = map(type_test, objs)

        return objs, has_branched, isinstance(key, dict)

    def _traverse_obj(obj, path, allow_empty, test_type):
        """
        Evaluate one path and collapse its results.

        A branched path with `get_all` returns all non-empty results as a list;
        otherwise the first non-empty result is returned. `allow_empty` (set for the
        last of `paths` only) turns an empty branched result into `[]`/`default`, and
        an empty dict-terminated result into `{}`.
        """
        results, has_branched, is_dict = apply_path(obj, path, test_type)
        # NOTE(review): LazyList (from ._utils) appears to materialize lazily, so
        # `results[0]` need not consume every branch -- confirm there
        results = LazyList(item for item in results if item not in (None, {}))
        if get_all and has_branched:
            if results:
                return results.exhaust()
            if allow_empty:
                return [] if default is NO_DEFAULT else default
            return None

        return results[0] if results else {} if allow_empty and is_dict else None

    # Try each path in order; the first non-None result wins
    for index, path in enumerate(paths, 1):
        is_last = index == len(paths)
        try:
            result = _traverse_obj(obj, path, is_last, True)
            if result is not None:
                return result
        except _RequiredError as e:
            # A failed `require(...)` only aborts when no further path is left to try
            if is_last:
                # Reraise to get cleaner stack trace
                raise ExtractorError(e.orig_msg, expected=e.expected) from None

    return None if default is NO_DEFAULT else default
def value(value, /):
    """
    Return a callable that ignores its single argument and always returns `value`.

    Useful as a `{func}` traversal key to substitute a fixed value.
    """
    def constant(_):
        return value

    return constant
def require(name, /, *, expected=False):
    """
    Build a validator that passes values through unchanged but rejects `None`.

    The returned callable raises `_RequiredError('Unable to extract <name>')`
    (with `expected` forwarded) when given `None`; any other value, including
    falsy ones, is returned as-is.
    """
    def func(value):
        if value is not None:
            return value
        raise _RequiredError(f'Unable to extract {name}', expected=expected)

    return func
class _RequiredError(ExtractorError):
    """Raised by `require()` validators; `traverse_obj` re-raises it as a plain `ExtractorError` only when the last path fails."""
@typing.overload
def subs_list_to_dict(*, ext: str | None = None) -> collections.abc.Callable[[list[dict]], dict[str, list[dict]]]: ...


@typing.overload
def subs_list_to_dict(subs: list[dict] | None, /, *, ext: str | None = None) -> dict[str, list[dict]]: ...


def subs_list_to_dict(subs: list[dict] | None = None, /, *, ext=None):
    """
    Convert subtitles from a traversal into a subtitle dict.
    The path should have an `all` immediately before this function.

    Called without `subs`, it returns a converter with `ext` pre-bound, usable
    as a `{func}` traversal key.

    Arguments:
    `ext`      The default value for `ext`, set on entries whose own `ext` is falsy

    In the dict you can set the following additional items (they are popped):
    `id`       The subtitle id to sort the dict into; entries without one are skipped
    `quality`  The sort order for each subtitle (ascending; missing/falsy counts as 0)

    Entries with neither a `url` accepted by `url_or_none` nor `data` are skipped.
    The input dicts are modified in place and reused in the output.
    """
    if subs is None:
        return functools.partial(subs_list_to_dict, ext=ext)

    def quality_of(entry):
        # Popping while sorting removes `quality` from the emitted entries
        return entry.pop('quality', 0) or 0

    grouped = {}
    for entry in subs:
        if not (url_or_none(entry.get('url')) or entry.get('data')):
            continue
        entry_id = entry.pop('id', None)
        if entry_id is None:
            continue
        if ext is not None and not entry.get('ext'):
            entry['ext'] = ext
        grouped.setdefault(entry_id, []).append(entry)

    for bucket in grouped.values():
        bucket.sort(key=quality_of)

    return grouped
@typing.overload
def find_element(*, attr: str, value: str, tag: str | None = None, html=False): ...


@typing.overload
def find_element(*, cls: str, html=False): ...


@typing.overload
def find_element(*, id: str, tag: str | None = None, html=False): ...


@typing.overload
def find_element(*, tag: str, html=False): ...


def find_element(*, tag=None, id=None, cls=None, attr=None, value=None, html=False):
    """
    Build a callable that extracts the first matching element from an HTML string.

    One selector is used, checked in this order: `attr` + `value`, `cls`, `id`,
    then `tag` alone. `tag` may further restrict `attr`/`id` matches (any tag
    name is accepted when it is omitted), but cannot be combined with `cls`.

    @param html  If true, use the `*_html_*` getters (whole element markup);
                 otherwise use the plain getters, which by their naming return
                 the element's content.
    @returns     A one-argument callable that takes the HTML document to search.
    @raises AssertionError  If no selector is given or selectors conflict.
    """
    # deliberately using `id=` and `cls=` for ease of readability
    assert tag or id or cls or (attr and value), 'One of tag, id, cls or (attr AND value) is required'
    ANY_TAG = r'[\w:.-]+'

    if attr and value:
        assert not cls, 'Cannot match both attr and cls'
        assert not id, 'Cannot match both attr and id'
        func = get_element_html_by_attribute if html else get_element_by_attribute
        return functools.partial(func, attr, value, tag=tag or ANY_TAG)

    elif cls:
        assert not id, 'Cannot match both cls and id'
        assert tag is None, 'Cannot match both cls and tag'
        # Singular getter: like every other branch, return only the first match
        # (the plural `get_elements_by_class` would return a list of all matches)
        func = get_element_html_by_class if html else get_element_by_class
        return functools.partial(func, cls)

    elif id:
        func = get_element_html_by_id if html else get_element_by_id
        return functools.partial(func, id, tag=tag or ANY_TAG)

    # Tag only: the helper returns a (content, html) pair; pick the requested part
    index = int(bool(html))
    return lambda html: get_element_text_and_html_by_tag(tag, html)[index]
@typing.overload
def find_elements(*, cls: str, html=False): ...


@typing.overload
def find_elements(*, attr: str, value: str, tag: str | None = None, html=False): ...


def find_elements(*, tag=None, cls=None, attr=None, value=None, html=False):
    """
    Build a callable that extracts all matching elements from an HTML string.

    Matches either by `attr` + `value` (optionally restricted to `tag`, any tag
    name otherwise) or by `cls`, which cannot be combined with `tag`.
    With `html` true, the `*_html_*` getters are used (whole element markup).

    @raises AssertionError  If no selector is given or selectors conflict.
    """
    # deliberately using `cls=` for ease of readability
    assert cls or (attr and value), 'One of cls or (attr AND value) is required'

    if not (attr and value):
        # Class-based lookup; tag filtering is not supported here
        assert not tag, 'Cannot match both cls and tag'
        getter = get_elements_html_by_class if html else get_elements_by_class
        return functools.partial(getter, cls)

    assert not cls, 'Cannot match both attr and cls'
    getter = get_elements_html_by_attribute if html else get_elements_by_attribute
    return functools.partial(getter, attr, value, tag=tag or r'[\w:.-]+')
def trim_str(*, start=None, end=None):
    """
    Return a callable that removes `start` from the beginning and `end` from the
    end of a string, each only when present. `None` is passed through unchanged.

    Both affixes are checked against the original string, so overlapping affixes
    can consume the same characters (e.g. 'aba' with start='ab', end='ba' -> '').
    """
    def trim(s):
        if s is None:
            return None
        begin = len(start) if start and s.startswith(start) else 0
        if not (end and s.endswith(end)):
            return s[begin:]
        return s[begin:-len(end)]

    return trim
def get_first(obj, *paths, **kwargs):
    """
    Return the first match of any of `paths` applied to each item of `obj`.

    Every path is prefixed with `...` so it is evaluated against all items of
    `obj`, and `traverse_obj` is called with `get_all=False` (passing `get_all`
    in `kwargs` is therefore an error). Other `kwargs` go to `traverse_obj`.
    """
    branched_paths = ((..., *variadic(keys)) for keys in paths)
    return traverse_obj(obj, *branched_paths, **kwargs, get_all=False)
def dict_get(d, key_or_keys, default=None, skip_false_values=True):
    """
    Return the value stored under the first usable key in `key_or_keys`.

    `key_or_keys` is a single key or an iterable of keys (normalized via
    `variadic`). A value is usable when it is not None and, if
    `skip_false_values` is true, also truthy. Returns `default` otherwise.
    """
    for key in variadic(key_or_keys):
        candidate = d.get(key)
        if candidate is None:
            continue
        if candidate or not skip_false_values:
            return candidate
    return default