# BindPred / train.py
# Last commit 31332ec ("Update train.py", verified) by hbp5181.
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr, spearmanr
from catboost import CatBoostRegressor
import matplotlib.pyplot as plt
# Publication-style typography: serif text with slightly enlarged axis labels.
# Each entry is written into matplotlib's global rcParams, so it affects every
# figure created later in this script.
for _rc_key, _rc_value in (
    ("font.family", "serif"),
    ("font.size", 13),
    ("axes.labelsize", 14),
    ("axes.titlesize", 14),
    ("xtick.labelsize", 12),
    ("ytick.labelsize", 12),
    ("legend.fontsize", 12),
):
    plt.rcParams[_rc_key] = _rc_value
# Load dataset. BINDPRED_DATA overrides the default location, so the script can
# be pointed at a different embedding CSV without editing the source.
DATA_PATH = os.environ.get(
    "BINDPRED_DATA", "/storage/group/cdm8/default/BindPred/embeddings/xxx.csv"
)
data = pd.read_csv(DATA_PATH)


def _parse_embedding(value):
    """Return the floats contained in a comma-separated embedding string.

    Non-string cells (e.g. NaN for a missing embedding) yield an empty list,
    and empty fields between commas are skipped.
    """
    if not isinstance(value, str):
        return []
    return [float(token) for token in value.split(",") if token.strip()]


# Convert embedding strings to float lists (missing embeddings -> []).
data["Ligand Features"] = data["Ligand Features"].apply(_parse_embedding)
data["Receptor Features"] = data["Receptor Features"].apply(_parse_embedding)

# Combine embeddings: ligand vector followed by receptor vector. Rows lacking
# either embedding get an empty array and are dropped below. A comprehension is
# used instead of DataFrame.apply(axis=1) so an empty frame is handled cleanly.
data["Combined Features"] = pd.Series(
    [
        np.concatenate((lig, rec)) if len(lig) > 0 and len(rec) > 0 else np.array([])
        for lig, rec in zip(data["Ligand Features"], data["Receptor Features"])
    ],
    index=data.index,
    dtype=object,
)

# Filter valid rows
data = data[data["Combined Features"].apply(len) > 0]
if data.empty:
    raise ValueError("No rows have both ligand and receptor features.")

# Check KD(M) column. log10 needs strictly positive, non-missing values; a zero
# or negative Kd would otherwise silently become a -inf/NaN training label.
if "KD(M)" not in data.columns or data["KD(M)"].isnull().any():
    raise ValueError("Missing or NaN values in 'KD(M)' column.")
if (data["KD(M)"] <= 0).any():
    raise ValueError("Non-positive values in 'KD(M)' column; log10 is undefined.")

# Every row must contribute a feature vector of the same length to build the
# design matrix; report a mismatch explicitly instead of via np.vstack.
feature_lengths = data["Combined Features"].apply(len)
if feature_lengths.nunique() != 1:
    raise ValueError(
        "Inconsistent combined feature lengths: "
        f"{sorted(feature_lengths.unique().tolist())}"
    )

# Prepare features and log-transformed labels. y may keep a non-contiguous
# index from the filter above, so downstream code selects rows positionally.
X = np.vstack(data["Combined Features"].to_list())
y = np.log10(data["KD(M)"])
# Cross-validation: 5 shuffled folds with a fixed split seed, so every sample
# receives exactly one out-of-fold prediction and fold assignment is repeatable.
kf = KFold(n_splits=5, shuffle=True, random_state=42)
all_y_true = []
all_y_pred = []
test_indices_all = []  # positional row indices into X / data, in prediction order
# Output directory
output_dir = "new_plt"
os.makedirs(output_dir, exist_ok=True)
# CatBoost trains on GPU device 0 by default (as before); set
# BINDPRED_TASK_TYPE=CPU to run on a machine without a GPU. `devices` is
# passed only for GPU training.
task_type = os.environ.get("BINDPRED_TASK_TYPE", "GPU").upper()
device_kwargs = {"devices": "0"} if task_type == "GPU" else {}
n_folds = kf.get_n_splits()
for fold, (train_index, test_index) in enumerate(kf.split(X), start=1):
    print(
        f"Fold {fold}/{n_folds}: "
        f"{len(train_index)} train / {len(test_index)} test samples"
    )
    X_train, X_test = X[train_index], X[test_index]
    # KFold yields positions, so select labels with iloc rather than by label.
    y_train, y_test = y.iloc[train_index], y.iloc[test_index]
    model = CatBoostRegressor(
        iterations=2000,
        learning_rate=0.08,
        depth=4,
        verbose=500,
        task_type=task_type,
        **device_kwargs,
    )
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    all_y_true.extend(y_test)
    all_y_pred.extend(y_pred)
    test_indices_all.extend(test_index)
# Convert pooled out-of-fold predictions to arrays
all_y_true = np.array(all_y_true)
all_y_pred = np.array(all_y_pred)
# Compute performance metrics over all folds' held-out predictions at once
pcc, _ = pearsonr(all_y_true, all_y_pred)
srcc, _ = spearmanr(all_y_true, all_y_pred)
rmse = np.sqrt(mean_squared_error(all_y_true, all_y_pred))
r2 = r2_score(all_y_true, all_y_pred)
# Report every metric in the run log; SRCC is computed here but not otherwise
# surfaced (the plot's metrics box shows PCC, RMSE and R² only).
print(f"PCC: {pcc:.3f}  SRCC: {srcc:.3f}  RMSE: {rmse:.3f}  R2: {r2:.3f}")
# Compute absolute error (used as the scatter color values in the plot)
errors = np.abs(all_y_true - all_y_pred)
# Plotting: predicted vs. experimental log10(Kd), colored by absolute error.
plt.figure(figsize=(5, 5))
plt.title("ESM2 Embeddings", fontsize=15, pad=10)
sc = plt.scatter(
    all_y_true,
    all_y_pred,
    s=25,
    c=errors,
    cmap='Reds',
    alpha=0.9,
    edgecolors='black',
    linewidth=0.4,
    marker='^'  # triangle markers
)
# Axis range: keep the standard [-15, -2] window, but widen it (never shrink
# it) when any finite value lies outside, so no point is silently clipped.
plotted = np.concatenate((np.asarray(all_y_true), np.asarray(all_y_pred)))
plotted = plotted[np.isfinite(plotted)]
lim_lo, lim_hi = -15.0, -2.0
if plotted.size:
    lim_lo = min(lim_lo, float(np.floor(plotted.min())))
    lim_hi = max(lim_hi, float(np.ceil(plotted.max())))
# Diagonal reference line (perfect agreement)
plt.plot([lim_lo, lim_hi], [lim_lo, lim_hi], color='black', linestyle='--', linewidth=1)
# Axis setup
plt.xlabel("Experimental Log10(Kd)", fontsize=14, labelpad=10)
plt.ylabel("BindPred Prediction of Log10(Kd)", fontsize=14, labelpad=10)
plt.xlim(lim_lo, lim_hi)
plt.ylim(lim_lo, lim_hi)
plt.gca().set_aspect('equal', adjustable='box')
# Metrics box (axes-relative coordinates: top-left corner)
plt.text(0.05, 0.95,
         f"PCC: {pcc:.3f}\nRMSE: {rmse:.3f}\nR²: {r2:.3f}",
         transform=plt.gca().transAxes,
         fontsize=12,
         verticalalignment='top',
         horizontalalignment='left',
         bbox=dict(facecolor='white', edgecolor='gray', boxstyle='round,pad=0.3'))
# Colorbar
cbar = plt.colorbar(sc)
cbar.set_label("Absolute Error", fontsize=12)
# Save plot as raster (PNG) and vector (PDF)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'esm2_plot.png'), dpi=700)
plt.savefig(os.path.join(output_dir, 'esm2_plot.pdf'), dpi=700)
plt.show()
# Save prediction results to CSV: one row per out-of-fold prediction, in the
# order predictions were produced. test_indices_all holds positional indices,
# so the matching metadata rows are selected in one step with iloc.
pred_rows = data.iloc[test_indices_all]


def _column_or_na(frame, column):
    """Return `column` of `frame` as an array, or the string "NA" if absent."""
    return frame[column].to_numpy() if column in frame.columns else "NA"


# Scalar "NA" entries broadcast to the length of the prediction arrays.
df_preds = pd.DataFrame({
    "PDB_ID": _column_or_na(pred_rows, "PDB_ID"),
    "Mutation": _column_or_na(pred_rows, "Mutation"),
    "Actual_log10Kd": all_y_true,
    "Predicted_log10Kd": all_y_pred,
})
csv_path = os.path.join(output_dir, "ESM2_predictions.csv")
df_preds.to_csv(csv_path, index=False)
print(f"Saved prediction results to {csv_path}")