import numpy as np
import pandas as pd
import torch
from rdkit import Chem
from rdkit.Chem.MolStandardize import rdMolStandardize
from torch_geometric.data import InMemoryDataset
from torch_geometric.utils import from_rdmol
from datasets import load_dataset


def get_tox21_split(token, cvfold=None):
    """Load the Tox21 dataset from the Hugging Face Hub and return train/validation splits."""
    ds = load_dataset("ml-jku/tox21", token=token)
    train_df = ds["train"].to_pandas()
    val_df = ds["validation"].to_pandas()

    if cvfold is None:
        return {"train": train_df, "validation": val_df}

    # Create new splits based on the requested cross-validation fold
    combined_df = pd.concat([train_df, val_df], ignore_index=True)
    cvfold = float(cvfold)
    train_df = combined_df[combined_df.CVfold != cvfold]
    val_df = combined_df[combined_df.CVfold == cvfold]

    # Exclude train molecules that also occur in the validation split
    val_inchikeys = set(val_df["inchikey"])
    train_df = train_df[~train_df["inchikey"].isin(val_inchikeys)]

    return {
        "train": train_df.reset_index(drop=True),
        "validation": val_df.reset_index(drop=True),
    }


def create_clean_mol_objects(smiles: list[str]) -> tuple[list[Chem.Mol], np.ndarray]:
    """Create cleaned RDKit Mol objects from SMILES.

    Returns:
        tuple: (list of standardized mols, boolean mask marking which inputs were valid).
    """
    clean_mol_mask = []
    mols = []

    # Standardizer components
    cleaner = rdMolStandardize.CleanupParameters()
    tautomer_enumerator = rdMolStandardize.TautomerEnumerator()

    for smi in smiles:
        try:
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                clean_mol_mask.append(False)
                continue

            # Cleanup and canonicalize the tautomer
            mol = rdMolStandardize.Cleanup(mol, cleaner)
            mol = tautomer_enumerator.Canonicalize(mol)

            # Recompute canonical SMILES & reload
            can_smi = Chem.MolToSmiles(mol)
            mol = Chem.MolFromSmiles(can_smi)
            if mol is not None:
                mols.append(mol)
                clean_mol_mask.append(True)
            else:
                clean_mol_mask.append(False)
        except Exception as e:
            print(f"Failed to standardize {smi}: {e}")
            clean_mol_mask.append(False)

    return mols, np.array(clean_mol_mask, dtype=bool)


class Tox21Dataset(InMemoryDataset):
    def __init__(self, dataframe):
        super().__init__()
        data_list = []

        # Clean molecules & keep a mask of which dataframe rows survived
        mols, clean_mask = create_clean_mol_objects(dataframe["smiles"].tolist())
        self.clean_mask = torch.tensor(clean_mask, dtype=torch.bool)

        # Store labels and label masks for all rows (including molecules that fail cleaning)
        drop_cols = ["ID", "smiles", "inchikey", "sdftitle", "order", "set", "CVfold"]
        labels_df = dataframe.drop(columns=drop_cols)
        numeric_labels = labels_df.apply(pd.to_numeric, errors="coerce")
        self.all_labels = torch.tensor(numeric_labels.fillna(0.0).values, dtype=torch.float)
        self.all_label_masks = torch.tensor(~numeric_labels.isna().values, dtype=torch.bool)

        # Filter the dataframe so it is aligned with the cleaned molecules
        dataframe = dataframe[clean_mask].reset_index(drop=True)

        # Now mols and dataframe are aligned, so we can zip them
        for mol, (_, row) in zip(mols, dataframe.iterrows()):
            try:
                data = from_rdmol(mol)

                # Extract labels as a pandas Series
                labels = row.drop(drop_cols)
                # Mask for valid (measured) labels
                mask = ~labels.isna()
                # Explicit numeric conversion, replaces NaN with 0.0 safely
                labels = pd.to_numeric(labels, errors="coerce").fillna(0.0).astype(float).values

                # Convert to tensors with shape [1, n_tasks] for easy batching
                y = torch.tensor(labels, dtype=torch.float).unsqueeze(0)
                m = torch.tensor(mask.values, dtype=torch.bool).unsqueeze(0)
                data.y = y
                data.mask = m
                data_list.append(data)
            except Exception as e:
                print(f"Skipping molecule {row['smiles']} due to error: {e}")

        # Collate into an InMemoryDataset
        self.data, self.slices = self.collate(data_list)
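
# Illustrative sketch (not part of the pipeline): how create_clean_mol_objects
# behaves on a mix of valid and invalid SMILES. The SMILES strings below are
# arbitrary examples chosen only for demonstration.
def _demo_clean_mols():
    example_smiles = ["CCO", "c1ccccc1O", "not_a_valid_smiles"]
    mols, mask = create_clean_mol_objects(example_smiles)
    # mols contains only the molecules that survived standardization;
    # mask marks which input positions were kept, e.g. [True, True, False].
    print(f"{len(mols)} of {len(example_smiles)} molecules kept, mask={mask}")
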
def get_graph_datasets(token):
    """Returns train and validation InMemoryDatasets that can be used in dataloaders.

    Args:
        token (str): Hugging Face access token used to download the Tox21 dataset.

    Returns:
        tuple[Tox21Dataset, Tox21Dataset]: train and validation datasets for dataloaders.
    """
    datasets = get_tox21_split(token, cvfold=4)
    train_df, val_df = datasets["train"], datasets["validation"]
    train_dataset = Tox21Dataset(train_df)
    val_dataset = Tox21Dataset(val_df)
    return train_dataset, val_dataset
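
# Example usage (a sketch, assuming a valid Hugging Face token): build the
# datasets and wrap them in PyTorch Geometric DataLoaders. The token string
# and batch size below are placeholders chosen for illustration.
if __name__ == "__main__":
    from torch_geometric.loader import DataLoader

    hf_token = "<your_huggingface_token>"  # placeholder, replace with a real token
    train_dataset, val_dataset = get_graph_datasets(hf_token)

    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)

    for batch in train_loader:
        # batch.y has shape [batch_size, n_tasks]; batch.mask marks which labels
        # are actually measured and should contribute to the loss.
        print(batch.num_graphs, batch.y.shape, batch.mask.shape)
        break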