Spark Data Frames

Scala

Data Frames with Spark 2.0

Prerequisites:

  • Spark 2.0.0 or higher, preferably with pre-built Hadoop. Download link
  • Scala 2.11.8 or higher. Download link

With the prerequisites satisfied, let's start the Spark shell using the command spark-shell. The printout will look something like this:

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel).
16/12/25 12:52:24 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Spark context Web UI available at http://10.10.10.10:4040
Spark context available as 'sc' (master = local[*], app id = local-1482637944859).
Spark session available as 'spark'.
Welcome to
      ___               __
     / __/__  ___ ___ _/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /___/ .__/\_,_/_/ /_/\_\   version 2.0.2
      /_/
         
Using Scala version 2.11.8 (OpenJDK 64-Bit Server VM, Java 1.8.0_111)
Type in expressions to have them evaluated.
Type :help for more information.

scala>

The default log verbosity is WARN, so we will change it to ERROR in order to avoid screen clutter.
In the shell, run the command sc.setLogLevel("ERROR").

As shown in the startup messages above, the Spark context is available in the shell as sc, while the Spark session is available as spark.

Let's load a CSV file from the local drive. Spark will automatically transform it into a data frame, which we will cache for later use and faster access.

scala> val file_name = "/path/to/local/file.csv"

scala> val df = spark.read.option("header", "true").option("inferSchema", "true").csv(file_name)

scala> df.cache
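
Note that cache is lazy: nothing is actually materialized until an action runs against the data frame. A simple action such as count (a minimal check, output omitted here) forces the file to be read and the cache to be populated:

scala> df.count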

The data that we just loaded are daily OHLCV values of a traded stock. Now we can perform some descriptive operations on the data set.

  • How to get the names of the columns
scala> df.columns
res6: Array[String] = Array(Date, Open, High, Low, Close, Volume)
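A closely related call is df.dtypes, which returns each column name paired with its data type and is handy as a quick sanity check (output omitted here):
scala> df.dtypes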
  • How to list the first n rows. We chose n = 5 in this case.
scala> df.show(5)
+--------------------+-----+-----+-----+-----+-------+
|                Date| Open| High|  Low|Close| Volume|
+--------------------+-----+-----+-----+-----+-------+
|2006-01-03 00:00:...|490.0|493.8|481.1|492.9|1537660|
|2006-01-04 00:00:...|488.6|491.0|483.5|483.8|1871020|
|2006-01-05 00:00:...|484.4|487.8|484.0|486.2|1143160|
|2006-01-06 00:00:...|488.8|489.0|482.0|486.2|1370250|
|2006-01-09 00:00:...|486.0|487.4|483.0|483.9|1680740|
+--------------------+-----+-----+-----+-----+-------+
only showing top 5 rows
  • How to get a statistical summary of the whole data set, e.g. the number of entries (count), mean, standard deviation, min and max values. A variant restricted to specific columns is sketched after the output below.
scala> df.describe().show
+-------+------------------+-----------------+------------------+------------------+-----------------+
|summary|              Open|             High|               Low|             Close|           Volume|
+-------+------------------+-----------------+------------------+------------------+-----------------+
|  count|               755|              755|               755|               755|              755|
|   mean| 386.0923178807949|390.6590596026489|380.80170860927143| 385.3421456953643|6308596.382781457|
| stddev|149.32301134820133|148.5151130063523|150.53136890891344|149.83310074439177| 8099892.56297633|
|    min|              54.4|             55.3|              30.5|              37.7|           632860|
|    max|             566.0|            570.0|             555.5|             564.1|        102869289|
+-------+------------------+-----------------+------------------+------------------+-----------------+
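If you only need a subset of the columns, describe also accepts column names, e.g.:
scala> df.describe("Open", "Close").show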
  • Select only some columns and show the first 10 entries
scala> df.select("Open", "Close").show(10)
+-----+-----+
| Open|Close|
+-----+-----+
|490.0|492.9|
|488.6|483.8|
|484.4|486.2|
|488.8|486.2|
|486.0|483.9|
|483.0|485.4|
|495.8|489.8|
|491.0|490.3|
|491.0|489.2|
|485.1|484.3|
+-----+-----+
only showing top 10 rows
  • Create a new data frame from the original one, with an extra column called H-L representing the difference between High and Low. The following two commands are equivalent.
scala> val df2 = df.withColumn("H-L", $"High" - $"Low")
scala> val df2 = df.withColumn("H-L", df("High") - df("Low"))
  • Check the schema of the newly created data frame
scala> df2.printSchema
root
 |-- Date: timestamp (nullable = true)
 |-- Open: double (nullable = true)
 |-- High: double (nullable = true)
 |-- Low: double (nullable = true)
 |-- Close: double (nullable = true)
 |-- Volume: integer (nullable = true)
 |-- H-L: double (nullable = true)
  • Let's rename the column H-L to HighMinusLow. The renaming is not done in place, since data frames are immutable; you can assign the result to a new value, as sketched after the output below.
scala> df2.withColumnRenamed("H-L", "HighMinusLow").show(5)
+--------------------+-----+-----+-----+-----+-------+------------------+
|                Date| Open| High|  Low|Close| Volume|      HighMinusLow|
+--------------------+-----+-----+-----+-----+-------+------------------+
|2006-01-03 00:00:...|490.0|493.8|481.1|492.9|1537660|12.699999999999989|
|2006-01-04 00:00:...|488.6|491.0|483.5|483.8|1871020|               7.5|
|2006-01-05 00:00:...|484.4|487.8|484.0|486.2|1143160|3.8000000000000114|
|2006-01-06 00:00:...|488.8|489.0|482.0|486.2|1370250|               7.0|
|2006-01-09 00:00:...|486.0|487.4|483.0|483.9|1680740| 4.399999999999977|
+--------------------+-----+-----+-----+-----+-------+------------------+
only showing top 5 rows
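To keep the renamed data frame around, assign the result to a new value (df3 below is just an arbitrary name used for illustration):
scala> val df3 = df2.withColumnRenamed("H-L", "HighMinusLow")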
  • Let's check how many entries have a High value over 550. The following two methods are equivalent and, of course, give the same result.
scala> df.filter($"High" > 550).count()
scala> df.filter("High > 550").count()
  • Another approach is to register the data frame as a temporary view (let's call it tempDB) and run regular SQL queries against it. The result will be the same as the one above.
scala> df.createOrReplaceTempView("tempDB")
scala> spark.sql("SELECT * FROM tempDB where High > 550").count()
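The result of spark.sql is itself a data frame, so it can be mixed freely with the methods shown earlier (highDF below is just an illustrative name):
scala> val highDF = spark.sql("SELECT * FROM tempDB WHERE High > 550")
scala> highDF.count()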
  • A query with an equality condition using three different notations. All return the same result. Note the first notation, which is Scala-specific.
df.filter($"High" === 540).show
df.filter("High = 540").show
spark.sql("SELECT * FROM tempDB WHERE High = 540").show
  • A query with multiple conditions. Again, all three notations return the same result.
df.filter($"High" > 410 && $"Volume" < 1000000).count
df.filter("High > 410 AND Volume < 1000000").count
spark.sql("SELECT * FROM tempDB WHERE High > 410 AND Volume < 1000000").count
  • Queries similar to SQL's LIKE syntax. contains checks for a substring; with like, the _ wildcard matches a single character, while % matches any sequence of characters. A concrete example on this data set is sketched below.
df.filter($"Column".contains("ReGex"))
df.filter($"Column".like("_ReGex%"))
  • If we have some missing data, we can choose to drop it or fill it in (see the dealing with missing data link). Let's assume we want to replace all null values in the Open column with the average of that column. Note that avg lives in org.apache.spark.sql.functions, which is not imported by default in the shell.
import org.apache.spark.sql.functions._
df.na.fill(Map("Open" -> df.select(avg($"Open")).collect()(0).getDouble(0))).show
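Alternatively, dropping the rows that contain nulls, either in any column or only in specific ones, is a one-liner (a minimal sketch):
df.na.drop().count
df.na.drop(Seq("Open")).count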
  • To change the date format we will use the date_format function, whose pattern string follows Java's SimpleDateFormat
df.select(date_format(df("Date"), "EEE, d MMM yyyy").as("Date"), $"Open", $"High", $"Low", $"Close", $"Volume").show
  • What is the average open price for each year, sorted by year in descending order
df.groupBy(year(df("Date")).as("Year")).mean().orderBy($"Year".desc).select("Year", "avg(Open)").show
  • In the whole data set, what percentage of the time was the opening price lower than 300?
df.filter("Open < 300").count / (1.0 * df.count) * 100
  • Return the Pearson correlation between two columns. More aggregate functions can be found at this link
scala> df.select(corr("High", "Open").as("Correlation")).show
+------------------+
|       Correlation|
+------------------+
|0.9994903683167583|
+------------------+
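The same Pearson correlation is also available through the data frame statistics API, which returns a plain Double rather than a data frame:
scala> df.stat.corr("High", "Open")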
  • Aggregate functions over a column
scala> df.select(mean("Close").as("Avg Close"), min("Close").as("Min Close"), max("Close").as("Max Close")).show
+-----------------+---------+---------+
|        Avg Close|Min Close|Max Close|
+-----------------+---------+---------+
|385.3421456953643|     37.7|    564.1|
+-----------------+---------+---------+
  • There may be situations where you need to change the data type of a column or its name. Let's assume that our DataFrame called df_old has the following schema:
scala> df_old.printSchema
root
 |-- Date: string (nullable = true)
 |-- Open: string (nullable = true)
 |-- High: string (nullable = true)
 |-- Low: string (nullable = true)
 |-- Close: string (nullable = true)
 |-- Volume: string (nullable = true)

You can achieve the above schema by loading the CSV file with .option("inferSchema", "false").
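
A minimal sketch, reusing file_name from the earlier load:

scala> val df_old = spark.read.option("header", "true").option("inferSchema", "false").csv(file_name)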
And we want to change that schema into:

scala> df_new.printSchema
root
 |-- Date: timestamp (nullable = true)
 |-- Open: double (nullable = true)
 |-- High: double (nullable = true)
 |-- Low: double (nullable = true)
 |-- Close: double (nullable = true)
 |-- Volume: integer (nullable = true)

We will need to create a new DataFrame as shown below. For more about data types, check the link

import org.apache.spark.sql.types._

val df_new = df_old.select(
	df_old.columns.map {
		case "Date"		=> df_old("Date").cast(TimestampType).as("Date")
		case "Open"		=> df_old("Open").cast(DoubleType).as("Open")
		case "High"		=> df_old("High").cast(DoubleType).as("High")
		case "Low"		=> df_old("Low").cast(DoubleType).as("Low")
		case "Close"	=> df_old("Close").cast(DoubleType).as("Close")
		case "Volume"	=> df_old("Volume").cast(IntegerType).as("Volume")
	} : _*
)
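
If only one or two columns need a new type, a somewhat simpler route is withColumn combined with cast; a minimal sketch for the Volume column only (df_partial is just an illustrative name):

val df_partial = df_old.withColumn("Volume", df_old("Volume").cast(IntegerType))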

Another way to impose the schema on a data frame is to build it prior to loading the data.

scala> import org.apache.spark.sql.types._
scala> val schema = StructType(
	StructField(name = "A", dataType = TimestampType, nullable = false)::
	StructField(name = "B", dataType = DoubleType, nullable = true)::
	StructField(name = "C", dataType = DoubleType, nullable = false)::
	StructField(name = "D", dataType = DoubleType, nullable = true)::
	StructField(name = "E", dataType = DoubleType, nullable = true)::
	StructField(name = "F", dataType = IntegerType, nullable = false)::Nil
)

scala> schema.printTreeString
root
 |-- A: timestamp (nullable = false)
 |-- B: double (nullable = true)
 |-- C: double (nullable = false)
 |-- D: double (nullable = true)
 |-- E: double (nullable = true)
 |-- F: integer (nullable = false)

scala> val df_with_schema = spark.read.option("header", "true").schema(schema).csv(file_name)

Note that the result goes into a new value (df_with_schema, an arbitrary name), since this schema names the columns A through F; the examples below keep using the original df.
  • In order to drop entire columns, we will use the drop method as shown below
scala> df.drop("Date", "Volume").show(3)
+-----+-----+-----+-----+
| Open| High|  Low|Close|
+-----+-----+-----+-----+
|490.0|493.8|481.1|492.9|
|488.6|491.0|483.5|483.8|
|484.4|487.8|484.0|486.2|
+-----+-----+-----+-----+
only showing top 3 rows

And another, fancier way to drop columns is to select everything except the columns to be dropped:

scala> import org.apache.spark.sql.functions._
scala> val dropCols = Seq("Date", "Volume")
scala> val viewCols = df.columns.diff(dropCols).map(col(_))
scala> df.select(viewCols: _*).show(3)
+-----+-----+-----+-----+
| Open| High|  Low|Close|
+-----+-----+-----+-----+
|490.0|493.8|481.1|492.9|
|488.6|491.0|483.5|483.8|
|484.4|487.8|484.0|486.2|
+-----+-----+-----+-----+
only showing top 3 rows

Disclaimer: This is by no means an original work; it is merely meant to serve as a compilation of thoughts, code snippets and teachings from different sources.