Multi-class Text Classification using Spark ML in Python
Build a News Article Classifier using the PySpark framework
Introduction
We increasingly deal with datasets that are too large or complex to be handled by traditional data-processing software. In this tutorial, we will use Apache Spark, a general-purpose distributed data processing engine, to process our data and build a machine-learning model for text classification.
What is Apache Spark?
Apache Spark is a multi-language engine for executing data engineering, data science, and machine learning on single-node machines or clusters.
It is a data processing engine that can run operations on very large datasets quickly and distribute the work across multiple machines.
What is PySpark?
PySpark is an interface for Apache Spark in Python. It allows you to write Spark applications using Python APIs and also provides the PySpark shell for interactively analyzing your data in a distributed environment.
PySpark supports most of Spark’s features such as Spark SQL, DataFrame, Streaming, MLlib (Machine Learning), and Spark Core.
Objective 🎯
To implement a machine learning model that analyses a news headline or article body and categorizes it as tech, politics, sports, etc. Such a model helps news websites automatically label articles and display them in their respective sections.
Data
The dataset consists of 120,000 training news articles with 3 columns: the first column is Class Id, the second is Title, and the third is Description.
The class ids range from 1 to 4, where 1 represents World, 2 represents Sports, 3 represents Business, and 4 represents Science.
Download dataset: kaggle.com
Imports
- Install pyspark using pip:
!pip install pyspark
# Importing necessary libraries
import seaborn as sns
import matplotlib.pyplot as plt
from pyspark.ml import Pipeline # pipeline to transform data
from pyspark.sql import SparkSession # to initiate spark
from pyspark.sql.types import FloatType
from pyspark.ml.feature import RegexTokenizer # tokenizer
from pyspark.ml.feature import HashingTF, IDF # vectorizer
from pyspark.ml.feature import StopWordsRemover # to remove stop words
from pyspark.sql.functions import concat_ws, col # to concatenate cols
from pyspark.ml.classification import LogisticRegression # ml model
from pyspark.ml.evaluation import MulticlassClassificationEvaluator # to evaluate the model
from pyspark.mllib.evaluation import MulticlassMetrics # performance metrics
- Creating a Spark session and loading the dataset
# create a new spark session
spark = SparkSession.builder.master("local[*]")\
        .appName("news classifier")\
        .getOrCreate()
# show session details
spark
Loading dataset
- Class Index — category of the news: 1 — World, 2 — Sports, 3 — Business, 4 — Science
- Title — title of the article
- Description — content of the article
# load dataset
df = spark.read.csv("train.csv", inferSchema=True, header=True)
# shows top 20 rows
df.show()
# Output
# +-----------+--------------------+--------------------+
# |Class Index| Title| Description|
# +-----------+--------------------+--------------------+
# | 3|Wall St. Bears Cl...|Reuters - Short-s...|
# | 3|Carlyle Looks Tow...|Reuters - Private...|
# | 3|Oil and Economy C...|Reuters - Soaring...|
# | 3|Iraq Halts Oil Ex...|Reuters - Authori...|
# | 3|Oil prices soar t...|AFP - Tearaway wo...|
# | 3|Stocks End Up, Bu...|Reuters - Stocks ...|
# | 3|Money Funds Fell ...|AP - Assets of th...|
# | 3|Fed minutes show ...|USATODAY.com - Re...|
# | 3|Safety Net (Forbe...|"Forbes.com - Aft...|
# | 3|Wall St. Bears Cl...| NEW YORK (Reuter...|
# | 3|Oil and Economy C...| NEW YORK (Reuter...|
# | 3|No Need for OPEC ...| TEHRAN (Reuters)...|
# | 3|Non-OPEC Nations ...| JAKARTA (Reuters...|
# | 3|Google IPO Auctio...| WASHINGTON/NEW Y...|
# | 3|Dollar Falls Broa...| NEW YORK (Reuter...|
# | 3|Rescuing an Old S...|If you think you ...|
# | 3|Kids Rule for Bac...|The purchasing po...|
# | 3|In a Down Market,...|There is little c...|
# | 3|US trade deficit ...|The US trade defi...|
# | 3|Shell 'could be t...|Oil giant Shell c...|
# +-----------+--------------------+--------------------+
# only showing top 20 rows
- Concatenate textual columns to get a single text column
# Renaming 'Class Index' col to 'label'
df = df.withColumnRenamed('Class Index', 'label')
# Add a new column 'Text' by concatenating 'Title' and 'Description'
df = df.withColumn("Text", concat_ws(" ", "Title", 'Description'))
# Remove old text columns
df = df.select('label', 'Text')
# Shows top 10 rows
df.show(10)
# Output: concatenated text features
# +-----+--------------------+
# |label| Text|
# +-----+--------------------+
# | 3|Wall St. Bears Cl...|
# | 3|Carlyle Looks Tow...|
# | 3|Oil and Economy C...|
# | 3|Iraq Halts Oil Ex...|
# | 3|Oil prices soar t...|
# | 3|Stocks End Up, Bu...|
# | 3|Money Funds Fell ...|
# | 3|Fed minutes show ...|
# | 3|Safety Net (Forbe...|
# | 3|Wall St. Bears Cl...|
# +-----+--------------------+
# only showing top 10 rows
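As a quick sanity check (not part of the original walkthrough), we can look at the class distribution to confirm the dataset is balanced before training:
# count articles per class; 'label' is the renamed 'Class Index' column
df.groupBy('label').count().orderBy('label').show()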
ML Pipeline Stages
- Tokenizer — converts each sentence into a list of words, also known as tokens. The RegexTokenizer below splits the text on any non-word character (pattern "\\W").
# convert sentences to list of words
tokenizer = RegexTokenizer(inputCol="Text", outputCol="words", pattern="\\W")
# adds a column 'words' to df after tokenization
df = tokenizer.transform(df)
df.select(['label','Text', 'words']).show(5)
# Output: tokenized text
# +-----+--------------------+--------------------+
# |label| Text| words|
# +-----+--------------------+--------------------+
# | 3|Wall St. Bears Cl...|[wall, st, bears,...|
# | 3|Carlyle Looks Tow...|[carlyle, looks, ...|
# | 3|Oil and Economy C...|[oil, and, econom...|
# | 3|Iraq Halts Oil Ex...|[iraq, halts, oil...|
# | 3|Oil prices soar t...|[oil, prices, soa...|
# +-----+--------------------+--------------------+
# only showing top 5 rows
- StopWordsRemover — removes stop words, i.e., very common words that carry little meaning (is, this, that, etc.).
# to remove stop words like is, the, in, etc.
stopwords_remover = StopWordsRemover(inputCol="words", outputCol="filtered")
# adds a column 'filtered' to df without stopwords
df = stopwords_remover.transform(df)
df.select(['label','Text', 'words', 'filtered']).show(5)
#Output:
# +-----+--------------------+--------------------+--------------------+
# |label| Text| words| filtered|
# +-----+--------------------+--------------------+--------------------+
# | 3|Wall St. Bears Cl...|[wall, st, bears,...|[wall, st, bears,...|
# | 3|Carlyle Looks Tow...|[carlyle, looks, ...|[carlyle, looks, ...|
# | 3|Oil and Economy C...|[oil, and, econom...|[oil, economy, cl...|
# | 3|Iraq Halts Oil Ex...|[iraq, halts, oil...|[iraq, halts, oil...|
# | 3|Oil prices soar t...|[oil, prices, soa...|[oil, prices, soa...|
# +-----+--------------------+--------------------+--------------------+
# only showing top 5 rows
- HashingTF — computes the term frequency (TF) of the words in each article, hashing each term into a fixed-size feature vector.
# Calculate term frequency in each article
hashing_tf = HashingTF(inputCol="filtered",
                       outputCol="raw_features",
                       numFeatures=10000)
# adds raw tf features to df
featurized_data = hashing_tf.transform(df)
- IDF — computes the inverse document frequency (IDF), which down-weights terms that appear in many documents.
The standard formulation is used: IDF = log((m + 1) / (d(t) + 1)), where m is the total number of documents and d(t) is the number of documents that contain the term t.
# Inverse document frequency
idf = IDF(inputCol="raw_features", outputCol="features")
idf_vectorizer = idf.fit(featurized_data)
# converting text to vectors
rescaled_data = idf_vectorizer.transform(featurized_data)
# top 20 rows
rescaled_data.select("label",'Text', 'words', 'filtered', "features").show()
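As a quick sanity check of the formula (a rough illustration, not part of the Spark workflow), Spark's IDF uses the natural logarithm, so a hypothetical term that appears in 1,000 of the 120,000 training articles would be weighted roughly as follows:
import math
# hypothetical term: appears in d(t) = 1,000 of m = 120,000 documents
m, d_t = 120000, 1000
idf_value = math.log((m + 1) / (d_t + 1))  # natural log, as in Spark's IDF
print(round(idf_value, 2))  # ~4.79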
- Train/Test split — keep 75% of the data as the training set and 25% as the test set.
# Split Train/Test data
(train, test) = rescaled_data.randomSplit([0.75, 0.25], seed = 202)
print("Training Dataset Count: " + str(train.count()))
print("Test Dataset Count: " + str(test.count()))
# Output
# Training Dataset Count: 90149
# Test Dataset Count: 29851
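Since every Spark action recomputes the whole lineage from the CSV, caching the two splits can speed things up if you experiment with several models (an optional tweak, not in the original flow):
# keep the splits in memory so repeated fits and evaluations don't recompute the TF-IDF lineage
train = train.cache()
test = test.cache()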
- LogisticRegression — one of the simplest yet most effective ML algorithms for text classification. PySpark's LogisticRegression handles multi-class problems when the family parameter is set to "multinomial" while creating the model object.
# model object
lr = LogisticRegression(featuresCol='features',
                        labelCol='label',
                        family="multinomial",
                        regParam=0.3,
                        elasticNetParam=0,
                        maxIter=50)
# train the model
lrModel = lr.fit(train)
# get predictions for test set
predictions = lrModel.transform(test)
# show top 20 predictions
predictions.select("Text", 'probability','prediction', 'label').show()
# Output
# +--------------------+--------------------+----------+-----+
# | Text| probability|prediction|label|
# +--------------------+--------------------+----------+-----+
# | #39;Batman #39; ...|[4.79798540233463...| 1.0| 1|
# | #39;Black boxes ...|[6.08123465944845...| 1.0| 1|
# | #39;Deserter #39...|[3.25929743141170...| 1.0| 1|
# | #39;Patience, pe...|[1.50809323657597...| 1.0| 1|
# | #39;She won #39;...|[2.22525102692627...| 1.0| 1|
# | #39;Suspect Pack...|[3.75379098074928...| 4.0| 1|
# | #39;The Scream #...|[4.44907744471914...| 1.0| 1|
# | #39;The world is...|[3.50140661367225...| 1.0| 1|
# |"Australian oppos...|[1.04995173599668...| 1.0| 1|
# |"Europeans expect...|[2.47196376832162...| 4.0| 1|
# |"US dismisses Nor...|[8.76768106766652...| 1.0| 1|
# |'Alien Vs. Predat...|[6.05476093864577...| 4.0| 1|
# |'Boots' Takes Gol...|[3.44151333177751...| 2.0| 1|
# |'Cold-Blooded Kil...|[1.54645775278094...| 1.0| 1|
# |'Distressed' That...|[1.68591918798567...| 1.0| 1|
# |'Dozens killed' i...|[1.20363007026082...| 1.0| 1|
# |'Hawaii Five-0' M...|[6.97588872950026...| 4.0| 1|
# |'Mock executions'...|[3.10698789262081...| 1.0| 1|
# |'On Death and Dyi...|[7.13012837580222...| 4.0| 1|
# |'Potential Develo...|[9.20192968773768...| 1.0| 1|
# +--------------------+--------------------+----------+-----+
# only showing top 20 rows
Model Evaluation
- Accuracy — the model achieves roughly 90% accuracy on the test set.
# to evaluate the model (metricName must be set to "accuracy"; the default is "f1")
evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")
# print test accuracy
print("Test-set Accuracy is : ", evaluator.evaluate(predictions))
# Output
# Test-set Accuracy is : 0.8964532714370209
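Accuracy alone can hide per-class weaknesses, so it is worth checking a couple of other metrics with the same evaluator (a small sketch reusing the evaluator defined above):
# F1 score and weighted precision via the same evaluator
print("F1 score: ", evaluator.evaluate(predictions, {evaluator.metricName: "f1"}))
print("Weighted precision: ", evaluator.evaluate(predictions, {evaluator.metricName: "weightedPrecision"}))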
- Confusion matrix — a visual comparison of actual vs. predicted classes; it shows how many articles of each category the model gets right and which categories it confuses with one another.
labels = ["World", "Sports", "Business","Science"]
# MulticlassMetrics expects an RDD of (prediction, label) pairs as doubles
preds_and_labels = predictions.select(['prediction', 'label']) \
    .withColumn('label', col('label').cast(FloatType())) \
    .orderBy('prediction')
# generate metrics
metrics = MulticlassMetrics(preds_and_labels.rdd.map(tuple))
# figure object
_ = plt.figure(figsize=(7, 7))
# plot confusion matrix
sns.heatmap(metrics.confusionMatrix().toArray(),
            cmap='viridis',
            annot=True, fmt='.0f',
            cbar=False,
            xticklabels=labels,
            yticklabels=labels)
plt.show()
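Beyond the overall numbers, MulticlassMetrics can also report per-class precision and recall, which makes it easy to see which category the model struggles with (a sketch; the 1.0–4.0 values follow the Class Index encoding):
# per-class precision and recall from the same metrics object
for class_id, name in zip([1.0, 2.0, 3.0, 4.0], labels):
    print(name,
          "precision:", round(metrics.precision(class_id), 3),
          "recall:", round(metrics.recall(class_id), 3))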
ML Pipeline
We can combine all the above steps into a single pipeline to make the workflow easier to run and more readable. PySpark provides a Pipeline API that chains a series of transformers and a final estimator to produce predictions (refer to figure 1).
# load dataset
df = spark.read.csv("train.csv", inferSchema=True, header=True)
# Renaming 'Class Index' col to 'label'
df = df.withColumnRenamed('Class Index', 'label')
# Add a new column 'Text' by concatenating 'Title' and 'Description'
df = df.withColumn("Text", concat_ws(" ", "Title", 'Description'))
# Select new text feature and labels
df = df.select('label', 'Text')
# tokenizer
tokenizer = RegexTokenizer(inputCol="Text", outputCol="words", pattern="\\W")
# stopwords
stopwords_remover = StopWordsRemover(inputCol="words", outputCol="filtered")
# term frequency
hashing_tf = HashingTF(inputCol="filtered",
                       outputCol="raw_features",
                       numFeatures=10000)
# Inverse Document Frequency - vectorizer
idf = IDF(inputCol="raw_features", outputCol="features")
# model
lr = LogisticRegression(featuresCol='features',
                        labelCol='label',
                        family="multinomial",
                        regParam=0.3,
                        elasticNetParam=0,
                        maxIter=50)
# Put everything in pipeline
pipeline = Pipeline(stages=[tokenizer,
                            stopwords_remover,
                            hashing_tf,
                            idf,
                            lr])
# Fit the pipeline to training documents.
pipelineFit = pipeline.fit(df)
# transform the training data to get predictions
dataset = pipelineFit.transform(df)
# show top 10 predictions
dataset.show(10)
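Because all the preprocessing lives inside the fitted pipeline, it can score unseen headlines directly. A minimal sketch (the headline string below is made up for illustration):
# classify a new, made-up headline with the fitted pipeline
new_df = spark.createDataFrame([("Stocks rally as oil prices fall",)], ["Text"])
pipelineFit.transform(new_df).select("Text", "prediction").show(truncate=False)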
Conclusion
In this tutorial, we developed a text classification model on a news dataset using the PySpark library.
We used PySpark transformers to vectorize the text, trained a Logistic Regression model that achieved about 90% test accuracy, and plotted a confusion matrix. Finally, we put all the preprocessing and modeling steps into a single pipeline so predictions are easy to run.