In this post, I demonstrate how you can use Apache Spark’s machine learning libraries to perform binary classification using logistic regression. The dataset I am using for this demo is taken from Andrew Ng’s machine learning course on Coursera.

P.S. Don’t forget to see the bonus at the end.

| Score 1 | Score 2 | Result |
|---|---|---|
| 69.07014406283025 | 52.74046973016765 | 1 |
| 67.94685547711617 | 46.67857410673128 | 0 |
| 70.66150955499435 | 92.92713789364831 | 1 |

Let’s assume a classroom scenario where students go through three exams to pass the class. The dataset has historical data of students: their scores in the first two exams and a label column showing whether each student was able to pass the third and final exam. My goal is to train a binary classifier on this historical data and then predict, given a particular student’s scores in the first two exams, whether that student will pass the final exam (1) or not (0).

To get a better sense of the data, let’s plot the scores and labels in a scatter plot using R.

In the plot above, a red dot shows a passed student and a black dot represents a failed one. The plot also shows a clear pattern, or separation, between the scores of passed and failed students. My objective is to train a model/classifier that captures this pattern and to use this model to make predictions later on. In this demo, I am going to use the logistic regression algorithm to create the model.
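Before diving into the Spark code, here is a minimal plain-Java sketch of the decision rule logistic regression learns: a weighted sum of the two scores is squashed through the sigmoid function, and the result is compared to a threshold. The weights and bias below are made up for illustration; the real model estimates them from the training data.

```java
// Sketch of the logistic (sigmoid) decision rule behind the model.
// Weights w1, w2 and the bias are hypothetical; Spark fits them during training.
public class LogisticRule {
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Probability that a student passes, given two exam scores.
    static double passProbability(double score1, double score2,
                                  double w1, double w2, double bias) {
        return sigmoid(w1 * score1 + w2 * score2 + bias);
    }

    public static void main(String[] args) {
        // Hypothetical weights, just to show the shape of the computation.
        double p = passProbability(69.07, 52.74, 0.2, 0.2, -20.0);
        // Predict 1 (pass) if the probability clears the 0.5 threshold.
        System.out.println(p >= 0.5 ? 1 : 0);
    }
}
```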

The complete code of this demo is available on GitHub.

I start by defining the schema that matches the dataset.

static StructType SCHEMA = new StructType(new StructField[] {
    new StructField(COLUMN_SCORE_1, DataTypes.DoubleType, false, Metadata.empty()),
    new StructField(COLUMN_SCORE_2, DataTypes.DoubleType, false, Metadata.empty()),
    new StructField(COLUMN_PREDICTION, DataTypes.IntegerType, false, Metadata.empty())
});

Then I load the data from the CSV file and assemble the feature columns into a vector, specifying the input and output columns.

Dataset<Row> dataSet = spark.read().schema(SCHEMA)
    .format("com.databricks.spark.csv")
    .option("header", "true")
    .load(inputFile);

dataSet = spark.createDataFrame(dataSet.javaRDD(), SCHEMA);

VectorAssembler vectorAssembler = new VectorAssembler()
    .setInputCols(new String[]{COLUMN_SCORE_1, COLUMN_SCORE_2})
    .setOutputCol(COLUMN_INPUT_FEATURES);
dataSet = vectorAssembler.transform(dataSet);

The next step is to split the data into training and cross-validation sets and set up the logistic regression classifier.

Dataset<Row>[] splitDataSet = dataSet.randomSplit(new double[]{0.7, 0.3}, SPLIT_SEED);
Dataset<Row> trainingDataSet = splitDataSet[0];
Dataset<Row> crossValidationDataSet = splitDataSet[1];

LogisticRegression logisticRegression = new LogisticRegression()
    .setMaxIter(LOGISTIC_REGRESSION_ITERATIONS)
    .setRegParam(LOGISTIC_REGRESSION_REG_PARAM)
    .setElasticNetParam(LOGISTIC_REGRESSION_ELASTIC_NET_PARAM);

logisticRegression.setLabelCol(COLUMN_PREDICTION);
logisticRegression.setFeaturesCol(COLUMN_INPUT_FEATURES);

Next, I train the model and get the training results.

LogisticRegressionModel logisticRegressionModel =
    logisticRegression.fit(trainingDataSet);
LogisticRegressionTrainingSummary logisticRegressionTrainingSummary =
    logisticRegressionModel.summary();

You can also print the training objective (the loss) at each iteration of logistic regression.

double[] objectiveHistory = logisticRegressionTrainingSummary.objectiveHistory();
for (double lossPerIteration : objectiveHistory)
    System.out.println(lossPerIteration);

Next, we find the best threshold value based on the F-score and use this threshold in our final model.

BinaryLogisticRegressionSummary binaryLogisticRegressionSummary =
    (BinaryLogisticRegressionSummary) logisticRegressionTrainingSummary;
// Get the threshold corresponding to the maximum F-measure and set it
// as the decision threshold of the model.
Dataset<Row> fScore = binaryLogisticRegressionSummary.fMeasureByThreshold();
double maximumFScore = fScore.select(functions.max("F-Measure")).head().getDouble(0);
double bestThreshold = fScore.where(fScore.col("F-Measure").equalTo(maximumFScore))
    .select("threshold").head().getDouble(0);
logisticRegressionModel.setThreshold(bestThreshold);
System.out.println("maximum FScore: " + maximumFScore);
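The threshold search above can be sketched in plain Java without Spark: scan (threshold, F-measure) pairs and keep the threshold where the F-measure peaks. The pairs here are made-up stand-ins for the DataFrame that `fMeasureByThreshold()` returns.

```java
// Plain-Java sketch of picking the decision threshold with the highest
// F-measure. The threshold/F-measure values below are hypothetical.
public class ThresholdSearch {
    static double bestThreshold(double[] thresholds, double[] fMeasures) {
        int best = 0;
        for (int i = 1; i < thresholds.length; i++) {
            if (fMeasures[i] > fMeasures[best]) {
                best = i;
            }
        }
        return thresholds[best];
    }

    public static void main(String[] args) {
        double[] thresholds = {0.3, 0.4, 0.5, 0.6};    // hypothetical values
        double[] fMeasures  = {0.70, 0.82, 0.88, 0.75};
        System.out.println(bestThreshold(thresholds, fMeasures)); // prints 0.5
    }
}
```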

Let’s use the initially separated cross-validation set to find the accuracy of our trained model and print the results.

Dataset<Row> crossValidationDataSetPredictions =
    logisticRegressionModel.transform(crossValidationDataSet);
JavaPairRDD<Double, Double> crossValidationPredictionRDD =
    convertToJavaRDDPair(crossValidationDataSetPredictions);
Utils.printFScoreBinaryClassfication(crossValidationPredictionRDD);
printPredictionResult(crossValidationDataSetPredictions);

The results look like this in my case, which is not bad for a start.

True positives: 11
False positives: 3
False negatives: 0
Precision: 0.7857142857142857
Recall: 1.0
FScore: 0.88
Correct predictions: 22/25
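The metrics above follow directly from the confusion counts: precision is TP / (TP + FP), recall is TP / (TP + FN), and the F-score is their harmonic mean. A small sketch that reproduces the printed numbers:

```java
// Compute precision, recall and F-score from the confusion counts above.
public class BinaryMetrics {
    static double precision(int tp, int fp) { return (double) tp / (tp + fp); }
    static double recall(int tp, int fn)    { return (double) tp / (tp + fn); }
    static double fScore(double p, double r) { return 2 * p * r / (p + r); }

    public static void main(String[] args) {
        int tp = 11, fp = 3, fn = 0;          // counts from the run above
        double p = precision(tp, fp);         // 11 / 14 = 0.7857...
        double r = recall(tp, fn);            // 11 / 11 = 1.0
        System.out.println("Precision: " + p);
        System.out.println("Recall: " + r);
        System.out.println("FScore: " + fScore(p, r)); // 22 / 25 = 0.88
    }
}
```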

Let’s also plot the results to get a visual intuition of how our algorithm did.

As you can see in the plot above, the prediction pattern matches the initial dataset very closely, which again suggests the implementation is correct.

You can download and run the complete code including the dataset from this GitHub repository.

**Bonus:** You can find a Random Forest based solution to the same problem here.