chiara5122 commited on
Commit
38e45e3
·
1 Parent(s): 689022d

added table extractor tool

Browse files
Files changed (1) hide show
  1. tools/table_extractor_tool.py +66 -70
tools/table_extractor_tool.py CHANGED
@@ -1,106 +1,102 @@
1
  from smolagents import Tool
2
- from tabula import read_pdf
3
  import pandas as pd
4
- from typing import Optional, Dict, Any
 
5
 
6
  class TableExtractorTool(Tool):
7
  """
8
- Tool to extract tables from PDFs/webpages and answer queries about them.
9
-
10
- Args:
11
- file_path (str): Path to PDF file (optional)
12
- url (str): URL of webpage containing tables (optional)
13
- query (str): Natural language question about the table data (optional)
14
-
15
- Returns:
16
- str: Extracted table data or answer to query
17
  """
18
-
19
- name = "extract_table"
20
- description = "Extracts tables from PDFs or webpages and answers questions about the data"
21
-
22
  inputs = {
23
  "file_path": {
24
  "type": "string",
25
- "description": "Path to PDF file (either file_path or url required)",
26
- "required": False
27
  },
28
- "url": {
29
  "type": "string",
30
- "description": "URL of webpage containing tables (either file_path or url required)",
31
- "required": False
 
32
  },
33
  "query": {
34
  "type": "string",
35
- "description": "Natural language question about the table data",
36
- "required": False
 
37
  }
38
  }
39
-
40
  output_type = "string"
41
 
42
- def forward(self, file_path: Optional[str] = None,
43
- url: Optional[str] = None,
 
44
  query: Optional[str] = None) -> str:
45
 
46
- # Validate input
47
- if not file_path and not url:
48
- return "Error: Either file_path or url must be provided"
49
-
50
  try:
51
- # Case 1: Extract from PDF
52
- if file_path and file_path.endswith(".pdf"):
53
- tables = read_pdf(file_path, pages="all", multiple_tables=True)
54
- df = pd.concat(tables) if tables else None
55
 
56
- # Case 2: Extract from HTML (webpage)
57
- elif url:
58
- dfs = pd.read_html(url)
59
- df = dfs[0] if dfs else None
60
 
61
- if df is None:
62
- return "No tables found in the input source"
 
 
 
 
63
 
64
- # Answer query if provided
65
- if query:
66
- return self._answer_query(df, query)
67
- return df.to_string()
68
 
69
  except Exception as e:
70
- return f"Error processing table data: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  def _answer_query(self, df: pd.DataFrame, query: str) -> str:
73
- """Helper method to answer questions about the table data"""
 
 
74
  try:
75
- query = query.lower()
76
-
77
- # Example simple queries - you could expand this or integrate an LLM
78
- if "total" in query and "sum" in query:
79
- if "revenue" in query:
80
- col = "Revenue"
81
- elif "sales" in query:
82
- col = "Sales"
83
- else:
84
- # Try to find a numeric column
85
- numeric_cols = df.select_dtypes(include=['number']).columns
86
- col = numeric_cols[0] if len(numeric_cols) > 0 else None
87
-
88
- if col:
89
- return f"Total {col}: {df[col].sum()}"
90
 
 
91
  elif "average" in query or "mean" in query:
92
- # Find the most likely column referenced in query
93
- for col in df.columns:
94
  if col.lower() in query:
95
  return f"Average {col}: {df[col].mean():.2f}"
96
-
97
- # Default to first numeric column
98
- numeric_cols = df.select_dtypes(include=['number']).columns
99
- if len(numeric_cols) > 0:
100
- return f"Average {numeric_cols[0]}: {df[numeric_cols[0]].mean():.2f}"
101
 
102
- # Fallback: return the table
103
- return f"Here's the table data:\n{df.to_string()}\n\nQuery '{query}' not fully understood."
 
 
 
 
 
 
 
104
 
105
  except Exception as e:
106
- return f"Error answering query: {str(e)}\nTable data:\n{df.to_string()}"
 
1
  from smolagents import Tool
 
2
  import pandas as pd
3
+ from typing import Optional
4
+ import os
5
 
6
  class TableExtractorTool(Tool):
7
  """
8
+ Extracts tables from Excel (.xlsx, .xls) or CSV files and answers queries.
9
+ Auto-detects file type based on extension.
 
 
 
 
 
 
 
10
  """
11
+ name = "table_extractor"
12
+ description = "Reads Excel/CSV files and answers questions about tabular data"
 
 
13
  inputs = {
14
  "file_path": {
15
  "type": "string",
16
+ "description": "Path to Excel/CSV file"
 
17
  },
18
+ "sheet_name": {
19
  "type": "string",
20
+ "description": "Sheet name (Excel only, optional)",
21
+ "required": False,
22
+ "nullable": True
23
  },
24
  "query": {
25
  "type": "string",
26
+ "description": "Question about the data (e.g., 'total sales')",
27
+ "required": False,
28
+ "nullable": True
29
  }
30
  }
 
31
  output_type = "string"
32
 
33
+ def forward(self,
34
+ file_path: str,
35
+ sheet_name: Optional[str] = None,
36
  query: Optional[str] = None) -> str:
37
 
 
 
 
 
38
  try:
39
+ # Validate file exists
40
+ if not os.path.exists(file_path):
41
+ return f"Error: File not found at {file_path}"
 
42
 
43
+ # Read file based on extension
44
+ ext = os.path.splitext(file_path)[1].lower()
 
 
45
 
46
+ if ext in ('.xlsx', '.xls'):
47
+ df = self._read_excel(file_path, sheet_name)
48
+ elif ext == '.csv':
49
+ df = pd.read_csv(file_path)
50
+ else:
51
+ return f"Error: Unsupported file type {ext}"
52
 
53
+ if df.empty:
54
+ return "Error: No data found in file."
55
+
56
+ return self._answer_query(df, query) if query else df.to_string()
57
 
58
  except Exception as e:
59
+ return f"Error processing file: {str(e)}"
60
+
61
+ def _read_excel(self, path: str, sheet_name: Optional[str]) -> pd.DataFrame:
62
+ """Read Excel file with sheet selection logic"""
63
+ if sheet_name:
64
+ return pd.read_excel(path, sheet_name=sheet_name)
65
+
66
+ # Auto-detect first non-empty sheet
67
+ sheets = pd.ExcelFile(path).sheet_names
68
+ for sheet in sheets:
69
+ df = pd.read_excel(path, sheet_name=sheet)
70
+ if not df.empty:
71
+ return df
72
+ return pd.DataFrame() # Return empty if all sheets are blank
73
 
74
  def _answer_query(self, df: pd.DataFrame, query: str) -> str:
75
+ """Handles queries with pandas operations"""
76
+ query = query.lower()
77
+
78
  try:
79
+ # SUM QUERIES (e.g., "total revenue")
80
+ if "total" in query or "sum" in query:
81
+ for col in df.select_dtypes(include='number').columns:
82
+ if col.lower() in query:
83
+ return f"Total {col}: {df[col].sum():.2f}"
 
 
 
 
 
 
 
 
 
 
84
 
85
+ # AVERAGE QUERIES (e.g., "average price")
86
  elif "average" in query or "mean" in query:
87
+ for col in df.select_dtypes(include='number').columns:
 
88
  if col.lower() in query:
89
  return f"Average {col}: {df[col].mean():.2f}"
 
 
 
 
 
90
 
91
+ # FILTER QUERIES (e.g., "show sales > 1000")
92
+ elif ">" in query or "<" in query:
93
+ col = next((c for c in df.columns if c.lower() in query), None)
94
+ if col:
95
+ filtered = df.query(query.replace(col, f"`{col}`"))
96
+ return filtered.to_string()
97
+
98
+ # DEFAULT: Return full table with column names
99
+ return f"Data:\nColumns: {', '.join(df.columns)}\n\n{df.to_string()}"
100
 
101
  except Exception as e:
102
+ return f"Query failed: {str(e)}\nAvailable columns: {', '.join(df.columns)}"