houseaudrey12 committed
Commit fd81a0d · verified · 1 Parent(s): 1cc4561

Update app.py

Files changed (1)
  1. app.py +97 -40
app.py CHANGED
@@ -1,36 +1,82 @@
 import gradio as gr
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
 import torchaudio
 import numpy as np
 from datasets import load_dataset
-import torch.nn.functional as F

 # ---------------------------
-# Load Dataset for Label Reference
+# Constants
+# ---------------------------
+TARGET_SR = 44100
+N_FFT = 1024
+HOP_LENGTH = 512
+N_MELS = 64
+
+# ---------------------------
+# Load Dataset Metadata for Labels
 # ---------------------------
 dataset = load_dataset("ccmusic-database/pianos", name="8_class")
 label_names = dataset["train"].features["label"].names
+num_classes = len(label_names)

 # ---------------------------
-# Placeholder Models (will be replaced later with trained models)
+# Define the Same CNN Model as in Training
 # ---------------------------
-def fake_classify(mel_spec):
-    # Random label for now, just so the app runs
-    return np.random.choice(label_names)
+class PianoCNNMultiTask(nn.Module):
+    def __init__(self, num_classes):
+        super().__init__()
+        self.features = nn.Sequential(
+            nn.Conv2d(3, 16, kernel_size=3, padding=1),
+            nn.BatchNorm2d(16),
+            nn.ReLU(),
+            nn.MaxPool2d(2),  # 128 -> 64
+
+            nn.Conv2d(16, 32, kernel_size=3, padding=1),
+            nn.BatchNorm2d(32),
+            nn.ReLU(),
+            nn.MaxPool2d(2),  # 64 -> 32
+
+            nn.Conv2d(32, 64, kernel_size=3, padding=1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(),
+            nn.MaxPool2d(2),  # 32 -> 16
+
+            nn.Conv2d(64, 128, kernel_size=3, padding=1),
+            nn.BatchNorm2d(128),
+            nn.ReLU(),
+            nn.AdaptiveAvgPool2d((4, 4))  # 4x4 feature map
+        )
+        self.flatten = nn.Flatten()
+        self.fc_shared = nn.Linear(128 * 4 * 4, 256)
+        self.dropout = nn.Dropout(0.3)
+
+        # Classification head
+        self.fc_class = nn.Linear(256, num_classes)
+        # Regression head (quality score)
+        self.fc_reg = nn.Linear(256, 1)
+
+    def forward(self, x):
+        x = self.features(x)
+        x = self.flatten(x)
+        x = F.relu(self.fc_shared(x))
+        x = self.dropout(x)
+        class_logits = self.fc_class(x)
+        quality_pred = self.fc_reg(x).squeeze(1)
+        return class_logits, quality_pred

-def fake_quality_score(mel_spec):
-    # Random quality between 1 and 10 for now
-    return round(float(np.random.uniform(1, 10)), 2)
+# ---------------------------
+# Initialize and Load Trained Model (CPU)
+# ---------------------------
+model = PianoCNNMultiTask(num_classes=num_classes)
+state_dict = torch.load("piano_cnn_multitask.pt", map_location=torch.device("cpu"))
+model.load_state_dict(state_dict)
+model.eval()  # inference mode

 # ---------------------------
 # Audio Preprocessing
 # ---------------------------
-
-TARGET_SR = 44100
-N_FFT = 1024
-HOP_LENGTH = 512
-N_MELS = 64
-
 mel_transform = torchaudio.transforms.MelSpectrogram(
     sample_rate=TARGET_SR,
     n_fft=N_FFT,
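The commit's big addition is this multi-task CNN, which replaces both placeholder functions with a shared convolutional trunk, a classification head over the dataset's eight piano classes, and a scalar regression head for the quality score. A minimal shape check of the class exactly as added above (this snippet is illustrative and not part of the commit):

```python
import torch

# The trunk ends in AdaptiveAvgPool2d((4, 4)) with 128 channels, so the
# flattened feature is 128 * 4 * 4 = 2048, matching fc_shared's input size.
m = PianoCNNMultiTask(num_classes=8)  # the 8_class config of ccmusic-database/pianos
m.eval()                              # avoid batch-size-1 BatchNorm statistics
logits, quality = m(torch.randn(1, 3, 128, 128))
assert logits.shape == (1, 8)         # one logit per piano class
assert quality.shape == (1,)          # one scalar quality score per clip
```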
@@ -39,14 +85,14 @@ mel_transform = torchaudio.transforms.MelSpectrogram(
     center=False  # we will handle padding manually
 )

-def preprocess_audio(audio):
+def preprocess_audio_to_mel_image(audio):
     """
-    audio from gradio.Audio(type="numpy") is a tuple: (sample_rate, data)
-    data is a NumPy array with shape (samples,) or (samples, channels)
+    audio from gradio.Audio(type="numpy") is (sample_rate, data)
+    Returns a 3x128x128 tensor ready for the CNN.
     """
     sr, data = audio

-    # Convert to torch tensor
+    # Convert to tensor
     waveform = torch.tensor(data, dtype=torch.float32)

     # If shape is (samples,), make it (1, samples)
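Why `center=False` forces the manual padding handled later in this function: torchaudio's STFT then produces `1 + (num_samples - N_FFT) // HOP_LENGTH` frames, so a clip shorter than `N_FFT` (1024 samples, roughly 23 ms at 44.1 kHz) would yield an empty spectrogram. A quick arithmetic check (illustrative only):

```python
N_FFT, HOP_LENGTH = 1024, 512

def n_frames(num_samples: int) -> int:
    # Frame count of an STFT with center=False.
    return max(0, 1 + (num_samples - N_FFT) // HOP_LENGTH)

assert n_frames(1000) == 0    # too short: this is the case the padding prevents
assert n_frames(1024) == 1    # exactly one analysis window
assert n_frames(44100) == 85  # one second of audio at 44.1 kHz
```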
@@ -57,7 +103,7 @@ def preprocess_audio(audio):
     if waveform.ndim == 2 and waveform.shape[0] < waveform.shape[1]:
         waveform = waveform.transpose(0, 1)

-    # Convert to mono if stereo or more channels
+    # Convert to mono if stereo
     if waveform.shape[0] > 1:
         waveform = waveform.mean(dim=0, keepdim=True)

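A reviewer might flag this block, which the commit leaves essentially unchanged: the docstring says gradio's `type="numpy"` data arrives as `(samples,)` or `(samples, channels)`, and the elided lines 53–56 (per the line-52 comment) turn mono into `(1, samples)`. Under that assumption, `shape[0] < shape[1]` is true for mono, so the transpose yields `(samples, 1)` and the later `mean(dim=0)` averages over time rather than channels. A channels-first sketch that would avoid the ambiguity (an illustration, not part of the commit):

```python
# Hypothetical replacement for the channel handling, assuming the shapes
# described in the function's own docstring.
if waveform.ndim == 1:
    waveform = waveform.unsqueeze(0)     # (samples,) -> (1, samples)
else:
    waveform = waveform.transpose(0, 1)  # (samples, channels) -> (channels, samples)

if waveform.shape[0] > 1:
    # Average the channel axis, keeping the time axis intact.
    waveform = waveform.mean(dim=0, keepdim=True)
```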
@@ -65,43 +111,54 @@ def preprocess_audio(audio):
     if sr != TARGET_SR:
         resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=TARGET_SR)
         waveform = resampler(waveform)
-        sr = TARGET_SR
-
-    # --- NEW: Ensure minimum length for STFT / MelSpectrogram ---
-    min_len = N_FFT  # at least one window
-    current_len = waveform.shape[-1]
-    if current_len < min_len:
-        pad_amount = min_len - current_len
-        # Pad at the end with zeros
+
+    # Ensure minimum length for STFT
+    min_len = N_FFT
+    if waveform.shape[-1] < min_len:
+        pad_amount = min_len - waveform.shape[-1]
         waveform = F.pad(waveform, (0, pad_amount))

-    # Mel-spectrogram (no internal centering/padding)
-    mel = mel_transform(waveform)
+    # Compute Mel-spectrogram and convert to dB
+    mel = mel_transform(waveform)  # [1, n_mels, time]
     mel_db = torchaudio.transforms.AmplitudeToDB()(mel)
-    return mel_db
+
+    # Normalize to 0–1
+    mel_db = (mel_db - mel_db.min()) / (mel_db.max() - mel_db.min() + 1e-6)
+
+    # Resize to 128x128 and make 3 channels
+    mel_db = mel_db.unsqueeze(0)  # [1, 1, H, W]
+    mel_resized = F.interpolate(mel_db, size=(128, 128), mode="bilinear", align_corners=False)
+    mel_rgb = mel_resized.repeat(1, 3, 1, 1)  # [1, 3, 128, 128]
+
+    return mel_rgb.squeeze(0)  # [3, 128, 128]

 # ---------------------------
-# Main Analyze Function
+# Main Inference Function
 # ---------------------------
 def analyze_piano(audio):
     if audio is None:
-        return "Please upload or record a piano audio clip (at least 1–2 seconds)."
+        return "Please upload or record a piano audio clip (around 1–3 seconds)."

     try:
-        mel = preprocess_audio(audio)
+        # Preprocess input
+        mel_img = preprocess_audio_to_mel_image(audio)  # [3,128,128]
+        mel_batch = mel_img.unsqueeze(0)  # [1,3,128,128]
+
+        with torch.no_grad():
+            logits, q_pred = model(mel_batch)
+            class_idx = torch.argmax(logits, dim=1).item()
+            quality_score = float(q_pred.item())

-        # Placeholder predictions (to be replaced with real models later)
-        piano_type = fake_classify(mel)
-        quality_score = fake_quality_score(mel)
+        piano_type = label_names[class_idx]
+        quality_score_rounded = round(quality_score, 2)

         output_text = (
             f"Piano Type Prediction: {piano_type}\n"
-            f"Estimated Sound Quality Score: {quality_score} / 10"
+            f"Estimated Sound Quality Score: {quality_score_rounded} / 10"
         )
         return output_text

     except Exception as e:
-        # Show error in the UI instead of crashing the app
         return f"An error occurred while processing the audio: {e}"

 # ---------------------------
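With the trained checkpoint in place, the whole path from waveform to prediction can be exercised outside Gradio. A local smoke-test sketch, assuming `preprocess_audio_to_mel_image`, `model`, and `label_names` from the updated app.py are in scope (the sine input is synthetic, chosen only to give the pipeline deterministic input):

```python
import numpy as np
import torch

# Two seconds of a 440 Hz sine at the app's target sample rate, shaped the
# way gradio.Audio(type="numpy") would deliver a mono recording.
sr = 44100
t = np.linspace(0.0, 2.0, 2 * sr, endpoint=False)
wave = (0.5 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)

mel_img = preprocess_audio_to_mel_image((sr, wave))  # [3, 128, 128]
with torch.no_grad():
    logits, q_pred = model(mel_img.unsqueeze(0))     # batch of one
print(label_names[logits.argmax(dim=1).item()], round(float(q_pred.item()), 2))
```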
@@ -116,7 +173,7 @@ demo = gr.Interface(
     ),
     outputs=gr.Textbox(label="AI Analysis Output"),
     title="AI Piano Sound Analyzer 🎹",
-    description="Upload a short piano recording (around 1–3 seconds) to get a predicted piano type and estimated sound-quality score."
+    description="Upload a short piano recording to get a predicted piano type and estimated sound-quality score from the trained CNN model."
 )

 if __name__ == "__main__":
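One dependency this diff takes on silently: `load_state_dict` only succeeds if `piano_cnn_multitask.pt` holds a bare state dict whose keys match this exact class definition. A sketch of the save call the training script presumably used (an assumption; the training code is not part of this commit):

```python
import torch

# Assumed training-side counterpart: saving only the weights, so the Space
# can rebuild PianoCNNMultiTask and load them onto CPU via map_location.
torch.save(model.state_dict(), "piano_cnn_multitask.pt")
```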
 