Spaces:
Sleeping
Sleeping
File size: 3,742 Bytes
f0bc9a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import roc_auc_score
def masked_bce_loss(logits, labels, mask):
    """Binary cross-entropy with logits, averaged over valid (unmasked) entries only.

    Args:
        logits: [batch_size, num_classes] raw model outputs.
        labels: [batch_size, num_classes] 0/1 targets; invalid positions may hold
            any filler value (including NaN) and may be an integer dtype.
        mask: [batch_size, num_classes], True (nonzero) where the label is valid.

    Returns:
        Scalar tensor: sum of per-entry BCE over valid entries divided by the number
        of valid entries. Returns 0 (not NaN) when no entry is valid, so a fully
        masked batch contributes no gradient instead of poisoning the parameters.
    """
    valid = mask.bool()
    weights = mask.to(logits.dtype)
    # BCE needs floating targets. Filler values are replaced by 0 *before* the loss:
    # multiplying a NaN loss by a 0 weight afterwards would still yield NaN, both in
    # the forward value and in the backward pass.
    safe_labels = torch.where(valid, labels.to(logits.dtype), torch.zeros_like(logits))
    criterion = nn.BCEWithLogitsLoss(reduction="none")
    loss_raw = criterion(logits, safe_labels)
    # clamp_min avoids 0/0 when every entry is masked; the numerator is then 0 too.
    return (loss_raw * weights).sum() / weights.sum().clamp_min(1.0)
def train_model(model, loader, optimizer, device):
    """Run one training epoch and return the graph-weighted mean loss.

    Each batch's masked BCE loss is a mean over that batch's valid labels. It is
    weighted by the batch's number of graphs, and the total is divided by the
    number of graphs actually processed.

    Args:
        model: called as ``model(x, edge_index, batch)``; expected to return logits
            of shape [num_graphs, num_classes].
        loader: iterable of batches exposing ``.to(device)``, ``x``, ``edge_index``,
            ``batch``, ``y``, ``mask`` and ``num_graphs``.
        optimizer: optimizer whose ``zero_grad()``/``step()`` update the model.
        device: device each batch is moved to.

    Returns:
        The mean loss as a float, or ``nan`` if the loader yielded no graphs.
    """
    model.train()
    total_loss = 0.0
    num_graphs = 0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.batch)  # [num_graphs, num_classes]
        loss = masked_bce_loss(out, batch.y, batch.mask)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.num_graphs
        # Count graphs actually seen: with drop_last=True or a custom sampler this
        # differs from len(loader.dataset), which would bias the average.
        num_graphs += batch.num_graphs
    if num_graphs == 0:
        return float("nan")
    return total_loss / num_graphs
@torch.no_grad()
def evaluate(model, loader, device):
    """Compute the graph-weighted mean masked BCE loss over ``loader``.

    Runs with gradients disabled and the model switched to eval mode (the model is
    left in eval mode afterwards); no parameters are updated.

    Args:
        model: called as ``model(x, edge_index, batch)``; expected to return logits
            of shape [num_graphs, num_classes].
        loader: iterable of batches exposing ``.to(device)``, ``x``, ``edge_index``,
            ``batch``, ``y``, ``mask`` and ``num_graphs``.
        device: device each batch is moved to.

    Returns:
        The mean loss as a float, or ``nan`` if the loader yielded no graphs.
    """
    model.eval()
    total_loss = 0.0
    num_graphs = 0
    for batch in loader:
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index, batch.batch)
        loss = masked_bce_loss(out, batch.y, batch.mask)
        total_loss += loss.item() * batch.num_graphs
        # Divide by graphs actually seen rather than len(loader.dataset), which is
        # wrong when the loader drops or resamples examples.
        num_graphs += batch.num_graphs
    if num_graphs == 0:
        return float("nan")
    return total_loss / num_graphs
@torch.no_grad()
def compute_roc_auc(model, loader, device):
    """Return the mean ROC-AUC across label columns, using only valid (masked) entries.

    Predictions from every batch are passed through a sigmoid and concatenated.
    A column is skipped when it has no valid entries, or when ``roc_auc_score``
    raises ``ValueError`` for it (e.g. only one class among its valid labels).

    Returns:
        The mean of the per-column AUCs, or ``nan`` if no column could be scored.
    """
    model.eval()
    prob_chunks, label_chunks, mask_chunks = [], [], []
    for data in loader:
        data = data.to(device)
        logits = model(data.x, data.edge_index, data.batch)
        prob_chunks.append(torch.sigmoid(logits).cpu())
        label_chunks.append(data.y.cpu())
        mask_chunks.append(data.mask.cpu())

    targets = torch.cat(label_chunks, dim=0).numpy()
    probs = torch.cat(prob_chunks, dim=0).numpy()
    valid = torch.cat(mask_chunks, dim=0).numpy()

    scores = []
    for col in range(targets.shape[1]):
        keep = valid[:, col].astype(bool)
        if not keep.any():
            continue  # nothing labelled for this column
        try:
            scores.append(roc_auc_score(targets[keep, col], probs[keep, col]))
        except ValueError:
            continue  # e.g. all valid labels in this column share one class

    if not scores:
        return float("nan")
    return np.mean(scores)
@torch.no_grad()
def compute_roc_auc_avg_and_per_class(model, loader, device):
    """Compute per-column ROC-AUC and their mean over the columns that could be scored.

    Predictions from every batch are passed through a sigmoid and concatenated; each
    column is scored only on entries whose mask is set. A column gets ``nan`` when it
    has no valid entries or when ``roc_auc_score`` raises ``ValueError`` for it
    (e.g. only one class among its valid labels).

    Returns:
        (auc_array, mean_auc): ``auc_array`` is a float32 array with one AUC (or nan)
        per column. ``mean_auc`` is the mean of the non-nan AUCs, computed in float64
        so it matches ``compute_roc_auc``; it is ``nan`` (without a RuntimeWarning)
        when no column could be scored.
    """
    model.eval()
    # The decorator already disables autograd; no inner no_grad() is needed.
    y_true, y_pred, y_mask = [], [], []
    for batch in loader:
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index, batch.batch)
        y_pred.append(torch.sigmoid(out).cpu())
        y_true.append(batch.y.cpu())
        y_mask.append(batch.mask.cpu())

    y_true = torch.cat(y_true, dim=0).numpy()
    y_pred = torch.cat(y_pred, dim=0).numpy()
    y_mask = torch.cat(y_mask, dim=0).numpy()

    auc_list = []
    for i in range(y_true.shape[1]):
        mask_i = y_mask[:, i].astype(bool)
        auc = np.nan
        if mask_i.any():
            try:
                auc = roc_auc_score(y_true[mask_i, i], y_pred[mask_i, i])
            except ValueError:
                pass  # e.g. only one class present among the valid labels
        auc_list.append(auc)

    auc_array = np.array(auc_list, dtype=np.float32)
    # Average the float64 scores directly. np.nanmean on an all-nan array would
    # warn "Mean of empty slice", so that case returns nan explicitly.
    scored = [a for a in auc_list if not np.isnan(a)]
    mean_auc = np.mean(scored) if scored else float("nan")
    return auc_array, mean_auc