File size: 4,210 Bytes
92b5988
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85b20ff
92b5988
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52a24fe
 
92b5988
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52a24fe
92b5988
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from google import genai
from google.genai import types
import os

# Module-level singleton for the Google GenAI client.
# Populated by initialize() and read by generate_content(); None until
# initialization has succeeded.
client = None

def initialize():
    """
    Initialize the module-level Google Generative AI client.

    Reads the API key from the GEMINI_API_KEY environment variable,
    falling back to GOOGLE_API_KEY. Idempotent: if the client has already
    been created, this returns immediately instead of rebuilding it.

    Raises:
        ValueError: If neither environment variable is set.
        Exception: Re-raised from genai.Client if construction fails.
    """
    global client
    if client is not None:
        # Already initialized — avoid recreating the client on repeated
        # calls (generate_content() invokes this lazily as a fallback).
        return

    # Prefer GEMINI_API_KEY; an unset or empty value falls back to
    # GOOGLE_API_KEY (both are falsy when missing/empty).
    api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        raise ValueError("Neither GEMINI_API_KEY nor GOOGLE_API_KEY environment variable is set.")

    try:
        client = genai.Client(api_key=api_key)
        print("Google Generative AI client initialized.")
    except Exception as e:
        print(f"Error initializing Google Generative AI client: {e}")
        raise

def generate_content(prompt: str, model_name: str = None, allow_fallbacks: bool = True, generation_config: dict = None) -> str:
    """
    Generates content using the Google Generative AI model.

    Args:
        prompt: The prompt to send to the model.
        model_name: The name of the model to use (e.g., "gemini-2.0-flash",
                    "gemini-2.5-flash"). If None, a default model is used.
        allow_fallbacks: Not used by client.models.generate_content; kept
                         for compatibility with agent.py's call signature.
        generation_config: Optional dict of generation parameters, e.g.
                           temperature, max_output_tokens, top_p, top_k,
                           stop_sequences, candidate_count, seed.

    Returns:
        The generated text content.

    Raises:
        RuntimeError: If the client cannot be initialized, or the response
                      contains no text (e.g. the request was blocked).
        Exception: Re-raised from the underlying API call on failure.
    """
    global client
    if client is None:
        # Lazy fallback; ideally callers invoke initialize() explicitly.
        print("Client not initialized. Attempting to initialize now...")
        initialize()
        if client is None:  # Defensive re-check after the attempt.
            raise RuntimeError("Google Generative AI client is not initialized. Call initialize() first.")

    # Default model if not specified - gemini-2.5-flash-lite (fastest, cost-efficient).
    effective_model_name = model_name or "gemini-2.5-flash-lite"

    # Translate the plain-dict config into the SDK's typed config object.
    # Forward every supported key, not just temperature/max_output_tokens,
    # so callers can tune top_p, top_k, etc. without code changes here.
    supported_keys = (
        "temperature",
        "max_output_tokens",
        "top_p",
        "top_k",
        "stop_sequences",
        "candidate_count",
        "seed",
    )
    config_obj = None
    if generation_config:
        config_params = {k: generation_config[k] for k in supported_keys if k in generation_config}
        if config_params:
            config_obj = types.GenerateContentConfig(**config_params)

    try:
        response = client.models.generate_content(
            model=effective_model_name,
            contents=[prompt],  # The API expects a list of contents.
            config=config_obj,
        )
    except Exception as e:
        print(f"Error during content generation: {e}")
        # Re-raise so callers (agent.py) can decide how to recover.
        raise

    if response.text is None:
        # A blocked or empty response yields text == None; raise explicitly
        # rather than silently returning None from a function annotated -> str.
        raise RuntimeError("Model returned no text (the response may have been blocked).")
    return response.text

if __name__ == '__main__':
    # Smoke-test this module directly. Requires GEMINI_API_KEY (or
    # GOOGLE_API_KEY) to be set in the environment, e.g. in PowerShell:
    #   $env:GEMINI_API_KEY="YOUR_API_KEY"
    try:
        initialize()
        if client:
            first_prompt = "Explain how AI works in a few words"
            print(f"Sending prompt: '{first_prompt}'")
            text = generate_content(
                first_prompt,
                generation_config={'temperature': 0.7, 'max_output_tokens': 50},
            )
            print("\nGenerated text:")
            print(text)

            second_prompt = "What is the capital of France?"
            print(f"\nSending prompt: '{second_prompt}'")
            # Explicitly select the fastest model for the second call.
            text = generate_content(second_prompt, model_name="gemini-2.5-flash-lite")
            print("\nGenerated text:")
            print(text)
    except Exception as e:
        print(f"An error occurred: {e}")