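# Evaluate a fine-tuned prompt-injection classifier checkpoint on the Aegis test split.
# Author: Tameem7 — commit 849ca5b ("fix eval speed")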
import os

import numpy as np
from datasets import DatasetDict
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer

from load_aegis_dataset import load_aegis_dataset
def compute_metrics(eval_pred):
    """Compute classification metrics for ``Trainer.evaluate``.

    Args:
        eval_pred: A ``(predictions, labels)`` pair, e.g. a
            ``transformers.EvalPrediction``. ``predictions`` holds per-class
            scores with the class axis last; it may also be a tuple whose
            first element is that score array.

    Returns:
        dict with ``accuracy``, support-weighted ``precision``, ``recall``
        and ``f1``, and ``confusion_matrix`` as a nested list
        (rows = true class, columns = predicted class) of shape
        ``(num_classes, num_classes)``.
    """
    predictions, labels = eval_pred
    # When the model returns extra outputs besides the logits, Trainer passes
    # a tuple here; the logits are conventionally its first element.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)
    preds = np.argmax(predictions, axis=-1)

    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='weighted', zero_division=0
    )
    accuracy = accuracy_score(labels, preds)
    # Pin the label set to every class the model can output so the matrix has
    # a fixed shape and row/column order, even if a class is absent here.
    num_classes = predictions.shape[-1]
    cm = confusion_matrix(labels, preds, labels=np.arange(num_classes))
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'confusion_matrix': cm.tolist(),
    }
# Checkpoint to evaluate. Set the EVAL_MODEL_DIR environment variable to score
# a different run without editing this file; the default is unchanged.
model_dir = os.environ.get('EVAL_MODEL_DIR', 'prompt-injection-detector/checkpoint-5628')
print(f'Loading model from {model_dir}')
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSequenceClassification.from_pretrained(model_dir)

print('Loading dataset...')
ds = load_aegis_dataset()
# The evaluation below needs a DatasetDict with a 'test' split; report which
# of the two requirements failed instead of one generic message.
if not isinstance(ds, DatasetDict):
    raise RuntimeError(
        f'Expected load_aegis_dataset() to return a DatasetDict, got {type(ds).__name__}.'
    )
if 'test' not in ds:
    raise RuntimeError(f'Test split not available in dataset (available splits: {list(ds)}).')
test_ds = ds['test']
print(f'Test samples: {len(test_ds)}')
def tokenize(batch):
    """Tokenize a batch of prompts for the classifier.

    Args:
        batch: A batched ``datasets`` row dict; only ``batch['prompt']``
            (a list of strings) is read.

    Returns:
        The tokenizer's encoding (``input_ids``, ``attention_mask``, ...),
        with each sequence truncated to at most 512 tokens.

    Sequences are deliberately NOT padded here. The Trainer in this script is
    given the tokenizer, so by default it collates with
    ``DataCollatorWithPadding`` and pads each batch only to its longest
    member. Padding every row to 512 tokens up front made evaluation do
    full-length work for short prompts.
    """
    return tokenizer(batch['prompt'], truncation=True, max_length=512)
# Tokenize the prompts, drop the raw text column, and expose the gold label
# under the name the model's forward() accepts ("labels").
# NOTE(review): assumes 'prompt_label' already holds integer class ids — confirm
# against load_aegis_dataset.
test_tok = test_ds.map(tokenize, batched=True, remove_columns=['prompt']).rename_column(
    'prompt_label', 'labels'
)
test_tok.set_format('torch')

# No TrainingArguments are passed, so Trainer runs with its default settings.
trainer = Trainer(
    model=model,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)
print('Evaluating...')
results = trainer.evaluate(eval_dataset=test_tok)

print('Test metrics:')
for metric_name, metric_value in results.items():
    print(f' {metric_name}: {metric_value}')