import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import roc_auc_score


def masked_bce_loss(logits, labels, mask):
    """Return the mean binary cross-entropy over the valid label entries.

    Args:
        logits: [batch_size, num_classes] raw (pre-sigmoid) model outputs.
        labels: [batch_size, num_classes] 0/1 targets; entries where ``mask``
            is False are filler and may hold any value, including NaN.
        mask: [batch_size, num_classes], True where the label is valid.

    Returns:
        Scalar tensor: BCE summed over valid entries divided by their count.
        When no entry is valid the result is 0 (still attached to the graph,
        with zero gradient) instead of NaN from a 0/0 division.
    """
    mask = mask.bool()
    # Overwrite filler targets *before* computing the loss: multiplying a NaN
    # loss term by a zero weight still yields NaN. The cast also lets integer
    # label tensors through, which BCEWithLogitsLoss would otherwise reject.
    safe_labels = torch.where(
        mask, labels.to(logits.dtype), torch.zeros_like(logits)
    )
    criterion = nn.BCEWithLogitsLoss(reduction="none")
    loss_raw = criterion(logits, safe_labels)

    weights = mask.to(loss_raw.dtype)
    # clamp keeps the denominator >= 1 so a fully masked batch gives 0, not NaN.
    num_valid = weights.sum().clamp(min=1)
    return (loss_raw * weights).sum() / num_valid


def train_model(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the average loss.

    Each batch must provide ``x``, ``edge_index``, ``batch``, ``y``, ``mask``,
    ``num_graphs`` and ``to(device)`` (this looks like a PyTorch Geometric
    batch -- confirm against the caller).

    Returns:
        The per-batch losses weighted by ``num_graphs`` and divided by the
        number of graphs actually iterated, so the average stays correct when
        the loader drops or subsamples data. ``nan`` if the loader is empty.
    """
    model.train()
    total_loss = 0.0
    graphs_seen = 0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.batch)  # [num_graphs, num_classes]
        loss = masked_bce_loss(out, batch.y, batch.mask)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.num_graphs
        graphs_seen += batch.num_graphs
    return total_loss / graphs_seen if graphs_seen else float("nan")


@torch.no_grad()
def evaluate(model, loader, device):
    """Return the average masked BCE loss over ``loader`` without training.

    The model is put in eval mode and gradients are disabled. The average is
    weighted by ``batch.num_graphs`` and normalised by the number of graphs
    actually iterated; ``nan`` is returned if the loader is empty.
    """
    model.eval()
    total_loss = 0.0
    graphs_seen = 0
    for batch in loader:
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index, batch.batch)
        loss = masked_bce_loss(out, batch.y, batch.mask)
        total_loss += loss.item() * batch.num_graphs
        graphs_seen += batch.num_graphs
    return total_loss / graphs_seen if graphs_seen else float("nan")


@torch.no_grad()
def _collect_predictions(model, loader, device):
    """Run ``model`` over ``loader``; return (y_true, y_pred, y_mask) arrays.

    ``y_pred`` holds sigmoid probabilities and ``y_mask`` is boolean. Returns
    ``None`` when the loader yields no batches.
    """
    model.eval()
    y_true, y_pred, y_mask = [], [], []
    for batch in loader:
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index, batch.batch)
        # Store predictions as probabilities (sigmoid of the logits).
        y_pred.append(torch.sigmoid(out).cpu())
        y_true.append(batch.y.cpu())
        y_mask.append(batch.mask.cpu())

    if not y_pred:
        return None

    # Concatenate across all batches.
    return (
        torch.cat(y_true, dim=0).numpy(),
        torch.cat(y_pred, dim=0).numpy(),
        torch.cat(y_mask, dim=0).numpy().astype(bool),
    )


def _per_class_auc(y_true, y_pred, y_mask):
    """Return one ROC-AUC per label column; NaN where the AUC is undefined."""
    aucs = []
    for i in range(y_true.shape[1]):
        valid = y_mask[:, i]
        labels = y_true[valid, i]
        # ROC-AUC needs both classes among the valid labels. Checking here
        # (rather than relying only on roc_auc_score's error) also covers the
        # "no valid labels" case and does not depend on whether the installed
        # scikit-learn raises or returns NaN for single-class input.
        if np.unique(labels).size < 2:
            aucs.append(np.nan)
            continue
        try:
            aucs.append(roc_auc_score(labels, y_pred[valid, i]))
        except ValueError:
            # Best effort: any other input roc_auc_score rejects is undefined.
            aucs.append(np.nan)
    return aucs


@torch.no_grad()
def compute_roc_auc(model, loader, device):
    """Return the mean ROC-AUC over all label columns with a defined AUC.

    Columns with no valid labels, or with only one class among the valid
    labels, are skipped. Returns ``nan`` if no column has a defined AUC or
    the loader is empty.
    """
    collected = _collect_predictions(model, loader, device)
    if collected is None:
        return float("nan")
    defined = [auc for auc in _per_class_auc(*collected) if not np.isnan(auc)]
    return np.mean(defined) if defined else float("nan")


@torch.no_grad()
def compute_roc_auc_avg_and_per_class(model, loader, device):
    """Return per-class ROC-AUCs together with their NaN-ignoring mean.

    Returns:
        ``(auc_array, mean_auc)`` -- note the order: per-class values first.
        ``auc_array`` is a float32 array with one entry per label column,
        NaN where the AUC is undefined. ``mean_auc`` is the mean of the
        defined entries, or ``nan`` if there are none (or the loader is
        empty, in which case ``auc_array`` is empty too).
    """
    collected = _collect_predictions(model, loader, device)
    if collected is None:
        return np.array([], dtype=np.float32), float("nan")

    auc_array = np.array(_per_class_auc(*collected), dtype=np.float32)
    # Avoid numpy's "Mean of empty slice" warning when nothing is defined.
    if np.isnan(auc_array).all():
        return auc_array, float("nan")
    return auc_array, np.nanmean(auc_array)