| import torch |
| import pandas as pd |
| from torch_geometric.data import Data |
| from rdkit import Chem |
| from rdkit.Chem import AllChem, Descriptors |
| from rdkit.Chem import Draw |
| from rdkit import RDLogger |
| RDLogger.DisableLog('rdApp.*') |
|
|
|
|
| |
def smiles_to_graph(smiles):
    """Convert a SMILES string into a PyTorch Geometric ``Data`` graph.

    The molecule is parsed, explicit hydrogens are added, and a 3D conformer
    is embedded with ETKDG and then relaxed with UFF. Every atom becomes one
    node and every bond becomes two directed edges (one per direction).

    Args:
        smiles: SMILES string describing the molecule.

    Returns:
        A ``Data`` object with the fields below, or ``None`` when the SMILES
        cannot be parsed, the molecule has no atoms, or no conformer could be
        generated/optimized.

        * ``x``: float tensor (num_atoms, 1), atomic number divided by 100.
        * ``pos``: float tensor (num_atoms, 3), conformer coordinates.
        * ``edge_index``: long tensor (2, 2 * num_bonds).
        * ``edge_attr``: long tensor (2 * num_bonds, 1) holding the bond
          class (0 single, 1 double, 2 triple, 3 aromatic; any other bond
          type is mapped to 0).
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None or mol.GetNumAtoms() == 0:
        return None
    mol = Chem.AddHs(mol)
    try:
        # EmbedMolecule signals failure with a negative return value rather
        # than an exception, so the result has to be checked explicitly.
        if AllChem.EmbedMolecule(mol, AllChem.ETKDG()) < 0:
            return None
        AllChem.UFFOptimizeMolecule(mol)
    except Exception:
        # Best effort: a molecule that cannot be embedded or optimized is
        # skipped. ``Exception`` (not a bare ``except``) keeps Ctrl-C and
        # interpreter shutdown working.
        return None

    conf = mol.GetConformer()

    # Built once instead of once per bond.
    bond_classes = {
        Chem.BondType.SINGLE: 0,
        Chem.BondType.DOUBLE: 1,
        Chem.BondType.TRIPLE: 2,
        Chem.BondType.AROMATIC: 3,
    }

    node_feats = []
    pos = []
    for atom in mol.GetAtoms():
        node_feats.append([atom.GetAtomicNum() / 100.0])
        position = conf.GetAtomPosition(atom.GetIdx())
        pos.append([position.x, position.y, position.z])

    edge_index = []
    edge_attrs = []
    for bond in mol.GetBonds():
        start = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        bond_class = bond_classes.get(bond.GetBondType(), 0)
        # Store both directions so the graph is effectively undirected.
        edge_index.extend([[start, end], [end, start]])
        edge_attrs.extend([[bond_class], [bond_class]])

    # Explicit reshapes keep the (2, E) / (E, 1) layout even when the
    # molecule has no bonds (E == 0), where a bare torch.tensor([]) would
    # collapse to a 1-D tensor.
    num_edges = len(edge_index)
    return Data(
        x=torch.tensor(node_feats, dtype=torch.float),
        pos=torch.tensor(pos, dtype=torch.float),
        edge_index=torch.tensor(edge_index, dtype=torch.long)
        .reshape(num_edges, 2)
        .t()
        .contiguous(),
        edge_attr=torch.tensor(edge_attrs, dtype=torch.long).reshape(num_edges, 1),
    )
|
|
| |
def load_goodscents_subset(filepath="../data/leffingwell-goodscent-merge-dataset.csv",
                           index=200,
                           shuffle=True,
                           random_state=None):
    """Load a slice of the odor dataset as SMILES strings plus label vectors.

    The CSV must contain a ``nonStereoSMILES`` column; every column from the
    third one onward is treated as an integer descriptor label.

    Args:
        filepath: Path of the CSV file to read.
        index: If positive, keep the first ``index`` rows; otherwise keep the
            last ``-index`` rows (``0`` therefore selects no rows).
        shuffle: Shuffle all rows before slicing.
        random_state: Seed forwarded to ``DataFrame.sample`` so a shuffled
            split can be reproduced. ``None`` keeps the shuffle random.

    Returns:
        A tuple ``(smiles_list, label_map, descriptor_names)``:
        ``smiles_list`` holds the SMILES of every kept row (duplicates are
        kept), ``label_map`` maps a SMILES to its list of integer labels (the
        last occurrence wins for duplicates) and ``descriptor_names`` is the
        list of label column names. Rows without a usable SMILES string or
        without any non-zero label are dropped.
    """
    df = pd.read_csv(filepath)
    if shuffle:
        df = df.sample(frac=1, random_state=random_state).reset_index(drop=True)
    df = df.head(index) if index > 0 else df.tail(-index)

    descriptor_cols = df.columns[2:]
    smiles_list, label_map = [], {}
    for _, row in df.iterrows():
        smiles = row["nonStereoSMILES"]
        # An empty CSV cell is read as NaN, which is a truthy float; require
        # a real, non-empty string so it never reaches the outputs.
        if not isinstance(smiles, str) or not smiles:
            continue
        labels = row[descriptor_cols].astype(int).tolist()
        if any(labels):
            smiles_list.append(smiles)
            label_map[smiles] = labels
    return smiles_list, label_map, list(descriptor_cols)
|
|
|
|
|
|
def sample(model, conditioner, label_vec, constrained=True, steps=1000, debug=True):
    """Generate one molecule by iterative denoising and return its SMILES.

    Starts from random node features, random positions and a random edge list
    for 10 candidate atoms and 20 candidate edges, then repeatedly subtracts
    the model's prediction from the node features. The final features are
    scaled by 100 and rounded to atomic numbers, and the last step's bond
    logits select a bond type per candidate edge.

    Args:
        model: Callable invoked as
            ``model(x_t, pos, edge_index, torch.tensor([t]), cond_embed)``
            returning ``(pred_x, bond_logits)``.
        conditioner: Callable mapping ``label_vec.unsqueeze(0)`` to the
            conditioning embedding passed to ``model``.
        label_vec: Label tensor used for conditioning.
        constrained: When true, atoms whose atomic number is not one of
            6, 7, 8, 9, 15, 16, 17 are dropped.
        steps: Number of denoising steps; must be at least 1.
        debug: Print intermediate tensors and predictions.

    Returns:
        The SMILES string of the generated molecule, or ``""`` when fewer
        than two atoms survive filtering or the molecule fails sanitization.

    Raises:
        ValueError: If ``steps`` is smaller than 1.
    """
    if steps < 1:
        raise ValueError("steps must be at least 1")

    x_t = torch.randn((10, 1))
    pos = torch.randn((10, 3))
    edge_index = torch.randint(0, 10, (2, 20))

    # Pure inference: without no_grad every step would extend the autograd
    # graph through all previous steps and memory would grow with `steps`.
    with torch.no_grad():
        for t in reversed(range(1, steps + 1)):
            cond_embed = conditioner(label_vec.unsqueeze(0))
            pred_x, bond_logits = model(x_t, pos, edge_index, torch.tensor([t]), cond_embed)
            # temperature_scaled_softmax is expected to be provided elsewhere
            # in this module.
            bond_logits = temperature_scaled_softmax(bond_logits, temperature=(1/t))
            x_t = x_t - pred_x * (1.0 / steps)

        # Undo the /100 scaling used for node features and snap to valid
        # atomic numbers.
        x_t = torch.relu(x_t * 100.0)
        atom_types = torch.clamp(x_t.round(), 1, 118).int().reshape(-1).tolist()
        bond_logits = torch.relu(bond_logits)
        bond_preds = torch.argmax(bond_logits, dim=-1).tolist()

    allowed_atoms = {6, 7, 8, 9, 15, 16, 17}
    if debug:
        print(f"\tcond_embed: {cond_embed}")
        print(f"\tx_t: {x_t}")
        print(f"\tprediction: {x_t}")
        print(f"\tbond logits: {bond_logits}")
        print(f"\tatoms: {atom_types}")
        print(f"\tbonds: {bond_preds}")

    mol = Chem.RWMol()
    idx_map = {}  # candidate atom index -> index inside `mol`
    for i, atomic_num in enumerate(atom_types):
        if constrained and atomic_num not in allowed_atoms:
            continue
        try:
            idx_map[i] = mol.AddAtom(Chem.Atom(int(atomic_num)))
        except Exception:
            # Best effort: skip an atom RDKit refuses to create.
            continue

    if len(idx_map) < 2:
        print("Molecule too small or no valid atoms after filtering.")
        return ""

    bond_type_map = {
        0: Chem.BondType.SINGLE,
        1: Chem.BondType.DOUBLE,
        2: Chem.BondType.TRIPLE,
        3: Chem.BondType.AROMATIC,
    }

    added = set()  # unordered atom pairs that already carry a bond
    for i in range(edge_index.shape[1]):
        a = int(edge_index[0, i])
        b = int(edge_index[1, i])
        pair = (min(a, b), max(a, b))
        if a == b or pair in added or a not in idx_map or b not in idx_map:
            continue
        try:
            bond_type = bond_type_map.get(bond_preds[i], Chem.BondType.SINGLE)
            mol.AddBond(idx_map[a], idx_map[b], bond_type)
            added.add(pair)
        except Exception:
            # Best effort: a candidate edge that cannot be added is skipped.
            continue

    try:
        mol = mol.GetMol()
        Chem.SanitizeMol(mol)
        smiles = Chem.MolToSmiles(mol)
    except Exception as e:
        print(f"Sanitization error: {e}")
        return ""

    # Rendering is a convenience only: a missing display or image viewer must
    # not turn a chemically valid molecule into a reported failure.
    try:
        img = Draw.MolToImage(mol)
        img.show()
    except Exception as e:
        print(f"Could not display molecule: {e}")

    print(f"Atom types: {atom_types}")
    print(f"Generated SMILES: {smiles}")
    return smiles
|
|
|
|
|
|
def sample_batch(model, conditioner, label_vec, steps=1000, batch_size=4):
    """Generate up to ``batch_size`` sanitized RDKit molecules.

    Each attempt starts from random node features, positions and a random
    edge list for 10 candidate atoms and 20 candidate edges, runs ``steps``
    denoising updates, converts the result to atoms and bonds, and keeps the
    molecule only if RDKit can sanitize it.

    Args:
        model: Callable invoked as
            ``model(x_t, pos, edge_index, torch.tensor([t]), cond_embed)``
            returning ``(pred_x, bond_logits)``.
        conditioner: Callable mapping ``label_vec.unsqueeze(0)`` to the
            conditioning embedding passed to ``model``.
        label_vec: Label tensor used for conditioning.
        steps: Number of denoising steps per molecule; must be at least 1.
        batch_size: Number of generation attempts.

    Returns:
        List of sanitized RDKit molecules. It may hold fewer than
        ``batch_size`` entries because attempts with fewer than two allowed
        atoms or failing sanitization are dropped.

    Raises:
        ValueError: If ``steps`` is smaller than 1.
    """
    if steps < 1:
        raise ValueError("steps must be at least 1")

    # Loop-invariant lookup tables, built once for the whole batch.
    allowed_atoms = {6, 7, 8, 9, 15, 16, 17}
    bond_type_map = {
        0: Chem.BondType.SINGLE,
        1: Chem.BondType.DOUBLE,
        2: Chem.BondType.TRIPLE,
        3: Chem.BondType.AROMATIC,
    }

    mols = []
    for _ in range(batch_size):
        x_t = torch.randn((10, 1))
        pos = torch.randn((10, 3))
        edge_index = torch.randint(0, 10, (2, 20))

        # Pure inference: no_grad stops the autograd graph from growing
        # across all denoising steps.
        with torch.no_grad():
            for t in reversed(range(1, steps + 1)):
                cond_embed = conditioner(label_vec.unsqueeze(0))
                pred_x, bond_logits = model(x_t, pos, edge_index, torch.tensor([t]), cond_embed)
                x_t = x_t - pred_x * (1.0 / steps)

            x_t = torch.relu(x_t * 100.0)
            atom_types = torch.clamp(x_t.round(), 1, 118).int().reshape(-1).tolist()
            # Previously `bond_preds` was never computed here: the resulting
            # NameError was swallowed below and no bond was ever added.
            bond_preds = torch.argmax(torch.relu(bond_logits), dim=-1).tolist()

        mol = Chem.RWMol()
        idx_map = {}  # candidate atom index -> index inside `mol`
        for i, atomic_num in enumerate(atom_types):
            if atomic_num not in allowed_atoms:
                continue
            try:
                idx_map[i] = mol.AddAtom(Chem.Atom(int(atomic_num)))
            except Exception:
                # Best effort: skip an atom RDKit refuses to create.
                continue

        if len(idx_map) < 2:
            continue

        added = set()  # unordered atom pairs that already carry a bond
        for i in range(edge_index.shape[1]):
            a = int(edge_index[0, i])
            b = int(edge_index[1, i])
            pair = (min(a, b), max(a, b))
            if a == b or pair in added or a not in idx_map or b not in idx_map:
                continue
            try:
                bond_type = bond_type_map.get(bond_preds[i], Chem.BondType.SINGLE)
                mol.AddBond(idx_map[a], idx_map[b], bond_type)
                added.add(pair)
            except Exception:
                # Best effort: a candidate edge that cannot be added is skipped.
                continue

        try:
            mol = mol.GetMol()
            Chem.SanitizeMol(mol)
            mols.append(mol)
        except Exception:
            # Molecules that fail sanitization are dropped from the batch.
            continue
    return mols
|
|
|
|
|
|
| |
def validate_molecule(smiles):
    """Check whether a SMILES string describes a real molecule.

    Args:
        smiles: SMILES string to check. Empty, blank or non-string input is
            treated as invalid.

    Returns:
        ``(False, {})`` when the input is empty or cannot be parsed into a
        molecule with at least one atom; otherwise ``(True, props)`` where
        ``props`` holds the molecular weight (``"MolWt"``) and the computed
        logP (``"LogP"``).
    """
    # An empty SMILES is what a failed generation hands back. RDKit parses it
    # into a zero-atom molecule instead of None, so reject it up front rather
    # than reporting it as valid.
    if not isinstance(smiles, str) or not smiles.strip():
        return False, {}
    mol = Chem.MolFromSmiles(smiles)
    if mol is None or mol.GetNumAtoms() == 0:
        return False, {}
    return True, {"MolWt": Descriptors.MolWt(mol), "LogP": Descriptors.MolLogP(mol)}
|
|
|
|
| |
def test_models(test_model, test_conditioner):
    """Sample one molecule per test-set label vector and report the success rate.

    Loads a subset of the dataset, keeps the molecules that can be converted
    to graphs, and for each of them calls ``sample`` conditioned on its label
    vector. A sample counts as a success when a non-empty SMILES string comes
    back. Progress, every generated SMILES with its validation result, and
    the final success fraction are printed.

    Args:
        test_model: Model to evaluate; switched to eval mode.
        test_conditioner: Conditioner to evaluate; switched to eval mode.
    """
    # Subset size: 20% of 4983 rows.
    index = int(4983.0 * 0.2)
    smiles_list, label_map, _ = load_goodscents_subset(index=index)
    test_model.eval()
    test_conditioner.eval()

    dataset = []
    for smi in smiles_list:
        g = smiles_to_graph(smi)
        # Compare with None explicitly instead of relying on object truthiness.
        if g is not None:
            g.y = torch.tensor(label_map[smi])
            dataset.append(g)

    total = len(dataset)
    if total == 0:
        # Avoids a ZeroDivisionError in the final ratio.
        print("No test molecules could be converted to graphs; nothing to evaluate.")
        return

    good_count = 0
    for i, data in enumerate(dataset, start=1):
        print(f"Testing molecule {i}/{total}")
        # Evaluation only: no gradients are needed while sampling.
        with torch.no_grad():
            new_smiles = sample(test_model, test_conditioner, label_vec=data.y)
        print(new_smiles)
        valid, props = validate_molecule(new_smiles)
        print(f"Generated SMILES: {new_smiles}\nValid: {valid}, Properties: {props}")
        if new_smiles != "":
            good_count += 1

    # Printed as a fraction in [0, 1] despite the label.
    percent_correct = good_count / total
    print(f"Percent correct: {percent_correct}")
|
|
|
|
|
|