Spaces:
Running
Running
Auto-deploy from GitHub: 9ca7e761f50169763dd91ccaf7c76c6bdcbe90b2
Browse files
- scripts/predict.py +6 -2
- src/models/mlp.py +8 -6
scripts/predict.py
CHANGED
|
@@ -31,13 +31,17 @@ def predict_pipeline(audio_file, lyrics):
|
|
| 31 |
label : int
|
| 32 |
A numerical representation of the prediction
|
| 33 |
"""
|
| 34 |
-
|
| 35 |
# 1.) Instantiate LLM2Vec Model
|
| 36 |
llm2vec_model = load_llm2vec_model()
|
| 37 |
|
| 38 |
# 2.) Preprocess both audio and lyrics
|
| 39 |
audio, lyrics = single_preprocessing(audio_file, lyrics)
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
# 3.) Call the train method for both models
|
| 42 |
audio_features = spectttra_predict(audio)
|
| 43 |
lyrics_features = l2vec_single_train(llm2vec_model, lyrics)
|
|
@@ -59,7 +63,7 @@ def predict_pipeline(audio_file, lyrics):
|
|
| 59 |
config = load_config("config/model_config.yml")
|
| 60 |
classifier = build_mlp(input_dim=results.shape[1], config=config)
|
| 61 |
|
| 62 |
-
# 7.) Load trained weights
|
| 63 |
model_path = "models/mlp/mlp_best.pth"
|
| 64 |
classifier.load_model(model_path)
|
| 65 |
classifier.model.eval()
|
|
|
|
| 31 |
label : int
|
| 32 |
A numerical representation of the prediction
|
| 33 |
"""
|
|
|
|
| 34 |
# 1.) Instantiate LLM2Vec Model
|
| 35 |
llm2vec_model = load_llm2vec_model()
|
| 36 |
|
| 37 |
# 2.) Preprocess both audio and lyrics
|
| 38 |
audio, lyrics = single_preprocessing(audio_file, lyrics)
|
| 39 |
|
| 40 |
+
# Truncate to 2 minutes to match explain pipeline
|
| 41 |
+
target_samples = int(2 * 60 * 22050)
|
| 42 |
+
if len(audio) > target_samples:
|
| 43 |
+
audio = audio[:target_samples]
|
| 44 |
+
|
| 45 |
# 3.) Call the train method for both models
|
| 46 |
audio_features = spectttra_predict(audio)
|
| 47 |
lyrics_features = l2vec_single_train(llm2vec_model, lyrics)
|
|
|
|
| 63 |
config = load_config("config/model_config.yml")
|
| 64 |
classifier = build_mlp(input_dim=results.shape[1], config=config)
|
| 65 |
|
| 66 |
+
# 7.) Load trained weights
|
| 67 |
model_path = "models/mlp/mlp_best.pth"
|
| 68 |
classifier.load_model(model_path)
|
| 69 |
classifier.model.eval()
|
src/models/mlp.py
CHANGED
|
@@ -442,7 +442,9 @@ class MLPClassifier:
|
|
| 442 |
|
| 443 |
return probabilities, predictions
|
| 444 |
|
| 445 |
-
def predict_single(
|
|
|
|
|
|
|
| 446 |
"""
|
| 447 |
Predict whether a single song is AI-generated or human-composed.
|
| 448 |
|
|
@@ -487,15 +489,15 @@ class MLPClassifier:
|
|
| 487 |
with torch.no_grad():
|
| 488 |
features_tensor = torch.FloatTensor(features).to(self.device)
|
| 489 |
outputs = self.model(features_tensor)
|
| 490 |
-
|
| 491 |
-
probabilities = torch.sigmoid(logit / temperature).item()
|
| 492 |
-
probabilities = np.clip(probabilities, 0.01, 0.99)
|
| 493 |
|
| 494 |
# Extract single results
|
| 495 |
prediction = int(probabilities >= 0.5)
|
| 496 |
label = "Human-Composed" if prediction == 1 else "AI-Generated"
|
| 497 |
-
probability =
|
| 498 |
-
|
|
|
|
|
|
|
| 499 |
return probability, prediction, label
|
| 500 |
|
| 501 |
def predict_batch(self, features: np.ndarray, return_details: bool = False) -> Dict:
|
|
|
|
| 442 |
|
| 443 |
return probabilities, predictions
|
| 444 |
|
| 445 |
+
def predict_single(
|
| 446 |
+
self, features: np.ndarray, temperature: float = 2.5
|
| 447 |
+
) -> Tuple[float, int, str]:
|
| 448 |
"""
|
| 449 |
Predict whether a single song is AI-generated or human-composed.
|
| 450 |
|
|
|
|
| 489 |
with torch.no_grad():
|
| 490 |
features_tensor = torch.FloatTensor(features).to(self.device)
|
| 491 |
outputs = self.model(features_tensor)
|
| 492 |
+
probabilities = outputs.item() # Just use raw output
|
|
|
|
|
|
|
| 493 |
|
| 494 |
# Extract single results
|
| 495 |
prediction = int(probabilities >= 0.5)
|
| 496 |
label = "Human-Composed" if prediction == 1 else "AI-Generated"
|
| 497 |
+
probability = (
|
| 498 |
+
probabilities * 100 if prediction == 1 else (1 - probabilities) * 100
|
| 499 |
+
)
|
| 500 |
+
|
| 501 |
return probability, prediction, label
|
| 502 |
|
| 503 |
def predict_batch(self, features: np.ndarray, return_details: bool = False) -> Dict:
|