NLarchive committed
Commit 39b087d · verified · 1 Parent(s): 70266e5

Create app.py

Files changed (1)
  1. app.py +904 -0
app.py ADDED
@@ -0,0 +1,904 @@
import asyncio
import os
import sys
import threading
import time
import re
import atexit
from contextlib import asynccontextmanager
from typing import Any, Optional, List, Dict, Tuple, Callable  # Enhanced typing imports

# Add parent dir so we can import inference.py
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from smolagents import CodeAgent, MCPClient
from smolagents.models import Model
from inference import initialize, generate_content
from workflow_vizualizer import track_workflow_step, track_communication, complete_workflow_step

# Global session management
_session_initialized = False
_session_lock = threading.Lock()
_session_start_time = None

# Enhanced global caching for Phase 2 optimizations with async support
_global_tools_cache = {}
_global_tools_timestamp = None
_global_model_instance = None
_global_model_lock = threading.Lock()
_global_connection_pool = {}
_global_connection_lock = threading.Lock()

# Managed event loop system
_managed_event_loop = None
_event_loop_lock = threading.Lock()
_event_loop_manager = None  # Global event loop manager instance

@asynccontextmanager
async def managed_event_loop():
    """Proper async context manager for event loop lifecycle."""
    global _managed_event_loop

    try:
        # Create a new event loop if needed
        if _managed_event_loop is None or _managed_event_loop.is_closed():
            _managed_event_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(_managed_event_loop)

        print("✅ Event loop initialized and set as current")
        yield _managed_event_loop

    except Exception as e:
        print(f"❌ Event loop error: {e}")
        raise
    finally:
        # Don't close the loop here - it is managed at a higher level
        pass

async def safe_async_call(coro_factory, timeout=30):
    """Safely execute async calls with proper error handling.

    Takes a zero-argument callable that returns a fresh coroutine, because a
    coroutine object cannot be awaited twice - the retry path below needs to
    build a new one after the first attempt fails.
    """
    try:
        return await asyncio.wait_for(coro_factory(), timeout=timeout)
    except asyncio.TimeoutError:
        print(f"⏱️ Async call timed out after {timeout}s")
        raise
    except RuntimeError as e:
        if "Event loop is closed" in str(e):
            print("🔄 Event loop closed - attempting to create new one")
            # Create a new event loop and retry with a fresh coroutine
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            return await asyncio.wait_for(coro_factory(), timeout=timeout)
        raise

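# Illustrative usage sketch for safe_async_call; `_demo_probe` is a made-up
# placeholder coroutine, not part of the application logic. The key point is
# that passing a factory (here a lambda) lets the retry path build a fresh
# coroutine, since an already-awaited coroutine cannot be awaited again.
async def _demo_probe(delay: float = 0.01) -> str:
    await asyncio.sleep(delay)
    return "ok"


async def _demo_safe_call() -> str:
    # The lambda is the zero-argument coroutine factory
    return await safe_async_call(lambda: _demo_probe(), timeout=5)
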
class AsyncEventLoopManager:
    def __init__(self):
        self._loop: Optional[asyncio.AbstractEventLoop] = asyncio.new_event_loop()
        self._thread: Optional[threading.Thread] = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()
        print("AsyncEventLoopManager: Initialized and thread started.")

    def _run_loop(self):
        if self._loop is None:
            print("AsyncEventLoopManager: _run_loop called but loop is None.")
            return
        asyncio.set_event_loop(self._loop)
        try:
            print("AsyncEventLoopManager: Event loop running.")
            self._loop.run_forever()
        except Exception as e:
            print(f"AsyncEventLoopManager: Exception in event loop: {e}")
        finally:
            # Ensure the loop is stopped if it was running.
            # The actual closing is handled by the shutdown() method.
            if self._loop and self._loop.is_running():
                self._loop.stop()
            print("AsyncEventLoopManager: Event loop stopped in _run_loop finally.")

    def run_async(self, coro):
        """Run a coroutine in the event loop from another thread."""
        coro_name = getattr(coro, '__name__', str(coro))
        if self._loop is None:
            print(f"AsyncEventLoopManager: Loop object is None. Cannot run coroutine {coro_name}.")
            raise RuntimeError("Event loop manager is not properly initialized (loop missing).")

        if self._loop.is_closed():
            print(f"AsyncEventLoopManager: Loop is CLOSED. Cannot schedule coroutine {coro_name}.")
            raise RuntimeError(f"Event loop is closed. Cannot run {coro_name}.")

        if self._thread is None or not self._thread.is_alive():
            print(f"AsyncEventLoopManager: Event loop thread is not alive or None. Cannot run coroutine {coro_name}.")
            raise RuntimeError("Event loop thread is not alive or None.")

        try:
            future = asyncio.run_coroutine_threadsafe(coro, self._loop)
            return future.result(timeout=30)  # 30s timeout for the coroutine result
        except RuntimeError as e:
            print(f"AsyncEventLoopManager: RuntimeError during run_coroutine_threadsafe for {coro_name}: {e}")
            raise
        except asyncio.TimeoutError:
            print(f"AsyncEventLoopManager: Timeout waiting for coroutine {coro_name} result.")
            raise
        except Exception as e:
            print(f"AsyncEventLoopManager: Error submitting coroutine {coro_name}: {e}")
            raise

    def shutdown(self):
        """Stop and close the event loop."""
        print("AsyncEventLoopManager: Shutdown initiated.")
        if self._loop and not self._loop.is_closed():
            if self._loop.is_running():
                self._loop.call_soon_threadsafe(self._loop.stop)
                print("AsyncEventLoopManager: Stop signal sent to running event loop.")
            else:
                print("AsyncEventLoopManager: Event loop was not running, but attempting to stop.")
                # If not running, stop may be unnecessary or may error,
                # but call_soon_threadsafe should be safe.
                try:
                    self._loop.call_soon_threadsafe(self._loop.stop)
                except RuntimeError as e:
                    print(f"AsyncEventLoopManager: Info - could not send stop to non-running loop: {e}")

        if self._thread and self._thread.is_alive():
            self._thread.join(timeout=10)
            if self._thread.is_alive():
                print("AsyncEventLoopManager: Thread did not join in time during shutdown.")
            else:
                print("AsyncEventLoopManager: Thread joined.")
        else:
            print("AsyncEventLoopManager: Thread already stopped, not initialized, or None at shutdown.")

        # Explicitly close the loop here after the thread has finished.
        if self._loop and not self._loop.is_closed():
            try:
                # Best effort: cancel all remaining tasks before closing.
                # asyncio.all_tasks(loop) is available on Python 3.7+.
                tasks = asyncio.all_tasks(self._loop)
                for task in tasks:
                    task.cancel()
                # Ideally the cancelled tasks would be awaited inside the loop's
                # own thread; since we are shutting down from outside, this is
                # a best effort only.
                self._loop.close()
                print("AsyncEventLoopManager: Event loop closed in shutdown.")
            except Exception as e:
                print(f"AsyncEventLoopManager: Exception while closing loop: {e}")
        elif self._loop and self._loop.is_closed():
            print("AsyncEventLoopManager: Event loop was already closed.")
        else:
            print("AsyncEventLoopManager: No loop to close or loop was None.")

        self._loop = None
        self._thread = None
        print("AsyncEventLoopManager: Shutdown process complete.")

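# Illustrative sketch of driving AsyncEventLoopManager from synchronous code;
# this helper is hypothetical and is never called at import time. run_async
# blocks the calling thread until the coroutine finishes (or the 30s result
# timeout elapses), because the loop runs in a background daemon thread.
def _demo_run_async():
    manager = AsyncEventLoopManager()
    try:
        # asyncio.sleep supports a `result` argument to return a value
        result = manager.run_async(asyncio.sleep(0.01, result="done"))
        print(f"run_async returned: {result}")
    finally:
        manager.shutdown()
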
def get_event_loop_manager():
    """Get or create the global event loop manager."""
    global _event_loop_manager

    with _event_loop_lock:
        # Check whether the existing manager's loop and thread are still valid
        manager_valid = False
        if _event_loop_manager is not None:
            # Robust check: loop exists, is not closed, thread exists and is alive
            if (_event_loop_manager._loop is not None and
                    not _event_loop_manager._loop.is_closed() and
                    _event_loop_manager._thread is not None and
                    _event_loop_manager._thread.is_alive()):
                manager_valid = True
            else:
                print("get_event_loop_manager: Existing manager found but its loop or thread is invalid. Recreating.")
                try:
                    _event_loop_manager.shutdown()  # Attempt to clean up the old one
                except Exception as e:
                    print(f"get_event_loop_manager: Error shutting down invalid manager: {e}")
                _event_loop_manager = None  # Ensure a new one is created below

        if _event_loop_manager is None:  # Covers both initial creation and recreation
            print("get_event_loop_manager: Creating new AsyncEventLoopManager instance.")
            _event_loop_manager = AsyncEventLoopManager()
        else:
            print("get_event_loop_manager: Reusing existing valid AsyncEventLoopManager instance.")
        return _event_loop_manager


def shutdown_event_loop_manager():
    """Shutdown the global event loop manager."""
    global _event_loop_manager
    with _event_loop_lock:
        if _event_loop_manager is not None:
            print("shutdown_event_loop_manager: Shutting down global event loop manager.")
            try:
                _event_loop_manager.shutdown()
            except Exception as e:
                print(f"shutdown_event_loop_manager: Error during shutdown: {e}")
            finally:
                _event_loop_manager = None
        else:
            print("shutdown_event_loop_manager: No active event loop manager to shut down.")

class AsyncMCPClientWrapper:
    """Wrapper for async MCP client operations."""

    def __init__(self, url: str):
        self.url = url
        self._mcp_client = None
        self._tools = None
        self._tools_cache_time = None
        self._cache_ttl = 300  # 5 minutes cache
        self._connected = False

    async def ensure_connected(self):
        """Ensure the async connection is established."""
        if not self._connected or self._mcp_client is None:
            try:
                # Create an MCP client with SSE transport for Gradio
                self._mcp_client = MCPClient({"url": self.url, "transport": "sse"})
                # A lightweight operation (e.g. get_tools or a custom ping) could
                # confirm connectivity here; for now, MCPClient constructor
                # success is taken to imply basic connectivity.
                self._connected = True
                print(f"✅ Connected to MCP server: {self.url}")
            except Exception as e:
                self._connected = False
                print(f"❌ Failed to connect to {self.url}: {e}")
                raise

    async def get_tools(self):
        """Get tools asynchronously, with a per-instance TTL cache."""
        current_time = time.time()

        # Check the instance cache
        if (self._tools is not None and
                self._tools_cache_time is not None and
                current_time - self._tools_cache_time < self._cache_ttl):
            return self._tools

        # Fetch fresh tools
        await self.ensure_connected()

        if self._mcp_client is None:  # Should be caught by ensure_connected, but as a safeguard
            raise RuntimeError("MCP client not connected")

        try:
            # MCPClient.get_tools() is assumed to be synchronous, matching the
            # original structure; if it were async it would need an await.
            self._tools = self._mcp_client.get_tools()
            self._tools_cache_time = current_time
            tool_names = [tool.name for tool in self._tools] if self._tools else []
            print(f"🔧 Fetched {len(tool_names)} tools from {self.url}: {tool_names}")

            return self._tools
        except Exception as e:
            print(f"❌ Error fetching tools from {self.url}: {e}")
            # A connection-related error could also reset self._connected here.
            raise

    async def disconnect(self):
        """Gracefully disconnect."""
        if self._mcp_client and self._connected:
            try:
                # MCPClient.disconnect() is assumed to be synchronous.
                self._mcp_client.disconnect()
            except Exception as e:
                # Log the error but continue to mark the client as disconnected
                print(f"Error during MCPClient disconnect for {self.url}: {e}")
        self._connected = False
        self._mcp_client = None
        print(f"🔌 Disconnected from MCP server: {self.url}")

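# Illustrative async usage of AsyncMCPClientWrapper; this coroutine is
# hypothetical and never invoked at import time (the URL default mirrors the
# one used by get_mcp_client below). It shows the intended call order:
# connect, fetch (cached) tools, then disconnect.
async def _demo_wrapper_roundtrip(url: str = "http://localhost:7859/gradio_api/mcp/sse"):
    wrapper = AsyncMCPClientWrapper(url)
    try:
        tools = await wrapper.get_tools()  # implicitly calls ensure_connected()
        print(f"Demo fetched {len(tools) if tools else 0} tools")
    finally:
        await wrapper.disconnect()
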
class AsyncPersistentMCPClient:
    """Async-aware persistent MCP client that survives multiple requests."""

    def __init__(self, url: str):
        self.url = url
        self._wrapper = AsyncMCPClientWrapper(url)
        self._loop_manager = None

    def _ensure_loop_manager(self):
        """Make sure a valid event loop manager is available."""
        if self._loop_manager is None:
            self._loop_manager = get_event_loop_manager()
        if self._loop_manager is None:
            raise RuntimeError("Failed to create event loop manager")

    def ensure_connected(self):
        """Sync wrapper for the async connection."""
        conn_step = track_communication("agent", "mcp_client", "connection_ensure", f"Ensuring connection to {self.url}")
        try:
            self._ensure_loop_manager()
            # Pass the coroutine object itself, not its result
            self._loop_manager.run_async(self._wrapper.ensure_connected())
            complete_workflow_step(conn_step, "completed", details={"url": self.url})
        except Exception as e:
            complete_workflow_step(conn_step, "error", details={"error": str(e)})
            raise

    def get_client(self):
        """Get the underlying MCP client."""
        self.ensure_connected()
        return self._wrapper._mcp_client

    def get_tools(self):
        """Get tools with enhanced caching and async support."""
        global _global_tools_cache, _global_tools_timestamp
        current_time = time.time()

        self._ensure_loop_manager()

        # Phase 2 Optimization: Check the server-specific global cache first
        with _global_connection_lock:
            server_cache_key = self.url
            server_cache = _global_tools_cache.get(server_cache_key, {})

            if (server_cache and _global_tools_timestamp and
                    current_time - _global_tools_timestamp < 300):
                # Track the global cache hit
                cache_step = track_communication("mcp_client", "mcp_server", "cache_hit_global", f"Using global cached tools for {self.url}")
                complete_workflow_step(cache_step, "completed", details={
                    "tools": list(server_cache.keys()),
                    "cache_type": "global_server_specific",
                    "server_url": self.url,
                    "cache_age": current_time - _global_tools_timestamp
                })
                return list(server_cache.values())

        # Fetch fresh tools using async
        tools_step = track_communication("mcp_client", "mcp_server", "get_tools", f"Fetching tools from {self.url} (cache refresh)")
        try:
            # Pass the coroutine object itself
            tools = self._loop_manager.run_async(self._wrapper.get_tools())

            # Update the global cache
            with _global_connection_lock:
                if tools:
                    _global_tools_cache[server_cache_key] = {tool.name: tool for tool in tools}
                    _global_tools_timestamp = current_time

                    total_tools = sum(len(server_tools) for server_tools in _global_tools_cache.values())
                    print(f"🔄 Global tools cache updated for {self.url}: {len(tools)} tools")
                    print(f"   Total cached tools across all servers: {total_tools}")

            tool_names = [tool.name for tool in tools] if tools else []
            complete_workflow_step(tools_step, "completed", details={
                "tools": tool_names,
                "count": len(tool_names),
                "server_url": self.url,
                "cache_status": "refreshed_server_specific",
                "global_cache_servers": len(_global_tools_cache)
            })
            return tools

        except Exception as e:
            complete_workflow_step(tools_step, "error", details={"error": str(e), "server_url": self.url})
            raise

    def disconnect(self):
        """Gracefully disconnect."""
        if self._loop_manager and self._wrapper:
            try:
                # Pass the coroutine object itself
                self._loop_manager.run_async(self._wrapper.disconnect())
            except RuntimeError as e:
                # The loop might already be closed or unable to run tasks
                print(f"AsyncPersistentMCPClient: Error running disconnect for {self.url} in async loop: {e}")
            except Exception as e:
                print(f"AsyncPersistentMCPClient: General error during disconnect for {self.url}: {e}")

def get_mcp_client(url: str = "http://localhost:7859/gradio_api/mcp/sse") -> AsyncPersistentMCPClient:
    """Get or create an MCP client with enhanced global connection pooling."""
    # Phase 2 Optimization: Use the global connection pool
    with _global_connection_lock:
        if url not in _global_connection_pool:
            conn_step = track_communication("agent", "mcp_client", "connection_create", f"Creating new global connection to {url}")
            _global_connection_pool[url] = AsyncPersistentMCPClient(url)
            complete_workflow_step(conn_step, "completed", details={"url": url, "pool_size": len(_global_connection_pool)})
        else:
            # Track connection reuse
            reuse_step = track_communication("agent", "mcp_client", "connection_reuse", f"Reusing global connection to {url}")
            complete_workflow_step(reuse_step, "completed", details={"url": url, "pool_size": len(_global_connection_pool)})

    return _global_connection_pool[url]

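# Illustrative sketch of the connection pooling contract: repeated calls with
# the same URL should hand back the same AsyncPersistentMCPClient instance.
# This helper is hypothetical and exists only to document that behavior.
def _demo_connection_pooling(url: str = "http://localhost:7859/gradio_api/mcp/sse"):
    first = get_mcp_client(url)
    second = get_mcp_client(url)
    assert first is second, "pool should return the same client per URL"
    print(f"Pool size: {len(_global_connection_pool)}")
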
def get_global_model() -> 'CachedLocalInferenceModel':
    """Get or create the global model instance for Phase 2 optimization."""
    global _global_model_instance

    with _global_model_lock:
        if _global_model_instance is None:
            model_step = track_workflow_step("model_init_global", "Initializing global model instance")

            # CRITICAL FIX: Create and assign BEFORE initialization
            _global_model_instance = CachedLocalInferenceModel()

            # Now initialize the model
            try:
                _global_model_instance.ensure_initialized()
                complete_workflow_step(model_step, "completed", details={"model_type": "global_cached"})
                print("🤖 Global model instance created and initialized")
            except Exception as e:
                # If initialization fails, reset the global instance
                _global_model_instance = None
                complete_workflow_step(model_step, "error", details={"error": str(e)})
                raise
        else:
            # Track model reuse
            reuse_step = track_workflow_step("model_reuse", "Reusing global model instance")
            complete_workflow_step(reuse_step, "completed", details={"model_type": "global_cached"})

    return _global_model_instance

def reset_global_state():
    """Reset global state for testing purposes, with server-specific cache awareness."""
    global _global_tools_cache, _global_tools_timestamp, _global_model_instance, _global_connection_pool, _event_loop_manager

    with _global_connection_lock:
        # FIXED: Clear the server-specific cache structure in place (don't rebind!)
        _global_tools_cache.clear()  # Clears the {url: {tool_name: tool}} structure
        _global_tools_timestamp = None

        # Disconnect all connections but keep the pool structure
        for client in _global_connection_pool.values():
            try:
                client.disconnect()
            except Exception:
                pass

    with _global_model_lock:
        # Don't reset the model instance - it should persist
        pass

    print("🔄 Global state reset for testing (server-specific cache cleared)")

# Enhanced LocalInferenceModel with workflow tracking
class CachedLocalInferenceModel(Model):
    """Model with enhanced caching and session persistence."""

    def __init__(self):
        super().__init__()
        self._response_cache = {}
        self._cache_hits = 0
        self._cache_misses = 0
        self._model_ready = False

    def ensure_initialized(self):
        """Lazy initialization of the model."""
        if not self._model_ready:
            init_step = track_workflow_step("model_init", "Initializing inference model (lazy)")
            try:
                initialize()
                self._model_ready = True
                complete_workflow_step(init_step, "completed")
            except Exception as e:
                complete_workflow_step(init_step, "error", details={"error": str(e)})
                raise

    def generate(self, messages: Any, **kwargs: Any) -> Any:
        self.ensure_initialized()

        prompt = self._format_messages(messages)

        # Enhanced cache with hash-based lookup
        cache_key = hash(prompt)
        if cache_key in self._response_cache:
            self._cache_hits += 1
            cached_response = self._response_cache[cache_key]

            # Track the cache hit
            cache_step = track_communication("agent", "llm_service", "cache_hit", "Using cached response")
            complete_workflow_step(cache_step, "completed", details={
                "cache_hits": self._cache_hits,
                "cache_misses": self._cache_misses,
                "cache_ratio": self._cache_hits / (self._cache_hits + self._cache_misses)
            })

            return ModelResponse(cached_response.content, prompt)

        self._cache_misses += 1

        # Track the LLM call
        llm_step = track_communication("agent", "llm_service", "generate_request", "Generating new response")

        try:
            enhanced_prompt = self._enhance_prompt_for_tools(prompt)

            response_text = generate_content(
                prompt=enhanced_prompt,
                model_name=kwargs.get('model_name'),
                allow_fallbacks=True,
                generation_config={
                    'temperature': kwargs.get('temperature', 0.3),
                    'max_output_tokens': kwargs.get('max_tokens', 512)
                }
            )

            # Validate and fix the response format
            if not self._is_valid_code_response(response_text):
                response_text = self._fix_response_format(response_text, prompt)

            response = ModelResponse(str(response_text), prompt)

            # Smart cache management (keep the most recent 10 responses)
            if len(self._response_cache) >= 10:
                # Remove the oldest entry (simple FIFO)
                oldest_key = next(iter(self._response_cache))
                del self._response_cache[oldest_key]

            self._response_cache[cache_key] = response

            complete_workflow_step(llm_step, "completed", details={
                "cache_status": "new",
                "input_tokens": response.token_usage.input_tokens,
                "output_tokens": response.token_usage.output_tokens,
                "model": response.model
            })

            return response

        except Exception as e:
            fallback_response = self._create_fallback_response(prompt, str(e))
            complete_workflow_step(llm_step, "error", details={"error": str(e)})
            return ModelResponse(fallback_response, prompt)

    def _enhance_prompt_for_tools(self, prompt: str) -> str:
        """Enhance the prompt with better tool usage examples."""
        if "sentiment" in prompt.lower():
            tool_example = """
IMPORTANT: When calling sentiment_analysis, use keyword arguments only:
Correct: sentiment_analysis(text="your text here")
Wrong: sentiment_analysis("your text here")

Example:
```py
text = "this is horrible"
result = sentiment_analysis(text=text)
final_answer(result)
```"""
            return prompt + "\n" + tool_example
        return prompt

    def _format_messages(self, messages: Any) -> str:
        """Convert messages to a single prompt string."""
        if isinstance(messages, str):
            return messages
        elif isinstance(messages, list):
            prompt_parts = []
            for msg in messages:
                if isinstance(msg, dict):
                    if 'content' in msg:
                        content = msg['content']
                        role = msg.get('role', 'user')
                        if isinstance(content, list):
                            text_parts = [part.get('text', '') for part in content if part.get('type') == 'text']
                            content = ' '.join(text_parts)
                        prompt_parts.append(f"{role}: {content}")
                    elif 'text' in msg:
                        prompt_parts.append(msg['text'])
                elif hasattr(msg, 'content'):
                    prompt_parts.append(str(msg.content))
                else:
                    prompt_parts.append(str(msg))
            return '\n'.join(prompt_parts)
        else:
            return str(messages)

    def _is_valid_code_response(self, response: str) -> bool:
        """Check whether the response contains a valid code block."""
        code_pattern = r'```(?:py|python)?\s*\n(.*?)\n```'
        return bool(re.search(code_pattern, response, re.DOTALL))

    def _fix_response_format(self, response: str, original_prompt: str) -> str:
        """Try to fix the response format to match the expected pattern."""
        # Comment out a leading "Thoughts:" that is not already in a code block;
        # it is a common source of SyntaxError when the LLM emits it directly.
        if "Thoughts:" in response and "```" not in response.split("Thoughts:")[0]:
            response = response.replace("Thoughts:", "# Thoughts:", 1)

        if "sentiment" in original_prompt.lower():
            text_to_analyze = "neutral text"
            if "this is horrible" in original_prompt:
                text_to_analyze = "this is horrible"
            elif "awful" in original_prompt:
                text_to_analyze = "awful"

            return f"""Thoughts: I need to analyze the sentiment of the given text using the sentiment_analysis tool.
Code:
```py
text = "{text_to_analyze}"
result = sentiment_analysis(text=text)
final_answer(result)
```<end_code>"""

        if "```" in response and ("Thoughts:" in response or "Code:" in response):
            return response

        clean_response = response.replace('"', '\\"').replace('\n', '\\n')
        return f"""Thoughts: Processing the user's request.
Code:
```py
result = "{clean_response}"
final_answer(result)
```<end_code>"""

    def _create_fallback_response(self, prompt: str, error_msg: str) -> str:
        """Create a valid fallback response when the model fails."""
        return """Thoughts: The AI service is experiencing issues, providing a fallback response.
Code:
```py
error_message = "I apologize, but the AI service is temporarily experiencing high load. Please try again in a moment."
final_answer(error_message)
```<end_code>"""

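# Illustrative sketch of the message flattening performed by _format_messages;
# the sample message is made up. Structured chat messages (including
# list-valued content parts) collapse into a single "role: text" prompt line.
def _demo_format_messages():
    model = CachedLocalInferenceModel()
    messages = [
        {"role": "user", "content": [{"type": "text", "text": "Analyze: this is horrible"}]}
    ]
    print(model._format_messages(messages))  # -> "user: Analyze: this is horrible"
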
class TokenUsage:
    def __init__(self, input_tokens: int = 0, output_tokens: int = 0):
        self.input_tokens = input_tokens
        self.output_tokens = output_tokens
        self.total_tokens = input_tokens + output_tokens
        self.prompt_tokens = input_tokens
        self.completion_tokens = output_tokens


class ModelResponse:
    def __init__(self, content: str, prompt: str = ""):
        self.content = content
        self.text = content
        # Rough token estimates based on whitespace-delimited word counts
        estimated_input_tokens = len(prompt.split()) if prompt else 0
        estimated_output_tokens = len(content.split()) if content else 0
        self.token_usage = TokenUsage(estimated_input_tokens, estimated_output_tokens)
        self.finish_reason = 'stop'
        self.model = 'local-inference'

    def __str__(self):
        return self.content

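# Illustrative check of the whitespace-based token estimation in ModelResponse;
# the strings are arbitrary. Four prompt words plus three content words give
# an estimated total of seven tokens.
def _demo_token_estimate():
    resp = ModelResponse("final answer text", prompt="what is the answer")
    print(resp.token_usage.total_tokens)  # 7
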
# Global variables
_mcp_client = None
_tools = None
_model = None
_agent = None
_initialized = False
_initialization_lock = threading.Lock()

def initialize_agent():
    """Initialize the agent components with Hugging Face Spaces MCP servers."""
    global _mcp_client, _tools, _model, _agent, _initialized

    with _initialization_lock:
        if _initialized:
            skip_step = track_workflow_step("agent_init_skip", "Agent already initialized - using cached instance")
            complete_workflow_step(skip_step, "completed", details={"optimization": "session_persistence"})
            return

        try:
            print("Initializing MCP agent...")

            agent_init_step = track_workflow_step("agent_init", "Initializing MCP agent components")

            # Get clients for the Hugging Face Spaces servers
            all_tools = []
            tool_names = set()

            # Semantic Search & Keywords server
            try:
                semantic_client = get_mcp_client("https://nlarchive-mcp-semantic-keywords.hf.space/gradio_api/mcp/sse")
                semantic_tools = semantic_client.get_tools()
                for tool in semantic_tools:
                    if tool.name not in tool_names:
                        all_tools.append(tool)
                        tool_names.add(tool.name)
                print(f"Connected to semantic server: {len(semantic_tools)} tools - {[t.name for t in semantic_tools]}")
            except Exception as e:
                print(f"WARNING: Semantic server unavailable: {e}")

            # Token Counter server
            try:
                token_client = get_mcp_client("https://nlarchive-mcp-gr-token-counter.hf.space/gradio_api/mcp/sse")
                token_tools = token_client.get_tools()
                for tool in token_tools:
                    if tool.name not in tool_names:
                        all_tools.append(tool)
                        tool_names.add(tool.name)
                print(f"Connected to token counter server: {len(token_tools)} tools - {[t.name for t in token_tools]}")
            except Exception as e:
                print(f"WARNING: Token counter server unavailable: {e}")

            # Sentiment Analysis server
            try:
                sentiment_client = get_mcp_client("https://nlarchive-mcp-sentiment.hf.space/gradio_api/mcp/sse")
                sentiment_tools = sentiment_client.get_tools()
                for tool in sentiment_tools:
                    if tool.name not in tool_names:
                        all_tools.append(tool)
                        tool_names.add(tool.name)
                print(f"Connected to sentiment analysis server: {len(sentiment_tools)} tools - {[t.name for t in sentiment_tools]}")
            except Exception as e:
                print(f"WARNING: Sentiment analysis server unavailable: {e}")

            _tools = all_tools
            _model = get_global_model()

            # Create the agent with unique tools only
            _agent = CodeAgent(tools=_tools, model=_model)

            complete_workflow_step(agent_init_step, "completed", details={
                "tools_count": len(_tools),
                "unique_tool_names": list(tool_names),
                "servers_connected": 3
            })

            _initialized = True
            print(f"Agent initialized with {len(_tools)} unique tools: {list(tool_names)}")

        except Exception as e:
            print(f"Agent initialization failed: {e}")
            _model = get_global_model()
            _agent = CodeAgent(tools=[], model=_model)
            _initialized = True
            print("Agent initialized in fallback mode")

def is_agent_initialized() -> bool:
    """Check if the agent is initialized."""
    return _initialized

def run_agent(message: str) -> str:
    """Send a message through the agent with comprehensive tracking."""
    if not _initialized:
        initialize_agent()
    if _agent is None:
        raise RuntimeError("Agent not properly initialized")

    # Track agent processing
    process_step = track_workflow_step("agent_process", f"Processing: {message}")

    try:
        # Enhanced tool tracking
        tool_step: Optional[str] = None
        detected_tools = []

        # Detect potential tool usage from keywords in the message
        if any(keyword in message.lower() for keyword in ['sentiment', 'analyze', 'feeling']):
            detected_tools.append('sentiment_analysis')
        if any(keyword in message.lower() for keyword in ['token', 'count']):
            detected_tools.extend(['count_tokens_openai_gpt4', 'count_tokens_bert_family'])
        if any(keyword in message.lower() for keyword in ['semantic', 'similar', 'keyword']):
            detected_tools.extend(['semantic_similarity', 'extract_semantic_keywords'])

        if detected_tools:
            tool_step = track_communication("agent", "mcp_server", "tool_call",
                                            f"Executing tools {detected_tools} for: {message[:50]}...")

        result = _agent.run(message)

        # Complete the tool step if it was tracked
        if tool_step is not None:
            complete_workflow_step(tool_step, "completed", details={
                "result": str(result)[:100],
                "detected_tools": detected_tools
            })

        complete_workflow_step(process_step, "completed", details={
            "result_length": len(str(result)),
            "detected_tools": detected_tools
        })

        return str(result)

    except Exception as e:
        error_msg = str(e)
        print(f"Agent execution error: {error_msg}")

        complete_workflow_step(process_step, "error", details={"error": error_msg})

        # Enhanced error responses
        if "503" in error_msg or "overloaded" in error_msg.lower():
            return "I apologize, but the AI service is currently experiencing high demand. Please try again in a few moments."
        elif "rate limit" in error_msg.lower():
            return "The service is currently rate-limited. Please wait a moment before trying again."
        elif "event loop" in error_msg.lower():
            return "There was an async processing issue. The system is recovering. Please try again."
        else:
            return "I encountered an error while processing your request. Please try rephrasing your question or try again later."

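# Illustrative end-to-end sketch; the query string is arbitrary. run_agent
# lazily initializes the agent, so no explicit setup call is required first.
def _demo_run_agent():
    answer = run_agent("What is the sentiment of: this is horrible")
    print(answer)
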
def disconnect():
    """Cleanly disconnect connections with global pool management."""
    global _mcp_client, _initialized
    disconnect_step = track_workflow_step("agent_disconnect", "Disconnecting MCP client")

    try:
        # Phase 2 Optimization: Preserve global connections for reuse
        with _global_connection_lock:
            preserved_connections = 0
            for url, client in _global_connection_pool.items():
                try:
                    # Keep connections alive but mark them as idle
                    if hasattr(client, '_last_used'):
                        client._last_used = time.time()
                    preserved_connections += 1
                except Exception:
                    pass

        complete_workflow_step(disconnect_step, "completed", details={
            "preserved_connections": preserved_connections,
            "optimization": "connection_persistence"
        })
    except Exception as e:
        complete_workflow_step(disconnect_step, "error", details={"error": str(e)})
    finally:
        # Don't reset global state - preserve it for the next session
        _initialized = False

def initialize_session():
    """Initialize the persistent session - alias for initialize_agent."""
    initialize_agent()


def is_session_initialized() -> bool:
    """Check if the persistent session is initialized - alias for is_agent_initialized."""
    return is_agent_initialized()

# Make sure these are exported for imports
__all__ = [
    'run_agent', 'initialize_agent', 'is_agent_initialized', 'disconnect',
    'initialize_session', 'is_session_initialized',
    'get_mcp_client', 'get_global_model', 'reset_global_state',
    '_global_tools_cache', '_global_connection_pool', '_global_model_instance',
    '_global_connection_lock', '_global_model_lock'
]


# Register cleanup function
def cleanup_global_resources():
    """Cleanup function for graceful shutdown."""
    global _global_connection_pool, _event_loop_manager

    print("Cleaning up global resources...")

    with _global_connection_lock:
        for client in _global_connection_pool.values():
            try:
                client.disconnect()
            except Exception:
                pass
        _global_connection_pool.clear()

    # Shutdown the event loop manager
    with _event_loop_lock:
        if _event_loop_manager:
            try:
                _event_loop_manager.shutdown()
            except Exception:
                pass
        _event_loop_manager = None


# Register cleanup on exit
atexit.register(cleanup_global_resources)