File size: 4,529 Bytes
4f89f53
 
 
 
 
877c6aa
4f89f53
877c6aa
4f89f53
877c6aa
 
 
 
 
 
 
 
 
 
4f89f53
877c6aa
31332ec
4f89f53
877c6aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f89f53
877c6aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f89f53
877c6aa
 
 
4f89f53
877c6aa
 
 
 
 
4f89f53
877c6aa
4f89f53
877c6aa
 
 
 
 
 
 
 
 
4f89f53
877c6aa
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr, spearmanr
from catboost import CatBoostRegressor
import matplotlib.pyplot as plt

# Publication-style matplotlib defaults: serif typeface plus explicit sizes
# for body text, axis labels/titles, tick labels and legend entries.
PUBLICATION_RC = {
    'font.family': 'serif',
    'font.size': 13,
    'axes.labelsize': 14,
    'axes.titlesize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
}
plt.rcParams.update(PUBLICATION_RC)

def _parse_embedding(cell):
    """Turn one CSV cell of comma-separated numbers into a list of floats.

    Blank tokens (for example from a trailing comma) are skipped. Any value
    that is not a string yields an empty list, which marks the row as having
    no usable embedding.
    """
    if not isinstance(cell, str):
        return []
    return [float(token) for token in cell.split(',') if token.strip()]


def _combine_embeddings(row):
    """Concatenate a row's ligand and receptor embeddings.

    Returns an empty array when either embedding is missing so that the row
    can be filtered out afterwards.
    """
    ligand = row['Ligand Features']
    receptor = row['Receptor Features']
    if len(ligand) > 0 and len(receptor) > 0:
        return np.concatenate((ligand, receptor))
    return np.array([])


# Load dataset
data = pd.read_csv("/storage/group/cdm8/default/BindPred/embeddings/xxx.csv")

# Missing cells become '' and therefore parse to an empty embedding.
for embedding_column in ('Ligand Features', 'Receptor Features'):
    data[embedding_column] = (
        data[embedding_column].fillna('').apply(_parse_embedding)
    )

# Combine embeddings (ligand first, receptor second)
data['Combined Features'] = data.apply(_combine_embeddings, axis=1)

# Keep only rows for which both embeddings were available
data = data[data['Combined Features'].apply(len) > 0]
if data.empty:
    raise ValueError("No rows with both ligand and receptor embeddings were found.")

# All feature vectors must share one width, otherwise they cannot be stacked
# into a 2-D matrix; fail here with a readable message instead of inside vstack.
feature_widths = data['Combined Features'].apply(len)
if feature_widths.nunique() > 1:
    raise ValueError(
        "Combined embeddings have inconsistent lengths: "
        f"{sorted(feature_widths.unique().tolist())}"
    )

# Check KD(M) column
if "KD(M)" not in data.columns or data["KD(M)"].isnull().any():
    raise ValueError("Missing or NaN values in 'KD(M)' column.")

# log10 is undefined for zero/negative values and would silently produce
# -inf/NaN labels that corrupt training and every downstream metric.
if (data["KD(M)"] <= 0).any():
    raise ValueError("Non-positive values in 'KD(M)' column; cannot take log10.")

# Prepare features and log-transformed labels
X = np.vstack(data['Combined Features'].tolist())
y = np.log10(data['KD(M)'])

# Cross-validation: 5 shuffled folds with a fixed seed, so each sample is
# predicted exactly once by a model that never saw it during training.
kf = KFold(n_splits=5, shuffle=True, random_state=42)

# Output directory for the figure and the prediction table
output_dir = "new_plt"
os.makedirs(output_dir, exist_ok=True)

# Hyper-parameters shared by the model trained in every fold.
catboost_params = {
    'iterations': 2000,
    'learning_rate': 0.08,
    'depth': 4,
    'verbose': 500,
    'task_type': "GPU",
    'devices': '0',
}

fold_truths = []
fold_predictions = []
test_indices_all = []  # positional row indices, in the order they were predicted

for train_rows, held_out_rows in kf.split(X):
    regressor = CatBoostRegressor(**catboost_params)
    regressor.fit(X[train_rows], y.iloc[train_rows])

    fold_truths.append(y.iloc[held_out_rows].to_numpy())
    fold_predictions.append(regressor.predict(X[held_out_rows]))
    test_indices_all.extend(held_out_rows)

# Pool the per-fold results into flat arrays aligned with test_indices_all
all_y_true = np.concatenate(fold_truths)
all_y_pred = np.concatenate(fold_predictions)

# Performance metrics over the pooled out-of-fold predictions.
# Only the correlation coefficients are kept; the p-values are discarded.
pcc = pearsonr(all_y_true, all_y_pred)[0]
srcc = spearmanr(all_y_true, all_y_pred)[0]
mse = mean_squared_error(all_y_true, all_y_pred)
rmse = np.sqrt(mse)
r2 = r2_score(all_y_true, all_y_pred)

# Per-sample absolute error, used to colour the scatter points
errors = np.absolute(all_y_pred - all_y_true)

# Plotting: square parity plot of predicted vs. experimental log10(Kd)
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111)
ax.set_title("ESM2 Embeddings", fontsize=15, pad=10)

# One triangle per sample, shaded by its absolute prediction error
sc = ax.scatter(
    all_y_true, all_y_pred,
    s=25,
    c=errors, cmap='Reds', alpha=0.9,
    edgecolors='black', linewidth=0.4,
    marker='^',
)

# Dashed y = x reference line (perfect agreement)
diagonal_ends = [-15, -2]
ax.plot(diagonal_ends, diagonal_ends, color='black', linestyle='--', linewidth=1)

# Axis labels, identical fixed limits on both axes, and a 1:1 aspect ratio
ax.set_xlabel("Experimental Log10(Kd)", fontsize=14, labelpad=10)
ax.set_ylabel("BindPred Prediction of Log10(Kd)", fontsize=14, labelpad=10)
ax.set_xlim(-15.0, -2.0)
ax.set_ylim(-15.0, -2.0)
ax.set_aspect('equal', adjustable='box')

# Metrics box anchored to the top-left corner in axes-fraction coordinates
metrics_text = f"PCC: {pcc:.3f}\nRMSE: {rmse:.3f}\nR²: {r2:.3f}"
ax.text(
    0.05, 0.95, metrics_text,
    transform=ax.transAxes,
    fontsize=12,
    verticalalignment='top',
    horizontalalignment='left',
    bbox={'facecolor': 'white', 'edgecolor': 'gray', 'boxstyle': 'round,pad=0.3'},
)

# Colorbar for the error shading
cbar = fig.colorbar(sc)
cbar.set_label("Absolute Error", fontsize=12)

# Save the figure as both PNG and PDF, then display it
fig.tight_layout()
for plot_name in ('esm2_plot.png', 'esm2_plot.pdf'):
    fig.savefig(os.path.join(output_dir, plot_name), dpi=700)
plt.show()

# Save prediction results to CSV.
# test_indices_all holds positional row indices into `data`, in the same order
# as all_y_true / all_y_pred, so a single positional lookup retrieves every
# held-out row at once instead of building one Series per sample in a loop.
held_out_rows = data.iloc[test_indices_all]

# Identifier columns are optional: fall back to "NA" when one is absent.
identifier_columns = {
    column: (
        held_out_rows[column].to_numpy()
        if column in held_out_rows.columns
        else ["NA"] * len(held_out_rows)
    )
    for column in ("PDB_ID", "Mutation")
}

df_preds = pd.DataFrame({
    "PDB_ID": identifier_columns["PDB_ID"],
    "Mutation": identifier_columns["Mutation"],
    "Actual_log10Kd": all_y_true,
    "Predicted_log10Kd": all_y_pred,
})

csv_path = os.path.join(output_dir, "ESM2_predictions.csv")
df_preds.to_csv(csv_path, index=False)
print(f"Saved prediction results to {csv_path}")