# File size: 4,529 Bytes
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr, spearmanr
from catboost import CatBoostRegressor
import matplotlib.pyplot as plt
# Publication-style typography: a serif family with explicit sizes for body
# text, axis labels/titles, tick labels and legend entries.
_PUBLICATION_RC = {
    "font.family": "serif",
    "font.size": 13,
    "axes.labelsize": 14,
    "axes.titlesize": 14,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    "legend.fontsize": 12,
}
plt.rcParams.update(_PUBLICATION_RC)
def _parse_embedding(text):
    """Parse a comma-separated embedding string into a list of floats.

    Non-string input yields an empty list, and blank tokens (for example the
    one produced by a trailing comma) are skipped. A token that is not a valid
    float literal makes ``float`` raise ``ValueError``.
    """
    if not isinstance(text, str):
        return []
    return [float(token) for token in text.split(',') if token.strip()]


def _combine_embeddings(row):
    """Concatenate the ligand and receptor embeddings of one row.

    Returns an empty array when either embedding is empty so that the row can
    be filtered out afterwards.
    """
    ligand = row['Ligand Features']
    receptor = row['Receptor Features']
    if len(ligand) > 0 and len(receptor) > 0:
        return np.concatenate((ligand, receptor))
    return np.array([])


# Load dataset
data = pd.read_csv("/storage/group/cdm8/default/BindPred/embeddings/xxx.csv")

# Missing embeddings become empty strings, which parse to empty lists and are
# dropped by the validity filter below.
for column in ('Ligand Features', 'Receptor Features'):
    data[column] = data[column].fillna('').apply(_parse_embedding)

# Combine embeddings
data['Combined Features'] = data.apply(_combine_embeddings, axis=1)

# Filter valid rows (both embeddings present)
data = data[data['Combined Features'].apply(len) > 0]

# Check KD(M) column
if "KD(M)" not in data.columns or data["KD(M)"].isnull().any():
    raise ValueError("Missing or NaN values in 'KD(M)' column.")
# log10 is only defined for strictly positive values; without this check
# np.log10 would emit a RuntimeWarning and hand -inf/NaN labels to the model.
if (data["KD(M)"] <= 0).any():
    raise ValueError("Non-positive values in 'KD(M)' column; log10 is undefined.")

# Prepare features and log-transformed labels
X = np.vstack(data['Combined Features'])
y = np.log10(data['KD(M)'])
# Cross-validation: 5 shuffled folds with a fixed seed for the split
kf = KFold(n_splits=5, shuffle=True, random_state=42)

# Output directory (created up front; reused when it already exists)
output_dir = "new_plt"
os.makedirs(output_dir, exist_ok=True)

# Per-fold results are collected here and pooled after the loop.
fold_targets = []
fold_predictions = []
test_indices_all = []

for fold, (train_index, test_index) in enumerate(kf.split(X)):
    # A fresh regressor is trained for every fold; training is requested on
    # GPU device 0 and progress is logged every 500 iterations.
    model = CatBoostRegressor(
        iterations=2000,
        learning_rate=0.08,
        depth=4,
        verbose=500,
        task_type="GPU",
        devices='0',
    )
    model.fit(X[train_index], y.iloc[train_index])

    fold_targets.append(y.iloc[test_index].to_numpy())
    fold_predictions.append(model.predict(X[test_index]))
    # Positional row indices of the held-out samples, in prediction order.
    test_indices_all.extend(test_index)

# Pool the out-of-fold targets and predictions into flat arrays
all_y_true = np.concatenate(fold_targets)
all_y_pred = np.concatenate(fold_predictions)
# Performance metrics over the pooled out-of-fold predictions.
# pearsonr/spearmanr return (statistic, p-value); only the statistic is kept.
residuals = all_y_true - all_y_pred
pcc = pearsonr(all_y_true, all_y_pred)[0]
srcc = spearmanr(all_y_true, all_y_pred)[0]
r2 = r2_score(all_y_true, all_y_pred)
rmse = np.sqrt(mean_squared_error(all_y_true, all_y_pred))

# Per-sample absolute error
errors = np.abs(residuals)
# Plotting: predicted vs. experimental values, coloured by absolute error
axis_limits = (-15.0, -2.0)
fig, ax = plt.subplots(figsize=(5, 5))
ax.set_title("ESM2 Embeddings", fontsize=15, pad=10)

sc = ax.scatter(
    all_y_true,
    all_y_pred,
    s=25,
    c=errors,
    cmap='Reds',
    alpha=0.9,
    edgecolors='black',
    linewidth=0.4,
    marker='^',  # triangle markers
)

# Diagonal reference line (perfect prediction)
ax.plot(axis_limits, axis_limits, color='black', linestyle='--', linewidth=1)

# Axis setup: identical limits and equal aspect so the diagonal sits at 45 degrees
ax.set_xlabel("Experimental Log10(Kd)", fontsize=14, labelpad=10)
ax.set_ylabel("BindPred Prediction of Log10(Kd)", fontsize=14, labelpad=10)
ax.set_xlim(*axis_limits)
ax.set_ylim(*axis_limits)
ax.set_aspect('equal', adjustable='box')

# Metrics box, anchored to the top-left corner in axes coordinates
metrics_summary = f"PCC: {pcc:.3f}\nRMSE: {rmse:.3f}\nR²: {r2:.3f}"
ax.text(
    0.05,
    0.95,
    metrics_summary,
    transform=ax.transAxes,
    fontsize=12,
    verticalalignment='top',
    horizontalalignment='left',
    bbox={'facecolor': 'white', 'edgecolor': 'gray', 'boxstyle': 'round,pad=0.3'},
)

# Colorbar
cbar = fig.colorbar(sc)
cbar.set_label("Absolute Error", fontsize=12)

# Save plot as both PNG and PDF
fig.tight_layout()
for plot_filename in ('esm2_plot.png', 'esm2_plot.pdf'):
    fig.savefig(os.path.join(output_dir, plot_filename), dpi=700)
plt.show()
def _prediction_record(row, actual, predicted):
    """Build one output record from a dataset row and its true/predicted value.

    "PDB_ID" and "Mutation" fall back to the string "NA" when the row has no
    such field.
    """
    return {
        "PDB_ID": row.get("PDB_ID", "NA"),
        "Mutation": row.get("Mutation", "NA"),
        "Actual_log10Kd": actual,
        "Predicted_log10Kd": predicted,
    }


# Save prediction results to CSV.
# test_indices_all holds positional row indices into `data`, listed in the same
# order as the pooled true/predicted arrays.
records = [
    _prediction_record(data.iloc[position], actual, predicted)
    for position, actual, predicted in zip(test_indices_all, all_y_true, all_y_pred)
]
df_preds = pd.DataFrame(records)
csv_path = os.path.join(output_dir, "ESM2_predictions.csv")
df_preds.to_csv(csv_path, index=False)
print(f"Saved prediction results to {csv_path}")