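"""
Training and evaluation utilities for multi-label graph classification with
partially observed labels, written for PyTorch Geometric-style batches
(attributes `x`, `edge_index`, `batch`, `y`, `mask`, `num_graphs`).

Provides a masked BCE-with-logits loss, train/eval loops, and ROC-AUC metrics
that skip label entries whose mask is False.
"""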
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import roc_auc_score

def masked_bce_loss(logits, labels, mask):
    """
    Binary cross-entropy with logits that ignores missing labels.

    logits: [batch_size, num_classes] (raw, unnormalized outputs)
    labels: [batch_size, num_classes] (0/1, with arbitrary filler where missing)
    mask:   [batch_size, num_classes] (True where the label is valid)
    """
    criterion = nn.BCEWithLogitsLoss(reduction="none")
    loss_raw = criterion(logits, labels.float())  # targets must be float
    mask = mask.float()
    # Average only over valid entries; clamping the denominator avoids a
    # 0/0 NaN in the (unlikely) case that a batch has no valid labels.
    loss = (loss_raw * mask).sum() / mask.sum().clamp(min=1.0)
    return loss
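

# Minimal illustration of the masking behaviour (hypothetical values, not part
# of the training pipeline): only the two masked-in entries contribute, and the
# filler value in the masked-out slot is ignored.
def _masked_bce_loss_example():
    logits = torch.tensor([[0.5, -1.0, 2.0]])
    labels = torch.tensor([[1.0, 0.0, 0.0]])      # middle entry is filler
    mask = torch.tensor([[True, False, True]])
    return masked_bce_loss(logits, labels, mask)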

def train_model(model, loader, optimizer, device):
    """Run one training epoch and return the loss averaged over all graphs."""
    model.train()
    total_loss = 0
    for batch in loader:
        batch = batch.to(device)

        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.batch)  # [num_graphs, num_classes]

        loss = masked_bce_loss(out, batch.y, batch.mask)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * batch.num_graphs
    return total_loss / len(loader.dataset)


@torch.no_grad()
def evaluate(model, loader, device):
    """Compute the masked BCE loss over a loader, averaged over all graphs."""
    model.eval()
    total_loss = 0
    for batch in loader:
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index, batch.batch)
        loss = masked_bce_loss(out, batch.y, batch.mask)
        total_loss += loss.item() * batch.num_graphs
    return total_loss / len(loader.dataset)
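

# Sketch of how the two loops above might be wired together; `model`, the
# loaders, `optimizer`, and `num_epochs` are assumed to be supplied by the
# caller and are not defined in this file.
def _run_training_example(model, train_loader, val_loader, optimizer, device, num_epochs=10):
    for epoch in range(1, num_epochs + 1):
        train_loss = train_model(model, train_loader, optimizer, device)
        val_loss = evaluate(model, val_loader, device)
        print(f"epoch {epoch:03d} | train loss {train_loss:.4f} | val loss {val_loss:.4f}")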


@torch.no_grad()
def compute_roc_auc(model, loader, device):
    """Return the mean ROC-AUC across labels, using only valid (masked-in) entries."""
    model.eval()
    y_true, y_pred, y_mask = [], [], []

    for batch in loader:
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index, batch.batch)

        # Store predictions (sigmoid → probabilities)
        y_pred.append(torch.sigmoid(out).cpu())
        y_true.append(batch.y.cpu())
        y_mask.append(batch.mask.cpu())

    # Concatenate across all batches
    y_true = torch.cat(y_true, dim=0).numpy()
    y_pred = torch.cat(y_pred, dim=0).numpy()
    y_mask = torch.cat(y_mask, dim=0).numpy()

    auc_list = []
    for i in range(y_true.shape[1]):  # per label
        mask_i = y_mask[:, i].astype(bool)
        if mask_i.sum() > 0:  # at least one valid label
            try:
                auc = roc_auc_score(y_true[mask_i, i], y_pred[mask_i, i])
                auc_list.append(auc)
            except ValueError:
                # happens if only one class present (all 0 or all 1)
                pass

    return np.mean(auc_list) if len(auc_list) > 0 else float("nan")

@torch.no_grad()
def compute_roc_auc_avg_and_per_class(model, loader, device):
    """Return (per-class ROC-AUC array, NaN-ignoring mean); undefined classes get NaN."""
    model.eval()
    y_true, y_pred, y_mask = [], [], []

    # The @torch.no_grad() decorator already disables gradient tracking, so no
    # inner `with torch.no_grad():` block is needed.
    for batch in loader:
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index, batch.batch)

        # Store predictions (sigmoid → probabilities)
        y_pred.append(torch.sigmoid(out).cpu())
        y_true.append(batch.y.cpu())
        y_mask.append(batch.mask.cpu())

    # Concatenate across all batches
    y_true = torch.cat(y_true, dim=0).numpy()
    y_pred = torch.cat(y_pred, dim=0).numpy()
    y_mask = torch.cat(y_mask, dim=0).numpy()

    # Compute AUC per class
    auc_list = []
    for i in range(y_true.shape[1]):
        mask_i = y_mask[:, i].astype(bool)
        if mask_i.sum() > 0:
            try:
                auc = roc_auc_score(y_true[mask_i, i], y_pred[mask_i, i])
            except ValueError:
                auc = np.nan  # in case only one class present
        else:
            auc = np.nan
        auc_list.append(auc)

    # Convert to numpy array for easier manipulation
    auc_array = np.array(auc_list, dtype=np.float32)
    mean_auc = np.nanmean(auc_array)  # overall mean ignoring NaNs

    # Return both per-class and mean
    return auc_array, mean_auc
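

# Example of reporting the metric at test time (hypothetical helper and
# `label_names`; a NaN per-class AUC means that label had no valid entries or
# only a single observed class in the evaluation set).
def _report_auc_example(model, test_loader, device, label_names=None):
    per_class_auc, mean_auc = compute_roc_auc_avg_and_per_class(model, test_loader, device)
    names = label_names or [f"task_{i}" for i in range(len(per_class_auc))]
    for name, auc in zip(names, per_class_auc):
        print(f"{name}: {auc:.4f}")
    print(f"mean ROC-AUC (NaNs ignored): {mean_auc:.4f}")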