-
Notifications
You must be signed in to change notification settings - Fork 186
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* init python_lambda_mapper * set default arg * fix init * support batched & add docs * fix docs * Quick fix for some minor problems (#503) * * remove str conversion for fps para of add_stream func + add requires from librosa to avoid lazy_loader failure during multiprocessing * * remove str conversion for fps para of add_stream func + add requires from librosa to avoid lazy_loader failure during multiprocessing * * install cmake before * * install cmake before * * install cmake before * * update unit test tags * * update unit test tags * * update unit test tags * * update unit test tags * * try to remove samplerate dep * * skip audio duration and audio nmf snr filters * * skip video_tagging_from_frames_filter * * skip video_tagging_from_audios_filter * * skip video_motion_score_raft_filter * fix batch bug (#504) * fix batch bug * fix filter batch * not rank for filter * limit pyav version --------- Co-authored-by: Yilun Huang <[email protected]> Co-authored-by: BeachWang <[email protected]>
- Loading branch information
1 parent
5a4b1a1
commit 0fe505e
Showing
6 changed files
with
163 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import ast | ||
|
||
from ..base_op import OPERATORS, Mapper | ||
|
||
OP_NAME = 'python_lambda_mapper' | ||
|
||
|
||
@OPERATORS.register_module(OP_NAME) | ||
class PythonLambdaMapper(Mapper): | ||
"""Mapper for executing Python lambda function on data samples.""" | ||
|
||
def __init__(self, lambda_str: str = '', batched: bool = False, **kwargs): | ||
""" | ||
Initialization method. | ||
:param lambda_str: A string representation of the lambda function to be | ||
executed on data samples. If empty, the identity function is used. | ||
:param batched: A boolean indicating whether to process input data in | ||
batches. | ||
:param kwargs: Additional keyword arguments passed to the parent class. | ||
""" | ||
self._batched_op = bool(batched) | ||
super().__init__(**kwargs) | ||
|
||
# Parse and validate the lambda function | ||
if not lambda_str: | ||
self.lambda_func = lambda sample: sample | ||
else: | ||
self.lambda_func = self._create_lambda(lambda_str) | ||
|
||
def _create_lambda(self, lambda_str: str): | ||
# Parse input string into an AST and check for a valid lambda function | ||
try: | ||
node = ast.parse(lambda_str, mode='eval') | ||
|
||
# Check if the body of the expression is a lambda | ||
if not isinstance(node.body, ast.Lambda): | ||
raise ValueError( | ||
'Input string must be a valid lambda function.') | ||
|
||
# Check that the lambda has exactly one argument | ||
if len(node.body.args.args) != 1: | ||
raise ValueError( | ||
'Lambda function must have exactly one argument.') | ||
|
||
# Compile the AST to code | ||
compiled_code = compile(node, '<string>', 'eval') | ||
# Safely evaluate the compiled code allowing built-in functions | ||
func = eval(compiled_code, {'__builtins__': __builtins__}) | ||
return func | ||
except Exception as e: | ||
raise ValueError(f'Invalid lambda function: {e}') | ||
|
||
def process_single(self, sample): | ||
# Process the input through the lambda function and return the result | ||
result = self.lambda_func(sample) | ||
|
||
# Check if the result is a valid | ||
if not isinstance(result, dict): | ||
raise ValueError(f'Lambda function must return a dictionary, ' | ||
f'got {type(result).__name__} instead.') | ||
|
||
return result | ||
|
||
def process_batched(self, samples): | ||
# Process the input through the lambda function and return the result | ||
result = self.lambda_func(samples) | ||
|
||
# Check if the result is a valid | ||
if not isinstance(result, dict): | ||
raise ValueError(f'Lambda function must return a dictionary, ' | ||
f'got {type(result).__name__} instead.') | ||
|
||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.