DistilRoBERTa for SMS Spam Detection (FP32)
This model is a fine-tuned version of distilroberta-base for binary text classification on the classic UCI SMS Spam Collection dataset. It is designed to distinguish between legitimate messages ("ham") and spam messages.
The model achieves 99.64% accuracy on the test set, with 100% precision on both the spam and ham classes, meaning it did not flag a single legitimate message as spam.
This repository contains the full-precision (FP32) PyTorch model, optimized for maximum accuracy. For a production-ready, quantized ONNX version suitable for mobile deployment, please see the sibling repository here.
Model Description
- Model type: distilroberta-base fine-tuned for sequence classification
- Language(s): English
- License: MIT
- Fine-tuned from: distilroberta-base
- Dataset: UCI SMS Spam Collection
How to Use
You can use this model directly with the `transformers` library's `pipeline` API.
```python
from transformers import pipeline

# Load the model from the Hub
spam_classifier = pipeline(
    "text-classification",
    model="SharpWoofer/distilroberta-sms-spam-detector",
)

# Test with a spam message
result_spam = spam_classifier("URGENT! You have won a 1 week FREE membership!")
print(result_spam)
# >> [{'label': 'SPAM', 'score': 0.999...}]

# Test with a ham message
result_ham = spam_classifier("Hey, just checking in about our meeting tomorrow.")
print(result_ham)
# >> [{'label': 'HAM', 'score': 0.999...}]
```
Training Procedure
Dataset & Preprocessing
The model was trained on the UCI SMS Spam Collection dataset, which is heavily imbalanced (4825 ham vs. 747 spam). To enable robust evaluation, the data was split into three sets:
- Train: 4456 examples (80%)
- Validation: 558 examples (10%)
- Test: 558 examples (10%)
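A stratified 80/10/10 split like the one above can be sketched as follows (a minimal standard-library version for illustration; the original work likely used a library utility such as scikit-learn's `train_test_split`, which is not shown here):

```python
import random

def stratified_split(items, labels, fracs=(0.8, 0.1, 0.1), seed=42):
    """Split items into train/val/test, preserving each label's proportion."""
    rng = random.Random(seed)
    by_label = {}
    for item, label in zip(items, labels):
        by_label.setdefault(label, []).append(item)

    train, val, test = [], [], []
    for group in by_label.values():
        rng.shuffle(group)
        n = len(group)
        n_train = int(n * fracs[0])
        n_val = int(n * fracs[1])
        train += group[:n_train]
        val += group[n_train:n_train + n_val]
        test += group[n_train + n_val:]
    return train, val, test

# Toy usage with an imbalanced label distribution
msgs = [f"msg{i}" for i in range(100)]
labs = ["spam" if i < 10 else "ham" for i in range(100)]
tr, va, te = stratified_split(msgs, labs)
print(len(tr), len(va), len(te))  # 80 10 10
```

Stratifying keeps the spam-to-ham ratio roughly constant across the three sets, which matters here because the minority class is only about 13% of the data.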
Fine-Tuning & Class Imbalance
A key challenge was the class imbalance. To address this, a custom WeightedLossTrainer was implemented in PyTorch. This trainer applies a higher penalty when the model misclassifies the minority class (spam), forcing it to pay more attention to spam examples. The calculated class weights were approximately 0.58 for 'ham' and 3.79 for 'spam'.
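The weighting scheme can be sketched as follows (assuming inverse-frequency weights derived from the class counts quoted above; the actual training script is not part of this card, so the exact formula and counts used are an assumption):

```python
# Inverse-frequency class weights: weight_c = N / (num_classes * count_c).
# Counts are the full-dataset figures quoted above; the training script may
# have computed weights on the train split only, so values are approximate.
counts = {"ham": 4825, "spam": 747}
total = sum(counts.values())

weights = {label: total / (len(counts) * n) for label, n in counts.items()}
print(weights)  # roughly {'ham': 0.58, 'spam': 3.73}
```

In a custom trainer, these weights would typically be passed to `torch.nn.CrossEntropyLoss(weight=...)` when computing the loss, so that each misclassified spam example costs roughly 6x more than a misclassified ham example.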
Hyperparameters:
- Optimizer: AdamW
- Learning Rate: 5e-5 (with linear decay and 500 warmup steps)
- Epochs: 3
- Batch Size: 16
- Weight Decay: 0.01
The best model checkpoint was selected based on the highest F1-score on the validation set at the end of Epoch 3.
Evaluation Results
The model was evaluated on the held-out test set of 558 examples. The performance is exceptional, particularly its precision.
Final Test Set Metrics:
- Accuracy: 99.64%
- F1 (Weighted): 0.9964
- Precision (Weighted): 0.9964
- Recall (Weighted): 0.9964
Classification Report:
| Class | Precision | Recall | F1-Score | Support |
|---|---|---|---|---|
| HAM | 1.00 | 1.00 | 1.00 | 480 |
| SPAM | 1.00 | 0.97 | 0.99 | 78 |
| Overall | 1.00 | 1.00 | 1.00 | 558 |
This report highlights the model's key strength: it did not misclassify a single legitimate message as spam in the entire test set, making it very safe for production use.
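These figures are internally consistent: with 78 spam messages in the test set and a spam recall of about 0.97, two spam messages were missed, which reproduces the reported accuracy (a quick arithmetic check using only the support counts from the table above):

```python
# Support counts from the classification report above
support_ham, support_spam = 480, 78
missed_spam = 2  # spam recall 76/78 ~= 0.97; ham recall is 1.00

correct = support_ham + (support_spam - missed_spam)
accuracy = correct / (support_ham + support_spam)
print(f"{accuracy:.2%}")  # 99.64%
```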