File size: 2,657 Bytes
66d6356
 
 
d6e5a14
9f80f96
66d6356
 
 
 
 
 
 
 
 
 
 
 
 
 
4baacf3
66d6356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6e5a14
66d6356
 
 
 
 
 
 
 
 
5099f8f
 
66d6356
21a0c2b
 
 
 
 
 
66d6356
 
 
 
 
 
5099f8f
 
66d6356
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
from typing import Any
from enum import Enum
from smolagents import InferenceClientModel, OpenAIServerModel  # type: ignore
from .models import LLMProviderType

# Name of the environment variable that holds each provider's API key.
# Looked up in LLMProvider.__init__ via os.getenv; a provider missing from
# this map yields api_key=None, which get_model() rejects with ValueError.
ENV_KEY_MAP = {
    LLMProviderType.OPENAI: "OPENAI_API_KEY",
    LLMProviderType.GEMINI: "GEMINI_API_KEY",
    LLMProviderType.CLAUDE: "CLAUDE_API_KEY",
    LLMProviderType.HF: "HF_API_KEY",
    LLMProviderType.OPENROUTER: "OPENROUTER_API_KEY",
}

# Base URLs for providers reached through their OpenAI-compatible endpoints
# (used with smolagents' OpenAIServerModel). HF is intentionally absent:
# it goes through InferenceClientModel instead — see LLMProvider.get_model.
BASE_URL_MAP = {
    LLMProviderType.OPENAI: "https://api.openai.com/v1",
    LLMProviderType.OPENROUTER: "https://openrouter.ai/api/v1",
    LLMProviderType.CLAUDE: "https://api.anthropic.com/v1",
    LLMProviderType.GEMINI: "https://generativelanguage.googleapis.com/v1beta/openai/",
}


class LLMProvider:
    """Factory for smolagents model clients, selected by provider type.

    Resolves the provider's API key from the environment (via ENV_KEY_MAP)
    at construction time; ``get_model`` builds the concrete client.
    """

    def __init__(self, provider: LLMProviderType, model_id: str):
        """Store the provider selection and resolve its API key.

        Args:
            provider: LLMProviderType enum value selecting the backend.
            model_id: model name string for that provider.

        Raises:
            ValueError: if ``provider`` or ``model_id`` is None.
        """
        # Guard clauses: fail fast on missing arguments instead of deferring
        # an AttributeError/None surprise to get_model().
        if provider is None:
            raise ValueError("LLMProvider requires a provider argument or LLM_PROVIDER env variable.")
        if model_id is None:
            raise ValueError("LLMProvider requires a model argument.")
        self.provider = provider
        self.model_id = model_id
        # Look up the provider-specific env var; api_key stays None when the
        # provider is unmapped or the variable is unset (rejected later in
        # get_model so construction itself never requires credentials).
        key_env = ENV_KEY_MAP.get(self.provider)
        self.api_key = os.getenv(key_env) if key_env else None

    def get_model(self, **kwargs: Any) -> Any:
        """Return a model client for the selected provider and model.

        Extra keyword arguments are forwarded to the underlying client
        constructor (previously ``**kwargs`` was accepted but silently
        dropped). A caller-supplied ``timeout`` is honored over the default.

        Raises:
            ValueError: if the API key is missing, the model id is empty,
                or the provider is unknown.
        """
        if not self.api_key:
            raise ValueError(f"API key for provider {self.provider} not found in environment.")
        if not self.model_id:
            raise ValueError("Model name must be provided.")

        # Default timeout of 300s unless the caller overrides it.
        kwargs.setdefault("timeout", 300)

        if self.provider == LLMProviderType.HF:
            # Hugging Face goes through its native inference client rather
            # than an OpenAI-compatible endpoint.
            return InferenceClientModel(
                model_id=self.model_id,
                token=self.api_key,
                **kwargs,
            )
        # All remaining supported providers are reached via their
        # OpenAI-compatible endpoints; BASE_URL_MAP is the source of truth
        # (the old hard-coded provider list merely duplicated its keys).
        api_base = BASE_URL_MAP.get(self.provider)
        if api_base:
            return OpenAIServerModel(
                model_id=self.model_id,
                api_key=self.api_key,
                api_base=api_base,
                **kwargs,
            )
        raise ValueError(f"Unknown provider: {self.provider}")