| | from typing import Dict, List, Any |
| | from transformers import pipeline |
| | import soundfile as sf |
| | import torch |
| | import logging |
| | import base64 |
| |
|
# Module-level logger, named after this module; EndpointHandler logs each
# generation request and any generation failure through it.
logger = logging.getLogger(__name__)
| |
|
class EndpointHandler:
    """Request handler that turns a text prompt into MusicGen audio.

    The text-to-audio pipeline is built once in ``__init__`` and reused for
    every call, so the model is only loaded a single time per handler.
    """

    # Checkpoint handed to the text-to-audio pipeline.
    MODEL_ID = "facebook/musicgen-stereo-large"
    # Generation budget used when the caller does not supply one.
    DEFAULT_MAX_NEW_TOKENS = 256

    def __init__(self, path=""):
        """Create the text-to-audio pipeline.

        Args:
            path: Accepted for interface compatibility but not used; the
                model is always loaded from ``MODEL_ID``.
        """
        if torch.cuda.is_available():
            device, dtype = "cuda", torch.float16
        else:
            # Fall back instead of failing at construction time on hosts
            # without a GPU; half precision is kept for the GPU path only.
            logger.warning(
                "CUDA is not available; loading %s on CPU in float32",
                self.MODEL_ID,
            )
            device, dtype = "cpu", torch.float32

        self.pipeline = pipeline(
            "text-to-audio",
            self.MODEL_ID,
            device=device,
            torch_dtype=dtype,
        )

    def generate_audio(self, text: str, max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS):
        """Run the pipeline on ``text`` and return the generated audio.

        Args:
            text: Prompt forwarded to the pipeline.
            max_new_tokens: Forwarded to the pipeline as
                ``forward_params["max_new_tokens"]``; defaults to 256.

        Returns:
            A ``(audio, sampling_rate)`` tuple, where ``audio`` is the first
            item of the pipeline's ``"audio"`` output, transposed
            (presumably laid out as (samples, channels) -- confirm against
            the pipeline's output), and ``sampling_rate`` is the pipeline's
            ``"sampling_rate"`` value.

        Raises:
            Exception: Whatever the pipeline (or reading its output) raises;
                the error is logged with its traceback and re-raised as is.
        """
        logger.info("Generating audio for text: %s", text)
        try:
            music = self.pipeline(
                text, forward_params={"max_new_tokens": max_new_tokens}
            )
            return music["audio"][0].T, music["sampling_rate"]
        except Exception:
            logger.exception("Error generating audio for text: %s", text)
            # Bare raise keeps the original exception and traceback intact.
            raise

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request.

        Args:
            data: Request payload; the prompt is read from its ``"inputs"``
                key. The payload is not modified.

        Returns:
            A dict with ``"audio_data"`` (the generated audio converted to
            nested lists via ``tolist()``) and ``"sampling_rate"``.

        Raises:
            ValueError: If ``"inputs"`` is missing or is not a string.
        """
        prompt = data.get("inputs")
        if not isinstance(prompt, str):
            # Fail early with a clear message instead of handing a
            # non-string object to the pipeline.
            raise ValueError(
                'Request payload must contain a string under the "inputs" key'
            )

        audio_data, sampling_rate = self.generate_audio(prompt)

        return {
            "audio_data": audio_data.tolist(),
            "sampling_rate": sampling_rate,
        }
| |
|
| |
|