Machine Learning with Spark (3/5) - model validation



Validating a model


  • Spark 2.0.0 or higher, preferably pre-built for Hadoop
  • Scala 2.11.8 or higher

This is a generic how-to on model validation with Spark.

The following tutorial runs entirely in the spark-shell, although you could just as well wrap everything in a function and run it as a compiled object (see this Scala tutorial).

This is a short post that builds on Part 1/5 of Machine Learning with Spark, so I'll skip the data-loading step. We assume the data is already in a DataFrame called df with two columns, label and features, as shown below.

 |-- label: double (nullable = true)
 |-- features: vector (nullable = true)

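Here are the classes we will use throughout the analysis, imported under short aliases:

import org.apache.spark.ml.evaluation.{RegressionEvaluator => RE}
import org.apache.spark.ml.regression.{LinearRegression => LR}
import org.apache.spark.ml.tuning.{ParamGridBuilder => PGB, TrainValidationSplit => TVS}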

Now let's split our DataFrame into training and test subsets and instantiate a linear regression model.

val Array(training, test) = df.randomSplit(Array(.7, .3), seed = 196)
val lr = new LR()
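randomSplit samples rows at random rather than cutting the DataFrame at an exact position, so the weights 0.7/0.3 give subsets of roughly, not exactly, 70% and 30% of the rows, and the fixed seed makes the split repeatable on the same input data. The plain-Scala sketch below illustrates that idea with scala.util.Random; it is not Spark's implementation, and SplitSketch and split are hypothetical names used only here.

```scala
import scala.util.Random

// Plain-Scala sketch of a seeded 70/30 random split (the idea behind
// DataFrame.randomSplit, NOT Spark's implementation): every row is assigned
// independently, so the subset sizes are only approximately 70% and 30%.
object SplitSketch {
  def split[A](rows: Seq[A], trainFraction: Double, seed: Long): (Seq[A], Seq[A]) = {
    val rng = new Random(seed)
    // One uniform draw per row, in order, so the result depends only on the seed
    val tagged = rows.toVector.map(row => (row, rng.nextDouble() < trainFraction))
    (tagged.collect { case (row, true) => row }, tagged.collect { case (row, false) => row })
  }

  def main(args: Array[String]): Unit = {
    val (training, test) = split(1 to 1000, trainFraction = 0.7, seed = 196L)
    println(s"training: ${training.size} rows, test: ${test.size} rows")
  }
}
```

Running it twice with the same seed gives the same two subsets; changing the seed changes both their contents and, slightly, their sizes.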

The linear regression model has lots of parameters that we can set and tune. In order to retrieve their current settings, run the following command and you will see something similar to what is shown below.

scala> lr.extractParamMap
res105: org.apache.spark.ml.param.ParamMap =
{
        linReg_db39bbba502d-elasticNetParam: 0.0,
        linReg_db39bbba502d-featuresCol: features,
        linReg_db39bbba502d-fitIntercept: true,
        linReg_db39bbba502d-labelCol: label,
        linReg_db39bbba502d-maxIter: 100,
        linReg_db39bbba502d-predictionCol: prediction,
        linReg_db39bbba502d-regParam: 0.0,
        linReg_db39bbba502d-solver: auto,
        linReg_db39bbba502d-standardization: true,
        linReg_db39bbba502d-tol: 1.0E-6
}

To see a brief explanation of each of the parameters above, run lr.explainParams.

All of these parameters can be tuned to improve the model's predictions. This is where ParamGridBuilder comes into play. For more on hyperparameter tuning, see the ML Tuning guide in the links at the end.

Let's build a grid of parameters.

val gridParams = new PGB().
	addGrid(lr.regParam, Array(.1, .2, .01, .02)).
	addGrid(lr.elasticNetParam, Array(.0, .5, .9)).
	addGrid(lr.maxIter, Array(10, 20, 30)).
	addGrid(lr.tol, Array(.1, .2, .3)).
	addGrid(lr.solver, Array("l-bfgs")).
	build()
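build() takes the Cartesian product of all the value lists, so this grid contains 4 × 3 × 3 × 3 × 1 = 108 parameter combinations, each of which means one more model to fit. The plain-Scala sketch below (no Spark; tuples stand in for Spark's ParamMap, and GridSketch is a hypothetical name) simply enumerates the same product.

```scala
// Plain-Scala sketch: enumerate the Cartesian product of the grid's value lists,
// one tuple per combination (Spark's ParamGridBuilder returns one ParamMap each).
object GridSketch {
  val regParam        = Seq(0.1, 0.2, 0.01, 0.02)
  val elasticNetParam = Seq(0.0, 0.5, 0.9)
  val maxIter         = Seq(10, 20, 30)
  val tol             = Seq(0.1, 0.2, 0.3)
  val solver          = Seq("l-bfgs")

  val combinations = for {
    r <- regParam
    e <- elasticNetParam
    m <- maxIter
    t <- tol
    s <- solver
  } yield (r, e, m, t, s)

  def main(args: Array[String]): Unit =
    println(s"${combinations.size} parameter combinations") // 4 * 3 * 3 * 3 * 1 = 108
}
```

Because the count multiplies, trimming a single list (for example tol, from three values to one) would cut the grid to 36 combinations.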

Next we will use TrainValidationSplit for tuning. It evaluates each combination of parameters only once, on a single train/validation split, and returns the best model. The drawback of this approach is that, unless the dataset is sufficiently large, the results may be less reliable than those of CrossValidator.
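The reliability difference comes from how often each row is used for validation. With k-fold cross-validation every row lands in exactly one validation fold, and CrossValidator averages the metric of each parameter combination over the k folds; TrainValidationSplit scores each combination on one validation slice only. The plain-Scala sketch below shows the k-fold bookkeeping under a simplifying assumption of my own (folds assigned by position after a seeded shuffle, which is not how Spark samples them); KFoldSketch and folds are hypothetical names.

```scala
import scala.util.Random

// Plain-Scala sketch of k-fold assignment (the idea behind CrossValidator, NOT its code).
object KFoldSketch {
  /** Returns k (training rows, validation rows) pairs for row indices 0 until n. */
  def folds(n: Int, k: Int, seed: Long): Vector[(Vector[Int], Vector[Int])] = {
    val shuffled = new Random(seed).shuffle((0 until n).toVector)
    (0 until k).toVector.map { f =>
      // Simplification: rows at shuffled positions f, f + k, f + 2k, ... form validation fold f
      val validation = shuffled.zipWithIndex.collect { case (row, pos) if pos % k == f => row }
      val validationSet: Set[Int] = validation.toSet
      // Every other row is used for training in this fold
      val training = shuffled.filterNot(validationSet.contains)
      (training, validation)
    }
  }

  def main(args: Array[String]): Unit =
    folds(n = 12, k = 3, seed = 196L).zipWithIndex.foreach { case ((training, validation), f) =>
      println(s"fold $f: ${training.size} training rows, validation rows ${validation.sorted.mkString(", ")}")
    }
}
```

Every row is validated once and trained on k - 1 times; the price is k fits per parameter combination instead of one.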

TrainValidationSplit is fitted on the training subset only and splits it internally: with setTrainRatio(.75), 75% of the training data is used to fit each candidate model and the remaining 25% to validate it. The test subset stays untouched until the final evaluation.
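Put together, TrainValidationSplit does roughly the following: split its input once according to the train ratio, fit one model per parameter combination on the larger part, score each on the smaller part with the evaluator, and refit the winning combination on the whole input. The toy, plain-Scala sketch below walks through those steps with a hypothetical one-parameter model (1-D ridge regression through the origin, tuned over its regularization strength lambda) and R2 as the metric. TvsSketch, Point, fit, r2 and select are illustrative names, and the shuffle-and-cut split is a simplification of Spark's random split.

```scala
import scala.util.Random

// Toy sketch of train-validation-split model selection (NOT Spark's implementation).
object TvsSketch {
  final case class Point(x: Double, y: Double)

  // Hypothetical model: 1-D ridge regression without intercept, w = sum(x*y) / (sum(x*x) + lambda)
  def fit(train: Seq[Point], lambda: Double): Double =
    train.map(p => p.x * p.y).sum / (train.map(p => p.x * p.x).sum + lambda)

  // R2 = 1 - SS_res / SS_tot on the given rows
  def r2(rows: Seq[Point], w: Double): Double = {
    val mean  = rows.map(_.y).sum / rows.size
    val ssRes = rows.map(p => math.pow(p.y - w * p.x, 2)).sum
    val ssTot = rows.map(p => math.pow(p.y - mean, 2)).sum
    1.0 - ssRes / ssTot
  }

  /** Returns (best lambda, its validation R2, slope refit on all rows). */
  def select(rows: Seq[Point], lambdas: Seq[Double], trainRatio: Double, seed: Long): (Double, Double, Double) = {
    val shuffled = new Random(seed).shuffle(rows.toVector)
    val (train, validation) = shuffled.splitAt((shuffled.size * trainRatio).toInt)
    // Each candidate is fitted once on `train` and scored once on `validation`
    val scored = lambdas.map(lambda => (lambda, r2(validation, fit(train, lambda))))
    val (bestLambda, bestScore) = scored.maxBy(_._2) // larger R2 is better
    (bestLambda, bestScore, fit(rows, bestLambda))   // refit the winner on all rows
  }

  def main(args: Array[String]): Unit = {
    val rng = new Random(7L)
    // Synthetic data: y = 3x + Gaussian noise
    val rows = (1 to 200).map { i =>
      val x = i / 20.0
      Point(x, 3.0 * x + rng.nextGaussian())
    }
    val (lambda, score, w) = select(rows, Seq(0.1, 1.0, 10.0, 100.0), trainRatio = 0.75, seed = 196L)
    println(f"best lambda = $lambda, validation R2 = $score%.3f, refit slope = $w%.3f")
  }
}
```

The real class does the same selection with any Estimator, any Evaluator and a grid of ParamMaps; the sketch just makes the single split and the final refit visible.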

The estimator in this case is the linear regression, and the evaluator is RegressionEvaluator with r2 as the metric. Feel free to experiment with these values and methods (see the RegressionEvaluator and model tuning API docs).

val trainValidationSplit = new TVS().
	setEstimator(lr).
	setEvaluator(new RE().setMetricName("r2")).
	setEstimatorParamMaps(gridParams).
	setTrainRatio(.75)

Now let's fit the model on the training subset and then run it against the test subset.

val model = trainValidationSplit.fit(training)

val results = model.transform(test).
	select("features", "label", "prediction")

And the winner of the title "best model" is...

val best = model.bestModel

If you are interested in extracting the values of its parameters run best.extractParamMap().

To check RMSE, MSE, MAE and R2 on the test predictions:

val eval = new RE().setLabelCol("label").setPredictionCol("prediction")
println(s"RMSE: ${eval.setMetricName("rmse").evaluate(results)}")
println(s"MSE: ${eval.setMetricName("mse").evaluate(results)}")
println(s"MAE: ${eval.setMetricName("mae").evaluate(results)}")
println(s"R2: ${eval.setMetricName("r2").evaluate(results)}")
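For reference, the four metrics measure different things: MSE is the mean squared error, RMSE its square root (back in the label's units), MAE the mean absolute error, and R2 the share of the label's variance explained by the predictions (1 is a perfect fit, and it can drop below 0 for a model worse than always predicting the mean). R2 is the only one of the four where larger is better. The plain-Scala sketch below writes out the textbook formulas on a three-row example; MetricsSketch is a hypothetical name and this is not RegressionEvaluator's source.

```scala
// Plain-Scala versions of the four regression metrics (textbook definitions).
object MetricsSketch {
  def mse(labels: Seq[Double], preds: Seq[Double]): Double =
    labels.zip(preds).map { case (y, p) => (y - p) * (y - p) }.sum / labels.size

  def rmse(labels: Seq[Double], preds: Seq[Double]): Double = math.sqrt(mse(labels, preds))

  def mae(labels: Seq[Double], preds: Seq[Double]): Double =
    labels.zip(preds).map { case (y, p) => math.abs(y - p) }.sum / labels.size

  // R2 = 1 - SS_res / SS_tot
  def r2(labels: Seq[Double], preds: Seq[Double]): Double = {
    val mean  = labels.sum / labels.size
    val ssRes = labels.zip(preds).map { case (y, p) => (y - p) * (y - p) }.sum
    val ssTot = labels.map(y => (y - mean) * (y - mean)).sum
    1.0 - ssRes / ssTot
  }

  def main(args: Array[String]): Unit = {
    val labels = Seq(1.0, 2.0, 3.0)
    val preds  = Seq(1.0, 2.0, 4.0) // only the last prediction is off, by 1
    println(s"RMSE: ${rmse(labels, preds)}") // sqrt(1/3)
    println(s"MSE: ${mse(labels, preds)}")   // 1/3
    println(s"MAE: ${mae(labels, preds)}")   // 1/3
    println(s"R2: ${r2(labels, preds)}")     // 1 - 1/2 = 0.5
  }
}
```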

Finally, some useful links:

  • Download Spark
  • Machine Learning Guide
  • Introduction to statistical learning
  • Part 1/5 Machine Learning with Spark
  • Part 2/5 Machine Learning with Spark
  • ML Tuning: model selection and hyperparameter tuning

Disclaimer: This is by no means an original work; it is merely meant to serve as a compilation of thoughts, code snippets and teachings from different sources.