Spaces:
Sleeping
Sleeping
| # This wrapper provides several features: | |
| # 1. An `ImageGenerationParams` dataclass to handle parameters with default values | |
| # 2. An `ImageGenerationResult` class to wrap the API response | |
| # 3. The main `ImagenWrapper` class with: | |
| # - Proper initialization with error handling | |
| # - Logging support | |
| # - Two methods for generation: | |
| # - `generate()` using the `ImageGenerationParams` class | |
| # - `generate_simple()` for a more straightforward interface | |
| # Here's how to use it: | |
| # # Example usage: | |
| # # Initialize the wrapper | |
| # wrapper = ImagenWrapper("https://bcdb8b7f9c4a57127c.gradio.live/") | |
| # # Method 1: Using ImageGenerationParams | |
| # params = ImageGenerationParams( | |
| # prompt="A beautiful sunset over mountains", | |
| # width=512, | |
| # height=512 | |
| # ) | |
| # result = wrapper.generate(params) | |
| # # Method 2: Using generate_simple | |
| # result = wrapper.generate_simple( | |
| # prompt="A beautiful sunset over mountains", | |
| # width=512, | |
| # height=512 | |
| # ) | |
| # # Access the results | |
| # print(f"Image URL: {result.image_url}") | |
| # print(f"Seed used: {result.seed}") | |
| # The wrapper includes: | |
| # - Type hints for better IDE support | |
| # - Error handling and logging | |
| # - Parameter validation | |
| # - Flexible parameter input (both through dataclass and dictionary) | |
| # - Clean result handling through a dedicated class | |
| # You can also add error handling in your code: | |
| # try: | |
| # wrapper = ImagenWrapper("https://bcdb8b7f9c4a57127c.gradio.live/") | |
| # result = wrapper.generate_simple("A beautiful sunset") | |
| # print(f"Generated image: {result}") | |
| # except ConnectionError as e: | |
| # print(f"Failed to connect to API: {e}") | |
| # except RuntimeError as e: | |
| # print(f"Generation failed: {e}") | |
| # except Exception as e: | |
| # print(f"Unexpected error: {e}") | |
| from gradio_client import Client | |
| from typing import Dict, Tuple, Optional, Union | |
| from dataclasses import dataclass | |
| import logging | |
@dataclass
class ImageGenerationParams:
    """Data class to hold image generation parameters.

    Only ``prompt`` is required; every other field has a default. The
    ``@dataclass`` decorator generates the keyword-accepting ``__init__``
    that ``ImageGenerationParams(prompt=..., **kwargs)`` and
    ``ImageGenerationParams(**some_dict)`` rely on.

    NOTE(review): the numeric fields are annotated ``float`` even where the
    defaults are whole numbers -- presumably to mirror the remote API's
    number inputs; confirm against the deployment's API description.
    """

    # Text prompt for the generation request (the only mandatory field).
    prompt: str
    # Seed value sent with the request.
    seed: float = 0
    # Flag sent with the request; presumably asks the backend to pick its
    # own seed instead of using ``seed`` -- confirm against the deployment.
    randomize_seed: bool = True
    # Requested output dimensions.
    width: float = 1024
    height: float = 1024
    # Sampler settings forwarded unchanged to the API.
    guidance_scale: float = 3.5
    num_inference_steps: float = 28
    lora_scale: float = 0.7
class ImageGenerationResult:
    """Class to handle the generation result.

    Stores the image metadata returned by the API together with the seed.

    Args:
        image_data: Either a mapping-like object (anything exposing
            ``.get``) with the optional keys ``path``, ``url``, ``size``,
            ``orig_name``, ``mime_type``, ``is_stream`` and ``meta``, or a
            plain string, which is taken to be the path of the image.
        seed: Seed value reported alongside the image.

    Raises:
        TypeError: If ``image_data`` is neither a string nor mapping-like.
    """

    def __init__(self, image_data: Union[Dict, str], seed: float):
        # NOTE(review): gradio_client appears to return a bare file path
        # string (not a dict) for image outputs when file downloading is
        # enabled -- confirm against the installed client version. Wrapping
        # the string keeps the attribute set uniform for both shapes.
        if isinstance(image_data, str):
            image_data = {'path': image_data}
        elif not hasattr(image_data, 'get'):
            raise TypeError(
                "image_data must be a dict or a file path string, "
                f"got {type(image_data).__name__}"
            )

        self.image_path: Optional[str] = image_data.get('path')
        self.image_url: Optional[str] = image_data.get('url')
        self.size = image_data.get('size')
        self.orig_name = image_data.get('orig_name')
        self.mime_type = image_data.get('mime_type')
        self.is_stream = image_data.get('is_stream', False)
        self.meta = image_data.get('meta', {})
        self.seed = seed

    def __str__(self) -> str:
        """Return a short summary showing the image URL and the seed."""
        return f"ImageGenerationResult(url={self.image_url}, seed={self.seed})"
class ImagenWrapper:
    """Wrapper class for the Imagen Gradio deployment."""

    def __init__(self, api_url: str):
        """
        Initialize the wrapper with the API URL.

        Args:
            api_url (str): The URL of the Gradio deployment

        Raises:
            ConnectionError: If the Gradio client cannot be created for
                ``api_url``; the underlying exception is chained as the cause.
        """
        self.api_url = api_url
        self.logger = logging.getLogger(__name__)
        try:
            self.client = Client(api_url)
        except Exception as e:
            self.logger.error("Failed to connect to API at %s: %s", api_url, e)
            # Chain the original exception so its traceback is not lost.
            raise ConnectionError(f"Failed to connect to API: {e}") from e
        self.logger.info("Successfully connected to API at %s", api_url)

    def generate(self,
                 params: Union[ImageGenerationParams, Dict],
                 ) -> ImageGenerationResult:
        """
        Generate an image using the provided parameters.

        Args:
            params: Either an ImageGenerationParams object or a dictionary
                with the parameters

        Returns:
            ImageGenerationResult: Object containing the generation results

        Raises:
            ValueError: If parameters are invalid (empty prompt, or a dict
                that cannot be turned into ImageGenerationParams)
            RuntimeError: If the API call fails or returns an unexpected
                response
        """
        # Parameter problems are reported before the API try-block so they
        # surface as the documented ValueError rather than being rewrapped
        # as RuntimeError by the blanket handler below.
        if isinstance(params, dict):
            try:
                params = ImageGenerationParams(**params)
            except TypeError as e:
                self.logger.error("Error during image generation: %s", e)
                raise ValueError(f"Invalid generation parameters: {e}") from e

        if not params.prompt:
            self.logger.error(
                "Error during image generation: %s", "Prompt cannot be empty"
            )
            raise ValueError("Prompt cannot be empty")

        try:
            # Make the API call
            result = self.client.predict(
                prompt=params.prompt,
                seed=params.seed,
                randomize_seed=params.randomize_seed,
                width=params.width,
                height=params.height,
                guidance_scale=params.guidance_scale,
                num_inference_steps=params.num_inference_steps,
                lora_scale=params.lora_scale,
                api_name="/infer"
            )

            # The endpoint is expected to return exactly (image_data, seed).
            if not result or len(result) != 2:
                raise RuntimeError("Invalid response from API")

            image_data, seed = result
            return ImageGenerationResult(image_data, seed)
        except Exception as e:
            self.logger.error("Error during image generation: %s", e)
            raise RuntimeError(f"Failed to generate image: {e}") from e

    def generate_simple(self,
                        prompt: str,
                        **kwargs) -> ImageGenerationResult:
        """
        Simplified interface for generating images.

        Builds an ImageGenerationParams from ``prompt`` and ``kwargs`` and
        delegates to ``generate``.

        Args:
            prompt (str): The prompt for image generation
            **kwargs: Optional parameters to override defaults

        Returns:
            ImageGenerationResult: Object containing the generation results
        """
        params = ImageGenerationParams(prompt=prompt, **kwargs)
        return self.generate(params)