"""Cross-validated CatBoost regression of log10(Kd) from ligand/receptor embeddings.

Reads a CSV whose 'Ligand Features' and 'Receptor Features' columns hold
comma-separated embedding strings, plus a 'KD(M)' column. It runs shuffled
5-fold cross-validation with a GPU CatBoostRegressor and writes three outputs
to ``new_plt/``:
  * a predicted-vs-experimental scatter plot (PNG and PDF), and
  * a per-sample CSV of out-of-fold predictions.
"""

import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from catboost import CatBoostRegressor
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import KFold

DATA_PATH = "/storage/group/cdm8/default/BindPred/embeddings/xxx.csv"
OUTPUT_DIR = "new_plt"
N_SPLITS = 5
RANDOM_STATE = 42
# Axis range (log10 units) used by the original figure. It is widened
# automatically when labels or predictions fall outside it, so no point is
# silently clipped.
DEFAULT_AXIS_LIMITS = (-15.0, -2.0)

# Publication-style fonts. They are applied in main() so that importing this
# module has no side effects.
PLOT_STYLE = {
    'font.family': 'serif',
    'font.size': 13,
    'axes.labelsize': 14,
    'axes.titlesize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
}


def _parse_embedding(value):
    """Parse a comma-separated string of numbers into a list of floats.

    Non-string values yield ``[]``. Empty tokens (e.g. from a trailing comma
    or an empty string) are skipped.
    """
    if not isinstance(value, str):
        return []
    return [float(token) for token in value.split(',') if token.strip()]


def load_dataset(path):
    """Load the CSV and return ``(data, X, y)``.

    Returns:
        data: the rows that have both ligand and receptor embeddings (the
            original index is kept; use positional ``iloc`` to match ``X``).
        X: 2-D array of ligand embeddings concatenated with receptor
            embeddings, one row per entry of ``data``.
        y: ``log10`` of the 'KD(M)' column, as a Series aligned with ``data``.

    Raises:
        ValueError: if no row has both embeddings, if combined embedding
            lengths differ between rows, or if 'KD(M)' is missing, NaN,
            non-numeric, or non-positive (log10 would be -inf/NaN).
    """
    data = pd.read_csv(path)

    # Missing embeddings become empty strings, which then parse to [].
    for column in ('Ligand Features', 'Receptor Features'):
        data[column] = data[column].fillna('').apply(_parse_embedding)

    # Keep only rows where both embeddings are present.
    has_both = (data['Ligand Features'].map(len) > 0) & (data['Receptor Features'].map(len) > 0)
    dropped = int((~has_both).sum())
    data = data[has_both]
    if dropped:
        print(f"Dropped {dropped} rows lacking ligand or receptor embeddings.")
    if data.empty:
        raise ValueError(f"No rows with both ligand and receptor embeddings in {path}")

    if "KD(M)" not in data.columns:
        raise ValueError("Missing or NaN values in 'KD(M)' column.")
    # Non-numeric entries become NaN here and are rejected by the NaN check.
    kd = pd.to_numeric(data["KD(M)"], errors="coerce")
    if kd.isnull().any():
        raise ValueError("Missing or NaN values in 'KD(M)' column.")
    if (kd <= 0).any():
        raise ValueError("Non-positive values in 'KD(M)' column; log10(Kd) is undefined.")

    features = [
        np.concatenate((lig, rec))
        for lig, rec in zip(data['Ligand Features'], data['Receptor Features'])
    ]
    lengths = {len(f) for f in features}
    if len(lengths) > 1:
        raise ValueError(f"Inconsistent combined embedding lengths: {sorted(lengths)}")

    X = np.vstack(features)
    y = np.log10(kd)
    return data, X, y


def cross_validate(X, y, n_splits=N_SPLITS, random_state=RANDOM_STATE):
    """Run shuffled K-fold CV with CatBoost on GPU 0.

    Returns three 1-D arrays in fold order: the true labels, the out-of-fold
    predictions, and the positional row indices (into ``X``/``y``) that
    produced them.
    """
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    y_true, y_pred, positions = [], [], []
    for fold, (train_index, test_index) in enumerate(kf.split(X), start=1):
        model = CatBoostRegressor(
            iterations=2000,
            learning_rate=0.08,
            depth=4,
            verbose=500,
            task_type="GPU",
            devices='0',
        )
        model.fit(X[train_index], y.iloc[train_index])
        y_true.extend(y.iloc[test_index])
        y_pred.extend(model.predict(X[test_index]))
        positions.extend(test_index)
        print(f"Fold {fold}/{n_splits} done ({len(test_index)} test samples)")
    return np.array(y_true), np.array(y_pred), np.array(positions)


def compute_metrics(y_true, y_pred):
    """Return PCC, Spearman rho, RMSE and R² of the predictions, in a dict."""
    pcc, _ = pearsonr(y_true, y_pred)
    srcc, _ = spearmanr(y_true, y_pred)
    return {
        "pcc": pcc,
        "srcc": srcc,
        "rmse": np.sqrt(mean_squared_error(y_true, y_pred)),
        "r2": r2_score(y_true, y_pred),
    }


def plot_results(y_true, y_pred, metrics, output_dir, axis_limits=DEFAULT_AXIS_LIMITS):
    """Draw the predicted-vs-experimental scatter plot and save it as PNG and PDF.

    Markers are coloured by absolute error. The axes span ``axis_limits``,
    extended to whole log10 units if any value lies outside that range.
    """
    errors = np.abs(y_true - y_pred)
    low = float(min(axis_limits[0], np.floor(min(y_true.min(), y_pred.min()))))
    high = float(max(axis_limits[1], np.ceil(max(y_true.max(), y_pred.max()))))

    plt.figure(figsize=(5, 5))
    plt.title("ESM2 Embeddings", fontsize=15, pad=10)
    sc = plt.scatter(
        y_true, y_pred,
        s=25,
        c=errors,
        cmap='Reds',
        alpha=0.9,
        edgecolors='black',
        linewidth=0.4,
        marker='^',  # triangle markers
    )

    # Diagonal reference line (perfect prediction).
    plt.plot([low, high], [low, high], color='black', linestyle='--', linewidth=1)

    plt.xlabel("Experimental Log10(Kd)", fontsize=14, labelpad=10)
    plt.ylabel("BindPred Prediction of Log10(Kd)", fontsize=14, labelpad=10)
    plt.xlim(low, high)
    plt.ylim(low, high)
    plt.gca().set_aspect('equal', adjustable='box')

    plt.text(0.05, 0.95,
             f"PCC: {metrics['pcc']:.3f}\nRMSE: {metrics['rmse']:.3f}\nR²: {metrics['r2']:.3f}",
             transform=plt.gca().transAxes,
             fontsize=12,
             verticalalignment='top',
             horizontalalignment='left',
             bbox=dict(facecolor='white', edgecolor='gray', boxstyle='round,pad=0.3'))

    cbar = plt.colorbar(sc)
    cbar.set_label("Absolute Error", fontsize=12)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'esm2_plot.png'), dpi=700)
    plt.savefig(os.path.join(output_dir, 'esm2_plot.pdf'), dpi=700)
    plt.show()
    plt.close()


def save_predictions(data, positions, y_true, y_pred, output_dir):
    """Write one CSV row per out-of-fold prediction and return the CSV path.

    'PDB_ID' and 'Mutation' are copied from ``data`` when those columns
    exist; otherwise they are filled with "NA".
    """
    rows = data.iloc[positions]

    def column_or_na(name):
        return rows[name].to_numpy() if name in rows.columns else "NA"

    df_preds = pd.DataFrame({
        "PDB_ID": column_or_na("PDB_ID"),
        "Mutation": column_or_na("Mutation"),
        "Actual_log10Kd": y_true,
        "Predicted_log10Kd": y_pred,
    })
    csv_path = os.path.join(output_dir, "ESM2_predictions.csv")
    df_preds.to_csv(csv_path, index=False)
    return csv_path


def main(data_path=DATA_PATH, output_dir=OUTPUT_DIR):
    """Run the full pipeline: load, cross-validate, score, plot and save."""
    plt.rcParams.update(PLOT_STYLE)

    data, X, y = load_dataset(data_path)
    os.makedirs(output_dir, exist_ok=True)

    all_y_true, all_y_pred, test_indices_all = cross_validate(X, y)

    metrics = compute_metrics(all_y_true, all_y_pred)
    print(
        f"PCC: {metrics['pcc']:.3f}  SRCC: {metrics['srcc']:.3f}  "
        f"RMSE: {metrics['rmse']:.3f}  R2: {metrics['r2']:.3f}"
    )

    plot_results(all_y_true, all_y_pred, metrics, output_dir)

    csv_path = save_predictions(data, test_indices_all, all_y_true, all_y_pred, output_dir)
    print(f"Saved prediction results to {csv_path}")


if __name__ == "__main__":
    main()