Spaces:
Running
Running
| """ | |
| AIFinder Dataset Evaluator | |
| Supports various HuggingFace dataset formats for evaluation. | |
| """ | |
| import os | |
| import re | |
| import json | |
| import random | |
| from collections import defaultdict | |
| from typing import Any | |
| from datasets import load_dataset | |
| from tqdm import tqdm | |
# Optional HuggingFace access token, read once from the environment.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Registry of recognised dataset layouts.
#
# Insertion order matters: detect_format() walks this mapping from top to
# bottom and returns the first entry whose "check" accepts enough sample rows,
# so more specific layouts must come before more permissive ones.
#
# NOTE: "combined" used to be declared twice. Python keeps the *last* value of
# a duplicated key in a dict literal (at the position of the first), which
# silently replaced the strict check below with a permissive one that also
# claimed prompt/completion (SFT) and response/output rows. There is now a
# single definition carrying the strict check.
SUPPORTED_FORMATS = {
    "teichai_healer": {
        "name": "TeichAI Healer Format",
        "description": "TeichAI Healer-Alpha format with 'prompt' and 'response' fields",
        "examples": ["TeichAI/Healer-Alpha-16k"],
        "check": lambda row: (
            "prompt" in row
            and "response" in row
            and isinstance(row.get("prompt"), (str, dict))
            and isinstance(row.get("response"), (str, dict))
        ),
    },
    "teichai": {
        "name": "TeichAI Format",
        "description": "TeichAI dataset format with 'conversations' or 'messages' containing role/content",
        "examples": [
            "TeichAI/claude-4.5-opus-high-reasoning-250x",
            "TeichAI/Claude-3.5-Sonnet-128k",
        ],
        "check": lambda row: _check_conversations_format(row),
    },
    "combined": {
        "name": "Combined Outputs",
        "description": "Dataset with 'output', 'outputs', 'combined', 'generated' or 'completion' field",
        "examples": ["jacobmorrison/gpt-oss-20b-combined-outputs"],
        # Rows that look like prompt/response pairs or conversations are left
        # for the dedicated formats; the cheap key tests run first so they
        # short-circuit before the conversation check.
        "check": lambda row: (
            "prompt" not in row
            and "response" not in row
            and not _check_conversations_format(row)
            and (
                any(
                    k in row
                    for k in ["output", "outputs", "combined", "generated", "completion"]
                )
                or (
                    isinstance(row.get("data"), str)
                    or isinstance(row.get("example"), str)
                )
            )
        ),
    },
    "conversations": {
        "name": "Conversations Format",
        "description": "Dataset with 'conversations' or 'messages' field containing role/content pairs",
        "examples": [
            "TeichAI/claude-4.5-opus-high-reasoning-250x",
            "ianncity/Hunter-Alpha-SFT-300000x",
        ],
        "check": lambda row: _check_conversations_format(row),
    },
    "chat": {
        "name": "Chat Format",
        "description": "Dataset with 'chat' or 'dialogue' field",
        "examples": ["some/chat-dataset"],
        "check": lambda row: ("chat" in row.keys() or "dialogue" in row.keys()),
    },
    "text": {
        "name": "Text Field",
        "description": "Dataset with a 'text' field containing the response",
        "examples": ["some/text-dataset"],
        "check": lambda row: "text" in row and isinstance(row.get("text"), str),
    },
    "response": {
        "name": "Response Field",
        "description": "Dataset with 'response' or 'output' field",
        "examples": ["some/response-dataset"],
        "check": lambda row: "response" in row or "output" in row,
    },
    "content": {
        "name": "Content Field",
        "description": "Dataset with 'content' field (single message)",
        "examples": ["some/content-dataset"],
        "check": lambda row: "content" in row and isinstance(row.get("content"), str),
    },
    "messages": {
        "name": "Messages Array",
        "description": "Dataset where each row is an array of message objects",
        "examples": ["some/messages-dataset"],
        "check": lambda row: isinstance(row, list)
        and len(row) > 0
        and isinstance(row[0], dict),
    },
    "sft": {
        "name": "SFT Format",
        "description": "Supervised Fine-Tuning format with 'prompt' and 'response' or 'completion'",
        "examples": ["some/sft-dataset"],
        "check": lambda row: "prompt" in row
        and ("response" in row or "completion" in row),
    },
    "qa": {
        "name": "Q&A Format",
        "description": "Question-Answer format with 'question' and 'answer' fields",
        "examples": ["some/qa-dataset"],
        "check": lambda row: "question" in row and "answer" in row,
    },
    "completion": {
        "name": "Completion Format",
        "description": "Dataset with 'completion' field (like OpenAI fine-tuning)",
        "examples": ["some/completion-dataset"],
        "check": lambda row: "completion" in row
        and isinstance(row.get("completion"), str),
    },
    "generations": {
        "name": "Generations Format",
        "description": "Dataset with 'generations' or 'generation' field (LLM outputs)",
        "examples": ["some/generations-dataset"],
        "check": lambda row: "generations" in row or "generation" in row,
    },
}
| def _check_conversations_format(row): | |
| """Check if row has conversations/messages with proper role/content structure.""" | |
| conv_key = ( | |
| "conversations" | |
| if "conversations" in row | |
| else "messages" | |
| if "messages" in row | |
| else None | |
| ) | |
| if not conv_key: | |
| return False | |
| convos = row.get(conv_key) | |
| if not isinstance(convos, list) or not convos: | |
| return False | |
| first_msg = convos[0] | |
| if isinstance(first_msg, dict): | |
| return "role" in first_msg and "content" in first_msg | |
| return False | |
def detect_format(rows, sample_size=10, formats=None):
    """Detect the dataset format from sample rows.

    Args:
        rows: Sequence of dataset rows.
        sample_size: Number of leading rows to test against each format.
        formats: Optional mapping of format name -> info dict (each with a
            "check" callable) to test against. Defaults to SUPPORTED_FORMATS.

    Returns:
        (format_name, format_info) for the first format, in mapping order,
        whose check accepts at least 60% of the sampled rows; (None, []) when
        there is nothing to sample or no format qualifies.
    """
    if not rows:
        return None, []

    sample = rows[:sample_size]
    if not sample:
        # An empty sample would trivially satisfy the threshold below and
        # report the first registered format.
        return None, []

    candidates = SUPPORTED_FORMATS if formats is None else formats
    threshold = len(sample) * 0.6

    for fmt_name, fmt_info in candidates.items():
        check_func = fmt_info["check"]
        matches = 0
        for row in sample:
            try:
                if check_func(row):
                    matches += 1
            except Exception:
                # A check that cannot handle this row shape simply does not
                # match; KeyboardInterrupt/SystemExit still propagate.
                continue
        if matches >= threshold:
            return fmt_name, fmt_info

    return None, []
| def _parse_msg(msg): | |
| """Parse a message that may be a dict or a JSON string.""" | |
| if isinstance(msg, dict): | |
| return msg | |
| if isinstance(msg, str): | |
| try: | |
| parsed = json.loads(msg) | |
| if isinstance(parsed, dict): | |
| return parsed | |
| except (ValueError, Exception): | |
| pass | |
| return {} | |
| def _extract_response_only(content): | |
| """Extract only the final response, stripping CoT blocks.""" | |
| if not content: | |
| return "" | |
| think_match = re.search(r"</?think(?:ing)?>(.*)$", content, re.DOTALL) | |
| if think_match: | |
| response = think_match.group(1).strip() | |
| if response: | |
| return response | |
| return content | |
def extract_texts_conversations(rows):
    """Collect assistant replies from rows in conversations/messages layout.

    Each row's message list is read from "conversations", falling back to
    "messages". Messages whose role is one of assistant/gpt/model/ai and whose
    content is non-empty are reduced to their final response; responses longer
    than 50 characters are kept.
    """
    assistant_roles = ("assistant", "gpt", "model", "ai")
    collected = []
    for record in rows:
        turns = record.get("conversations") or record.get("messages") or []
        for raw_turn in turns:
            turn = _parse_msg(raw_turn)
            speaker = turn.get("role", "")
            body = turn.get("content", "")
            if speaker not in assistant_roles or not body:
                continue
            answer = _extract_response_only(body)
            if answer and len(answer) > 50:
                collected.append(answer)
    return collected
def extract_texts_chat(rows):
    """Collect assistant replies from rows in chat/dialogue layout.

    The message list is read from "chat", falling back to "dialogue"; rows
    where that value is not a list are skipped. Only messages with role
    "assistant" or "ai" contribute, and only responses longer than 50
    characters are kept.
    """
    assistant_roles = ("assistant", "ai")
    collected = []
    for record in rows:
        exchange = record.get("chat") or record.get("dialogue") or []
        if not isinstance(exchange, list):
            continue
        for raw_turn in exchange:
            turn = _parse_msg(raw_turn)
            speaker = turn.get("role", "")
            body = turn.get("content", "")
            if speaker not in assistant_roles or not body:
                continue
            answer = _extract_response_only(body)
            if answer and len(answer) > 50:
                collected.append(answer)
    return collected
def extract_texts_text_field(rows, field="text"):
    """Collect responses stored in a single column of each row.

    The value under `field` is stringified; values of 50 characters or fewer
    are dropped, both before and after reasoning blocks are stripped.
    """
    kept = []
    for record in rows:
        raw = record.get(field, "")
        if not raw:
            continue
        as_text = str(raw)
        if len(as_text) <= 50:
            continue
        cleaned = _extract_response_only(as_text)
        if cleaned and len(cleaned) > 50:
            kept.append(cleaned)
    return kept
def extract_texts_sft(rows):
    """Collect responses from SFT-style rows.

    The reply is taken from "response", falling back to "completion". It is
    stringified, stripped of reasoning blocks, and kept when longer than 50
    characters both before and after stripping.
    """
    kept = []
    for record in rows:
        reply = record.get("response") or record.get("completion") or ""
        if not reply:
            continue
        reply_text = str(reply)
        if len(reply_text) <= 50:
            continue
        cleaned = _extract_response_only(reply_text)
        if cleaned and len(cleaned) > 50:
            kept.append(cleaned)
    return kept
def extract_texts_qa(rows):
    """Collect the "answer" column of question/answer rows as responses.

    Answers are stringified, stripped of reasoning blocks, and kept when
    longer than 50 characters both before and after stripping.
    """
    kept = []
    for record in rows:
        reply = record.get("answer", "")
        if not reply:
            continue
        reply_text = str(reply)
        if len(reply_text) <= 50:
            continue
        cleaned = _extract_response_only(reply_text)
        if cleaned and len(cleaned) > 50:
            kept.append(cleaned)
    return kept
def extract_texts_messages_array(rows):
    """Collect assistant replies from rows that are themselves message lists.

    Rows that are not lists are skipped. Messages with role assistant/ai/model
    and non-empty content are reduced to their final response; responses
    longer than 50 characters are kept.
    """
    assistant_roles = ("assistant", "ai", "model")
    collected = []
    for record in rows:
        if not isinstance(record, list):
            continue
        for raw_turn in record:
            turn = _parse_msg(raw_turn)
            speaker = turn.get("role", "")
            body = turn.get("content", "")
            if speaker not in assistant_roles or not body:
                continue
            answer = _extract_response_only(body)
            if answer and len(answer) > 50:
                collected.append(answer)
    return collected
def extract_texts_teichai_healer(rows):
    """Collect responses from TeichAI Healer-Alpha rows.

    The "response" value may be a plain value or a dict; for a dict the text
    is taken from its "content" key, falling back to "text". The text is
    stringified, stripped of reasoning blocks, and kept when longer than 50
    characters both before and after stripping.
    """
    kept = []
    for record in rows:
        reply = record.get("response")
        if not reply:
            continue
        if isinstance(reply, dict):
            reply = reply.get("content") or reply.get("text") or ""
        if not reply:
            continue
        reply_text = str(reply)
        if len(reply_text) <= 50:
            continue
        cleaned = _extract_response_only(reply_text)
        if cleaned and len(cleaned) > 50:
            kept.append(cleaned)
    return kept
def _get_dataset_size(dataset_id, load_kwargs):
    """Get dataset size without loading all data.

    Args:
        dataset_id: HuggingFace dataset repository id.
        load_kwargs: Extra keyword arguments forwarded to load_dataset().

    Returns:
        The number of rows in the "train" split when the streaming dataset's
        metadata exposes it; otherwise the row count of the first parquet
        shard of the train split; 0 when neither can be obtained.
    """
    try:
        ds = load_dataset(dataset_id, split="train", streaming=True, **load_kwargs)
        info = ds.info
        # Kept for compatibility in case the info object does expose num_rows.
        direct = getattr(info, "num_rows", None)
        if direct:
            return direct
        # NOTE(review): DatasetInfo is expected to publish per-split sizes via
        # `splits[<name>].num_examples`; this may be absent for datasets
        # without metadata, in which case we fall through to the parquet path.
        splits = getattr(info, "splits", None) or {}
        train_split = splits.get("train")
        num_examples = getattr(train_split, "num_examples", None)
        if num_examples:
            return num_examples
    except Exception:
        pass

    try:
        import pandas as pd

        url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
        df = pd.read_parquet(url)
        return len(df)
    except Exception:
        return 0
def _streaming_download_with_progress(
    dataset_id, load_kwargs, progress_callback=None, max_rows=None
):
    """Download dataset using streaming with progress tracking.

    Args:
        dataset_id: HuggingFace dataset repository id.
        load_kwargs: Extra keyword arguments forwarded to load_dataset().
        progress_callback: Optional callable(current, total, stage); called
            with stage "fetching_info" once and "downloading" per row.
        max_rows: Optional cap on the number of rows to download.

    Returns:
        (rows, total) where rows is a list of row dicts.

    Raises:
        Whatever the parquet fallback raises when streaming has already
        failed (the streaming error itself is only printed).
    """
    total_rows = _get_dataset_size(dataset_id, load_kwargs)
    print(f"[PROGRESS] Dataset size: {total_rows} rows", flush=True)

    # When the size is unknown (0) but a cap was requested, report the cap as
    # the expected total instead of 0.
    if max_rows and (not total_rows or max_rows < total_rows):
        download_limit = max_rows
    else:
        download_limit = total_rows

    if progress_callback:
        progress_callback(0, download_limit, "fetching_info")
        print(f"[PROGRESS] Initial callback: 0/{download_limit}", flush=True)

    try:
        ds = load_dataset(dataset_id, split="train", streaming=True, **load_kwargs)
        rows = []
        for i, row in enumerate(tqdm(ds, desc="Downloading", unit="rows")):
            rows.append(row)
            if progress_callback:
                progress_callback(i + 1, download_limit, "downloading")
            if i % 100 == 0:
                print(f"[PROGRESS] Downloaded {i + 1}/{download_limit}", flush=True)
            if max_rows and i + 1 >= max_rows:
                print(f"[PROGRESS] Stopping at {i + 1} rows", flush=True)
                break
        return rows, min(len(rows), total_rows or len(rows))
    except Exception as e:
        print(f"[PROGRESS] Streaming failed: {e}", flush=True)

    # Fallback: read the first parquet shard of the train split. pandas is
    # imported here so a missing pandas cannot break the streaming path.
    import pandas as pd

    url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
    df = pd.read_parquet(url)
    if max_rows and max_rows < len(df):
        df = df.head(max_rows)
        print(f"[PROGRESS] Limited to first {max_rows} rows", flush=True)
    total = len(df)
    if progress_callback:
        progress_callback(0, total, "downloading")
    rows = []
    for i, row in enumerate(df.to_dict(orient="records")):
        rows.append(row)
        if progress_callback:
            progress_callback(i + 1, total, "downloading")
    return rows, total
def _load_sample_rows(dataset_id, sample_size, load_kwargs):
    """Load just a few rows for format detection.

    Args:
        dataset_id: HuggingFace dataset repository id.
        sample_size: Maximum number of rows to return.
        load_kwargs: Extra keyword arguments forwarded to load_dataset().

    Returns:
        Up to sample_size distinct leading rows of the train split, taken
        from the streaming dataset or, failing that, from the first parquet
        shard; [] when both sources fail.
    """
    from itertools import islice

    try:
        ds = load_dataset(dataset_id, split="train", streaming=True, **load_kwargs)
        # Consume ONE iterator. Calling next(iter(ds)) per sample restarted
        # the stream each time and returned the first row repeatedly.
        head = list(islice(ds, sample_size))
        if head:
            return head
    except Exception:
        pass

    try:
        import pandas as pd

        url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
        df = pd.read_parquet(url)
        return df.head(sample_size).to_dict(orient="records")
    except Exception:
        return []
def _failed_load_result(error, total_rows=0):
    """Build the result dict returned by load_dataset_texts() on failure."""
    return {
        "texts": [],
        "format": None,
        "format_info": None,
        "total_rows": total_rows,
        "supported": False,
        "error": error,
    }


def _read_first_parquet_shard(dataset_id, max_samples, progress_callback):
    """Read the first train parquet shard as (rows, total), reporting progress if asked."""
    import pandas as pd

    url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
    df = pd.read_parquet(url)
    if max_samples and max_samples < len(df):
        df = df.head(max_samples)
    total = len(df)
    if progress_callback:
        progress_callback(0, total, "downloading")
    rows = []
    for i, row in enumerate(df.to_dict(orient="records")):
        rows.append(row)
        if progress_callback:
            progress_callback(i + 1, total, "downloading")
    return rows, total


def load_dataset_texts(
    dataset_id,
    max_samples=None,
    sample_size=None,
    progress_callback=None,
    custom_format=None,
):
    """
    Load a HuggingFace dataset and extract assistant response texts.

    Returns: {
        "texts": list of extracted texts,
        "format": detected format name,
        "format_info": format info dict,
        "total_rows": total rows in dataset,
        "supported": bool,
        "error": error message if failed,
    }

    max_samples: optional cap on rows loaded and on texts returned; when more
        texts than this are extracted, a deterministic (seed 42) random
        subset is returned.
    sample_size: when set, only this many rows are loaded (for format
        detection) and texts are extracted from those rows only.
    progress_callback: optional function(current, total, stage) -> None
        stage can be: "fetching_info", "downloading", "extracting"
    custom_format: optional custom format specification string
        Examples:
        - "column: response"
        - "column: prompt, column: response"
        - "pattern: user:, pattern: assistant:"
        - "user:[startuser]assistant:[startassistant]"
    """
    load_kwargs = {"token": HF_TOKEN} if HF_TOKEN else {}
    rows = []
    total_rows = 0

    if sample_size:
        total_rows = _get_dataset_size(dataset_id, load_kwargs)
        if not total_rows:
            return _failed_load_result("Dataset is empty")
        rows = _load_sample_rows(dataset_id, sample_size, load_kwargs)
    else:
        try:
            if progress_callback:
                rows, total_rows = _streaming_download_with_progress(
                    dataset_id, load_kwargs, progress_callback, max_samples
                )
            else:
                ds = load_dataset(dataset_id, split="train", **load_kwargs)
                total_rows = len(ds)
                if max_samples and max_samples < total_rows:
                    total_rows = max_samples
                rows = list(ds)[:max_samples] if max_samples else list(ds)
        except Exception as e:
            # Last resort: read the first parquet shard directly.
            try:
                rows, total_rows = _read_first_parquet_shard(
                    dataset_id, max_samples, progress_callback
                )
            except Exception as e2:
                return _failed_load_result(
                    f"Failed to load: {e}. Parquet fallback also failed: {e2}"
                )

    if not rows:
        return _failed_load_result("Dataset is empty")

    detect_rows = rows[:sample_size] if sample_size else rows
    custom_format_spec = custom_format

    # A custom spec wins only when it actually applies to the sampled rows;
    # otherwise fall back to automatic detection.
    if custom_format_spec and check_custom_format(detect_rows, custom_format_spec):
        fmt_name = "custom"
        fmt_info = {
            "name": "Custom Format",
            "description": f"Custom format: {custom_format_spec}",
            "examples": [],
        }
    else:
        fmt_name, fmt_info = detect_format(detect_rows, sample_size=sample_size or 10)

    if fmt_name is None:
        return _failed_load_result(
            "Unknown dataset format. Supported formats: "
            + ", ".join(f["name"] for f in SUPPORTED_FORMATS.values()),
            total_rows,
        )

    # Format name -> extractor. Multi-field formats try each candidate column
    # in turn and use the first one that yields any text.
    extractors = {
        "teichai_healer": extract_texts_teichai_healer,
        "teichai": extract_texts_conversations,
        "conversations": extract_texts_conversations,
        "chat": extract_texts_chat,
        "text": lambda r: extract_texts_text_field(r, "text"),
        "response": lambda r: extract_texts_text_field(r, "response")
        or extract_texts_text_field(r, "output"),
        "content": lambda r: extract_texts_text_field(r, "content"),
        "messages": extract_texts_messages_array,
        "sft": extract_texts_sft,
        "qa": extract_texts_qa,
        "combined": lambda r: (
            extract_texts_text_field(r, "output")
            or extract_texts_text_field(r, "outputs")
            or extract_texts_text_field(r, "generated")
            or extract_texts_text_field(r, "completion")
            or extract_texts_text_field(r, "combined")
            or extract_texts_text_field(r, "data")
            or extract_texts_text_field(r, "example")
        ),
        "completion": lambda r: extract_texts_text_field(r, "completion"),
        "generations": lambda r: (
            extract_texts_text_field(r, "generations")
            or extract_texts_text_field(r, "generation")
        ),
        "custom": lambda r: extract_texts_custom(r, custom_format_spec),
    }
    extractor = extractors.get(fmt_name)
    texts = extractor(rows) if extractor else []

    if max_samples and len(texts) > max_samples:
        # A private generator gives the same deterministic subset as seeding
        # the module-level RNG with 42, without reseeding global state that
        # other code may rely on.
        texts = random.Random(42).sample(texts, max_samples)

    return {
        "texts": texts,
        "format": fmt_name,
        "format_info": fmt_info,
        "total_rows": total_rows,
        "supported": True,
        "error": None,
    }
def parse_custom_format_spec(spec):
    """
    Parse custom format specification.

    Supported formats:
    - "column: <field_name>" - extract single field as text
    - "column: <user_col>, column: <assistant_col>" - extract from two columns (user/assistant)
    - "pattern: <user_regex>, pattern: <assistant_regex>" - use regex patterns
    - "user:<marker> assistant:<marker>" - markers; a bracketed marker such
      as "[startuser]" is treated as literal text
    - "<user_col>, <assistant_col>" or "<field_name>" - bare column names

    Examples:
    - "column: response"
    - "column: prompt, column: response"
    - "pattern: user:, pattern: assistant:"
    - "user:[startuser]assistant:[startassistant]"

    Returns:
        None for an empty/blank spec, otherwise a dict with keys "type"
        ("single_column", "two_column", "single_pattern" or "two_pattern"),
        "user_field", "assistant_field", "user_pattern", "assistant_pattern".
    """
    if not spec:
        return None

    spec = spec.strip()
    if not spec:
        return None

    result = {
        "type": None,
        "user_field": None,
        "assistant_field": None,
        "user_pattern": None,
        "assistant_pattern": None,
    }

    def _as_pattern(marker):
        # A "[...]" marker is meant literally; unescaped it would compile to
        # a regex character class matching a single character.
        marker = marker.strip()
        if marker.startswith("[") and marker.endswith("]"):
            return re.escape(marker)
        return marker

    if spec.startswith(("column:", "col:")):
        cols_spec = spec.replace("column:", "").replace("col:", "").strip()
        if "," in cols_spec:
            parts = [p.strip() for p in cols_spec.split(",")]
            if len(parts) >= 2:
                result["type"] = "two_column"
                result["user_field"] = parts[0]
                result["assistant_field"] = parts[1]
        else:
            result["type"] = "single_column"
            result["assistant_field"] = cols_spec
        return result

    if spec.startswith(("pattern:", "regex:")):
        patterns_spec = spec.replace("pattern:", "").replace("regex:", "").strip()
        if "," in patterns_spec:
            parts = [p.strip() for p in patterns_spec.split(",")]
            if len(parts) >= 2:
                result["type"] = "two_pattern"
                result["user_pattern"] = parts[0]
                result["assistant_pattern"] = parts[1]
        else:
            result["type"] = "single_pattern"
            result["assistant_pattern"] = patterns_spec
        return result

    lowered = spec.lower()
    if "user:" in lowered and "assistant:" in lowered:
        # No function-local `import re` here: it would make `re` a local name
        # for the whole function and break the "[startuser]" branch below
        # with UnboundLocalError whenever this branch is skipped.
        user_match = re.search(
            r"user:\s*(\[.*?\]|(?:(?!\s+assistant:).)+)",
            spec,
            re.IGNORECASE | re.DOTALL,
        )
        # Mirrors the user pattern: consume up to the next " user:" or the
        # end of the spec.
        assistant_match = re.search(
            r"assistant:\s*(\[.*?\]|(?:(?!\s+user:).)+)",
            spec,
            re.IGNORECASE | re.DOTALL,
        )
        if user_match and assistant_match:
            result["type"] = "two_pattern"
            result["user_pattern"] = _as_pattern(user_match.group(1))
            result["assistant_pattern"] = _as_pattern(assistant_match.group(1))
            return result

    if "[startuser]" in spec and "[startassistant]" in spec:
        result["type"] = "two_pattern"
        result["user_pattern"] = re.escape("[startuser]")
        result["assistant_pattern"] = re.escape("[startassistant]")
        return result

    if "," in spec:
        parts = [p.strip() for p in spec.split(",")]
        if len(parts) >= 2:
            result["type"] = "two_column"
            result["user_field"] = parts[0]
            result["assistant_field"] = parts[1]
            return result

    result["type"] = "single_column"
    result["assistant_field"] = spec
    return result
def extract_texts_custom(rows, format_spec):
    """Collect responses from rows according to a custom format spec.

    Column specs read the assistant column of each row. Pattern specs run the
    assistant regex against str(row) and take capture group 1 when the pattern
    has groups, otherwise the whole match. Candidates must exceed 50
    characters before and after reasoning blocks are stripped. An unparsable
    spec or an invalid regex yields an empty list.
    """
    spec = parse_custom_format_spec(format_spec)
    if not spec or not spec.get("type"):
        return []

    kind = spec["type"]
    flags = re.DOTALL | re.IGNORECASE
    collected = []

    def _keep_if_long_enough(candidate):
        # Shared length filter: applied to the raw text and to what remains
        # once reasoning blocks have been removed.
        if candidate and len(candidate) > 50:
            cleaned = _extract_response_only(candidate)
            if cleaned and len(cleaned) > 50:
                collected.append(cleaned)

    if kind in ("single_column", "two_column"):
        # The user column of a two-column spec is not needed for extraction.
        column = spec["assistant_field"]
        for record in rows:
            value = record.get(column, "")
            if value:
                _keep_if_long_enough(str(value))
    elif kind in ("single_pattern", "two_pattern"):
        assistant_pattern = spec.get("assistant_pattern")
        if assistant_pattern:
            try:
                user_pattern = spec.get("user_pattern")
                if kind == "two_pattern" and user_pattern:
                    # Compiled only for validation: an invalid user pattern
                    # voids the extraction, exactly as an invalid assistant
                    # pattern does.
                    re.compile(user_pattern, flags)
                matcher = re.compile(assistant_pattern, flags)
                for record in rows:
                    found = matcher.search(str(record))
                    if not found:
                        continue
                    piece = found.group(1) if found.groups() else found.group(0)
                    _keep_if_long_enough(piece)
            except re.error:
                pass

    return collected
def check_custom_format(rows, format_spec):
    """Report whether a custom format spec applies to the dataset.

    Only the first row is inspected. Column specs require the assistant
    column to be present in it; pattern specs require the assistant regex to
    match str(row). Unparsable specs, empty row lists and invalid regexes all
    give False.
    """
    spec = parse_custom_format_spec(format_spec)
    if not (spec and spec.get("type")):
        return False
    if not rows:
        return False

    first_row = rows[0]
    kind = spec["type"]

    if kind in ("single_column", "two_column"):
        return spec.get("assistant_field") in first_row

    if kind in ("single_pattern", "two_pattern"):
        assistant_pattern = spec.get("assistant_pattern")
        if assistant_pattern:
            try:
                matcher = re.compile(assistant_pattern, re.DOTALL | re.IGNORECASE)
            except re.error:
                return False
            return matcher.search(str(first_row)) is not None

    return False
def get_supported_formats():
    """Expose the registry of supported formats.

    Returns the module-level SUPPORTED_FORMATS mapping itself (format name ->
    info dict), not a copy.
    """
    registry = SUPPORTED_FORMATS
    return registry