Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| FreshQA accuracy calculation script | |
| This script computes the accuracy of the FreshQA dataset and analyzes it across various categories. | |
| """ | |
| import pandas as pd | |
| import sys | |
| import os | |
def load_freshqa_data(csv_path='freshqa.csv'):
    """Load a FreshQA CSV file into a DataFrame.

    The file is first read as-is. If that first read exposes a ``rating``
    column, the frame is returned unchanged. Otherwise the file is read a
    second time with its first two lines skipped, so that the third line
    becomes the header.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        The loaded ``pandas.DataFrame``.

    Exits:
        Prints a message and calls ``sys.exit(1)`` when the file is missing
        or when any other error occurs while reading it.
    """
    try:
        probe = pd.read_csv(csv_path)
        has_rating = 'rating' in probe.columns
        # Without a rating column, re-read skipping the two leading lines.
        return probe if has_rating else pd.read_csv(csv_path, skiprows=[0, 1])
    except FileNotFoundError:
        print(f"์ค๋ฅ: {csv_path} ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค.")
        print("ํ์ฌ ๋๋ ํ ๋ฆฌ์ freshqa.csv ํ์ผ์ด ์๋์ง ํ์ธํด์ฃผ์ธ์.")
        sys.exit(1)
    except Exception as err:
        print(f"๋ฐ์ดํฐ ๋ก๋ ์ค ์ค๋ฅ ๋ฐ์: {err}")
        sys.exit(1)
def process_freshqa_dataframe(df):
    """Return a copy of ``df`` that is guaranteed to have a ``rating`` column.

    Args:
        df: Input DataFrame; it is never modified.

    Returns:
        A copy of ``df``. When the input has no ``rating`` column, the copy
        gains one filled with the default value ``0``.

    Raises:
        Exception: Any error raised while processing is printed and then
            re-raised unchanged.
    """
    try:
        needs_default = 'rating' not in df.columns
        result = df.copy()
        if needs_default:
            # No ratings supplied: every row defaults to 0.
            result['rating'] = 0
        return result
    except Exception as err:
        print(f"๋ฐ์ดํฐ ์ฒ๋ฆฌ ์ค ์ค๋ฅ ๋ฐ์: {err}")
        raise
def _is_correct_rating(value):
    """Return True when a single rating cell marks the answer as correct.

    Accepts the string ``TRUE`` (any case, surrounding whitespace ignored),
    the boolean ``True`` and the number ``1``. Missing values count as
    incorrect.
    """
    if isinstance(value, str):
        return value.strip().upper() == 'TRUE'
    if pd.isna(value):
        return False
    return bool(value == 1)


def update_results(df, d_acc, d_count, field_name):
    """Store the accuracy and sample count of ``df`` under ``field_name``.

    Accuracy is the percentage of rows whose ``rating`` marks a correct
    answer. Numeric and boolean columns are compared against ``1``. Any
    other column (plain ``object``, pandas string dtype, or a mix of
    booleans, strings and missing values) is evaluated cell by cell, so the
    result does not depend on how the column happened to be parsed.

    Args:
        df: DataFrame with a ``rating`` column; may be empty.
        d_acc: Dict that receives ``field_name -> accuracy`` (0-100 float).
        d_count: Dict that receives ``field_name -> number of rows``.
        field_name: Key written into both dicts.
    """
    total = len(df)
    if total == 0:
        # Empty subset: report 0.0 instead of dividing by zero.
        accuracy = 0.0
    else:
        rating = df['rating']
        if pd.api.types.is_bool_dtype(rating) or pd.api.types.is_numeric_dtype(rating):
            correct = int((rating == 1).sum())
        else:
            correct = int(rating.map(_is_correct_rating).sum())
        accuracy = correct * 100 / total
    d_acc[field_name] = accuracy
    d_count[field_name] = total
def calculate_accuracy_simple(fresh_qa):
    """Compute FreshQA accuracies, tolerating missing optional columns.

    Every breakdown is produced only when the column it needs is present:
    ``split`` (TEST/DEV), ``fact_type``, ``false_premise`` and ``domain``.
    The per-domain TEST/DEV figures additionally require ``split``; when
    ``split`` is absent only the overall per-domain accuracy is reported.

    Note that a missing ``rating`` column is added to ``fresh_qa`` in place
    with the default value ``0``.

    Args:
        fresh_qa: DataFrame of evaluated questions.

    Returns:
        Dict mapping metric name to accuracy percentage, plus
        ``total_questions`` holding the number of rows.
    """
    print("์ ํ๋ ๊ณ์ฐ ์ค...")

    # Default every row to 0 when no ratings were supplied.
    if 'rating' not in fresh_qa.columns:
        fresh_qa['rating'] = 0

    accs = {}
    counts = {}
    has_split = 'split' in fresh_qa.columns

    # Overall accuracy.
    update_results(fresh_qa, accs, counts, 'overall_accuracy')

    # Accuracy per data split.
    if has_split:
        update_results(fresh_qa[fresh_qa['split'] == 'TEST'], accs, counts, 'acc_test')
        update_results(fresh_qa[fresh_qa['split'] == 'DEV'], accs, counts, 'acc_dev')

    # Accuracy per fact type, only for fact types that actually occur.
    if 'fact_type' in fresh_qa.columns:
        for fact_type in ['fast-changing', 'slow-changing', 'never-changing']:
            if fact_type in fresh_qa['fact_type'].values:
                sub_df = fresh_qa[fresh_qa['fact_type'] == fact_type]
                update_results(sub_df, accs, counts, f'{fact_type}_accuracy')

    # Accuracy on false-premise questions, skipped when there are none.
    if 'false_premise' in fresh_qa.columns:
        fp_df = fresh_qa[fresh_qa['false_premise'] == True]
        if len(fp_df) > 0:
            update_results(fp_df, accs, counts, 'false_premise_accuracy')

    if 'domain' in fresh_qa.columns:
        domain_values = fresh_qa['domain'].values

        # Domain labels as they appear in the CSV file.
        korean_domains = ['์ ์น', '์คํฌ์ธ ', '์ฐ์', '๋ ์จ', '์ธ๊ณ', '๊ฒฝ์ ', '์ฌํ', 'IT/๊ณผํ', '์ํ/๋ฌธํ', 'UNK']
        # Only these labels get an English key; the others keep their own
        # text with '/' and ' ' replaced by '_' and lower-cased.
        key_overrides = {
            'IT/๊ณผํ': 'it_science',
            '์ํ/๋ฌธํ': 'life_culture',
            'UNK': 'unknown',
        }
        for domain in korean_domains:
            if domain not in domain_values:
                continue
            domain_df = fresh_qa[fresh_qa['domain'] == domain]
            domain_key = key_overrides.get(
                domain, domain.replace('/', '_').replace(' ', '_').lower()
            )
            update_results(domain_df, accs, counts, f'acc_{domain_key}')
            # Per-split figures need the optional split column; without this
            # guard a frame lacking `split` raised AttributeError here.
            if has_split:
                update_results(domain_df[domain_df['split'] == 'TEST'], accs, counts, f'acc_test_{domain_key}')
                update_results(domain_df[domain_df['split'] == 'DEV'], accs, counts, f'acc_dev_{domain_key}')

        # English domain labels are still supported for compatibility.
        english_domains = ['politics', 'sports', 'entertainment', 'weather', 'world', 'economy', 'society', 'it_science', 'life_culture']
        for domain in english_domains:
            if domain in domain_values:
                domain_df = fresh_qa[fresh_qa['domain'] == domain]
                update_results(domain_df, accs, counts, f'{domain}_accuracy')

    # Total number of questions.
    accs['total_questions'] = len(fresh_qa)
    return accs
def calculate_accuracy(fresh_qa):
    """Compute the full FreshQA accuracy breakdown.

    Every metric is reported three times: over all rows, over rows whose
    ``split`` is ``TEST`` and over rows whose ``split`` is ``DEV``. The
    breakdowns are: overall, fact type (valid-premise rows only), premise
    type (``vp``/``fp``) further divided by hop count and by effective
    year, and, when a ``domain`` column exists, domain.

    Args:
        fresh_qa: DataFrame read via attribute access for ``split``,
            ``false_premise``, ``fact_type``, ``num_hops`` and
            ``effective_year``; ``domain`` is optional.

    Returns:
        Tuple ``(accs, counts)`` of dicts keyed by metric name, holding the
        accuracy percentage and the number of rows respectively.
    """
    accs = {}
    counts = {}

    def record(frame, suffix=None):
        # Store all / TEST / DEV results for `frame`, in that order.
        tail = '' if suffix is None else f'_{suffix}'
        by_split = (
            ('acc', frame),
            ('acc_test', frame[frame.split == 'TEST']),
            ('acc_dev', frame[frame.split == 'DEV']),
        )
        for prefix, subset in by_split:
            update_results(subset, accs, counts, prefix + tail)

    # Overall accuracy.
    record(fresh_qa)

    # Accuracy per fact type, restricted to valid-premise questions.
    for fact_type in ['fast-changing', 'slow-changing', 'never-changing']:
        selection = (fresh_qa.false_premise == False) & (fresh_qa.fact_type == fact_type)
        record(fresh_qa[selection], fact_type.replace('-', '_'))

    # Accuracy per premise type (vp: valid premise, fp: false premise).
    for qt, premise_is_false in (('vp', False), ('fp', True)):
        data = fresh_qa[fresh_qa.false_premise == premise_is_false]

        # Hop-count subsets.
        one_hop = data[data.num_hops == 'one-hop']
        multi_hop = data[data.num_hops == 'multi-hop']

        # Year subsets: "new" is 2022 or 2023, "old" is everything else.
        recent = (data.effective_year == '2022') | (data.effective_year == '2023')
        older = data[~recent]
        newer = data[recent]

        record(data, qt)
        record(one_hop, f'{qt}_one_hop')
        record(multi_hop, f'{qt}_two_hop')
        record(older, f'{qt}_old')
        record(newer, f'{qt}_new')

    # Accuracy per domain, only for domains that occur in the data.
    if 'domain' in fresh_qa.columns:
        # CSV domain label -> English key used in the result dicts.
        domain_keys = {
            '์ ์น': 'politics',
            '์คํฌ์ธ ': 'sports',
            '์ฐ์': 'entertainment',
            '๋ ์จ': 'weather',
            '์ธ๊ณ': 'world',
            '๊ฒฝ์ ': 'economy',
            '์ฌํ': 'society',
            'IT/๊ณผํ': 'it_science',
            '์ํ/๋ฌธํ': 'life_culture',
            'UNK': 'unknown',
        }
        for label, domain_key in domain_keys.items():
            if label in fresh_qa['domain'].values:
                record(fresh_qa[fresh_qa.domain == label], domain_key)

    return accs, counts
def print_results(accs, counts):
    """Print a human-readable accuracy report to stdout.

    Args:
        accs: Dict of metric name -> accuracy percentage.
        counts: Dict of metric name -> number of samples.

    The overall, fact-type and premise-type metrics are looked up
    unconditionally, so a missing key raises ``KeyError``. Domain metrics
    are printed only for the keys that are present in ``accs``.
    """
    rule = "=" * 80
    label_all, label_test, label_dev = '์ ์ฒด', 'ํ ์คํธ', '๊ฐ๋ฐ'

    def row(label, key):
        # One report line: label, accuracy and sample count of one metric.
        print(f" {label}: {accs[key]}% ({counts[key]}๊ฐ ์ํ)")

    def rows_by_split(suffix):
        # The all / TEST / DEV lines of one metric family.
        row(label_all, f'acc_{suffix}')
        row(label_test, f'acc_test_{suffix}')
        row(label_dev, f'acc_dev_{suffix}')

    print("\n" + rule)
    print("FreshQA ์ ํ๋ ๋ถ์ ๊ฒฐ๊ณผ")
    print(rule)

    # Overall accuracy.
    print("\n๐ ์ ์ฒด ์ ํ๋:")
    row(label_all, 'acc')
    row(label_test, 'acc_test')
    row(label_dev, 'acc_dev')

    # Accuracy per fact type.
    print("\n๐ ์ฌ์ค ์ ํ๋ณ ์ ํ๋:")
    fact_titles = (
        ('fast_changing', '๋น ๋ฅด๊ฒ ๋ณํ๋ ์ฌ์ค'),
        ('slow_changing', '์ฒ์ฒํ ๋ณํ๋ ์ฌ์ค'),
        ('never_changing', '๋ณํ์ง ์๋ ์ฌ์ค'),
    )
    for suffix, title in fact_titles:
        print(f" {title}:")
        rows_by_split(suffix)

    # Accuracy per premise type, with hop-count and year details.
    print("\nโ ์ง๋ฌธ ์ ํ๋ณ ์ ํ๋:")
    premise_titles = (
        ('vp', '์ ํจํ ์ ์ (Valid Premise)'),
        ('fp', '์๋ชป๋ ์ ์ (False Premise)'),
    )
    for suffix, title in premise_titles:
        print(f" {title}:")
        rows_by_split(suffix)
        row('๋จ์ผ ํ', f'acc_{suffix}_one_hop')
        row('๋ค์ค ํ', f'acc_{suffix}_two_hop')
        row('์ค๋๋ ๋ฐ์ดํฐ', f'acc_{suffix}_old')
        row('์ต์ ๋ฐ์ดํฐ', f'acc_{suffix}_new')

    # Accuracy per domain; only domains with recorded results are shown.
    print("\n๐ ๋๋ฉ์ธ๋ณ ์ ํ๋:")
    domain_titles = (
        ('politics', '์ ์น'),
        ('sports', '์คํฌ์ธ '),
        ('entertainment', '์ฐ์'),
        ('weather', '๋ ์จ'),
        ('world', '์ธ๊ณ'),
        ('economy', '๊ฒฝ์ '),
        ('society', '์ฌํ'),
        ('it_science', 'IT/๊ณผํ'),
        ('life_culture', '์ํ/๋ฌธํ'),
        ('unknown', 'UNK'),
    )
    for suffix, title in domain_titles:
        if f'acc_{suffix}' not in accs:
            continue
        print(f" {title}:")
        row(label_all, f'acc_{suffix}')
        if f'acc_test_{suffix}' in accs:
            row(label_test, f'acc_test_{suffix}')
        if f'acc_dev_{suffix}' in accs:
            row(label_dev, f'acc_dev_{suffix}')

    print("\n" + rule)
def main():
    """Run the script: load a CSV, compute accuracies and print the report.

    The CSV path is taken from the first command-line argument and falls
    back to ``freshqa.csv``. When the path does not exist, an error and a
    usage line are printed and the process exits with status 1.
    """
    print("FreshQA ์ ํ๋ ๊ณ์ฐ ์คํฌ๋ฆฝํธ")
    print("="*50)

    # Resolve the CSV path and make sure it exists before loading.
    csv_path = sys.argv[1] if len(sys.argv) > 1 else 'freshqa.csv'
    if not os.path.exists(csv_path):
        print(f"์ค๋ฅ: {csv_path} ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค.")
        print("์ฌ์ฉ๋ฒ: python freshqa_acc.py [csv_file_path]")
        sys.exit(1)

    # Load the data, compute the metrics and print the formatted report.
    accs, counts = calculate_accuracy(load_freshqa_data(csv_path))
    print_results(accs, counts)

    # Also print the raw accuracy dict.
    print("\n๐ ๋์ ๋๋ฆฌ ํํ ๊ฒฐ๊ณผ:")
    print(accs)


if __name__ == "__main__":
    main()