
Churn prediction in Sparkify, a digital music service

Sparkify music service

Music streaming services have become very popular in recent years among people who enjoy listening to music and want access to a large music library. These services usually offer two kinds of accounts: a free one that places advertisements between songs, and a premium one, where music can be streamed without interruption for a monthly fee.

From a company’s perspective, premium users are more profitable than free ones. Profit can therefore be increased by converting more users to a premium plan and, at the same time, by retaining current premium users so that they do not cancel their accounts. Cancellation, also known as churn, is important for companies to anticipate: if they can predict when a user is going to cancel a plan, they can offer discounts or promotions to retain the user and avoid a considerable loss of revenue.

In our digital music service, Sparkify, every user interaction, such as playing a song, logging out, liking a song or downgrading the service, generates data that is collected. This data will be used to predict which users are most likely to churn.

The process followed to obtain these predictions is divided into five steps:

I. Exploratory Data Analysis. Churn Definition

This step describes the available features, the values they can take and what is defined as churn. It also discusses how missing and empty values are found and treated.

II. Visualisation

In this section some features are plotted for two groups: users who churned and users who did not.

III. Feature Engineering

Next, the features used to build and train the models are selected.

IV. Modeling

Five models are trained: Random Forest, Logistic Regression, Decision Tree, Gradient Boosted Trees and LinearSVC.

V. Model Evaluation and Validation

In the final step, the results obtained for each model are analysed, along with the parameters of the best-performing model.

1. Exploratory Data Analysis. Churn Definition

Let’s start with a general view of the data available for analysis.

The dataset consists of 286,500 log entries, each with values for 18 features, as can be seen in the picture on the left.

Among these features, userId, level (paid or free), gender (M, F), page (NextSong, Submit Downgrade, Add Friend,…) and ts (timestamp), among others, will be used to build the model and predict churn.

A user is considered to have churned at the moment they reach a Cancellation Confirmation event; such users are labelled 1 and all others 0. With this definition, the dataset has 173 users who did not churn and 52 who submitted the Cancellation Confirmation.
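The labelling rule above can be sketched in a few lines. This is a minimal illustration in plain Python on a tiny made-up event list, so it runs without a Spark session; the sample users and the helper name `label_churn` are assumptions, not part of the project code. In PySpark the same idea would typically be a per-user aggregation over the page column.

```python
# Tiny made-up event log; only the userId and page fields matter here.
events = [
    {"userId": "10", "page": "NextSong"},
    {"userId": "10", "page": "Thumbs Up"},
    {"userId": "18", "page": "NextSong"},
    {"userId": "18", "page": "Cancellation Confirmation"},
    {"userId": "32", "page": "Add Friend"},
]


def label_churn(events, churn_page="Cancellation Confirmation"):
    """Return {userId: 1 if the user ever reached churn_page, else 0}."""
    labels = {}
    for event in events:
        user = event["userId"]
        churned = 1 if event["page"] == churn_page else 0
        # Once a user is marked 1 they stay 1, whatever other events they have.
        labels[user] = max(labels.get(user, 0), churned)
    return labels


labels = label_churn(events)
print(labels)  # {'10': 0, '18': 1, '32': 0}
```

The label is attached to the user, not to the single cancellation event, so every row of a churned user can later be aggregated under the same label.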

Data Preprocessing

Since userId is the feature that identifies our users, it is important that it has no missing or empty values. To check this we can execute the following code:

Missing Values

# Count null values in every column of the events dataframe
for value in sparkify_events.columns:
    print("In column (", value, ") the number of missing values is", sparkify_events.where(sparkify_events[value].isNull()).count())
--> In column ( userId ) the number of missing values is 0

Since no rows in userId have missing values, let’s check for empty values

Empty Values

# Count empty strings in every column of the events dataframe
for value in sparkify_events.columns:
    print("In column (", value, ") the number of empty values is", sparkify_events.filter(sparkify_events[value] == "").count())
--> In column ( userId ) the number of empty values is 8346

Here the number of empty values is 8346. These rows can be dropped by keeping only the rows with a non-empty userId:

sparkify_events = sparkify_events.filter(sparkify_events.userId != "")

After this step the number of missing or empty values for userId is 0.

2. Visualisation

When visualising the available data, the focus will be on how users who churned and users who did not are distributed by gender, Thumbs Up interactions, NextSong events and time elapsed until the service was cancelled.

From the picture above, male users seem more likely than female users to cancel the service.

Another important feature to look into is “page”, because it is where user interactions are recorded. For example, when users log out, ask for help, upgrade or unlike a song, they generate a labelled interaction. As can be seen in the picture on the left, and as expected for a music service, the most frequent event is “NextSong”, with over 200,000 records, followed by “Thumbs Up”. Their distribution by userId can be seen in the following pictures:

The left one makes clear that the more songs a user plays, the less likely the account is to end in churn. Something similar occurs with the like function: as the picture on the right depicts, the more Thumbs Up a user gives, the less likely they are to cancel their account.

Finally, looking at the time that passes between registration and cancellation, the picture below illustrates that it usually takes around 50 days.

3. Feature Engineering

Once the data has been visualised, several characteristics are selected and joined in one dataframe to train the models that will predict churn. More specifically, the following 15 features will be used:

  1. Listened songs
  2. Average played songs per session
  3. Songs added to playlist
  4. Total different artists listened
  5. Average played time per session
  6. Sessions of the user
  7. Time passed since registration
  8. Friends number
  9. Thumbs up
  10. Thumbs down
  11. Help requests
  12. Errors
  13. Downgrade
  14. Gender
  15. Free or paid user
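To make the aggregation concrete, here is a minimal sketch of how a few of the features listed above (listened songs, sessions, average played songs per session, thumbs up) could be derived per user. It uses plain Python on made-up rows so it runs without Spark; the `sessionId` field and the output key names are assumptions for illustration, and the project itself builds these features with PySpark dataframes.

```python
from collections import defaultdict

# Made-up events; "sessionId" is an assumed field name for the session key.
events = [
    {"userId": "10", "sessionId": 1, "page": "NextSong"},
    {"userId": "10", "sessionId": 1, "page": "NextSong"},
    {"userId": "10", "sessionId": 1, "page": "Thumbs Up"},
    {"userId": "10", "sessionId": 2, "page": "NextSong"},
    {"userId": "18", "sessionId": 7, "page": "NextSong"},
    {"userId": "18", "sessionId": 7, "page": "Thumbs Down"},
]


def build_features(events):
    """Aggregate event rows into one feature dictionary per user."""
    songs = defaultdict(int)
    thumbs_up = defaultdict(int)
    sessions = defaultdict(set)
    for event in events:
        user = event["userId"]
        sessions[user].add(event["sessionId"])
        if event["page"] == "NextSong":
            songs[user] += 1
        elif event["page"] == "Thumbs Up":
            thumbs_up[user] += 1
    return {
        user: {
            "listened_songs": songs[user],
            "sessions": len(session_ids),
            "avg_songs_session": songs[user] / len(session_ids),
            "thumbs_up": thumbs_up[user],
        }
        for user, session_ids in sessions.items()
    }


features = build_features(events)
print(features["10"])
```

Each user ends up as a single row of numbers, which is the shape the models need once the per-feature results are joined on userId.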

4. Modeling

After the features have been selected, five different Machine Learning algorithms are trained. Hyperparameters are tuned with a parameter grid (ParamGridBuilder in the PySpark library) and Cross Validation, and accuracy and f1 score are used to measure how well the models perform.

Accuracy is the ratio of correctly predicted observations to the total number of observations. It is most informative when the dataset is balanced (for example, if we had a similar number of users who did and did not churn).

The f1 score is the harmonic mean of precision and recall, so it takes both false positives and false negatives into account. It is more informative than accuracy when the dataset is unbalanced, such as the one used in this project. Both kinds of error matter here: a false positive is a user flagged as likely to leave who would have stayed, so discounts or offers are spent unnecessarily, while a false negative is a user who churns without having been flagged, and therefore never receives the offers or discounts that might have made them stay.
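A short worked example shows why accuracy alone can mislead on this dataset. The sketch below computes both metrics from confusion-matrix counts for a hypothetical baseline that predicts “no churn” for all 225 users (173 who stayed, 52 who churned). The function names are illustrative, and the f1 here is the score of the churn class only; an evaluator that averages f1 over both classes would report a different number.

```python
def accuracy(tp, fp, fn, tn):
    """Share of all predictions that are correct."""
    return (tp + tn) / (tp + fp + fn + tn)


def f1_score(tp, fp, fn):
    """Harmonic mean of precision and recall for the positive (churn) class."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Hypothetical baseline: nobody is predicted to churn.
baseline = {"tp": 0, "fp": 0, "fn": 52, "tn": 173}

baseline_accuracy = accuracy(**baseline)
baseline_f1 = f1_score(baseline["tp"], baseline["fp"], baseline["fn"])
print(round(baseline_accuracy, 3))  # 0.769
print(baseline_f1)                  # 0.0
```

A model that never finds a single churner still reaches about 77% accuracy, while its f1 score for the churn class is zero, which is why both metrics are reported.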


First, each trained model is presented along with the hyperparameters used; the results are then gathered together. To run all of them in sequence, a function is defined that trains each model and saves its results.
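The original helper is not shown, so the following is only a sketch of what such a function might look like. The name `train_and_record` and its arguments are hypothetical: `fit` stands for any callable that trains a model (for example a cross-validator's fit on the training set) and `evaluate` for any callable that returns the metrics of the fitted model.

```python
import time


def train_and_record(name, fit, evaluate, results):
    """Run fit(), time it, and store the model, metrics and duration under name."""
    start = time.perf_counter()
    model = fit()
    elapsed = time.perf_counter() - start
    minutes, seconds = divmod(elapsed, 60)
    results[name] = {
        "model": model,
        "metrics": evaluate(model),
        # Same MM:SS.ffffff layout as the training times reported below.
        "training_time": f"{int(minutes):02d}:{seconds:09.6f} min.",
    }
    return results[name]


# Usage with stand-in callables instead of a real Spark model.
results = {}
train_and_record(
    "demo",
    fit=lambda: "fitted-model",
    evaluate=lambda model: {"accuracy": 1.0, "f1": 1.0},
    results=results,
)
print(results["demo"]["training_time"])
```

Keeping every model's metrics and training time in one dictionary makes it straightforward to compare them side by side afterwards.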

Random Forest

The hyperparameters used in this model were: maxDepth = [5, 10], numTrees = [10, 20] and impurity = [‘entropy’, ‘gini’]

from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.tuning import ParamGridBuilder

model = RandomForestClassifier(labelCol = "label", featuresCol = "features")
paramGrid = ParamGridBuilder().addGrid(model.maxDepth, [5, 10]).\
    addGrid(model.numTrees, [10, 20]).\
    addGrid(model.impurity, ["entropy", "gini"]).build()


The metrics obtained are:
Accuracy: 0.800000
f1 score: 0.736508
Training time: 04:34.036536 min.
Best depth 10 and best number of trees 20

Logistic Regression

The hyperparameters used in this model were: elasticNetParam = [0.0, 0.2, 0.6, 1.0] and regParam = [0.02, 0.06, 0.1]

from pyspark.ml.classification import LogisticRegression

model = LogisticRegression(labelCol = "label", featuresCol = "features", maxIter = 10)
paramGrid = ParamGridBuilder().addGrid(model.elasticNetParam, [0.0, 0.2, 0.6, 1.0]).addGrid(model.regParam, [0.02, 0.06, 0.1]).build()


The metrics obtained are:
Accuracy: 0.885714
f1 score: 0.860829
Training time: 04:08.942884 min.
Best elasticNetParam 0.6 and best regParam 0.1

Decision Tree

The hyperparameters used in this model were: impurity = [‘entropy’, ‘gini’] and maxDepth = [2, 3, 5, 8, 13, 21, 30]

from pyspark.ml.classification import DecisionTreeClassifier

model = DecisionTreeClassifier(labelCol = "label", featuresCol = "features", seed = 42)
paramGrid = ParamGridBuilder().addGrid(model.impurity, ["entropy", "gini"]).addGrid(model.maxDepth, [2, 3, 5, 8, 13, 21, 30]).build()


The metrics obtained are:
Accuracy: 0.800000
f1 score: 0.805938
Training time: 04:47.750021 min.
Best depth 30

Gradient Boosted Trees

The hyperparameters used in this model were: maxDepth = [3, 5, 7] and maxIter = [6, 12]

from pyspark.ml.classification import GBTClassifier

model = GBTClassifier(labelCol = "label", featuresCol = "features", maxIter = 10, seed = 42)
paramGrid = ParamGridBuilder().addGrid(model.maxDepth, [3, 5, 7]).addGrid(model.maxIter, [6, 12]).build()


The metrics obtained are:
Accuracy: 0.685714
f1 score: 0.717108
Training time: 04:24.014526 min.
Best depth 7 and best iterations 12


LinearSVC

The hyperparameters used in this model were: maxIter = [6, 12] and regParam = [0.02, 0.06, 0.1]

from pyspark.ml.classification import LinearSVC

model = LinearSVC(labelCol = "label", featuresCol = "features")
paramGrid = ParamGridBuilder().addGrid(model.maxIter, [6, 12]).addGrid(model.regParam, [0.02, 0.06, 0.1]).build()


The metrics obtained are:
Accuracy: 0.828571
f1 score: 0.750893
Training time: 04:04.784299 min.
Best iteration 6 and best regParam 0.06
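To get a feel for the size of each search, the grids listed in this section can be enumerated in plain Python. The sketch below expands every grid into its candidate combinations and estimates the number of model fits during cross-validation. The number of folds is not stated in this article, so `NUM_FOLDS = 3` is an assumption (it is the default of PySpark's CrossValidator), and the helper name `candidates` is illustrative.

```python
from itertools import product

# Hyperparameter grids exactly as listed for each model in this section.
grids = {
    "Random Forest": {"maxDepth": [5, 10], "numTrees": [10, 20], "impurity": ["entropy", "gini"]},
    "Logistic Regression": {"elasticNetParam": [0.0, 0.2, 0.6, 1.0], "regParam": [0.02, 0.06, 0.1]},
    "Decision Tree": {"impurity": ["entropy", "gini"], "maxDepth": [2, 3, 5, 8, 13, 21, 30]},
    "Gradient Boosted Trees": {"maxDepth": [3, 5, 7], "maxIter": [6, 12]},
    "LinearSVC": {"maxIter": [6, 12], "regParam": [0.02, 0.06, 0.1]},
}

NUM_FOLDS = 3  # assumption: the fold count used in the project is not stated


def candidates(grid):
    """Expand a grid into the list of all parameter combinations."""
    names = list(grid)
    return [dict(zip(names, values)) for values in product(*grid.values())]


grid_sizes = {name: len(candidates(grid)) for name, grid in grids.items()}
for name, size in grid_sizes.items():
    print(f"{name}: {size} candidates, {size * NUM_FOLDS} cross-validation fits")
```

Even these small grids imply dozens of fits per model once folds are taken into account, which is consistent with the limited number of values tried per hyperparameter.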

5. Model Evaluation and Validation

The results obtained are gathered in the following table:

Table with results obtained

Training takes between 4 and 5 minutes for each model, and Logistic Regression performs best, with an accuracy of 0.886 and an f1 score of 0.861.
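The comparison can be reproduced from the figures reported in the Modeling section. The following plain-Python sketch collects those reported numbers, ranks the models by f1 score and averages the training times; the variable and helper names are illustrative only.

```python
# (model, accuracy, f1 score, training time as MM:SS.ffffff), as reported above.
results = [
    ("Random Forest", 0.800000, 0.736508, "04:34.036536"),
    ("Logistic Regression", 0.885714, 0.860829, "04:08.942884"),
    ("Decision Tree", 0.800000, 0.805938, "04:47.750021"),
    ("Gradient Boosted Trees", 0.685714, 0.717108, "04:24.014526"),
    ("LinearSVC", 0.828571, 0.750893, "04:04.784299"),
]


def to_seconds(mmss):
    """Convert a MM:SS.ffffff string into seconds."""
    minutes, seconds = mmss.split(":")
    return int(minutes) * 60 + float(seconds)


# Rank by f1 score, best first.
ranked = sorted(results, key=lambda row: row[2], reverse=True)
mean_seconds = sum(to_seconds(row[3]) for row in results) / len(results)

print(f"{'Model':<24}{'Accuracy':>10}{'f1':>10}")
for model, acc, f1, _ in ranked:
    print(f"{model:<24}{acc:>10.3f}{f1:>10.3f}")
print(f"Mean training time: {mean_seconds / 60:.1f} min")
```

Ranking by f1 score rather than accuracy changes the order of the middle of the table: Decision Tree moves ahead of LinearSVC even though its accuracy is lower.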

Focusing on the best-performing model, the best parameters are elasticNetParam = 0.6 and regParam = 0.1, and the features with the most impact are registration_min, errors, friend, played_time_session, avg_songs_session and thumbs_down.

Among them, the more negative a feature’s coefficient is, the more likely the user is to stay in the service as that feature grows. For example, if a user spends more time listening to songs or adds more friends, their probability of churn is lower.
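The effect of a negative coefficient can be illustrated with the logistic function that Logistic Regression uses to turn a weighted sum into a probability. The coefficient and feature values below are made up for illustration and are not the fitted values of the model in this project.

```python
import math


def churn_probability(feature_value, coefficient, intercept=0.0):
    """Logistic function applied to a single (scaled) feature."""
    return 1.0 / (1.0 + math.exp(-(intercept + coefficient * feature_value)))


coefficient = -0.8  # made-up negative coefficient, e.g. for time spent playing songs
probabilities = [churn_probability(value, coefficient) for value in (0.0, 1.0, 2.0)]
for value, probability in zip((0.0, 1.0, 2.0), probabilities):
    print(f"feature = {value}: churn probability = {probability:.3f}")
```

With a negative coefficient, each increase in the feature pushes the predicted churn probability down; a positive coefficient would push it up.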

This can be easily seen in the pictures above: users who spent more time playing songs and had more friends appear to be less likely to churn.

Taking into account that the data was divided into a train and a test set, and that the model had not seen the test data before, an f1 score of approximately 0.86 seems to be a good mark for the churn prediction task we set out to solve.


To predict user churn in our music service, five different models were trained. Among them, the best model yields an accuracy of 0.885714 and an f1 score of 0.860829. The f1 score is important since the dataset is unbalanced, with 173 users who did not churn and 52 who did. From the model performance perspective, we are predicting churn well, so we could implement measures to keep those users from abandoning the music service, which would protect revenue. Because of limited computational capacity, the number of hyperparameter values tried during tuning was small, so the accuracy and f1 score might have been higher if more values had been tried.


A final overview of the project is collected in this section

End to end summary

In this project a dataset from a music streaming service is used to predict when a user will churn, that is, when a user will cancel their account. To begin with, the features and their possible values are reviewed, and rows with missing or empty values for the userId feature are removed. Next, the fifteen features used in the models are selected and joined in one dataset. Five different Machine Learning models are trained: Random Forest, Logistic Regression, Decision Tree, Gradient Boosted Trees and LinearSVC. All of them are trained using Cross Validation, with hyperparameter tuning to find the best parameters. From the results obtained, the Logistic Regression model performs best.


Among the difficulties encountered, the most important were programming the whole project with the PySpark library and defining and joining the features used in the models.


As a future line of work, and to improve the models’ performance, it would be good to try a bigger dataset, so the models have more data to learn from, and to use more parameters in Grid Search to tune the models. Collecting more data from the service would also help. These actions could increase the accuracy of the predictions and therefore make churn management more effective.
