Multi-class Text Classification using Spark ML in Python

Ashok kumar Palivela
Nov 25, 2022


Build a News Article Classifier using the PySpark framework

Image by Author

Introduction

Modern applications often deal with datasets that are too large or complex for traditional data-processing software to handle. In this tutorial, we will use Apache Spark, a general-purpose distributed data-processing engine, to process our data and build a machine learning model for text classification.

What is Apache Spark?

Apache Spark is a multi-language engine for executing data engineering, data science, and machine learning on single-node machines or clusters.

It can run operations on very large datasets quickly and distribute that work across multiple machines.

What is PySpark?

PySpark is an interface for Apache Spark in Python. It allows you to write Spark applications using Python APIs and also provides the PySpark shell for interactively analyzing your data in a distributed environment.

PySpark supports most of Spark’s features such as Spark SQL, DataFrame, Streaming, MLlib (Machine Learning), and Spark Core.
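As a quick, self-contained illustration of the DataFrame and Spark SQL APIs (a minimal sketch with toy data, separate from the tutorial's dataset):

# Minimal sketch of the DataFrame and Spark SQL APIs (toy data, purely illustrative)
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").appName("pyspark-demo").getOrCreate()

toy_df = spark.createDataFrame(
    [(1, "World"), (2, "Sports"), (3, "Business"), (4, "Science")],
    ["id", "category"])

# DataFrame API
toy_df.filter(toy_df.id > 2).show()

# Spark SQL API
toy_df.createOrReplaceTempView("categories")
spark.sql("SELECT category FROM categories WHERE id > 2").show()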

Image: spark.apache.org/docs/latest/api/python

Objective 🎯

To implement a machine learning model that analyzes a news headline or article body and categorizes it as tech, politics, sports, and so on. Such a model helps news websites automatically label articles and display them in their respective sections.

Data

This dataset consists of 120,000 training news articles with 3 columns: the first column is Class Index, the second is Title, and the third is Description.

The class indices are 1–4, where 1 represents World, 2 represents Sports, 3 represents Business, and 4 represents Science.

Download dataset: kaggle.com

Imports

  • Install pyspark using pip: !pip install pyspark
# Importing necessary libraries 
import seaborn as sns
import matplotlib.pyplot as plt

from pyspark.ml import Pipeline # pipeline to transform data
from pyspark.sql import SparkSession # to initiate spark
from pyspark.sql.types import FloatType
from pyspark.ml.feature import RegexTokenizer # tokenizer
from pyspark.ml.feature import HashingTF, IDF # vectorizer
from pyspark.ml.feature import StopWordsRemover # to remove stop words
from pyspark.sql.functions import concat_ws, col # to concatenate cols
from pyspark.ml.classification import LogisticRegression # ml model
from pyspark.ml.evaluation import MulticlassClassificationEvaluator # to evaluate the model
from pyspark.mllib.evaluation import MulticlassMetrics # performance metrics
  • Creating a Spark session and loading the dataset
# create a new spark session
spark = SparkSession.builder.master("local[*]") \
    .appName("news classifier") \
    .getOrCreate()

# show session details
spark
PySpark version 3.3.1

Loading dataset

Class Index — Category of the news
1 — World
2 — Sports
3 — Business
4 — Science

Title — Title of the article

Description — Content in the article

# load dataset
df = spark.read.csv("train.csv", inferSchema=True, header=True)

# shows top 20 rows
df.show()

# Output
# +-----------+--------------------+--------------------+
# |Class Index| Title| Description|
# +-----------+--------------------+--------------------+
# | 3|Wall St. Bears Cl...|Reuters - Short-s...|
# | 3|Carlyle Looks Tow...|Reuters - Private...|
# | 3|Oil and Economy C...|Reuters - Soaring...|
# | 3|Iraq Halts Oil Ex...|Reuters - Authori...|
# | 3|Oil prices soar t...|AFP - Tearaway wo...|
# | 3|Stocks End Up, Bu...|Reuters - Stocks ...|
# | 3|Money Funds Fell ...|AP - Assets of th...|
# | 3|Fed minutes show ...|USATODAY.com - Re...|
# | 3|Safety Net (Forbe...|"Forbes.com - Aft...|
# | 3|Wall St. Bears Cl...| NEW YORK (Reuter...|
# | 3|Oil and Economy C...| NEW YORK (Reuter...|
# | 3|No Need for OPEC ...| TEHRAN (Reuters)...|
# | 3|Non-OPEC Nations ...| JAKARTA (Reuters...|
# | 3|Google IPO Auctio...| WASHINGTON/NEW Y...|
# | 3|Dollar Falls Broa...| NEW YORK (Reuter...|
# | 3|Rescuing an Old S...|If you think you ...|
# | 3|Kids Rule for Bac...|The purchasing po...|
# | 3|In a Down Market,...|There is little c...|
# | 3|US trade deficit ...|The US trade defi...|
# | 3|Shell 'could be t...|Oil giant Shell c...|
# +-----------+--------------------+--------------------+
# only showing top 20 rows
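Before reshaping the columns, it can be worth confirming that the load matches the dataset description above (an optional sanity check):

# Optional sanity check: the schema should show the 3 columns described above
# and the row count should be 120,000 for the training split
df.printSchema()
print("Rows:", df.count())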
  • Concatenate textual columns to get a single text column
# Renaming 'Class Index' col to 'label'
df = df.withColumnRenamed('Class Index', 'label')

# Add a new column 'Text' by concatenating 'Title' and 'Description'
df = df.withColumn("Text", concat_ws(" ", "Title", 'Description'))

# Remove old text columns
df = df.select('label', 'Text')

# Shows top 10 rows
df.show(10)

# Output: concatenated text features
# +-----+--------------------+
# |label| Text|
# +-----+--------------------+
# | 3|Wall St. Bears Cl...|
# | 3|Carlyle Looks Tow...|
# | 3|Oil and Economy C...|
# | 3|Iraq Halts Oil Ex...|
# | 3|Oil prices soar t...|
# | 3|Stocks End Up, Bu...|
# | 3|Money Funds Fell ...|
# | 3|Fed minutes show ...|
# | 3|Safety Net (Forbe...|
# | 3|Wall St. Bears Cl...|
# +-----+--------------------+
# only showing top 10 rows

ML Pipeline

  • Tokenizer — converts each sentence into a list of words, also known as tokens.
# convert sentences to list of words
tokenizer = RegexTokenizer(inputCol="Text", outputCol="words", pattern="\\W")

# adds a column 'words' to df after tokenization
df = tokenizer.transform(df)

df.select(['label','Text', 'words']).show(5)

# Output: tokenized text
# +-----+--------------------+--------------------+
# |label| Text| words|
# +-----+--------------------+--------------------+
# | 3|Wall St. Bears Cl...|[wall, st, bears,...|
# | 3|Carlyle Looks Tow...|[carlyle, looks, ...|
# | 3|Oil and Economy C...|[oil, and, econom...|
# | 3|Iraq Halts Oil Ex...|[iraq, halts, oil...|
# | 3|Oil prices soar t...|[oil, prices, soa...|
# +-----+--------------------+--------------------+
# only showing top 5 rows
  • Stopwords remover — removes stop words, i.e., words that carry little meaning (is, this, that, etc.)
# to remove stop words like is, the, in, etc.
stopwords_remover = StopWordsRemover(inputCol="words", outputCol="filtered")

# adds a column 'filtered' to df without stopwords
df = stopwords_remover.transform(df)

df.select(['label','Text', 'words', 'filtered']).show(5)

#Output:
# +-----+--------------------+--------------------+--------------------+
# |label| Text| words| filtered|
# +-----+--------------------+--------------------+--------------------+
# | 3|Wall St. Bears Cl...|[wall, st, bears,...|[wall, st, bears,...|
# | 3|Carlyle Looks Tow...|[carlyle, looks, ...|[carlyle, looks, ...|
# | 3|Oil and Economy C...|[oil, and, econom...|[oil, economy, cl...|
# | 3|Iraq Halts Oil Ex...|[iraq, halts, oil...|[iraq, halts, oil...|
# | 3|Oil prices soar t...|[oil, prices, soa...|[oil, prices, soa...|
# +-----+--------------------+--------------------+--------------------+
# only showing top 5 rows
  • HashingTF — calculates the term frequency of the words in the corpus
# Calculate term frequency in each article
hashing_tf = HashingTF(inputCol="filtered",
                       outputCol="raw_features",
                       numFeatures=10000)

# adds raw tf features to df
featurized_data = hashing_tf.transform(df)
  • IDF Vectorizer — weights the raw term frequencies by inverse document frequency (IDF).

The standard formulation is used: IDF = log((m + 1) / (d(t) + 1)), where m is the total number of documents and d(t) is the number of documents that contain the term t.
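As a quick sanity check of this formula (a hand-worked example with made-up counts, not part of the pipeline), a term that appears in 1,000 of 120,000 documents gets an IDF of roughly 4.79:

# Hand-worked IDF example (illustrative numbers only)
import math

m = 120000    # total number of documents
d_t = 1000    # documents containing the term t

print(round(math.log((m + 1) / (d_t + 1)), 2))   # ~4.79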

# Inverse document frequency
idf = IDF(inputCol="raw_features", outputCol="features")

idf_vectorizer = idf.fit(featurized_data)

# converting text to vectors
rescaled_data = idf_vectorizer.transform(featurized_data)

# top 20 rows
rescaled_data.select("label",'Text', 'words', 'filtered', "features").show()
  • Train/test split — keep 75% of the data as the training set and 25% as the test set
# Split Train/Test data
(train, test) = rescaled_data.randomSplit([0.75, 0.25], seed = 202)
print("Training Dataset Count: " + str(train.count()))
print("Test Dataset Count: " + str(test.count()))

# Output
# Training Dataset Count: 90149
# Test Dataset Count: 29851
  • LogisticRegression — one of the simplest and most effective ML algorithms for text classification. Algorithms in the PySpark ML library support multi-class problems when the family parameter is set to 'multinomial' while creating the model object.
# model object
lr = LogisticRegression(featuresCol='features',
                        labelCol='label',
                        family="multinomial",
                        regParam=0.3,
                        elasticNetParam=0,
                        maxIter=50)

# train model with default parameters
lrModel = lr.fit(train)

# get predictions for test set
predictions = lrModel.transform(test)

# show top 20 predictions
predictions.select("Text", 'probability','prediction', 'label').show()

# Output
# +--------------------+--------------------+----------+-----+
# | Text| probability|prediction|label|
# +--------------------+--------------------+----------+-----+
# | #39;Batman #39; ...|[4.79798540233463...| 1.0| 1|
# | #39;Black boxes ...|[6.08123465944845...| 1.0| 1|
# | #39;Deserter #39...|[3.25929743141170...| 1.0| 1|
# | #39;Patience, pe...|[1.50809323657597...| 1.0| 1|
# | #39;She won #39;...|[2.22525102692627...| 1.0| 1|
# | #39;Suspect Pack...|[3.75379098074928...| 4.0| 1|
# | #39;The Scream #...|[4.44907744471914...| 1.0| 1|
# | #39;The world is...|[3.50140661367225...| 1.0| 1|
# |"Australian oppos...|[1.04995173599668...| 1.0| 1|
# |"Europeans expect...|[2.47196376832162...| 4.0| 1|
# |"US dismisses Nor...|[8.76768106766652...| 1.0| 1|
# |'Alien Vs. Predat...|[6.05476093864577...| 4.0| 1|
# |'Boots' Takes Gol...|[3.44151333177751...| 2.0| 1|
# |'Cold-Blooded Kil...|[1.54645775278094...| 1.0| 1|
# |'Distressed' That...|[1.68591918798567...| 1.0| 1|
# |'Dozens killed' i...|[1.20363007026082...| 1.0| 1|
# |'Hawaii Five-0' M...|[6.97588872950026...| 4.0| 1|
# |'Mock executions'...|[3.10698789262081...| 1.0| 1|
# |'On Death and Dyi...|[7.13012837580222...| 4.0| 1|
# |'Potential Develo...|[9.20192968773768...| 1.0| 1|
# +--------------------+--------------------+----------+-----+
# only showing top 20 rows

Model Evaluation

  • Accuracy — the model achieves about 90% accuracy on the test set
# to evaluate the model (metricName is set to "accuracy"; the evaluator's default is "f1")
evaluator = MulticlassClassificationEvaluator(predictionCol="prediction",
                                              metricName="accuracy")

# print test accuracy
print("Test-set Accuracy is : ", evaluator.evaluate(predictions))

# Output
# Test-set Accuracy is : 0.8964532714370209
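The same evaluator can report other metrics by overriding metricName at evaluation time; a brief sketch (exact values will vary from run to run):

# Other metrics from the same evaluator (values depend on the run)
f1 = evaluator.evaluate(predictions, {evaluator.metricName: "f1"})
w_precision = evaluator.evaluate(predictions, {evaluator.metricName: "weightedPrecision"})
w_recall = evaluator.evaluate(predictions, {evaluator.metricName: "weightedRecall"})

print("F1:", f1)
print("Weighted precision:", w_precision)
print("Weighted recall:", w_recall)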
  • Confusion matrix — a visual comparison of actual vs. predicted labels that shows how well the classifier performs on each class
labels = ["World", "Sports", "Business","Science"]

# MulticlassMetrics expects float prediction/label pairs,
# so cast the label column to float and order by prediction
preds_and_labels = predictions.select(['prediction', 'label']) \
    .withColumn('label', col('label').cast(FloatType())) \
    .orderBy('prediction')
# generate metrics
metrics = MulticlassMetrics(preds_and_labels.rdd.map(tuple))

# figure object
_ = plt.figure(figsize=(7, 7))

# plot confusion matrix
sns.heatmap(metrics.confusionMatrix().toArray(),
            cmap='viridis',
            annot=True, fmt='.0f',   # show counts as integers
            cbar=False,
            xticklabels=labels,
            yticklabels=labels)
plt.show()
Confusion matrix
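The same MulticlassMetrics object also exposes per-class precision and recall, which complement the confusion matrix; a brief sketch (the class ids 1.0–4.0 follow the Class Index values):

# Per-class precision and recall from the same metrics object
for class_id, name in zip([1.0, 2.0, 3.0, 4.0], labels):
    print(name,
          "precision:", round(metrics.precision(class_id), 3),
          "recall:", round(metrics.recall(class_id), 3))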

ML Pipeline

We can combine all of the above steps into a single pipeline to make the workflow easier to run and more readable. PySpark provides a Pipeline API that chains a series of transformers and a final estimator to produce predictions.

# load dataset
df = spark.read.csv("train.csv", inferSchema=True, header=True)

# Renaming 'Class Index' col to 'label'
df = df.withColumnRenamed('Class Index', 'label')

# Add a new column 'Text' by concatenating 'Title' and 'Description'
df = df.withColumn("Text", concat_ws(" ", "Title", 'Description'))

# Select new text feature and labels
df = df.select('label', 'Text')

# tokenizer
tokenizer = RegexTokenizer(inputCol="Text", outputCol="words", pattern="\\W")

# stopwords
stopwords_remover = StopWordsRemover(inputCol="words", outputCol="filtered")

# term frequency
hashing_tf = HashingTF(inputCol="filtered",
                       outputCol="raw_features",
                       numFeatures=10000)

# Inverse Document Frequency - vectorizer
idf = IDF(inputCol="raw_features", outputCol="features")

# model
lr = LogisticRegression(featuresCol='features',
                        labelCol='label',
                        family="multinomial",
                        regParam=0.3,
                        elasticNetParam=0,
                        maxIter=50)


# Put everything in pipeline
pipeline = Pipeline(stages=[tokenizer,
                            stopwords_remover,
                            hashing_tf,
                            idf,
                            lr])

# Fit the pipeline to training documents.
pipelineFit = pipeline.fit(df)

# transform the data through the fitted pipeline to get predictions
dataset = pipelineFit.transform(df)

# show top 10 predictions
dataset.show(10)
pipeline predictions
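A fitted pipeline can also be saved to disk and loaded back later, so the classifier can be reused without retraining; a brief sketch (the path below is just a placeholder):

# Save the fitted pipeline model and reload it later (the path is a placeholder)
from pyspark.ml import PipelineModel

pipelineFit.write().overwrite().save("news_classifier_pipeline")

loaded_model = PipelineModel.load("news_classifier_pipeline")
loaded_model.transform(df).select("Text", "prediction").show(5)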

Conclusion

In this tutorial, we developed a text classification model on a news dataset using the PySpark library.

We used PySpark transformers to vectorize the text, trained a Logistic Regression model that reached about 90% test accuracy, and plotted a confusion matrix. Finally, we combined all the preprocessing and modeling steps into a single pipeline so predictions can be run easily.

Let’s connect: Twitter | LinkedIn | GitHub
