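# Gradio Space exposing a collection of image, audio, video, and text models behind
# a tabbed interface. Generated artifacts (and some pipelines) are pickled and cached
# in a Google Cloud Storage bucket so repeated requests can be served from the cache.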
import os
import pickle
import torch
import torchaudio
from PIL import Image
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    FluxPipeline,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
)
from diffusers.utils import export_to_video
from transformers import (
    pipeline as transformers_pipeline,
    AutoModelForCausalLM,
    AutoTokenizer,
    GPT2Tokenizer,
    GPT2LMHeadModel,
    AutoModel,
)
from audiocraft.models import musicgen
import gradio as gr
from huggingface_hub import snapshot_download, HfApi, HfFolder
import io
import time
from tqdm import tqdm
from google.cloud import storage
import json
hf_token = os.getenv("HF_TOKEN")
gcs_credentials = json.loads(os.getenv("GCS_CREDENTIALS"))
gcs_bucket_name = os.getenv("GCS_BUCKET_NAME")
HfFolder.save_token(hf_token)
storage_client = storage.Client.from_service_account_info(gcs_credentials)
bucket = storage_client.bucket(gcs_bucket_name)
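# GCS-backed cache helpers: objects are pickled and stored under deterministic blob
# names, so a cache hit avoids re-running (or re-downloading) the corresponding model.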
def load_object_from_gcs(blob_name):
    blob = bucket.blob(blob_name)
    if blob.exists():
        return pickle.loads(blob.download_as_bytes())
    return None


def save_object_to_gcs(blob_name, obj):
    blob = bucket.blob(blob_name)
    blob.upload_from_string(pickle.dumps(obj))
def get_model_or_download(model_id, blob_name, loader_func):
    model = load_object_from_gcs(blob_name)
    if model:
        return model
    try:
        with tqdm(total=1, desc=f"Downloading {model_id}") as pbar:
            model = loader_func(model_id, torch_dtype=torch.float16)
            pbar.update(1)
        save_object_to_gcs(blob_name, model)
        return model
    except Exception as e:
        print(f"Failed to load or save model: {e}")
        return None
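# Generation functions: each one checks the GCS cache first, runs the model on a miss,
# stores the result, and returns it in the format its Gradio component expects.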
def generate_image(prompt):
    blob_name = f"diffusers/generated_image:{prompt}"
    image_bytes = load_object_from_gcs(blob_name)
    if not image_bytes:
        try:
            with tqdm(total=1, desc="Generating image") as pbar:
                image = text_to_image_pipeline(prompt).images[0]
                pbar.update(1)
            buffered = io.BytesIO()
            image.save(buffered, format="JPEG")
            image_bytes = buffered.getvalue()
            save_object_to_gcs(blob_name, image_bytes)
        except Exception as e:
            print(f"Failed to generate image: {e}")
            return None
    # Return a PIL image to match the gr.Image(type="pil") output component.
    return Image.open(io.BytesIO(image_bytes))
def edit_image_with_prompt(image, prompt, strength=0.75):
    # gr.Image(type="pil") passes a PIL image, so no bytes decoding is needed.
    blob_name = f"diffusers/edited_image:{prompt}:{strength}"
    edited_image_bytes = load_object_from_gcs(blob_name)
    if not edited_image_bytes:
        try:
            with tqdm(total=1, desc="Editing image") as pbar:
                edited_image = img2img_pipeline(
                    prompt=prompt, image=image, strength=strength
                ).images[0]
                pbar.update(1)
            buffered = io.BytesIO()
            edited_image.save(buffered, format="JPEG")
            edited_image_bytes = buffered.getvalue()
            save_object_to_gcs(blob_name, edited_image_bytes)
        except Exception as e:
            print(f"Failed to edit image: {e}")
            return None
    return Image.open(io.BytesIO(edited_image_bytes))
def generate_song(prompt, duration=10):
    blob_name = f"music/generated_song:{prompt}:{duration}"
    song = load_object_from_gcs(blob_name)
    if not song:
        try:
            with tqdm(total=1, desc="Generating song") as pbar:
                music_gen.set_generation_params(duration=duration)
                wav = music_gen.generate([prompt])
                pbar.update(1)
            # Cache and return (sample_rate, waveform) for gr.Audio(type="numpy").
            song = (music_gen.sample_rate, wav[0, 0].cpu().numpy())
            save_object_to_gcs(blob_name, song)
        except Exception as e:
            print(f"Failed to generate song: {e}")
            return None
    return song
def generate_text(prompt):
    blob_name = f"transformers/generated_text:{prompt}"
    text = load_object_from_gcs(blob_name)
    if not text:
        try:
            with tqdm(total=1, desc="Generating text") as pbar:
                text = text_gen_pipeline(prompt, max_new_tokens=256)[0][
                    "generated_text"
                ].strip()
                pbar.update(1)
            save_object_to_gcs(blob_name, text)
        except Exception as e:
            print(f"Failed to generate text: {e}")
            return None
    return text
def generate_flux_image(prompt):
    blob_name = f"diffusers/generated_flux_image:{prompt}"
    flux_image_bytes = load_object_from_gcs(blob_name)
    if not flux_image_bytes:
        try:
            with tqdm(total=1, desc="Generating FLUX image") as pbar:
                flux_image = flux_pipeline(
                    prompt,
                    guidance_scale=0.0,
                    num_inference_steps=4,
                    max_sequence_length=256,
                    generator=torch.Generator("cpu").manual_seed(0),
                ).images[0]
                pbar.update(1)
            buffered = io.BytesIO()
            flux_image.save(buffered, format="JPEG")
            flux_image_bytes = buffered.getvalue()
            save_object_to_gcs(blob_name, flux_image_bytes)
        except Exception as e:
            print(f"Failed to generate flux image: {e}")
            return None
    return Image.open(io.BytesIO(flux_image_bytes))
def generate_code(prompt):
    blob_name = f"transformers/generated_code:{prompt}"
    code = load_object_from_gcs(blob_name)
    if not code:
        try:
            with tqdm(total=1, desc="Generating code") as pbar:
                inputs = starcoder_tokenizer.encode(prompt, return_tensors="pt")
                outputs = starcoder_model.generate(inputs, max_new_tokens=256)
                code = starcoder_tokenizer.decode(outputs[0])
                pbar.update(1)
            save_object_to_gcs(blob_name, code)
        except Exception as e:
            print(f"Failed to generate code: {e}")
            return None
    return code
def test_model_meta_llama():
    blob_name = "transformers/meta_llama_test_response"
    response = load_object_from_gcs(blob_name)
    if not response:
        try:
            messages = [
                {
                    "role": "system",
                    "content": "You are a pirate chatbot who always responds in pirate speak!",
                },
                {"role": "user", "content": "Who are you?"},
            ]
            with tqdm(total=1, desc="Testing Meta-Llama") as pbar:
                # With chat-style input, generated_text is the full conversation;
                # take the content of the final (assistant) message.
                response = meta_llama_pipeline(messages, max_new_tokens=256)[0][
                    "generated_text"
                ][-1]["content"].strip()
                pbar.update(1)
            save_object_to_gcs(blob_name, response)
        except Exception as e:
            print(f"Failed to test Meta-Llama: {e}")
            return None
    return response
def generate_image_sdxl(prompt):
    blob_name = f"diffusers/generated_image_sdxl:{prompt}"
    image_bytes = load_object_from_gcs(blob_name)
    if not image_bytes:
        try:
            with tqdm(total=1, desc="Generating SDXL image") as pbar:
                image = base(
                    prompt=prompt,
                    num_inference_steps=40,
                    denoising_end=0.8,
                    output_type="latent",
                ).images
                image = refiner(
                    prompt=prompt,
                    num_inference_steps=40,
                    denoising_start=0.8,
                    image=image,
                ).images[0]
                pbar.update(1)
            buffered = io.BytesIO()
            image.save(buffered, format="JPEG")
            image_bytes = buffered.getvalue()
            save_object_to_gcs(blob_name, image_bytes)
        except Exception as e:
            print(f"Failed to generate SDXL image: {e}")
            return None
    return Image.open(io.BytesIO(image_bytes))
def generate_musicgen_melody(prompt):
    blob_name = f"music/generated_musicgen_melody:{prompt}"
    song = load_object_from_gcs(blob_name)
    if not song:
        try:
            with tqdm(total=1, desc="Generating MusicGen melody") as pbar:
                melody, sr = torchaudio.load("./assets/bach.mp3")
                # The melody batch size must match the number of descriptions (one here).
                wav = music_gen_melody.generate_with_chroma(
                    [prompt], melody[None], sr
                )
                pbar.update(1)
            # Return (sample_rate, waveform) for gr.Audio(type="numpy").
            song = (music_gen_melody.sample_rate, wav[0, 0].cpu().numpy())
            save_object_to_gcs(blob_name, song)
        except Exception as e:
            print(f"Failed to generate MusicGen melody: {e}")
            return None
    return song
def generate_musicgen_large(prompt):
    blob_name = f"music/generated_musicgen_large:{prompt}"
    song = load_object_from_gcs(blob_name)
    if not song:
        try:
            with tqdm(total=1, desc="Generating MusicGen large") as pbar:
                wav = music_gen_large.generate([prompt])
                pbar.update(1)
            song = (music_gen_large.sample_rate, wav[0, 0].cpu().numpy())
            save_object_to_gcs(blob_name, song)
        except Exception as e:
            print(f"Failed to generate MusicGen large: {e}")
            return None
    return song
def transcribe_audio(audio_sample):
    # gr.Audio(type="numpy") passes a (sample_rate, data) tuple.
    sample_rate, data = audio_sample
    blob_name = f"transformers/transcribed_audio:{hash(data.tobytes())}"
    text = load_object_from_gcs(blob_name)
    if not text:
        try:
            with tqdm(total=1, desc="Transcribing audio") as pbar:
                audio = data.astype("float32")
                if audio.ndim > 1:
                    audio = audio.mean(axis=1)
                if abs(audio).max() > 0:
                    audio = audio / abs(audio).max()
                text = whisper_pipeline(
                    {"sampling_rate": sample_rate, "raw": audio}, batch_size=8
                )["text"]
                pbar.update(1)
            save_object_to_gcs(blob_name, text)
        except Exception as e:
            print(f"Failed to transcribe audio: {e}")
            return None
    return text
def generate_mistral_instruct(prompt):
    blob_name = f"transformers/generated_mistral_instruct:{prompt}"
    response = load_object_from_gcs(blob_name)
    if not response:
        try:
            conversation = [{"role": "user", "content": prompt}]
            with tqdm(total=1, desc="Generating Mistral Instruct response") as pbar:
                # Move the encoded inputs to the model's device (device_map="auto").
                inputs = mistral_instruct_tokenizer.apply_chat_template(
                    conversation,
                    tools=tools,
                    add_generation_prompt=True,
                    return_dict=True,
                    return_tensors="pt",
                ).to(mistral_instruct_model.device)
                outputs = mistral_instruct_model.generate(
                    **inputs, max_new_tokens=1000
                )
                response = mistral_instruct_tokenizer.decode(
                    outputs[0], skip_special_tokens=True
                )
                pbar.update(1)
            save_object_to_gcs(blob_name, response)
        except Exception as e:
            print(f"Failed to generate Mistral Instruct response: {e}")
            return None
    return response
def generate_mistral_nemo(prompt):
    blob_name = f"transformers/generated_mistral_nemo:{prompt}"
    response = load_object_from_gcs(blob_name)
    if not response:
        try:
            conversation = [{"role": "user", "content": prompt}]
            with tqdm(total=1, desc="Generating Mistral Nemo response") as pbar:
                inputs = mistral_nemo_tokenizer.apply_chat_template(
                    conversation,
                    tools=tools,
                    add_generation_prompt=True,
                    return_dict=True,
                    return_tensors="pt",
                ).to(mistral_nemo_model.device)
                outputs = mistral_nemo_model.generate(**inputs, max_new_tokens=1000)
                response = mistral_nemo_tokenizer.decode(
                    outputs[0], skip_special_tokens=True
                )
                pbar.update(1)
            save_object_to_gcs(blob_name, response)
        except Exception as e:
            print(f"Failed to generate Mistral Nemo response: {e}")
            return None
    return response
def generate_gpt2_xl(prompt):
    blob_name = f"transformers/generated_gpt2_xl:{prompt}"
    response = load_object_from_gcs(blob_name)
    if not response:
        try:
            with tqdm(total=1, desc="Generating GPT-2 XL response") as pbar:
                inputs = gpt2_xl_tokenizer(prompt, return_tensors="pt")
                # Generate token ids and decode them (the bare GPT2Model returns
                # hidden states, which cannot be decoded into text).
                outputs = gpt2_xl_model.generate(**inputs, max_new_tokens=256)
                response = gpt2_xl_tokenizer.decode(
                    outputs[0], skip_special_tokens=True
                )
                pbar.update(1)
            save_object_to_gcs(blob_name, response)
        except Exception as e:
            print(f"Failed to generate GPT-2 XL response: {e}")
            return None
    return response
def store_user_question(question):
    blob_name = "user_questions.txt"
    blob = bucket.blob(blob_name)
    if blob.exists():
        blob.download_to_filename("user_questions.txt")
    with open("user_questions.txt", "a") as f:
        f.write(question + "\n")
    blob.upload_from_filename("user_questions.txt")


def retrain_models():
    pass
def generate_text_to_video_ms_1_7b(prompt, num_frames=200):
    blob_name = f"diffusers/text_to_video_ms_1_7b:{prompt}:{num_frames}"
    video_bytes = load_object_from_gcs(blob_name)
    if not video_bytes:
        try:
            with tqdm(total=1, desc="Generating video") as pbar:
                video_frames = text_to_video_ms_1_7b_pipeline(
                    prompt, num_inference_steps=25, num_frames=num_frames
                ).frames
                pbar.update(1)
            video_path = export_to_video(video_frames)
            with open(video_path, "rb") as f:
                video_bytes = f.read()
            save_object_to_gcs(blob_name, video_bytes)
            os.remove(video_path)
        except Exception as e:
            print(f"Failed to generate video: {e}")
            return None
    # gr.Video expects a file path, so write the (possibly cached) bytes to disk.
    output_path = f"text_to_video_ms_1_7b_{abs(hash(blob_name))}.mp4"
    with open(output_path, "wb") as f:
        f.write(video_bytes)
    return output_path
def generate_text_to_video_ms_1_7b_short(prompt):
    blob_name = f"diffusers/text_to_video_ms_1_7b_short:{prompt}"
    video_bytes = load_object_from_gcs(blob_name)
    if not video_bytes:
        try:
            with tqdm(total=1, desc="Generating short video") as pbar:
                video_frames = text_to_video_ms_1_7b_short_pipeline(
                    prompt, num_inference_steps=25
                ).frames
                pbar.update(1)
            video_path = export_to_video(video_frames)
            with open(video_path, "rb") as f:
                video_bytes = f.read()
            save_object_to_gcs(blob_name, video_bytes)
            os.remove(video_path)
        except Exception as e:
            print(f"Failed to generate short video: {e}")
            return None
    output_path = f"text_to_video_ms_1_7b_short_{abs(hash(blob_name))}.mp4"
    with open(output_path, "wb") as f:
        f.write(video_bytes)
    return output_path
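# Model and pipeline initialisation. Everything below is loaded eagerly at startup,
# either straight from the Hub or from the GCS cache via get_model_or_download.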
text_to_image_pipeline = get_model_or_download(
    "stabilityai/stable-diffusion-2",
    "diffusers/text_to_image_model",
    StableDiffusionPipeline.from_pretrained,
)
img2img_pipeline = get_model_or_download(
    "CompVis/stable-diffusion-v1-4",
    "diffusers/img2img_model",
    StableDiffusionImg2ImgPipeline.from_pretrained,
)
flux_pipeline = get_model_or_download(
    "black-forest-labs/FLUX.1-schnell",
    "diffusers/flux_model",
    FluxPipeline.from_pretrained,
)
text_gen_pipeline = transformers_pipeline(
    "text-generation", model="google/gemma-2-9b", tokenizer="google/gemma-2-9b"
)
music_gen = (
    load_object_from_gcs("music/music_gen")
    or musicgen.MusicGen.get_pretrained("melody")
)
meta_llama_pipeline = get_model_or_download(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "transformers/meta_llama_model",
    # get_model_or_download calls loader_func(model_id, ...); wrap the pipeline
    # factory so the model id is not mistaken for the task name.
    lambda model_id, **kwargs: transformers_pipeline(
        "text-generation", model=model_id, **kwargs
    ),
)
starcoder_model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder")
starcoder_tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder")
base = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
)
music_gen_melody = musicgen.MusicGen.get_pretrained("melody")
music_gen_melody.set_generation_params(duration=8)
music_gen_large = musicgen.MusicGen.get_pretrained("large")
music_gen_large.set_generation_params(duration=8)
whisper_pipeline = transformers_pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    chunk_length_s=30,
)
mistral_instruct_model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-Large-Instruct-2407",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
mistral_instruct_tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-Large-Instruct-2407"
)
mistral_nemo_model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-Nemo-Instruct-2407",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
mistral_nemo_tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-Nemo-Instruct-2407"
)
gpt2_xl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-xl")
# Use the LM-head variant so .generate() is available for text generation.
gpt2_xl_model = GPT2LMHeadModel.from_pretrained("gpt2-xl")
llama_3_groq_70b_tool_use_pipeline = transformers_pipeline(
    "text-generation", model="Groq/Llama-3-Groq-70B-Tool-Use"
)
phi_3_5_mini_instruct_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3.5-mini-instruct", torch_dtype="auto", trust_remote_code=True
)
phi_3_5_mini_instruct_tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/Phi-3.5-mini-instruct"
)
phi_3_5_mini_instruct_pipeline = transformers_pipeline(
    "text-generation",
    model=phi_3_5_mini_instruct_model,
    tokenizer=phi_3_5_mini_instruct_tokenizer,
)
meta_llama_3_1_8b_pipeline = transformers_pipeline(
    "text-generation",
    model="meta-llama/Meta-Llama-3.1-8B",
    model_kwargs={"torch_dtype": torch.bfloat16},
)
meta_llama_3_1_70b_pipeline = transformers_pipeline(
    "text-generation",
    model="meta-llama/Meta-Llama-3.1-70B",
    model_kwargs={"torch_dtype": torch.bfloat16},
)
medical_text_summarization_pipeline = transformers_pipeline(
    "summarization", model="your/medical_text_summarization_model"
)
bart_large_cnn_summarization_pipeline = transformers_pipeline(
    "summarization", model="facebook/bart-large-cnn"
)
flux_1_dev_pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
flux_1_dev_pipeline.enable_model_cpu_offload()
gemma_2_9b_pipeline = transformers_pipeline("text-generation", model="google/gemma-2-9b")
gemma_2_9b_it_pipeline = transformers_pipeline(
    "text-generation",
    model="google/gemma-2-9b-it",
    model_kwargs={"torch_dtype": torch.bfloat16},
)
gemma_2_2b_pipeline = transformers_pipeline("text-generation", model="google/gemma-2-2b")
gemma_2_2b_it_pipeline = transformers_pipeline(
    "text-generation",
    model="google/gemma-2-2b-it",
    model_kwargs={"torch_dtype": torch.bfloat16},
)
gemma_2_27b_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-27b")
gemma_2_27b_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-27b")
gemma_2_27b_it_pipeline = transformers_pipeline(
    "text-generation",
    model="google/gemma-2-27b-it",
    model_kwargs={"torch_dtype": torch.bfloat16},
)
text_to_video_ms_1_7b_pipeline = DiffusionPipeline.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
)
text_to_video_ms_1_7b_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
    text_to_video_ms_1_7b_pipeline.scheduler.config
)
text_to_video_ms_1_7b_pipeline.enable_model_cpu_offload()
text_to_video_ms_1_7b_pipeline.enable_vae_slicing()
text_to_video_ms_1_7b_short_pipeline = DiffusionPipeline.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
)
text_to_video_ms_1_7b_short_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
    text_to_video_ms_1_7b_short_pipeline.scheduler.config
)
text_to_video_ms_1_7b_short_pipeline.enable_model_cpu_offload()
tools = []
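# Gradio UI: one Interface per capability, combined into the TabbedInterface below.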
gen_image_tab = gr.Interface(
    fn=generate_image,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Image(type="pil"),
    title="Generate Image",
)
edit_image_tab = gr.Interface(
    fn=edit_image_with_prompt,
    inputs=[
        gr.Image(type="pil", label="Image:"),
        gr.Textbox(label="Prompt:"),
        gr.Slider(0.1, 1.0, 0.75, step=0.05, label="Strength:"),
    ],
    outputs=gr.Image(type="pil"),
    title="Edit Image",
)
generate_song_tab = gr.Interface(
    fn=generate_song,
    inputs=[
        gr.Textbox(label="Prompt:"),
        gr.Slider(5, 60, 10, step=1, label="Duration (s):"),
    ],
    outputs=gr.Audio(type="numpy"),
    title="Generate Songs",
)
generate_text_tab = gr.Interface(
    fn=generate_text,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Textbox(label="Generated Text:"),
    title="Generate Text",
)
generate_flux_image_tab = gr.Interface(
    fn=generate_flux_image,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Image(type="pil"),
    title="Generate FLUX Images",
)
generate_code_tab = gr.Interface(
    fn=generate_code,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Textbox(label="Generated Code:"),
    title="Generate Code",
)
model_meta_llama_test_tab = gr.Interface(
    fn=test_model_meta_llama,
    inputs=None,
    outputs=gr.Textbox(label="Model Output:"),
    title="Test Meta-Llama",
)
generate_image_sdxl_tab = gr.Interface(
    fn=generate_image_sdxl,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Image(type="pil"),
    title="Generate SDXL Image",
)
generate_musicgen_melody_tab = gr.Interface(
    fn=generate_musicgen_melody,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Audio(type="numpy"),
    title="Generate MusicGen Melody",
)
generate_musicgen_large_tab = gr.Interface(
    fn=generate_musicgen_large,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Audio(type="numpy"),
    title="Generate MusicGen Large",
)
transcribe_audio_tab = gr.Interface(
    fn=transcribe_audio,
    inputs=gr.Audio(type="numpy", label="Audio Sample:"),
    outputs=gr.Textbox(label="Transcribed Text:"),
    title="Transcribe Audio",
)
generate_mistral_instruct_tab = gr.Interface(
    fn=generate_mistral_instruct,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Textbox(label="Mistral Instruct Response:"),
    title="Generate Mistral Instruct Response",
)
generate_mistral_nemo_tab = gr.Interface(
    fn=generate_mistral_nemo,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Textbox(label="Mistral Nemo Response:"),
    title="Generate Mistral Nemo Response",
)
generate_gpt2_xl_tab = gr.Interface(
    fn=generate_gpt2_xl,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Textbox(label="GPT-2 XL Response:"),
    title="Generate GPT-2 XL Response",
)
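# answer_question_minicpm is referenced by the interface below but was never defined.
# A minimal sketch follows, assuming the MiniCPM-V chat API (openbmb/MiniCPM-V-2_6
# with trust_remote_code); the model id, dtype, and .chat() call are assumptions and
# not part of the original app.
minicpm_model = AutoModel.from_pretrained(
    "openbmb/MiniCPM-V-2_6", torch_dtype=torch.bfloat16, trust_remote_code=True
)
minicpm_tokenizer = AutoTokenizer.from_pretrained(
    "openbmb/MiniCPM-V-2_6", trust_remote_code=True
)


def answer_question_minicpm(image, question):
    blob_name = f"transformers/minicpm_answer:{question}"
    answer = load_object_from_gcs(blob_name)
    if not answer:
        try:
            # MiniCPM-V chat format: a single user turn containing the image and question.
            msgs = [{"role": "user", "content": [image, question]}]
            with tqdm(total=1, desc="Answering with MiniCPM") as pbar:
                answer = minicpm_model.chat(
                    image=None, msgs=msgs, tokenizer=minicpm_tokenizer
                )
                pbar.update(1)
            save_object_to_gcs(blob_name, answer)
        except Exception as e:
            print(f"Failed to answer question with MiniCPM: {e}")
            return None
    return answer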
answer_question_minicpm_tab = gr.Interface(
    fn=answer_question_minicpm,
    inputs=[
        gr.Image(type="pil", label="Image:"),
        gr.Textbox(label="Question:"),
    ],
    outputs=gr.Textbox(label="MiniCPM Answer:"),
    title="Answer Question with MiniCPM",
)
llama_3_groq_70b_tool_use_tab = gr.Interface(
    fn=llama_3_groq_70b_tool_use_pipeline,
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Llama 3 Groq 70B Tool Use Response:"),
    title="Llama 3 Groq 70B Tool Use",
)
phi_3_5_mini_instruct_tab = gr.Interface(
    fn=phi_3_5_mini_instruct_pipeline,
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Phi 3.5 Mini Instruct Response:"),
    title="Phi 3.5 Mini Instruct",
)
meta_llama_3_1_8b_tab = gr.Interface(
    fn=meta_llama_3_1_8b_pipeline,
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Meta Llama 3.1 8B Response:"),
    title="Meta Llama 3.1 8B",
)
meta_llama_3_1_70b_tab = gr.Interface(
    fn=meta_llama_3_1_70b_pipeline,
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Meta Llama 3.1 70B Response:"),
    title="Meta Llama 3.1 70B",
)
medical_text_summarization_tab = gr.Interface(
    fn=medical_text_summarization_pipeline,
    inputs=[gr.Textbox(label="Medical Document:")],
    outputs=gr.Textbox(label="Medical Text Summarization:"),
    title="Medical Text Summarization",
)
bart_large_cnn_summarization_tab = gr.Interface(
    fn=bart_large_cnn_summarization_pipeline,
    inputs=[gr.Textbox(label="Article:")],
    outputs=gr.Textbox(label="Bart Large CNN Summarization:"),
    title="Bart Large CNN Summarization",
)
flux_1_dev_tab = gr.Interface(
    # Unwrap the pipeline output to a PIL image so it matches gr.Image(type="pil").
    fn=lambda prompt: flux_1_dev_pipeline(prompt).images[0],
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Image(type="pil"),
    title="FLUX 1 Dev",
)
gemma_2_9b_tab = gr.Interface(
    fn=gemma_2_9b_pipeline,
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Gemma 2 9B Response:"),
    title="Gemma 2 9B",
)
gemma_2_9b_it_tab = gr.Interface(
    fn=gemma_2_9b_it_pipeline,
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Gemma 2 9B IT Response:"),
    title="Gemma 2 9B IT",
)
gemma_2_2b_tab = gr.Interface(
    fn=gemma_2_2b_pipeline,
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Gemma 2 2B Response:"),
    title="Gemma 2 2B",
)
gemma_2_2b_it_tab = gr.Interface(
    fn=gemma_2_2b_it_pipeline,
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Gemma 2 2B IT Response:"),
    title="Gemma 2 2B IT",
)


def generate_gemma_2_27b(prompt):
    input_ids = gemma_2_27b_tokenizer(prompt, return_tensors="pt")
    outputs = gemma_2_27b_model.generate(**input_ids, max_new_tokens=32)
    return gemma_2_27b_tokenizer.decode(outputs[0])


gemma_2_27b_tab = gr.Interface(
    fn=generate_gemma_2_27b,
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Gemma 2 27B Response:"),
    title="Gemma 2 27B",
)
gemma_2_27b_it_tab = gr.Interface(
    fn=gemma_2_27b_it_pipeline,
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Textbox(label="Gemma 2 27B IT Response:"),
    title="Gemma 2 27B IT",
)
text_to_video_ms_1_7b_tab = gr.Interface(
    fn=generate_text_to_video_ms_1_7b,
    inputs=[
        gr.Textbox(label="Prompt:"),
        gr.Slider(50, 200, 200, step=1, label="Number of Frames:"),
    ],
    outputs=gr.Video(),
    title="Text to Video MS 1.7B",
)
text_to_video_ms_1_7b_short_tab = gr.Interface(
    fn=generate_text_to_video_ms_1_7b_short,
    inputs=[gr.Textbox(label="Prompt:")],
    outputs=gr.Video(),
    title="Text to Video MS 1.7B Short",
)
app = gr.TabbedInterface(
    [
        gen_image_tab,
        edit_image_tab,
        generate_song_tab,
        generate_text_tab,
        generate_flux_image_tab,
        generate_code_tab,
        model_meta_llama_test_tab,
        generate_image_sdxl_tab,
        generate_musicgen_melody_tab,
        generate_musicgen_large_tab,
        transcribe_audio_tab,
        generate_mistral_instruct_tab,
        generate_mistral_nemo_tab,
        generate_gpt2_xl_tab,
        llama_3_groq_70b_tool_use_tab,
        phi_3_5_mini_instruct_tab,
        meta_llama_3_1_8b_tab,
        meta_llama_3_1_70b_tab,
        medical_text_summarization_tab,
        bart_large_cnn_summarization_tab,
        flux_1_dev_tab,
        gemma_2_9b_tab,
        gemma_2_9b_it_tab,
        gemma_2_2b_tab,
        gemma_2_2b_it_tab,
        gemma_2_27b_tab,
        gemma_2_27b_it_tab,
        text_to_video_ms_1_7b_tab,
        text_to_video_ms_1_7b_short_tab,
    ],
    [
        "Generate Image",
        "Edit Image",
        "Generate Song",
        "Generate Text",
        "Generate FLUX Image",
        "Generate Code",
        "Test Meta-Llama",
        "Generate SDXL Image",
        "Generate MusicGen Melody",
        "Generate MusicGen Large",
        "Transcribe Audio",
        "Generate Mistral Instruct Response",
        "Generate Mistral Nemo Response",
        "Generate GPT-2 XL Response",
        "Llama 3 Groq 70B Tool Use",
        "Phi 3.5 Mini Instruct",
        "Meta Llama 3.1 8B",
        "Meta Llama 3.1 70B",
        "Medical Text Summarization",
        "Bart Large CNN Summarization",
        "FLUX 1 Dev",
        "Gemma 2 9B",
        "Gemma 2 9B IT",
        "Gemma 2 2B",
        "Gemma 2 2B IT",
        "Gemma 2 27B",
        "Gemma 2 27B IT",
        "Text to Video MS 1.7B",
        "Text to Video MS 1.7B Short",
    ],
)
app.launch(share=True)