Prompt-Injection-Classifier / load_aegis_dataset.py
Tameem7's picture
Add application file
5a27052
#!/usr/bin/env python3
"""
Utility for loading Nvidia's Aegis AI Content Safety Dataset 2.0 with
the exact fields needed for prompt injection detection experiments.
Only the `prompt` text and the normalized `prompt_label` fields are kept.
Labels are mapped to integers: `safe -> 0`, `unsafe -> 1`.
"""
from __future__ import annotations
from typing import Dict, Optional
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict, load_dataset
# Dataset identifier handed to `datasets.load_dataset`.
DATASET_NAME = "nvidia/Aegis-AI-Content-Safety-Dataset-2.0"
# Translation table from the dataset's string labels to integer class ids.
LABEL_MAP = {"safe": 0, "unsafe": 1}
# The only columns retained in a processed split.
SELECTED_COLUMNS = ["prompt", "prompt_label"]
def _map_labels(batch: Dict[str, list]) -> Dict[str, list]:
    """Replace the string labels of a batch with their integer codes.

    Meant to be passed to ``map(..., batched=True)``. The ``prompt_label``
    column of ``batch`` is overwritten in place with the values looked up in
    ``LABEL_MAP`` and the same batch object is returned. A label that is not
    a key of ``LABEL_MAP`` raises ``KeyError``.
    """
    encoded = []
    for raw_label in batch["prompt_label"]:
        encoded.append(LABEL_MAP[raw_label])
    batch["prompt_label"] = encoded
    return batch
def _prepare_split(ds: Dataset) -> Dataset:
    """
    Reduce one split to the selected columns and integer-encode its labels.
    """
    # Drop unused columns first so the label mapping only touches what is kept.
    return ds.select_columns(SELECTED_COLUMNS).map(_map_labels, batched=True)
def _prepare_streaming_split(ds: IterableDataset) -> IterableDataset:
    """Keep the selected columns and integer-encode labels of a streamed split."""
    # IterableDataset exposes select_columns/map just like Dataset; both are
    # applied on the fly during iteration, so the stream does not need to be
    # re-wrapped in a hand-written generator.
    return ds.select_columns(SELECTED_COLUMNS).map(_map_labels, batched=True)


def load_aegis_dataset(
    split: Optional[str] = None,
    streaming: bool = False,
) -> Dataset | DatasetDict | IterableDataset | IterableDatasetDict:
    """
    Load the Aegis dataset with normalized `prompt_label`.

    Args:
        split: Optional split name ("train", "validation", "test", etc.).
            When omitted, every available split is loaded.
        streaming: Whether to stream the data instead of downloading it locally.

    Returns:
        A single processed split when `split` is given, otherwise a dict-like
        container with one processed entry per split. The result is a
        `Dataset` / `DatasetDict` when `streaming` is false and an
        `IterableDataset` / `IterableDatasetDict` when it is true. In every
        case only the `prompt` and integer `prompt_label` columns remain.

    Raises:
        KeyError: If a row carries a label that is not in `LABEL_MAP`. For
            streamed data this surfaces while iterating, not at load time.
    """
    dataset = load_dataset(DATASET_NAME, split=split, streaming=streaming)

    # Both flavours go through the same select-then-map pipeline; only the
    # concrete dataset type (and therefore the helper) differs.
    prepare = _prepare_streaming_split if streaming else _prepare_split

    if split is not None:
        return prepare(dataset)

    # No split requested: load_dataset returned a mapping of split name -> data.
    processed = {name: prepare(split_ds) for name, split_ds in dataset.items()}
    if streaming:
        return IterableDatasetDict(processed)
    return DatasetDict(processed)
if __name__ == "__main__":
    # Manual smoke test: load every split (non-streaming) and show, for each
    # one, its size followed by its first processed row.
    all_splits = load_aegis_dataset()
    for split_name in all_splits:
        split_ds = all_splits[split_name]
        print(f"{split_name}: {len(split_ds)} samples")
        print(split_ds[0])