Logistic Regression Using Spark Machine Learning


In this post, I demonstrate how you can use Apache Spark’s machine learning libraries to perform binary classification using logistic regression. The dataset I am using for this demo is taken from Andrew Ng’s machine learning course on Coursera.

P.S.: Don't forget to check out the bonus at the end.

Score 1              Score 2              Result
69.07014406283025    52.74046973016765    1
67.94685547711617    46.67857410673128    0
70.66150955499435    92.92713789364831    1

Let's assume a classroom scenario in which students take three exams to pass a class. The dataset contains historical data about students: their scores on the first two exams and a label column showing whether each student passed the third and final exam. My goal is to train a binary classifier on this historical data. Given a particular student's scores on the first two exams, the classifier should predict whether that student will pass the final exam (1) or not (0).

To get a better sense of the data, let's plot the scores and labels in a scatter plot using R.


In the plot above, a red dot represents a student who passed and a black dot represents a student who failed. The plot shows a clear separation between the scores of passed and failed students. My objective is to train a model (classifier) that captures this pattern and then to use the model to make predictions. In this demo, I use the logistic regression algorithm to create the model.
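For reference, the model being fitted is the standard logistic (sigmoid) hypothesis over the two scores. The notation below is mine, not Spark's:

- x_1 and x_2 are the exam scores, and y ∈ {0, 1} is the result.
- θ are the learned coefficients.
- t is the decision threshold. Spark's LogisticRegression uses 0.5 by default and predicts 1 only when the probability is strictly greater than t.
- J(θ) is the average log-loss that training minimizes when no regularization is used.

```latex
h_\theta(x) = \sigma(\theta_0 + \theta_1 x_1 + \theta_2 x_2), \qquad \sigma(z) = \frac{1}{1 + e^{-z}}

\hat{y} = \begin{cases} 1 & \text{if } h_\theta(x) > t \\ 0 & \text{otherwise} \end{cases}

J(\theta) = -\frac{1}{m} \sum_{i=1}^{m} \left[ y^{(i)} \log h_\theta(x^{(i)}) + \left(1 - y^{(i)}\right) \log\left(1 - h_\theta(x^{(i)})\right) \right]
```

The decision boundary is the line θ_0 + θ_1 x_1 + θ_2 x_2 = logit(t), which is a straight line in the score plane. A straight-line boundary is what the scatter plot suggests should separate the two groups.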

The complete code for this demo is available on GitHub.

I start by defining the schema that matches the dataset.

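import org.apache.spark.sql.types.*;

// Two non-nullable double score columns and a non-nullable integer label (1 = passed, 0 = failed)
static StructType SCHEMA = new StructType(new StructField[] {
	new StructField(COLUMN_SCORE_1, DataTypes.DoubleType, false, Metadata.empty()),
	new StructField(COLUMN_SCORE_2, DataTypes.DoubleType, false, Metadata.empty()),
	new StructField(COLUMN_PREDICTION, DataTypes.IntegerType, false, Metadata.empty())
});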

Then I load the data from the CSV file and convert it into a vectorized format by assembling the two score columns into a single feature vector.

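import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
// Read the CSV file (with a header row) using the schema defined above
Dataset<Row> dataSet = spark.read().schema(SCHEMA)
	.format("csv")
	.option("header", "true")
	.load(inputFile);
// Re-apply SCHEMA to the loaded rows (file sources mark every column as nullable)
dataSet = spark.createDataFrame(dataSet.javaRDD(), SCHEMA);
// Combine the two score columns into a single feature vector column
VectorAssembler vectorAssembler = new VectorAssembler()
	.setInputCols(new String[]{COLUMN_SCORE_1, COLUMN_SCORE_2})
	.setOutputCol("features"); // assumed name; "features" is LogisticRegression's default
dataSet = vectorAssembler.transform(dataSet);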
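To make the "vectorized format" concrete outside Spark, here is a plain-Java sketch of the shape of data that the reader plus VectorAssembler produce: one two-element feature vector and one 0/1 label per row. It assumes the file holds comma-separated `score1,score2,result` rows with a single header line, as the `header` option suggests. The class `ExamData` and its `parse` method are hypothetical illustrations, not part of Spark or of the original project.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper: turns "score1,score2,result" lines into feature vectors and labels. */
final class ExamData {
    final List<double[]> features = new ArrayList<>(); // one [score1, score2] vector per row
    final List<Integer> labels = new ArrayList<>();     // 1 = passed, 0 = failed

    static ExamData parse(List<String> lines, boolean hasHeader) {
        ExamData data = new ExamData();
        for (int i = hasHeader ? 1 : 0; i < lines.size(); i++) {
            String line = lines.get(i).trim();
            if (line.isEmpty()) {
                continue; // skip blank lines
            }
            String[] cols = line.split(",");
            // Assemble the two score columns into one feature vector
            data.features.add(new double[] {
                Double.parseDouble(cols[0].trim()),
                Double.parseDouble(cols[1].trim())
            });
            // The third column is the label
            data.labels.add(Integer.parseInt(cols[2].trim()));
        }
        return data;
    }
}
```

In Spark, the same result lives in a single DataFrame column of type Vector rather than in a Java list. The idea is the same: the classifier sees one feature vector per row next to a label.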

The next step is to split the data into training and cross-validation sets and to set up the logistic regression classifier.

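import org.apache.spark.ml.classification.LogisticRegression;
// 70/30 split; the seed value 12345L is a hypothetical placeholder
Dataset<Row>[] splittedDataSet = dataSet.randomSplit(new double[]{0.7, 0.3}, 12345L);
Dataset<Row> trainingDataSet = splittedDataSet[0];
Dataset<Row> crossValidationDataSet = splittedDataSet[1];
// Classifier reading the 0/1 label column and the assembled feature vector
LogisticRegression logisticRegression = new LogisticRegression()
	.setLabelCol(COLUMN_PREDICTION)
	.setFeaturesCol("features"); // must match VectorAssembler's output column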
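This plain-Java sketch illustrates what a seeded random split does. Each row independently goes to the training set with probability 0.7, so the result is reproducible for a given seed, but the sizes are only approximately 70% and 30%. It is not Spark's implementation of `randomSplit`, and the class name `RandomSplit` is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/** Sketch of a seeded split where each row is assigned independently. */
final class RandomSplit {
    static <T> List<List<T>> split(List<T> rows, double trainFraction, long seed) {
        Random random = new Random(seed); // fixed seed => same split on every run
        List<T> training = new ArrayList<>();
        List<T> validation = new ArrayList<>();
        for (T row : rows) {
            if (random.nextDouble() < trainFraction) {
                training.add(row);
            } else {
                validation.add(row);
            }
        }
        return List.of(training, validation);
    }
}
```

On a small dataset this per-row sampling can noticeably miss the nominal 70/30 ratio, so it is worth printing the sizes of both sets after splitting.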

Next, I train the model on the training set and retrieve its training summary.

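import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
// Fit on the training set, then fetch the training summary
LogisticRegressionModel logisticRegressionModel =
	logisticRegression.fit(trainingDataSet);
LogisticRegressionTrainingSummary logisticRegressionTrainingSummary =
	logisticRegressionModel.summary();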

You can also print the value of the objective (loss) function at each iteration of the optimizer to see how training converged.

// objectiveHistory() holds the objective (loss) value recorded at each iteration
double[] objectiveHistory = logisticRegressionTrainingSummary.objectiveHistory();
for (double errorPerIteration : objectiveHistory)
	System.out.println(errorPerIteration);
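To show what an "objective value per iteration" means, here is a minimal plain-Java sketch that trains a two-feature logistic regression with batch gradient descent and records the average log-loss after every step. This is a swapped-in technique for illustration only. Spark's LogisticRegression uses quasi-Newton optimizers (L-BFGS, or OWL-QN with L1 regularization) and by default standardizes features internally. The class `LogisticGD`, the learning rate and the iteration count are all hypothetical. Plain gradient descent on raw exam scores (roughly 30 to 100) converges slowly, so scaled features are assumed.

```java
/** Minimal batch gradient descent for logistic regression with two features (illustrative only). */
final class LogisticGD {
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Average log-loss of weights w = [bias, w1, w2] on the data. */
    static double logLoss(double[] w, double[][] x, int[] y) {
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            double p = sigmoid(w[0] + w[1] * x[i][0] + w[2] * x[i][1]);
            p = Math.min(Math.max(p, 1e-15), 1 - 1e-15); // avoid log(0)
            sum += y[i] * Math.log(p) + (1 - y[i]) * Math.log(1 - p);
        }
        return -sum / x.length;
    }

    /** Updates w in place and returns the log-loss measured after each iteration. */
    static double[] train(double[] w, double[][] x, int[] y, double learningRate, int iterations) {
        double[] history = new double[iterations];
        for (int it = 0; it < iterations; it++) {
            double[] grad = new double[3];
            for (int i = 0; i < x.length; i++) {
                // Prediction error: predicted probability minus the 0/1 label
                double error = sigmoid(w[0] + w[1] * x[i][0] + w[2] * x[i][1]) - y[i];
                grad[0] += error;
                grad[1] += error * x[i][0];
                grad[2] += error * x[i][1];
            }
            for (int j = 0; j < 3; j++) {
                w[j] -= learningRate * grad[j] / x.length; // average gradient step
            }
            history[it] = logLoss(w, x, y);
        }
        return history;
    }
}
```

A steadily decreasing history that flattens out is the pattern to look for in Spark's objectiveHistory too. If the values jump around or keep falling sharply at the last iteration, the optimizer may need more iterations.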

Next, we find the threshold that maximizes the F-score (F1) on the training data and set it on the model, which gives us our final model.

import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary;
import org.apache.spark.sql.functions;

BinaryLogisticRegressionSummary binaryLogisticRegressionSummary =
	(BinaryLogisticRegressionSummary) logisticRegressionTrainingSummary;
// Get the threshold corresponding to the maximum F-Measure and set it on the
// model, so that later predictions use this selected threshold.
Dataset<Row> fScore = binaryLogisticRegressionSummary.fMeasureByThreshold();
double maximumFScore = fScore.select(functions.max("F-Measure")).head().getDouble(0);
double bestThreshold = fScore.where(fScore.col("F-Measure").equalTo(maximumFScore))
	.select("threshold").head().getDouble(0);
logisticRegressionModel.setThreshold(bestThreshold);
System.out.println("maximum FScore: " + maximumFScore);
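The idea behind `fMeasureByThreshold()` plus the two queries can be sketched in plain Java: try each distinct predicted probability as a threshold, compute F1 at that threshold, and keep the best one. This is an illustration, not Spark's implementation. The class `BestThreshold` is hypothetical, and the sketch counts a probability at or above t as a positive prediction.

```java
import java.util.TreeSet;

/** Sketch: pick the probability threshold that maximizes F1 on labeled data. */
final class BestThreshold {
    /** F1 when every example with probability >= t is predicted positive. */
    static double f1AtThreshold(double[] probabilities, int[] labels, double t) {
        int tp = 0, fp = 0, fn = 0;
        for (int i = 0; i < probabilities.length; i++) {
            boolean predictedPositive = probabilities[i] >= t;
            if (predictedPositive && labels[i] == 1) tp++;
            else if (predictedPositive) fp++;
            else if (labels[i] == 1) fn++;
        }
        // F1 = 2TP / (2TP + FP + FN), defined as 0 when the denominator is 0
        int denominator = 2 * tp + fp + fn;
        return denominator == 0 ? 0.0 : 2.0 * tp / denominator;
    }

    /** Tries each distinct probability as a candidate; falls back to 0.5 if there are none. */
    static double bestThreshold(double[] probabilities, int[] labels) {
        TreeSet<Double> candidates = new TreeSet<>();
        for (double p : probabilities) candidates.add(p);
        double best = 0.5, bestF1 = -1.0;
        for (double t : candidates) {
            double f1 = f1AtThreshold(probabilities, labels, t);
            if (f1 > bestF1) { // strict '>' keeps the smallest threshold on ties
                bestF1 = f1;
                best = t;
            }
        }
        return best;
    }
}
```

With probabilities {0.1, 0.4, 0.35, 0.8} and labels {0, 0, 1, 1}, the candidate 0.35 gives TP = 2, FP = 1, FN = 0, so F1 = 4/5 = 0.8. That is the best of the four candidates. Because the threshold is tuned on the training data, the cross-validation set is the fairer place to judge it.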

Let's use the cross-validation set we held out earlier to measure the accuracy of our trained model and print the results.

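import org.apache.spark.api.java.JavaPairRDD;
// Apply the trained model (with the selected threshold) to the held-out set
Dataset<Row> crossValidationDataSetPredictions = logisticRegressionModel.transform(crossValidationDataSet);
// convertToJavaRDDPair, Utils.printFScoreBinaryClassfication and printPredictionResult
// are helpers from the complete code on GitHub (not shown in this post)
JavaPairRDD<Double, Double> crossValidationPredictionRDD = convertToJavaRDDPair(crossValidationDataSetPredictions);
Utils.printFScoreBinaryClassfication(crossValidationPredictionRDD);
printPredictionResult(crossValidationDataSetPredictions);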

In my case, the results look like this, which is not bad for a start.

True positives: 11
False positives: 3
False negatives: 0
Precision: 0.7857142857142857
Recall: 1.0
FScore: 0.88
Correct predictions: 22/25
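As a sanity check, these numbers are consistent with each other:

- Precision is TP / (TP + FP) = 11/14 ≈ 0.786.
- Recall is TP / (TP + FN) = 11/11 = 1.0.
- F1 is 2·TP / (2·TP + FP + FN) = 22/25 = 0.88.
- Since 22 of the 25 predictions are correct and 11 of them are true positives, the other 11 must be true negatives (11 + 3 + 0 + 11 = 25).

The small sketch below computes these metrics from confusion-matrix counts. The class name `BinaryMetrics` is hypothetical, and it is not the `Utils` helper used in the post.

```java
/** Sketch: standard binary-classification metrics from confusion-matrix counts. */
final class BinaryMetrics {
    static double precision(int tp, int fp) {
        return (double) tp / (tp + fp);
    }

    static double recall(int tp, int fn) {
        return (double) tp / (tp + fn);
    }

    /** Harmonic mean of precision and recall. */
    static double f1(int tp, int fp, int fn) {
        double p = precision(tp, fp);
        double r = recall(tp, fn);
        return 2 * p * r / (p + r);
    }

    static double accuracy(int correct, int total) {
        return (double) correct / total;
    }
}
```

Here F1 and accuracy both come out to 0.88, but that is a coincidence of these counts, not a general property.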

Let's also plot the results to get a visual sense of how well our model did.


As the plot above shows, the predicted labels closely match the pattern in the original dataset we plotted, which suggests that the implementation works as intended.

You can download and run the complete code including the dataset from this GitHub repository.

Bonus: You can find a Random Forest-based solution to the same problem here.
