Firdavs222 committed on
Commit
7e10eb3
·
verified ·
1 Parent(s): c4c4cd7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -74
app.py CHANGED
@@ -3,7 +3,7 @@ from transformers import WhisperProcessor, WhisperForConditionalGeneration, Gene
3
  import torch
4
  import torchaudio
5
  import numpy as np
6
- import av # Ensure you have installed this: pip install av
7
 
8
  # --- Configuration and Model Loading ---
9
  model_id = "OvozifyLabs/whisper-small-uz-v1"
@@ -19,7 +19,6 @@ try:
19
  model = WhisperForConditionalGeneration.from_pretrained(model_id).to(device)
20
  except Exception as e:
21
  print(f"Error loading model or processor: {e}")
22
- # Handle the error gracefully if the model cannot be loaded
23
  processor = None
24
  model = None
25
 
@@ -31,21 +30,19 @@ def load_audio_file(file_path):
31
  Loads an audio file (handles M4A, MP3, WAV, etc.) and ensures it is
32
  resampled to 16000 Hz and converted to mono, which Whisper models require.
33
  """
34
- sr_target = 16000 # Target sampling rate for the Whisper model
35
 
36
  if not file_path:
37
  raise FileNotFoundError("Audio file path is empty.")
38
 
39
  audio_data_list = []
40
- current_sr = sr_target # Assume target SR initially
41
 
42
  try:
43
- # 1. Try torchaudio's built-in loader first (usually handles WAV, FLAC well)
44
  audio, sr = torchaudio.load(file_path)
45
  current_sr = sr
46
 
47
- # If torchaudio succeeds, perform necessary post-loading processing
48
-
49
  # Resample if needed
50
  if current_sr != sr_target:
51
  if audio.dtype != torch.float32:
@@ -55,43 +52,33 @@ def load_audio_file(file_path):
55
  audio = resampler(audio)
56
  current_sr = sr_target
57
 
58
- # Convert to mono if necessary (take the mean across channels)
59
  if audio.shape[0] > 1:
60
  audio = torch.mean(audio, dim=0, keepdim=True)
61
 
62
  return audio, current_sr
63
 
64
  except Exception as torchaudio_e:
65
- # 2. Fallback to using PyAV (FFmpeg wrapper) for formats like M4A, MP3
66
- # print(f"Torchaudio failed. Falling back to PyAV. Error: {torchaudio_e}")
67
-
68
  try:
69
  import av
70
  with av.open(file_path) as container:
71
  stream = container.streams.audio[0]
72
 
73
- # Set up a resampler to ensure 16kHz float mono output
74
  resampler = av.AudioResampler(
75
- format='fltp', # 32-bit floating point
76
- layout='mono', # Force mono output
77
- rate=sr_target # Target sampling rate 16000 Hz
78
  )
79
 
80
- # Decode the audio stream and resample frames
81
  for frame in container.decode(stream):
82
  for resampled_frame in resampler.resample(frame):
83
- # *** FIX APPLIED HERE: Removed 'format' keyword argument ***
84
- # to_ndarray() converts the frame to a NumPy array.
85
- # For a mono stream, [0] selects the single channel's data.
86
  audio_data_list.append(resampled_frame.to_ndarray()[0])
87
 
88
-
89
  if not audio_data_list:
90
  raise RuntimeError("Could not decode audio frames using PyAV.")
91
 
92
- # Concatenate all the 1D NumPy arrays into a single, continuous array
93
  audio_np = np.concatenate(audio_data_list, axis=0)
94
- # Convert the NumPy array back to a PyTorch tensor, ensuring it's 1-channel (mono)
95
  audio = torch.from_numpy(audio_np).unsqueeze(0).float()
96
 
97
  return audio, sr_target
@@ -99,25 +86,47 @@ def load_audio_file(file_path):
99
  except Exception as av_e:
100
  raise RuntimeError(f"Failed to load audio file using both torchaudio and PyAV. Error: {av_e}")
101
 
102
- # Note: The main `transcribe_audio` function and the Gradio setup do not need changes.
103
- # Just replace this one function and restart your application.
104
-
105
- # --- Post-Loading Processing (Only executes if torchaudio succeeded) ---
106
-
107
- # Resample if needed (if torchaudio succeeded but the rate was wrong)
108
- if current_sr != sr_target:
109
- if audio_data.dtype != torch.float32:
110
- audio_data = audio_data.float()
111
-
112
- resampler = torchaudio.transforms.Resample(orig_freq=current_sr, new_freq=sr_target)
113
- audio_data = resampler(audio_data)
114
- current_sr = sr_target
115
 
116
- # Convert to mono if necessary (take the mean across channels)
117
- if audio_data.shape[0] > 1:
118
- audio_data = torch.mean(audio_data, dim=0, keepdim=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
- return audio_data, current_sr
 
 
 
 
 
 
121
 
122
 
123
  # --- Transcription Function ---
@@ -125,6 +134,7 @@ def load_audio_file(file_path):
125
  def transcribe_audio(audio_file_path, language):
126
  """
127
  Transcribes an audio file using the pre-loaded Whisper model.
 
128
  """
129
  if model is None:
130
  return "Error: Model was not loaded successfully at startup."
@@ -141,72 +151,95 @@ def transcribe_audio(audio_file_path, language):
141
  language = lang_dict[language]
142
 
143
  try:
144
- # Load audio using the robust loader and get the 16kHz mono tensor
145
  audio, sr = load_audio_file(audio_file_path)
146
-
147
- # The processor expects a 1D NumPy array for raw audio input
148
- # audio.squeeze().numpy() converts the (1, N) torch tensor to a (N,) numpy array
149
- inputs = processor(audio.squeeze().numpy(), sampling_rate=sr, return_tensors="pt")
150
 
151
- # Move inputs to the appropriate device
152
- input_features = inputs.input_features.to(device)
153
-
154
- forced_ids = processor.get_decoder_prompt_ids(language=language, task="transcribe")
155
-
156
- gen_config = GenerationConfig(
157
- forced_decoder_ids=forced_ids,
158
- max_length=448
159
- )
160
-
161
- with torch.no_grad():
162
- predicted_ids = model.generate(
163
- input_features,
164
- generation_config=gen_config
165
- )
166
 
167
- # Decode the generated token IDs to get the text transcript
168
- text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
- return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
  except Exception as e:
173
  return f"An error occurred during transcription: {e}"
174
 
175
 
176
  # --- Gradio Interface Setup ---
177
- # 🖼️ Interface Description
178
  title = "Whisper Small Uz v1: Multilingual audio transcription"
179
- description = "A Gradio demo for the **OvozifyLabs/whisper-small-uz-v1** model for Uzbek ASR. Upload an audio file (M4A, MP3, WAV supported) or record directly."
 
180
 
181
  language_input = gr.Dropdown(
182
  label="Select Language",
183
  choices=["Uzbek", "English", "Russian"],
184
- value="Uzbek" # default
185
  )
186
 
187
- # 🎤 Input Component
188
  audio_input = gr.Audio(
189
  sources=["microphone", "upload"],
190
  type="filepath",
191
  label="Input Audio (M4A/MP3/WAV, etc.)"
192
  )
193
 
194
- # 📝 Output Component
195
- text_output = gr.Textbox(label="Transcription Result", lines=6, max_lines = 25)
196
 
197
- # 🚀 Create the Interface
198
  demo = gr.Interface(
199
  fn=transcribe_audio,
200
  inputs=[audio_input, language_input],
201
  outputs=text_output,
202
  title=title,
203
  description=description,
204
- # The 'allow_flagging' argument caused the TypeError and is removed/replaced
205
- # 'flagging_enabled=None' disables the flagging button, which is cleaner
206
- # flagging_enabled=None,
207
- # theme=gr.themes.Soft()
208
  )
209
 
210
- # 💻 Launch the App
211
  if __name__ == "__main__":
212
  demo.launch()
 
3
  import torch
4
  import torchaudio
5
  import numpy as np
6
+ import av
7
 
8
  # --- Configuration and Model Loading ---
9
  model_id = "OvozifyLabs/whisper-small-uz-v1"
 
19
  model = WhisperForConditionalGeneration.from_pretrained(model_id).to(device)
20
  except Exception as e:
21
  print(f"Error loading model or processor: {e}")
 
22
  processor = None
23
  model = None
24
 
 
30
  Loads an audio file (handles M4A, MP3, WAV, etc.) and ensures it is
31
  resampled to 16000 Hz and converted to mono, which Whisper models require.
32
  """
33
+ sr_target = 16000
34
 
35
  if not file_path:
36
  raise FileNotFoundError("Audio file path is empty.")
37
 
38
  audio_data_list = []
39
+ current_sr = sr_target
40
 
41
  try:
42
+ # Try torchaudio's built-in loader first
43
  audio, sr = torchaudio.load(file_path)
44
  current_sr = sr
45
 
 
 
46
  # Resample if needed
47
  if current_sr != sr_target:
48
  if audio.dtype != torch.float32:
 
52
  audio = resampler(audio)
53
  current_sr = sr_target
54
 
55
+ # Convert to mono if necessary
56
  if audio.shape[0] > 1:
57
  audio = torch.mean(audio, dim=0, keepdim=True)
58
 
59
  return audio, current_sr
60
 
61
  except Exception as torchaudio_e:
62
+ # Fallback to PyAV for formats like M4A, MP3
 
 
63
  try:
64
  import av
65
  with av.open(file_path) as container:
66
  stream = container.streams.audio[0]
67
 
 
68
  resampler = av.AudioResampler(
69
+ format='fltp',
70
+ layout='mono',
71
+ rate=sr_target
72
  )
73
 
 
74
  for frame in container.decode(stream):
75
  for resampled_frame in resampler.resample(frame):
 
 
 
76
  audio_data_list.append(resampled_frame.to_ndarray()[0])
77
 
 
78
  if not audio_data_list:
79
  raise RuntimeError("Could not decode audio frames using PyAV.")
80
 
 
81
  audio_np = np.concatenate(audio_data_list, axis=0)
 
82
  audio = torch.from_numpy(audio_np).unsqueeze(0).float()
83
 
84
  return audio, sr_target
 
86
  except Exception as av_e:
87
  raise RuntimeError(f"Failed to load audio file using both torchaudio and PyAV. Error: {av_e}")
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
+ # --- Audio Chunking Function ---
91
+
92
def chunk_audio(audio_tensor, sampling_rate, chunk_length_s=30, overlap_s=5):
    """
    Split a mono audio tensor into fixed-length, overlapping chunks.

    Consecutive chunks start ``chunk_length_s - overlap_s`` seconds apart, so
    each chunk repeats the last ``overlap_s`` seconds of the previous one.
    The final chunk may be shorter than ``chunk_length_s``.

    Args:
        audio_tensor: 2-D array-like indexed as (channels, num_samples);
            only ``shape[1]`` and slicing along the second axis are used.
        sampling_rate: Number of samples per second of ``audio_tensor``.
        chunk_length_s: Length of each chunk in seconds.
        overlap_s: Overlap between consecutive chunks in seconds.

    Returns:
        A list of chunks sliced from ``audio_tensor``. If the audio is not
        longer than one chunk, the list contains ``audio_tensor`` itself.

    Raises:
        ValueError: If the chunk length is not positive in samples, or if the
            overlap is negative or not strictly smaller than the chunk length.
    """
    chunk_samples = int(chunk_length_s * sampling_rate)
    overlap_samples = int(overlap_s * sampling_rate)

    if chunk_samples <= 0:
        raise ValueError(
            "chunk_length_s * sampling_rate must be at least one sample."
        )
    # A non-positive stride would make the loop below never advance, and a
    # negative overlap would silently drop audio between chunks.
    if overlap_samples < 0 or overlap_samples >= chunk_samples:
        raise ValueError(
            "overlap_s must be non-negative and smaller than chunk_length_s."
        )

    stride = chunk_samples - overlap_samples
    audio_length = audio_tensor.shape[1]

    # Audio that fits in a single chunk is returned untouched.
    if audio_length <= chunk_samples:
        return [audio_tensor]

    chunks = []
    for start in range(0, audio_length, stride):
        end = min(start + chunk_samples, audio_length)
        chunks.append(audio_tensor[:, start:end])

        # Stop once a chunk reaches the end, so no trailing chunk made up
        # purely of already-covered overlap is emitted.
        if end >= audio_length:
            break

    return chunks
130
 
131
 
132
  # --- Transcription Function ---
 
134
  def transcribe_audio(audio_file_path, language):
135
  """
136
  Transcribes an audio file using the pre-loaded Whisper model.
137
+ Automatically chunks audio longer than 30 seconds.
138
  """
139
  if model is None:
140
  return "Error: Model was not loaded successfully at startup."
 
151
  language = lang_dict[language]
152
 
153
  try:
154
+ # Load audio using the robust loader
155
  audio, sr = load_audio_file(audio_file_path)
 
 
 
 
156
 
157
+ # Calculate audio duration
158
+ duration_s = audio.shape[1] / sr
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
+ # Check if chunking is needed
161
+ if duration_s > 30:
162
+ print(f"Audio duration: {duration_s:.2f}s - Chunking into segments...")
163
+ chunks = chunk_audio(audio, sr, chunk_length_s=30, overlap_s=5)
164
+
165
+ # Transcribe each chunk
166
+ transcriptions = []
167
+ for i, chunk in enumerate(chunks):
168
+ print(f"Processing chunk {i+1}/{len(chunks)}...")
169
+
170
+ inputs = processor(chunk.squeeze().numpy(), sampling_rate=sr, return_tensors="pt")
171
+ input_features = inputs.input_features.to(device)
172
+
173
+ forced_ids = processor.get_decoder_prompt_ids(language=language, task="transcribe")
174
+ gen_config = GenerationConfig(
175
+ forced_decoder_ids=forced_ids,
176
+ max_length=448
177
+ )
178
+
179
+ with torch.no_grad():
180
+ predicted_ids = model.generate(
181
+ input_features,
182
+ generation_config=gen_config
183
+ )
184
+
185
+ text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
186
+ transcriptions.append(text)
187
+
188
+ # Combine all transcriptions
189
+ full_transcription = " ".join(transcriptions)
190
+ return f"[Audio duration: {duration_s:.2f}s - Processed in {len(chunks)} chunks]\n\n{full_transcription}"
191
 
192
+ else:
193
+ # Process normally for short audio
194
+ print(f"Audio duration: {duration_s:.2f}s - Processing as single segment...")
195
+ inputs = processor(audio.squeeze().numpy(), sampling_rate=sr, return_tensors="pt")
196
+ input_features = inputs.input_features.to(device)
197
+
198
+ forced_ids = processor.get_decoder_prompt_ids(language=language, task="transcribe")
199
+ gen_config = GenerationConfig(
200
+ forced_decoder_ids=forced_ids,
201
+ max_length=448
202
+ )
203
+
204
+ with torch.no_grad():
205
+ predicted_ids = model.generate(
206
+ input_features,
207
+ generation_config=gen_config
208
+ )
209
+
210
+ text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
211
+ return text
212
 
213
  except Exception as e:
214
  return f"An error occurred during transcription: {e}"
215
 
216
 
217
  # --- Gradio Interface Setup ---
 
218
  title = "Whisper Small Uz v1: Multilingual audio transcription"
219
+ description = """A Gradio demo for the **OvozifyLabs/whisper-small-uz-v1** model for Uzbek ASR.
220
+ Upload an audio file (M4A, MP3, WAV supported) or record directly. """
221
 
222
  language_input = gr.Dropdown(
223
  label="Select Language",
224
  choices=["Uzbek", "English", "Russian"],
225
+ value="Uzbek"
226
  )
227
 
 
228
  audio_input = gr.Audio(
229
  sources=["microphone", "upload"],
230
  type="filepath",
231
  label="Input Audio (M4A/MP3/WAV, etc.)"
232
  )
233
 
234
+ text_output = gr.Textbox(label="Transcription Result", lines=6, max_lines=25)
 
235
 
 
236
  demo = gr.Interface(
237
  fn=transcribe_audio,
238
  inputs=[audio_input, language_input],
239
  outputs=text_output,
240
  title=title,
241
  description=description,
 
 
 
 
242
  )
243
 
 
244
  if __name__ == "__main__":
245
  demo.launch()