From 1382193b56453d11189f9dfe26db832cf73f0189 Mon Sep 17 00:00:00 2001 From: Jose Venegas <126916083+sfc-gh-jvaenegasvega@users.noreply.github.com> Date: Mon, 24 Apr 2023 09:08:41 -0600 Subject: [PATCH 1/2] Updating data frame reader extensions to support csv function --- .../dataframe_reader_extensions.py | 64 ++++++++++++++----- tests/test_dataframe_reader_extensions.py | 27 +++++--- 2 files changed, 68 insertions(+), 23 deletions(-) diff --git a/snowpark_extensions/dataframe_reader_extensions.py b/snowpark_extensions/dataframe_reader_extensions.py index 994076d..44a950b 100644 --- a/snowpark_extensions/dataframe_reader_extensions.py +++ b/snowpark_extensions/dataframe_reader_extensions.py @@ -9,6 +9,8 @@ import logging DataFrameReader.___extended = True DataFrameReader.__option = DataFrameReader.option + DataFrameReader.__csv = DataFrameReader.csv + def _option(self, key: str, value: Any) -> "DataFrameReader": key = key.upper() if key == "SEP" or key == "DELIMITER": @@ -42,11 +44,7 @@ def _load(self,path: Union[str, List[str], None] = None, format: Optional[str] = self.format(format) if schema: self.schema(schema) - files = [] - if isinstance(path,list): - files.extend(path) - else: - files.append(path) + files = get_file_paths(path) session = context.get_active_session() if stage is None: stage = f'{session.get_fully_qualified_current_schema()}.{_generate_prefix("TEMP_STAGE")}' @@ -54,16 +52,9 @@ def _load(self,path: Union[str, List[str], None] = None, format: Optional[str] = stage_files = [x for x in path if x.startswith("@")] if len(stage_files) > 1: raise Exception("Currently only one staged file can be specified. You can use a pattern if you want to specify several files") - print(f"Uploading files using stage {stage}") - for file in files: - if file.startswith("file://"): # upload local file - session.file.put(file,stage) - elif file.startswith("@"): #ignore it is on an stage - return self._read_semi_structured_file(file,format) - else: #assume it is file too - session.file.put(f"file://{file}",f"@{stage}") + stage = get_stage(self, session, files, stage) if self._file_type == "csv": - return self.csv(f"@{stage}") + return self.__csv(f"@{stage}") return self._read_semi_structured_file(f"@{stage}",format) def _format(self, file_type: str) -> "DataFrameReader": @@ -72,7 +63,50 @@ def _format(self, file_type: str) -> "DataFrameReader": self._file_type = file_type else: raise Exception(f"Unsupported file format {file_type}") + + def _csv(self,path: Union[str, List[str]],schema: Optional[Union[StructType, str]] = None,sep: Optional[str] = None,encoding: Optional[str] = None,quote: Optional[str] = None, + escape: Optional[str] = None,comment: Optional[str] = None,header: Optional[Union[bool, str]] = None,inferSchema: Optional[Union[bool, str]] = None, + ignoreLeadingWhiteSpace: Optional[Union[bool, str]] = None,ignoreTrailingWhiteSpace: Optional[Union[bool, str]] = None,nullValue: Optional[str] = None, + nanValue: Optional[str] = None,positiveInf: Optional[str] = None,negativeInf: Optional[str] = None,dateFormat: Optional[str] = None,timestampFormat: Optional[str] = None, + maxColumns: Optional[Union[int, str]] = None,maxCharsPerColumn: Optional[Union[int, str]] = None,maxMalformedLogPerPartition: Optional[Union[int, str]] = None, + mode: Optional[str] = None,columnNameOfCorruptRecord: Optional[str] = None,multiLine: Optional[Union[bool, str]] = None,charToEscapeQuoteEscaping: Optional[str] = None, + samplingRatio: Optional[Union[float, str]] = None,enforceSchema: Optional[Union[bool, str]] = None,emptyValue: Optional[str] = None,locale: Optional[str] = None, + lineSep: Optional[str] = None,pathGlobFilter: Optional[Union[bool, str]] = None,recursiveFileLookup: Optional[Union[bool, str]] = None,modifiedBefore: Optional[Union[bool, str]] = None, + modifiedAfter: Optional[Union[bool, str]] = None,unescapedQuoteHandling: Optional[str] = None) -> "DataFrame": + params = {k: v for k, v in locals().items() if v is not None} + params.pop("self", None) + params.pop("path", None) + params.pop("schema", None) + if schema: + self.schema(schema) + files = get_file_paths(path) + session = context.get_active_session() + stage = f'{session.get_fully_qualified_current_schema()}.{_generate_prefix("TEMP_STAGE")}' + session.sql(f'create TEMPORARY stage if not exists {stage}').show() + stage = get_stage(self, session, files, stage) + for key, value in params.items(): + self = self.option(key, value) + return self.__csv(f"@{stage}") + def get_file_paths(path: Union[str, List[str]]): + if isinstance(path,list): + return path + else: + return [path] + + def get_stage(self, session, files: List[str], stage: str): + print(f"Uploading files using stage {stage}") + for file in files: + if file.startswith("file://"): # upload local file + session.file.put(file,stage) + elif file.startswith("@"): #ignore it is on an stage + return self._read_semi_structured_file(file,format) + else: #assume it is file too + session.file.put(f"file://{file}",f"@{stage}") + return stage + DataFrameReader.format = _format DataFrameReader.load = _load - DataFrameReader.option = _option \ No newline at end of file + DataFrameReader.option = _option + DataFrameReader.csv = _csv + \ No newline at end of file diff --git a/tests/test_dataframe_reader_extensions.py b/tests/test_dataframe_reader_extensions.py index 1458a49..731a88c 100644 --- a/tests/test_dataframe_reader_extensions.py +++ b/tests/test_dataframe_reader_extensions.py @@ -1,10 +1,26 @@ import pytest -from snowflake.snowpark import Session, Row +from snowflake.snowpark import Session, Row, DataFrameReader from snowflake.snowpark.types import * import snowpark_extensions def test_load(): session = Session.builder.from_snowsql().getOrCreate() + cases = session.read.load(["./tests/data/test1_0.csv","./tests/data/test1_1.csv"], + schema=get_schema(), + format="csv", + sep=",", + header="true") + assert 10 == len(cases.collect()) + +def test_csv(): + session = Session.builder.from_snowsql().getOrCreate() + csvInfo = session.read.csv(["./tests/data/test1_0.csv","./tests/data/test1_1.csv"], + schema=get_schema(), + sep=",", + header="true") + assert 10 == len(csvInfo.collect()) + +def get_schema(): schema = StructType([ \ StructField("case_id", StringType()), \ StructField("province", StringType()), \ @@ -13,11 +29,6 @@ def test_load(): StructField("infection_case",StringType()), \ StructField("confirmed", IntegerType()), \ StructField("latitude", FloatType()), \ - StructField("cilongitudety", FloatType()) \ + StructField("longitude", FloatType()) \ ]) - cases = session.read.load(["./tests/data/test1_0.csv","./tests/data/test1_1.csv"], - schema=schema, - format="csv", - sep=",", - header="true") - assert 10 == len(cases.collect()) \ No newline at end of file + return schema \ No newline at end of file From 54c03a5e0a1afccb3bcc2155faf2ae5a676f6e8a Mon Sep 17 00:00:00 2001 From: Jose Venegas <126916083+sfc-gh-jvaenegasvega@users.noreply.github.com> Date: Tue, 25 Apr 2023 10:41:26 -0600 Subject: [PATCH 2/2] Updating csv extension and test --- .../dataframe_reader_extensions.py | 86 +++++++++++-------- tests/test_dataframe_reader_extensions.py | 14 ++- 2 files changed, 62 insertions(+), 38 deletions(-) diff --git a/snowpark_extensions/dataframe_reader_extensions.py b/snowpark_extensions/dataframe_reader_extensions.py index 44a950b..b189bc4 100644 --- a/snowpark_extensions/dataframe_reader_extensions.py +++ b/snowpark_extensions/dataframe_reader_extensions.py @@ -44,7 +44,11 @@ def _load(self,path: Union[str, List[str], None] = None, format: Optional[str] = self.format(format) if schema: self.schema(schema) - files = get_file_paths(path) + files = [] + if isinstance(path,list): + files.extend(path) + else: + files.append(path) session = context.get_active_session() if stage is None: stage = f'{session.get_fully_qualified_current_schema()}.{_generate_prefix("TEMP_STAGE")}' @@ -52,9 +56,16 @@ def _load(self,path: Union[str, List[str], None] = None, format: Optional[str] = stage_files = [x for x in path if x.startswith("@")] if len(stage_files) > 1: raise Exception("Currently only one staged file can be specified. You can use a pattern if you want to specify several files") - stage = get_stage(self, session, files, stage) + print(f"Uploading files using stage {stage}") + for file in files: + if file.startswith("file://"): # upload local file + session.file.put(file,stage) + elif file.startswith("@"): #ignore it is on an stage + return self._read_semi_structured_file(file,format) + else: #assume it is file too + session.file.put(f"file://{file}",f"@{stage}") if self._file_type == "csv": - return self.__csv(f"@{stage}") + return self.csv(f"@{stage}") return self._read_semi_structured_file(f"@{stage}",format) def _format(self, file_type: str) -> "DataFrameReader": @@ -64,46 +75,51 @@ def _format(self, file_type: str) -> "DataFrameReader": else: raise Exception(f"Unsupported file format {file_type}") - def _csv(self,path: Union[str, List[str]],schema: Optional[Union[StructType, str]] = None,sep: Optional[str] = None,encoding: Optional[str] = None,quote: Optional[str] = None, - escape: Optional[str] = None,comment: Optional[str] = None,header: Optional[Union[bool, str]] = None,inferSchema: Optional[Union[bool, str]] = None, - ignoreLeadingWhiteSpace: Optional[Union[bool, str]] = None,ignoreTrailingWhiteSpace: Optional[Union[bool, str]] = None,nullValue: Optional[str] = None, - nanValue: Optional[str] = None,positiveInf: Optional[str] = None,negativeInf: Optional[str] = None,dateFormat: Optional[str] = None,timestampFormat: Optional[str] = None, - maxColumns: Optional[Union[int, str]] = None,maxCharsPerColumn: Optional[Union[int, str]] = None,maxMalformedLogPerPartition: Optional[Union[int, str]] = None, - mode: Optional[str] = None,columnNameOfCorruptRecord: Optional[str] = None,multiLine: Optional[Union[bool, str]] = None,charToEscapeQuoteEscaping: Optional[str] = None, - samplingRatio: Optional[Union[float, str]] = None,enforceSchema: Optional[Union[bool, str]] = None,emptyValue: Optional[str] = None,locale: Optional[str] = None, - lineSep: Optional[str] = None,pathGlobFilter: Optional[Union[bool, str]] = None,recursiveFileLookup: Optional[Union[bool, str]] = None,modifiedBefore: Optional[Union[bool, str]] = None, - modifiedAfter: Optional[Union[bool, str]] = None,unescapedQuoteHandling: Optional[str] = None) -> "DataFrame": + def _csv(self, + path: str, + schema: Optional[Union[StructType, str]] = None, + sep: Optional[str] = None, + encoding: Optional[str] = None, + quote: Optional[str] = None, + escape: Optional[str] = None, + comment: Optional[str] = None, + header: Optional[Union[bool, str]] = None, + inferSchema: Optional[Union[bool, str]] = None, + ignoreLeadingWhiteSpace: Optional[Union[bool, str]] = None, + ignoreTrailingWhiteSpace: Optional[Union[bool, str]] = None, + nullValue: Optional[str] = None, + nanValue: Optional[str] = None, + positiveInf: Optional[str] = None, + negativeInf: Optional[str] = None, + dateFormat: Optional[str] = None, + timestampFormat: Optional[str] = None, + maxColumns: Optional[Union[int, str]] = None, + maxCharsPerColumn: Optional[Union[int, str]] = None, + maxMalformedLogPerPartition: Optional[Union[int, str]] = None, + mode: Optional[str] = None, + columnNameOfCorruptRecord: Optional[str] = None, + multiLine: Optional[Union[bool, str]] = None, + charToEscapeQuoteEscaping: Optional[str] = None, + samplingRatio: Optional[Union[float, str]] = None, + enforceSchema: Optional[Union[bool, str]] = None, + emptyValue: Optional[str] = None, + locale: Optional[str] = None, + lineSep: Optional[str] = None, + pathGlobFilter: Optional[Union[bool, str]] = None, + recursiveFileLookup: Optional[Union[bool, str]] = None, + modifiedBefore: Optional[Union[bool, str]] = None, + modifiedAfter: Optional[Union[bool, str]] = None, + unescapedQuoteHandling: Optional[str] = None + ) -> "DataFrame": params = {k: v for k, v in locals().items() if v is not None} params.pop("self", None) params.pop("path", None) params.pop("schema", None) if schema: self.schema(schema) - files = get_file_paths(path) - session = context.get_active_session() - stage = f'{session.get_fully_qualified_current_schema()}.{_generate_prefix("TEMP_STAGE")}' - session.sql(f'create TEMPORARY stage if not exists {stage}').show() - stage = get_stage(self, session, files, stage) for key, value in params.items(): self = self.option(key, value) - return self.__csv(f"@{stage}") - - def get_file_paths(path: Union[str, List[str]]): - if isinstance(path,list): - return path - else: - return [path] - - def get_stage(self, session, files: List[str], stage: str): - print(f"Uploading files using stage {stage}") - for file in files: - if file.startswith("file://"): # upload local file - session.file.put(file,stage) - elif file.startswith("@"): #ignore it is on an stage - return self._read_semi_structured_file(file,format) - else: #assume it is file too - session.file.put(f"file://{file}",f"@{stage}") - return stage + return self.__csv(path) DataFrameReader.format = _format DataFrameReader.load = _load diff --git a/tests/test_dataframe_reader_extensions.py b/tests/test_dataframe_reader_extensions.py index 731a88c..7a9c2bc 100644 --- a/tests/test_dataframe_reader_extensions.py +++ b/tests/test_dataframe_reader_extensions.py @@ -1,6 +1,7 @@ import pytest from snowflake.snowpark import Session, Row, DataFrameReader from snowflake.snowpark.types import * +from snowflake.snowpark.dataframe import _generate_prefix import snowpark_extensions def test_load(): @@ -14,11 +15,18 @@ def test_load(): def test_csv(): session = Session.builder.from_snowsql().getOrCreate() - csvInfo = session.read.csv(["./tests/data/test1_0.csv","./tests/data/test1_1.csv"], + stage = f'{session.get_fully_qualified_current_schema()}.{_generate_prefix("TEST_STAGE")}' + session.sql(f'CREATE TEMPORARY STAGE IF NOT EXISTS {stage}').show() + session.file.put(f"file://./tests/data/test1_0.csv", f"@{stage}") + session.file.put(f"file://./tests/data/test1_1.csv", f"@{stage}") + dfReader = session.read + csvInfo = dfReader.csv(f"@{stage}", schema=get_schema(), sep=",", header="true") - assert 10 == len(csvInfo.collect()) + assert 10 == len(csvInfo.collect()) + assert dfReader._cur_options["FIELD_DELIMITER"] == "," + assert dfReader._cur_options["SKIP_HEADER"] == 1 def get_schema(): schema = StructType([ \ @@ -29,6 +37,6 @@ def get_schema(): StructField("infection_case",StringType()), \ StructField("confirmed", IntegerType()), \ StructField("latitude", FloatType()), \ - StructField("longitude", FloatType()) \ + StructField("longitude", FloatType()) \ ]) return schema \ No newline at end of file