import os
from typing import Any
from enum import Enum

from smolagents import InferenceClientModel, OpenAIServerModel  # type: ignore

from .models import LLMProviderType
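
# For context, the LLMProviderType enum imported above from .models is assumed
# to look roughly like the sketch below (illustrative only; the real definition
# may differ):
#
#     class LLMProviderType(str, Enum):
#         OPENAI = "openai"
#         GEMINI = "gemini"
#         CLAUDE = "claude"
#         HF = "hf"
#         OPENROUTER = "openrouter"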

# Environment variable names that hold each provider's API key
ENV_KEY_MAP = {
    LLMProviderType.OPENAI: "OPENAI_API_KEY",
    LLMProviderType.GEMINI: "GEMINI_API_KEY",
    LLMProviderType.CLAUDE: "CLAUDE_API_KEY",
    LLMProviderType.HF: "HF_API_KEY",
    LLMProviderType.OPENROUTER: "OPENROUTER_API_KEY",
}

# Base URLs for OpenAI-compatible API providers
BASE_URL_MAP = {
    LLMProviderType.OPENAI: "https://api.openai.com/v1",
    LLMProviderType.OPENROUTER: "https://openrouter.ai/api/v1",
    LLMProviderType.CLAUDE: "https://api.anthropic.com/v1",
    LLMProviderType.GEMINI: "https://generativelanguage.googleapis.com/v1beta/openai/",
}


class LLMProvider:
    def __init__(self, provider: LLMProviderType, model_id: str):
        """
        provider: LLMProviderType enum value.
        model_id: model name string for the provider.

        Raises ValueError if either argument is missing.
        """
        if provider is None:
            raise ValueError("LLMProvider requires a provider argument or LLM_PROVIDER env variable.")
        self.provider = provider

        if model_id is None:
            raise ValueError("LLMProvider requires a model_id argument.")
        self.model_id = model_id

        # Resolve the API key from the environment variable mapped to this provider.
        key_env = ENV_KEY_MAP.get(self.provider)
        self.api_key = os.getenv(key_env) if key_env else None

    def get_model(self, **kwargs: Any) -> Any:
        """Return a model client for the selected provider and model.

        Extra keyword arguments are forwarded to the underlying model client.
        """
        if not self.api_key:
            raise ValueError(f"API key for provider {self.provider} not found in environment.")
        if not self.model_id:
            raise ValueError("Model name must be provided.")

        if self.provider == LLMProviderType.HF:
            # Hugging Face models go through the Inference API client.
            return InferenceClientModel(
                model_id=self.model_id,
                token=self.api_key,
                timeout=300,
                **kwargs,
            )
        elif self.provider in [
            LLMProviderType.OPENAI,
            LLMProviderType.OPENROUTER,
            LLMProviderType.CLAUDE,
            LLMProviderType.GEMINI,
        ]:
            # These providers expose OpenAI-compatible endpoints.
            api_base = BASE_URL_MAP.get(self.provider)
            if not api_base:
                raise ValueError(f"Base URL not configured for provider: {self.provider}")
            return OpenAIServerModel(
                model_id=self.model_id,
                api_key=self.api_key,
                api_base=api_base,
                timeout=300,
                **kwargs,
            )
        else:
            raise ValueError(f"Unknown provider: {self.provider}")