diff --git a/dags/violation_detection_helpers/transform.py b/dags/violation_detection_helpers/transform.py index 5fe80949..00297621 100644 --- a/dags/violation_detection_helpers/transform.py +++ b/dags/violation_detection_helpers/transform.py @@ -6,19 +6,31 @@ class TransformPlatformRawData: - def __init__(self) -> None: + def __init__(self, cursor_batch_size: int = 10) -> None: + """ + Transform the raw data by classifying the violations in it + + Parameters + ------------ + cursor_batch_size : int + the pymongo cursor batch size + lowering it could increase the IO (requests to mongo) + increasing it could reduce IO but increase the idle time of the mongo cursor + default is 10 documents per batch + """ + self.batch_size = cursor_batch_size self.classifier = Classifier() def transform( self, - raw_data: Cursor, + data_cursor: Cursor, ) -> list[dict]: """ transform a list of platform's `rawmemberactivities` by labeling them Parameters ------------- - raw_data : Cursor + data_cursor : Cursor the data cursor to be transformed (using cursor for more efficiency of database) the transformation here is to label the violation for texts @@ -28,13 +40,15 @@ def transform( labeled_data : list[dict] the same data but with a label for violation detection """ + data_cursor = data_cursor.hint({"$natural": 1}).batch_size(self.batch_size) + labeled_data = [] # caching label per source_id # since we might have multiple document with same text cached_label: dict[str, str] = {} - for record in raw_data: + for record in data_cursor: try: data = copy.deepcopy(record)