Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| """ | |
| Utility for loading Nvidia's Aegis AI Content Safety Dataset 2.0 with | |
| the exact fields needed for prompt injection detection experiments. | |
| Only the `prompt` text and the normalized `prompt_label` fields are kept. | |
| Labels are mapped to integers: `safe -> 0`, `unsafe -> 1`. | |
| """ | |
| from __future__ import annotations | |
| from typing import Dict, Optional | |
| from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict, load_dataset | |
# Hugging Face Hub identifier passed to `load_dataset`.
DATASET_NAME = "nvidia/Aegis-AI-Content-Safety-Dataset-2.0"

# The only columns kept from each split; all others are dropped.
SELECTED_COLUMNS = ["prompt", "prompt_label"]

# Encoding of the string `prompt_label` values into integer class ids.
LABEL_MAP = {"safe": 0, "unsafe": 1}
def _map_labels(batch: Dict[str, list]) -> Dict[str, list]:
    """Convert a batch's string ``prompt_label`` values into integer ids.

    Designed for ``Dataset.map(..., batched=True)``: ``batch`` maps each
    column name to a list of values. Labels are matched against
    ``LABEL_MAP`` after stripping surrounding whitespace and lower-casing,
    so ``"Safe"`` and ``" unsafe "`` are accepted as well.

    Args:
        batch: Column-name -> list-of-values mapping; must contain
            a ``prompt_label`` column.

    Returns:
        The same ``batch`` object, with ``prompt_label`` replaced by ints.

    Raises:
        KeyError: If a label, after normalization, is not a key of
            ``LABEL_MAP``. The message names the offending value and the
            accepted labels. KeyError is kept for backward compatibility.
    """
    encoded = []
    for label in batch["prompt_label"]:
        # str() so that unexpected values such as None produce a clear
        # error below instead of an AttributeError on .strip().
        key = str(label).strip().lower()
        try:
            encoded.append(LABEL_MAP[key])
        except KeyError:
            raise KeyError(
                f"Unexpected prompt_label {label!r}; "
                f"expected one of {sorted(LABEL_MAP)}"
            ) from None
    batch["prompt_label"] = encoded
    return batch
def _prepare_split(ds: Dataset) -> Dataset:
    """Reduce one non-streaming split to the selected columns with int labels.

    The split is first projected onto ``SELECTED_COLUMNS``. ``_map_labels``
    is then applied in batched mode to encode ``prompt_label``.

    Args:
        ds: A single split as returned by ``load_dataset(..., split=...)``.

    Returns:
        A new ``Dataset`` holding only the selected columns.
    """
    return ds.select_columns(SELECTED_COLUMNS).map(_map_labels, batched=True)
def _stream_split(stream: IterableDataset) -> IterableDataset:
    """Wrap one streaming split so it yields only the selected, label-encoded fields.

    Each yielded example contains the ``SELECTED_COLUMNS`` fields of the
    source row, with ``prompt_label`` converted through ``LABEL_MAP``. An
    unknown label raises ``KeyError`` lazily, while the result is iterated.
    """

    def generator():
        for row in stream:
            example = {column: row[column] for column in SELECTED_COLUMNS}
            example["prompt_label"] = LABEL_MAP[example["prompt_label"]]
            yield example

    # NOTE(review): recent `datasets` releases also provide
    # IterableDataset.select_columns/map, which would keep feature metadata;
    # confirm the pinned version supports them before switching.
    return IterableDataset.from_generator(generator)


def load_aegis_dataset(
    split: Optional[str] = None,
    streaming: bool = False,
) -> Dataset | DatasetDict | IterableDataset | IterableDatasetDict:
    """
    Load the Aegis dataset with normalized `prompt_label`.

    Args:
        split: Optional split name ("train", "validation", "test", etc.).
            When omitted, all splits are loaded.
        streaming: Whether to stream the data instead of downloading it locally.

    Returns:
        - With `split` given: a `Dataset`, or an `IterableDataset` when
          `streaming` is true.
        - Without `split`: a `DatasetDict`, or an `IterableDatasetDict` when
          `streaming` is true, keyed by split name.

        Every split contains only the `prompt` and integer `prompt_label`
        columns. Labels are encoded through `LABEL_MAP`. In streaming mode an
        unknown label surfaces as an error during iteration, not at load time.
    """
    dataset = load_dataset(DATASET_NAME, split=split, streaming=streaming)

    if split is not None:
        # A single split was requested, so `dataset` is one (Iterable)Dataset.
        return _stream_split(dataset) if streaming else _prepare_split(dataset)

    # No split requested: `dataset` maps split name -> split.
    if streaming:
        # Each stream is passed as an argument, so every generator is bound to
        # its own split (no late-binding closure over a loop variable).
        return IterableDatasetDict(
            {name: _stream_split(stream) for name, stream in dataset.items()}
        )
    return DatasetDict(
        {name: _prepare_split(split_ds) for name, split_ds in dataset.items()}
    )
if __name__ == "__main__":
    # Quick smoke check: load every split and show its size plus one example.
    processed = load_aegis_dataset()
    for split_name, split_ds in processed.items():
        print(f"{split_name}: {len(split_ds)} samples")
        # Indexing an empty split would raise IndexError; only preview when
        # there is at least one row.
        if len(split_ds) > 0:
            print(split_ds[0])