#!/usr/bin/env python3
"""
Utility for loading Nvidia's Aegis AI Content Safety Dataset 2.0 with the exact
fields needed for prompt injection detection experiments.

Only the `prompt` text and the normalized `prompt_label` fields are kept.
Labels are mapped to integers: `safe -> 0`, `unsafe -> 1`.
"""

from __future__ import annotations

from typing import Dict, Optional

from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict, load_dataset

DATASET_NAME = "nvidia/Aegis-AI-Content-Safety-Dataset-2.0"
LABEL_MAP = {"safe": 0, "unsafe": 1}
SELECTED_COLUMNS = ["prompt", "prompt_label"]


def _encode_label(label: str) -> int:
    """
    Convert one string label to its integer code from `LABEL_MAP`.

    Raises:
        KeyError: If `label` is not a key of `LABEL_MAP`. The message names
            the offending value so a bad row is easy to identify.
    """
    try:
        return LABEL_MAP[label]
    except KeyError:
        raise KeyError(
            f"Unexpected prompt_label {label!r}; expected one of {sorted(LABEL_MAP)}"
        ) from None


def _map_labels(batch: Dict[str, list]) -> Dict[str, list]:
    """Batched mapping function that converts string labels to ints."""
    batch["prompt_label"] = [_encode_label(label) for label in batch["prompt_label"]]
    return batch


def _prepare_split(ds: Dataset | IterableDataset) -> Dataset | IterableDataset:
    """
    Keep only the required columns and normalize labels for a single split.

    Both `Dataset` and `IterableDataset` provide `select_columns` and batched
    `map`, so the same pipeline serves the downloaded and the streaming case.
    For an `IterableDataset` the transformation is applied on the fly while
    iterating, and the stream is not re-wrapped in a generator.
    """
    subset = ds.select_columns(SELECTED_COLUMNS)
    return subset.map(_map_labels, batched=True)


def load_aegis_dataset(
    split: Optional[str] = None,
    streaming: bool = False,
) -> Dataset | DatasetDict | IterableDataset | IterableDatasetDict:
    """
    Load the Aegis dataset with normalized `prompt_label`.

    Args:
        split: Optional split name ("train", "validation", "test", etc.).
        streaming: Whether to stream the data instead of downloading it locally.

    Returns:
        A processed Dataset / IterableDataset (if `split` is provided) or a
        DatasetDict / IterableDatasetDict containing only `prompt` and integer
        `prompt_label` columns.

    Raises:
        KeyError: If a row carries a `prompt_label` that is not in `LABEL_MAP`.
            In streaming mode this surfaces while iterating, not at load time.
    """
    dataset = load_dataset(DATASET_NAME, split=split, streaming=streaming)

    # Dispatch on what load_dataset actually returned rather than on the
    # arguments, and rebuild the container with its own type so the streaming
    # and non-streaming cases share one code path.
    if isinstance(dataset, (DatasetDict, IterableDatasetDict)):
        return type(dataset)(
            {split_name: _prepare_split(split_ds) for split_name, split_ds in dataset.items()}
        )
    return _prepare_split(dataset)


def main() -> None:
    """Load every split (non-streaming) and print its size and first row."""
    processed = load_aegis_dataset()
    for split_name, split_ds in processed.items():
        print(f"{split_name}: {len(split_ds)} samples")
        # Guard against an empty split; indexing row 0 would raise IndexError.
        if len(split_ds) > 0:
            print(split_ds[0])


if __name__ == "__main__":
    main()