import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from rdkit import Chem
from rdkit.Chem.MolStandardize import rdMolStandardize
from torch_geometric.data import InMemoryDataset
from torch_geometric.utils import from_rdmol
def get_tox21_split(token, cvfold=None):
    """Load the Tox21 dataset and return train/validation DataFrames.

    If `cvfold` is given, the predefined splits are merged and re-split
    along the `CVfold` column.
    """
    ds = load_dataset("ml-jku/tox21", token=token)
    train_df = ds["train"].to_pandas()
    val_df = ds["validation"].to_pandas()
    if cvfold is None:
        return {
            "train": train_df,
            "validation": val_df,
        }
    # Create new splits along the CVfold column
    combined_df = pd.concat([train_df, val_df], ignore_index=True)
    cvfold = float(cvfold)
    train_df = combined_df[combined_df.CVfold != cvfold]
    val_df = combined_df[combined_df.CVfold == cvfold]
    # Exclude train molecules that also occur in the validation split
    val_inchikeys = set(val_df["inchikey"])
    train_df = train_df[~train_df["inchikey"].isin(val_inchikeys)]
    return {
        "train": train_df.reset_index(drop=True),
        "validation": val_df.reset_index(drop=True),
    }
def create_clean_mol_objects(smiles: list[str]) -> tuple[list[Chem.Mol], np.ndarray]:
    """Create cleaned RDKit Mol objects from SMILES.

    Returns (list of mols, mask of valid mols).
    """
    clean_mol_mask = []
    mols = []
    # Standardizer components
    cleanup_params = rdMolStandardize.CleanupParameters()
    tautomer_enumerator = rdMolStandardize.TautomerEnumerator()
    for smi in smiles:
        try:
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                clean_mol_mask.append(False)
                continue
            # Clean up the molecule and canonicalize its tautomer
            mol = rdMolStandardize.Cleanup(mol, cleanup_params)
            mol = tautomer_enumerator.Canonicalize(mol)
            # Recompute the canonical SMILES and reload to get a fresh Mol
            can_smi = Chem.MolToSmiles(mol)
            mol = Chem.MolFromSmiles(can_smi)
            if mol is not None:
                mols.append(mol)
                clean_mol_mask.append(True)
            else:
                clean_mol_mask.append(False)
        except Exception as e:
            print(f"Failed to standardize {smi}: {e}")
            clean_mol_mask.append(False)
    return mols, np.array(clean_mol_mask, dtype=bool)
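
# Illustration with hypothetical inputs: an unparseable SMILES yields a
# False entry in the mask and is dropped from the returned list.
#
#   mols, mask = create_clean_mol_objects(["CCO", "C1=CC=CC=C1", "not-a-smiles"])
#   len(mols) == 2
#   mask.tolist() == [True, True, False]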
class Tox21Dataset(InMemoryDataset):
    """PyTorch Geometric dataset of Tox21 molecular graphs with label masks."""

    def __init__(self, dataframe):
        super().__init__()
        data_list = []
        # Clean molecules and filter the dataframe accordingly
        mols, clean_mask = create_clean_mol_objects(dataframe["smiles"].tolist())
        self.clean_mask = torch.tensor(clean_mask, dtype=torch.bool)
        # Labels and masks for the *unfiltered* dataframe; align them with
        # clean_mask if per-molecule correspondence is needed
        drop_cols = ["ID", "smiles", "inchikey", "sdftitle", "order", "set", "CVfold"]
        labels_df = dataframe.drop(columns=drop_cols)
        numeric_labels = labels_df.apply(pd.to_numeric, errors="coerce").fillna(0.0)
        self.all_labels = torch.tensor(numeric_labels.values, dtype=torch.float)
        self.all_label_masks = torch.tensor(~labels_df.isna().values, dtype=torch.bool)
        dataframe = dataframe[clean_mask].reset_index(drop=True)
        # mols and dataframe are now aligned, so we can zip them
        for mol, (_, row) in zip(mols, dataframe.iterrows()):
            try:
                data = from_rdmol(mol)
                # Extract the label columns as a pandas Series
                labels = row.drop(drop_cols)
                # Mask marking which labels are actually measured
                mask = ~labels.isna()
                # Explicit numeric conversion; replaces NaN with 0.0 safely
                labels = pd.to_numeric(labels, errors="coerce").fillna(0.0).astype(float).values
                # Convert to tensors with a leading batch dimension
                data.y = torch.tensor(labels, dtype=torch.float).unsqueeze(0)
                data.mask = torch.tensor(mask.values, dtype=torch.bool).unsqueeze(0)
                data_list.append(data)
            except Exception as e:
                print(f"Skipping molecule {row['smiles']} due to error: {e}")
        # Collate the graphs into the InMemoryDataset storage format
        self.data, self.slices = self.collate(data_list)
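
# Sketch of direct use (assumes `df` is a DataFrame from get_tox21_split):
#
#   dataset = Tox21Dataset(df)
#   sample = dataset[0]                 # a torch_geometric.data.Data object
#   sample.y.shape, sample.mask.shape   # each (1, num_tasks)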
def get_graph_datasets(token):
    """Build train/validation InMemoryDatasets usable with dataloaders.

    Args:
        token (str): Hugging Face access token for loading ml-jku/tox21.

    Returns:
        tuple[Tox21Dataset, Tox21Dataset]: train and validation datasets.
    """
    datasets = get_tox21_split(token, cvfold=4)
    train_df, val_df = datasets["train"], datasets["validation"]
    train_dataset = Tox21Dataset(train_df)
    val_dataset = Tox21Dataset(val_df)
    return train_dataset, val_dataset
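

# Minimal end-to-end sketch. Assumptions not in the original: the token is
# read from an HF_TOKEN environment variable, and batch size 32 is arbitrary.
if __name__ == "__main__":
    import os

    from torch_geometric.loader import DataLoader

    train_dataset, val_dataset = get_graph_datasets(os.environ["HF_TOKEN"])
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    batch = next(iter(train_loader))
    # batch.y holds per-task labels; batch.mask marks which labels are measured
    print(batch.num_graphs, batch.y.shape, batch.mask.shape)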