forked from szilard/benchm-ml
-
Notifications
You must be signed in to change notification settings - Fork 0
/
5xa-spark-1hot--sp20.txt
40 lines (22 loc) · 1.32 KB
/
5xa-spark-1hot--sp20.txt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
spark-2.0.0-bin-hadoop2.7/bin/spark-shell --driver-memory 10G --executor-memory 10G --packages com.databricks:spark-csv_2.11:1.5.0
// from Joseph Bradley https://gist.github.com/jkbradley/1e3cc0b3116f2f615b3f
// modifications by Xusen Yin https://github.com/szilard/benchm-ml/commit/db65cf000c9b1565b6e93d2d10c92dd646644d85
// some changes by @szilard for Spark 2.0
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.DoubleType
// Reader for CSV files that have a header row. No schema inference is requested here,
// so the two columns that must be numeric are cast explicitly further down.
val loader = spark.read.format("com.databricks.spark.csv").option("header", "true")
val trainDF = loader.load("train-1m.csv")
val testDF = loader.load("test.csv")

// Tag every row with the file it came from and stack both sets, so that one RFormula
// fit sees the rows of both files; the tag is used at the end to split them again.
// `union` replaces `unionAll`, which is deprecated as of Spark 2.0.
// (Lines are broken only where the expression is still open, so the script can be
// pasted into spark-shell line by line.)
val fullDF0 = trainDF.withColumn("isTrain", lit(true)).union(
  testDF.withColumn("isTrain", lit(false)))

// Cast the two numeric inputs to double before they reach RFormula.
val fullDF = fullDF0.
  withColumn("DepTime", col("DepTime").cast(DoubleType)).
  withColumn("Distance", col("Distance").cast(DoubleType))
fullDF.printSchema
fullDF.show(5)

// `.` stands for every column other than the label, which would also pull in the
// bookkeeping column `isTrain`. Remove it from the formula so the train/test marker
// is not encoded as a feature; the column itself stays in the DataFrame for the split.
val formula = new RFormula().setFormula("dep_delayed_15min ~ . - isTrain")
val res = formula.fit(fullDF).transform(fullDF)
res.printSchema
res.show(5)

// Split the transformed rows back into the original train and test sets and persist
// each one as parquet, replacing any previous output at the same path.
val finalTrainDF = res.where(col("isTrain"))
val finalTestDF = res.where(!col("isTrain"))
finalTrainDF.write.mode("overwrite").parquet("spark1hot-train-1m.parquet")
finalTestDF.write.mode("overwrite").parquet("spark1hot-test-1m.parquet")