EastSync-AI / LLM /llm_provider.py
StanSava's picture
update inference timeout
5099f8f
import os
from typing import Any
from enum import Enum
from smolagents import InferenceClientModel, OpenAIServerModel # type: ignore
from .models import LLMProviderType
# Maps each provider to the name of the environment variable that is read
# for its API key.
ENV_KEY_MAP: dict[LLMProviderType, str] = {
    LLMProviderType.OPENAI: "OPENAI_API_KEY",
    LLMProviderType.GEMINI: "GEMINI_API_KEY",
    LLMProviderType.CLAUDE: "CLAUDE_API_KEY",
    LLMProviderType.HF: "HF_API_KEY",
    LLMProviderType.OPENROUTER: "OPENROUTER_API_KEY",
}
# Base URLs for OpenAI-compatible API providers.
# Providers absent from this mapping have no OpenAI-compatible endpoint
# configured here.
BASE_URL_MAP: dict[LLMProviderType, str] = {
    LLMProviderType.OPENAI: "https://api.openai.com/v1",
    LLMProviderType.OPENROUTER: "https://openrouter.ai/api/v1",
    LLMProviderType.CLAUDE: "https://api.anthropic.com/v1",
    LLMProviderType.GEMINI: "https://generativelanguage.googleapis.com/v1beta/openai/",
}
class LLMProvider:
    """Build smolagents model clients for a given provider and model id.

    The API key is read once, at construction time, from the environment
    variable that ``ENV_KEY_MAP`` associates with the provider.
    """

    # Default request timeout passed to the model constructors unless the
    # caller supplies its own ``timeout`` to ``get_model``.
    DEFAULT_TIMEOUT = 300

    def __init__(self, provider: LLMProviderType, model_id: str):
        """Store the provider/model pair and look up the provider's API key.

        Args:
            provider: LLMProviderType enum value selecting the backend.
            model_id: Model name string for the provider.

        Raises:
            ValueError: If ``provider`` or ``model_id`` is None.
        """
        if provider is None:
            # NOTE(review): the message mentions an LLM_PROVIDER env variable,
            # but this class does not read it; presumably callers resolve it
            # before constructing the provider. Message kept unchanged.
            raise ValueError("LLMProvider requires a provider argument or LLM_PROVIDER env variable.")
        if model_id is None:
            raise ValueError("LLMProvider requires a model argument.")

        self.provider = provider
        self.model_id = model_id

        # Providers without an entry in ENV_KEY_MAP end up with no key;
        # get_model() reports that as a missing API key.
        key_env = ENV_KEY_MAP.get(self.provider)
        self.api_key = os.getenv(key_env) if key_env else None

    def get_model(self, **kwargs: Any) -> Any:
        """Return a model client for the selected provider and model.

        Args:
            **kwargs: Extra keyword arguments forwarded to the underlying
                model constructor. ``timeout`` defaults to
                ``DEFAULT_TIMEOUT`` when not supplied.

        Returns:
            An ``InferenceClientModel`` for the HF provider, or an
            ``OpenAIServerModel`` pointed at the provider's base URL for
            OPENAI, OPENROUTER, CLAUDE and GEMINI.

        Raises:
            ValueError: If the API key or model id is missing or empty, if
                no base URL is configured for an OpenAI-compatible provider,
                or if the provider is not recognised.
        """
        if not self.api_key:
            raise ValueError(f"API key for provider {self.provider} not found in environment.")
        if not self.model_id:
            raise ValueError("Model name must be provided.")

        # Caller-supplied options win over the default timeout; previously
        # these keyword arguments were accepted but silently discarded.
        options = {"timeout": self.DEFAULT_TIMEOUT, **kwargs}

        if self.provider == LLMProviderType.HF:
            return InferenceClientModel(
                model_id=self.model_id,
                token=self.api_key,
                **options,
            )

        openai_compatible = (
            LLMProviderType.OPENAI,
            LLMProviderType.OPENROUTER,
            LLMProviderType.CLAUDE,
            LLMProviderType.GEMINI,
        )
        if self.provider in openai_compatible:
            api_base = BASE_URL_MAP.get(self.provider)
            if not api_base:
                raise ValueError(f"Base URL not configured for provider: {self.provider}")
            return OpenAIServerModel(
                model_id=self.model_id,
                api_key=self.api_key,
                api_base=api_base,
                **options,
            )

        raise ValueError(f"Unknown provider: {self.provider}")