Spark impl
dvadym committed Oct 12, 2023
1 parent e7dab0c commit b2f5aac
Showing 2 changed files with 42 additions and 16 deletions.
21 changes: 17 additions & 4 deletions examples/restaurant_visits/run_on_dataframes.py
@@ -19,12 +19,15 @@
 
 from absl import app
 from absl import flags
-import pipeline_dp
-from pipeline_dp import dataframes
+import os
+import shutil
 import pandas as pd
 from pyspark.sql import SparkSession
 import pyspark
 
+import pipeline_dp
+from pipeline_dp import dataframes
+
 FLAGS = flags.FLAGS
 flags.DEFINE_string('input_file', 'restaurants_week_data.csv',
                     'The file with the restaraunt visits data')
@@ -33,6 +36,14 @@
                     'Which dataframes to use.')
 
 
+def delete_if_exists(filename):
+    if os.path.exists(filename):
+        if os.path.isdir(filename):
+            shutil.rmtree(filename)
+        else:
+            os.remove(filename)
+
+
 def load_data_in_pandas_dataframe() -> pd.DataFrame:
     df = pd.read_csv(FLAGS.input_file)
     df.rename(inplace=True,
@@ -48,7 +59,7 @@ def load_data_in_pandas_dataframe() -> pd.DataFrame:
 
 def load_data_in_spark_dataframe(
         spark: SparkSession) -> pyspark.sql.dataframe.DataFrame:
-    df = spark.read.csv(FLAGS.input_file, header=True)
+    df = spark.read.csv(FLAGS.input_file, header=True, inferSchema=True)
     return df.withColumnRenamed('VisitorId', 'visitor_id').withColumnRenamed(
         'Time entered', 'enter_time').withColumnRenamed(
             'Time spent (minutes)', 'spent_minutes').withColumnRenamed(
@@ -80,7 +91,9 @@ def compute_on_spark_dataframes() -> None:
     df = load_data_in_spark_dataframe(spark)
     df.printSchema()
     result_df = compute_private_result(df)
-    result_df.write.format("csv").save(FLAGS.output_file)
+    result_df.printSchema()
+    delete_if_exists(FLAGS.output_file)
+    result_df.write.format("csv").option("header", True).save(FLAGS.output_file)
 
 
 def main(unused_argv):
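Side note for readers adapting this example (not part of the commit): PySpark's DataFrameWriter also supports an overwrite mode, so a hedged one-line alternative to the manual cleanup, assuming the same result_df and FLAGS as above, would be:

    # Sketch (assumption, not from this commit): let Spark overwrite the output path itself.
    result_df.write.format("csv").option("header", True).mode("overwrite").save(FLAGS.output_file)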
37 changes: 25 additions & 12 deletions pipeline_dp/dataframes.py
@@ -4,7 +4,7 @@
 import pandas as pd
 
 import pipeline_dp
-from typing import Any, Sequence, Callable, Optional, List, Dict
+from typing import Optional, List, Dict, Iterable
 import pyspark
 
 
@@ -52,9 +52,9 @@ def dataframe_to_collection(self, df: pd.DataFrame,
         # name=None makes that tuples instead of name tuple are returned.
         return list(df.itertuples(index=False, name=None))
 
-    def collection_to_dataframe(self, col: list,
+    def collection_to_dataframe(self, col: Iterable,
                                 partition_key_column: str) -> pd.DataFrame:
-        assert isinstance(col, list), "Only local run is supported for now"
+        assert isinstance(col, Iterable), "Only local run is supported for now"
         partition_keys, data = list(zip(*col))
         df = pd.DataFrame(data=data)
         df[partition_key_column] = partition_keys
@@ -66,33 +66,47 @@ def collection_to_dataframe(self, col: list,
 
 class SparkConverter(DataFrameConvertor):
 
+    def __init__(self, spark):
+        self._spark = spark
+
     def dataframe_to_collection(self, df, columns: _Columns) -> pyspark.RDD:
         columns_to_keep = [columns.privacy_key, columns.partition_key]
         if columns.value is not None:
             columns_to_keep.append(columns.value)
         df = df[columns_to_keep]  # leave only needed columns.
-        return []
+        rdd = df.rdd.map(lambda row: (row[0], row[1], row[2]))
+        return rdd
 
-    def collection_to_dataframe(self, col: pyspark.RDD,
-                                partition_key_column: str):
-        pass
+    def collection_to_dataframe(
+            self, col: pyspark.RDD,
+            partition_key_column: str) -> pyspark.sql.dataframe.DataFrame:
+
+        def convert_to_dict(row):
+            partition_key, metrics = row
+            result = {partition_key_column: partition_key}
+            result.update(metrics._asdict())
+            return result
+
+        col = col.map(convert_to_dict)
+        df = self._spark.createDataFrame(col)
+        return df
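For orientation, a minimal usage sketch of the new SparkConverter (hypothetical, not part of this commit): it assumes a local SparkSession and a toy Metrics namedtuple standing in for the per-partition metrics tuples that PipelineDP produces.

    # Hypothetical sketch: round-trip an RDD of (partition_key, metrics) pairs
    # into a Spark DataFrame via SparkConverter.collection_to_dataframe.
    from collections import namedtuple

    from pyspark.sql import SparkSession

    from pipeline_dp.dataframes import SparkConverter

    spark = SparkSession.builder.master("local[1]").appName("sketch").getOrCreate()
    converter = SparkConverter(spark)

    # Toy stand-in for the per-partition metrics namedtuple returned by the DP engine.
    Metrics = namedtuple("Metrics", ["count", "sum"])
    dp_result = spark.sparkContext.parallelize([
        ("restaurant_1", Metrics(count=10.2, sum=83.5)),
        ("restaurant_2", Metrics(count=7.8, sum=41.0)),
    ])

    # Each pair becomes one row; the metrics fields become columns.
    result_df = converter.collection_to_dataframe(dp_result,
                                                  partition_key_column="restaurant")
    result_df.show()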


 def create_backend_for_dataframe(
         df) -> pipeline_dp.pipeline_backend.PipelineBackend:
     if isinstance(df, pd.DataFrame):
         return pipeline_dp.LocalBackend()
-    if isinstance(df, pyspark.DataFrame):
-        return pipeline_dp.SparkRDDBackend()
+    if isinstance(df, pyspark.sql.dataframe.DataFrame):
+        return pipeline_dp.SparkRDDBackend(df.sparkSession.sparkContext)
     raise NotImplementedError(
         f"Dataframes of type {type(df)} not yet supported")
 
 
 def create_dataframe_converter(df) -> DataFrameConvertor:
     if isinstance(df, pd.DataFrame):
         return PandasConverter()
-    if isinstance(df, pyspark.DataFrame):
-        return SparkConverter()
+    if isinstance(df, pyspark.sql.dataframe.DataFrame):
+        return SparkConverter(df.sparkSession)
     raise NotImplementedError(
         f"Dataframes of type {type(df)} not yet supported")

@@ -151,7 +165,6 @@ def run_query(self,
             public_partitions=self._public_partitions,
             out_explain_computation_report=explain_computation_report)
         budget_accountant.compute_budgets()
-        dp_result = list(dp_result)
         self._expain_computation_report = explain_computation_report.text()
         return converter.collection_to_dataframe(dp_result,
                                                  self._columns.partition_key)