Prompt-Injection-Classifier / train_prompt_injection_detector.py
Tameem7's picture
fix eval speed
849ca5b
#!/usr/bin/env python3
"""
Project #1: Prompt Injection Detection Classifier
Train a binary classifier to detect safe (0) vs unsafe (1) prompts
using the Aegis AI Content Safety Dataset 2.0.
Steps:
1. Load dataset with prompt and prompt_label fields
2. Convert labels: "safe" → 0, "unsafe" → 1
3. Create train/validation split (since dataset is for "testing")
4. Train a sequence classification model
5. Evaluate on test split
"""
from __future__ import annotations
import argparse
import logging
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from datasets import Dataset, DatasetDict
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorWithPadding,
TrainingArguments,
Trainer,
TrainerCallback,
)
from load_aegis_dataset import load_aegis_dataset
# Set up logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
def compute_metrics(eval_pred):
"""Compute classification metrics."""
predictions, labels = eval_pred
predictions = np.argmax(predictions, axis=1)
precision, recall, f1, _ = precision_recall_fscore_support(
labels, predictions, average='weighted', zero_division=0
)
accuracy = accuracy_score(labels, predictions)
# Confusion matrix
cm = confusion_matrix(labels, predictions)
return {
'accuracy': accuracy,
'f1': f1,
'precision': precision,
'recall': recall,
'confusion_matrix': cm.tolist(),
}
def tokenize_function(examples, tokenizer):
"""Tokenize the prompts."""
return tokenizer(
examples["prompt"],
truncation=True,
padding="max_length",
max_length=512,
)
class TestLossCallback(TrainerCallback):
"""Callback to track test loss after each epoch."""
def __init__(self, test_dataset, trainer):
self.test_dataset = test_dataset
self.trainer = trainer
self.test_losses = []
self.test_epochs = []
def on_epoch_end(self, args, state, control, **kwargs):
"""Evaluate on test set after each epoch."""
if self.test_dataset is not None:
test_results = self.trainer.evaluate(eval_dataset=self.test_dataset)
if "eval_loss" in test_results:
self.test_losses.append(test_results["eval_loss"])
self.test_epochs.append(state.epoch)
logger.info(f"Epoch {state.epoch}: Test Loss = {test_results['eval_loss']:.4f}")
def main():
parser = argparse.ArgumentParser(description="Train prompt injection detection classifier")
parser.add_argument(
"--model-name",
type=str,
default="distilbert-base-uncased",
help="Base model for classification (distilbert-base-uncased, bert-base-uncased, roberta-base)"
)
parser.add_argument(
"--output-dir",
type=str,
default="./prompt-injection-detector",
help="Directory to save the trained model"
)
parser.add_argument(
"--num-epochs",
type=int,
default=3,
help="Number of training epochs"
)
parser.add_argument(
"--batch-size",
type=int,
default=16,
help="Training batch size"
)
parser.add_argument(
"--learning-rate",
type=float,
default=5e-5,
help="Learning rate"
)
parser.add_argument(
"--test-size",
type=float,
default=0.1,
help="Fraction of data to use for validation (rest for training)"
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="Random seed for reproducibility"
)
args = parser.parse_args()
logger.info("=" * 60)
logger.info("Project #1: Prompt Injection Detection Classifier")
logger.info("=" * 60)
logger.info(f"Model: {args.model_name}")
logger.info(f"Output directory: {args.output_dir}")
logger.info(f"Epochs: {args.num_epochs}, Batch size: {args.batch_size}")
logger.info("=" * 60)
# Step 1: Load dataset (train/validation/test if available)
logger.info("Step 1: Loading Aegis dataset splits...")
dataset = load_aegis_dataset()
if isinstance(dataset, DatasetDict):
logger.info(f"Available splits: {list(dataset.keys())}")
train_dataset = dataset.get("train")
val_dataset = dataset.get("validation") or dataset.get("val")
test_dataset = dataset.get("test")
elif isinstance(dataset, Dataset):
logger.warning("Dataset returned a single split. Treating as 'train'.")
train_dataset = dataset
val_dataset = None
test_dataset = None
else:
raise ValueError("Unexpected dataset type returned from load_aegis_dataset.")
if train_dataset is None:
raise ValueError("Train split not found in dataset.")
logger.info(f"Train split size: {len(train_dataset)}")
logger.info(f"Train fields: {train_dataset.column_names}")
logger.info(f"Train sample: {train_dataset[0]}")
if val_dataset is not None:
logger.info(f"Validation split size: {len(val_dataset)}")
else:
logger.info("Validation split not found; will create from train split.")
if test_dataset is not None:
logger.info(f"Test split size: {len(test_dataset)}")
else:
logger.info("Test split not found; will fall back to validation split for final evaluation if needed.")
# Step 2: Verify label mapping and create validation split if missing
logger.info("\nStep 2: Verifying label mapping and preparing splits...")
unique_labels = set(train_dataset["prompt_label"])
logger.info(f"Unique labels: {unique_labels}")
assert unique_labels == {0, 1}, f"Expected labels {{0, 1}}, got {unique_labels}"
# Count safe vs unsafe
safe_count = sum(1 for label in train_dataset["prompt_label"] if label == 0)
unsafe_count = sum(1 for label in train_dataset["prompt_label"] if label == 1)
logger.info(f"Safe prompts: {safe_count}, Unsafe prompts: {unsafe_count}")
if val_dataset is None:
logger.info("Creating validation split from train data...")
split_dataset = train_dataset.train_test_split(
test_size=args.test_size,
shuffle=True,
seed=args.seed
)
train_dataset = split_dataset["train"]
val_dataset = split_dataset["test"]
logger.info(f"Final train samples: {len(train_dataset)}")
logger.info(f"Final validation samples: {len(val_dataset)}")
# Step 3: Load model and tokenizer
logger.info(f"\nStep 3: Loading model and tokenizer: {args.model_name}")
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
model = AutoModelForSequenceClassification.from_pretrained(
args.model_name,
num_labels=2,
)
# Step 4: Tokenize datasets
logger.info("\nStep 4: Tokenizing datasets...")
tokenize_fn = lambda examples: tokenize_function(examples, tokenizer)
train_tokenized = train_dataset.map(
tokenize_fn,
batched=True,
remove_columns=["prompt"], # Keep prompt_label for labels
)
val_tokenized = val_dataset.map(
tokenize_fn,
batched=True,
remove_columns=["prompt"],
)
# Rename prompt_label to labels for Trainer
train_tokenized = train_tokenized.rename_column("prompt_label", "labels")
val_tokenized = val_tokenized.rename_column("prompt_label", "labels")
# Set format for PyTorch
train_tokenized.set_format("torch")
val_tokenized.set_format("torch")
# Prepare test dataset if available
test_tokenized = None
if test_dataset is not None:
test_tokenized = test_dataset.map(
tokenize_fn,
batched=True,
remove_columns=["prompt"],
)
test_tokenized = test_tokenized.rename_column("prompt_label", "labels")
test_tokenized.set_format("torch")
# Step 5: Set up training
logger.info("\nStep 5: Setting up training...")
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
training_args = TrainingArguments(
output_dir=str(output_dir),
num_train_epochs=args.num_epochs,
per_device_train_batch_size=args.batch_size,
per_device_eval_batch_size=args.batch_size,
learning_rate=args.learning_rate,
weight_decay=0.01,
warmup_steps=500,
logging_dir=str(output_dir / "logs"),
logging_steps=100,
eval_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
metric_for_best_model="f1",
greater_is_better=True,
save_total_limit=3,
fp16=False, # Set to True if you have GPU
report_to="none",
)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_tokenized,
eval_dataset=val_tokenized,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
)
# Add callback to track test loss if test dataset is available
test_callback = None
if test_tokenized is not None:
test_callback = TestLossCallback(test_tokenized, trainer)
trainer.add_callback(test_callback)
# Step 6: Train
logger.info("\nStep 6: Training classifier...")
trainer.train()
# Extract training history for plotting
train_losses = []
train_epochs = []
val_losses = []
val_epochs = []
for log_entry in trainer.state.log_history:
if "loss" in log_entry and "epoch" in log_entry:
train_losses.append(log_entry["loss"])
train_epochs.append(log_entry["epoch"])
elif "eval_loss" in log_entry and "epoch" in log_entry:
val_losses.append(log_entry["eval_loss"])
val_epochs.append(log_entry["epoch"])
# Step 7: Evaluate on validation set
logger.info("\nStep 7: Evaluating on validation set...")
eval_results = trainer.evaluate()
logger.info("\nValidation Results:")
for key, value in eval_results.items():
if key != "confusion_matrix":
logger.info(f" {key}: {value:.4f}")
else:
logger.info(f" {key}:")
logger.info(" " + "\n ".join(str(row) for row in value))
# Step 8: Test on test split (if available)
logger.info("\nStep 8: Testing on test split...")
if test_tokenized is not None:
logger.info(f"Test dataset found with {len(test_dataset)} samples.")
# Get test losses from callback if available
if test_callback and test_callback.test_losses:
test_losses = test_callback.test_losses
test_epochs = test_callback.test_epochs
logger.info(f"Test losses tracked over {len(test_losses)} epochs via callback.")
else:
# Fallback: evaluate final model on test set
test_results = trainer.evaluate(eval_dataset=test_tokenized)
test_losses = [test_results["eval_loss"]]
test_epochs = [args.num_epochs]
logger.info("Evaluated final model on test set.")
# Final test evaluation
test_results = trainer.evaluate(eval_dataset=test_tokenized)
logger.info("\nFinal Test Results:")
for key, value in test_results.items():
if key != "confusion_matrix":
logger.info(f" {key}: {value:.4f}")
else:
logger.info(f" {key}:")
logger.info(" " + "\n ".join(str(row) for row in value))
else:
logger.warning("Test split not found; using validation losses for plotting.")
# Use validation losses as test losses for plotting
test_losses = val_losses
test_epochs = val_epochs
# Step 9: Plot training and test loss
logger.info("\nStep 9: Plotting training and test loss...")
plt.figure(figsize=(10, 6))
if train_losses and train_epochs:
plt.plot(train_epochs, train_losses, 'b-o', label='Train Loss', linewidth=2, markersize=6)
if test_losses and test_epochs:
plt.plot(test_epochs, test_losses, 'r-s', label='Test Loss', linewidth=2, markersize=6)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Loss', fontsize=12)
plt.title('Training and Test Loss Over Epochs', fontsize=14, fontweight='bold')
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.tight_layout()
# Save plot
plot_path = output_dir / "loss_plot.png"
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
logger.info(f"Loss plot saved to: {plot_path}")
plt.close()
# Step 10: Save model
logger.info(f"\nStep 10: Saving model to {output_dir}...")
trainer.save_model()
tokenizer.save_pretrained(str(output_dir))
logger.info("=" * 60)
logger.info("Training complete!")
logger.info(f"Model saved to: {output_dir}")
logger.info(f"Loss plot saved to: {plot_path}")
logger.info("=" * 60)
if __name__ == "__main__":
main()