Churn Prediction using PySpark

Predicting User Churn for music streaming app

Project Overview

The project aims to explore user interaction data to identify customer churn. This problem specifically focuses on handling large datasets using Spark. PySpark is used to clean, wrangle, and process the data, and to build and tune a churn prediction model.

Problem Statement

Sparkify is a music streaming company, similar to Spotify. The dataset contains logs of user interactions. The full dataset is 12 GB; we work with a smaller 128 MB subset. Using the user logs, we need to identify customers who are likely to churn so they can be offered promotions. We can also identify factors that are significant indicators of churn.

Exploratory Data Analysis

The dataset contains a log of 286,500 user interactions. It has information about the user, the interactions, timestamps, device used by the user, etc. The schema of the data is as below:

Data Schema

The dataset snapshot is as below:

Data Cleaning

We clean the dataset by removing records with a missing sessionId and records where the userId is missing or empty, since most of these belong to users who had not yet logged in or were looking to sign up. This leaves us with 278,154 records. We also convert the registration date and event timestamp from epoch milliseconds to a more readable date-time format. The page feature lists all the possible user actions in the dataset.

User Actions

We define churn as a user performing the Cancellation Confirmation action. There are two subscription levels in the service, Free and Paid, and a user can upgrade or downgrade their subscription. We create flags to identify whether a user downgraded their account and whether they churned.


We look at distribution of some features to identify patterns in the data.

Distribution of Churned Users

The distribution of churned users shows that the dataset is heavily imbalanced. While this is standard for churn analysis problems, we will need to account for it during modeling, either by balancing the dataset through under-sampling or by choosing metrics that better account for the minority class.

Distribution by Subscription level

There are a lot more free users of our music streaming service than paid users. We should also check whether users of a given subscription level show a higher tendency to churn.

Churn by Subscription level

We notice that paid users are more likely to churn than free users. This could be an important factor for predicting user churn.

Gender of users

We have more male users than female users, but the numbers don't differ by much, so we can consider our dataset mostly balanced with respect to gender.

User churn by gender

We notice that male users have a higher propensity to churn than female users. We could also analyze this trend with subscription level in the mix.

Number of songs listened per session

We see that free-tier users listen to far fewer songs per session than paid users. Also, users who did not churn listen to slightly more songs per session on average, especially paid-tier users.

Finally let’s look at the most popular artists on our streaming service.

Feature Engineering

Since the dataset contains individual user interactions, we need to aggregate them at the user level to build a user profile that helps identify propensity to churn.

Based on the features available in the dataset and our understanding of user behaviour, we create features that would represent user engagement, user profile, etc.

We identify features like user lifetime, gender of the user, etc.

User Lifetime feature

We also create features for user engagement like number of friends added, number of songs added to playlist, thumbs up and thumbs down given to songs.

Number of friends added

We also create features related to their music listening behaviour, like the number of songs listened to, the number of distinct artists the user listens to, the number of songs per session, etc.

Number of songs listened to by user

Once we create these features, we merge them for each user to create the final dataframe that we will use for modeling.

Final Dataframe

Next we move to the modeling part.


Note: The medium-size dataset of 230 MB was used for the modeling part and training was done on IBM Watson cluster, while the EDA was done on the mini dataset of 128 MB.

Before we start modeling, we need to convert our features to numeric form and then scale them. Whether to scale depends on the models we plan to use. We also split the dataset into train-validation-test sets to ensure an accurate assessment of model performance; I used a 70:15:15 split.

Now that we have our dataset ready, we will test it against a baseline model. We define the problem as binary classification, where 0 means the user does not churn and 1 means the user churns. Since our dataset is imbalanced, accuracy might not be the best measure: we could predict all values as 0 and achieve high accuracy yet poor recall. To balance this, we use the F1 score as our metric.

Our dummy classifier is one that always predicts that a user will not churn. We could also have used a baseline model that randomly predicts churn. The results for the dummy classifier are as below:

Dummy Classifier

We try other models: Logistic Regression, Random Forest Classifier, Linear Support Vector Machine, and Gradient-Boosted Tree Classifier. We fit each model on the training set and make predictions on the validation set, then choose the model with the best F1 score. In our case it was the Random Forest Classifier.

Logistic Regression metrics
Gradient Boosted Tree Classifier metrics

Next, we tune our Random Forest Classifier model. Once we have a final model, we use it to benchmark model performance by making predictions on test set.

Test Set results for final model

We can have a look at the feature importance of our model.

Feature Importance

We see that user_lifetime has by far the most impact on predicting churn.


We explored the churn prediction problem using the CRISP-DM process. This is very similar to real-world problems that add business value: by strategically implementing insights gleaned from the exploration and modeling, we can retain users and increase revenue for the company.

Accuracy is not a reliable metric for a highly imbalanced dataset like this, so we used F1 score. We could also have tried SMOTE for oversampling or random undersampling of our dataset to balance the classes.

We can extract other features like device used by users and see their effect on churn prediction.

We tried the modeling exercise on the small and medium datasets. We could also explore the full 12 GB dataset; having more records for the minority class would improve the model's predictive power.

The full code for the project can be found on my GitHub.

Churn Prediction using Spark was originally published in Towards Data Science on Medium.