import os
import string
import sys
from collections import defaultdict
from typing import Dict, List, Tuple

from tqdm import tqdm
def read_lines(fname: str) -> List[str]:
    """
    Reads all lines from an input file and returns them as a list of strings.

    Args:
        fname (str): path to the input file to read.

    Returns:
        List[str]: a list of strings, one per line in the file. Returns an
        empty list if the file does not exist.
    """
    # if the path doesn't exist, return an empty list
    if not os.path.exists(fname):
        return []
    with open(fname, "r", encoding="utf-8") as f:
        lines = f.readlines()
    return lines
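# Note: read_lines keeps each line's trailing "\n" (the readlines default);
# create_txt below inspects the first line to decide whether newlines must be added.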
def create_txt(out_file: str, lines: List[str]):
    """
    Creates a text file and writes the given list of lines to it.

    Args:
        out_file (str): path to the output file to be created.
        lines (List[str]): a list of strings to be written to the output file.
    """
    add_newline = bool(lines) and "\n" not in lines[0]
    with open(out_file, "w", encoding="utf-8") as outfile:
        for line in lines:
            if add_newline:
                outfile.write(line + "\n")
            else:
                outfile.write(line)
def pair_dedup_lists(src_list: List[str], tgt_list: List[str]) -> Tuple[List[str], List[str]]:
    """
    Removes duplicates from two parallel lists by pairing their elements and
    deduplicating the pairs.

    Args:
        src_list (List[str]): a list of strings from the source language data.
        tgt_list (List[str]): a list of strings from the target language data.

    Returns:
        Tuple[List[str], List[str]]: a deduplicated version of `(src_list, tgt_list)`.
    """
    src_tgt = list(set(zip(src_list, tgt_list)))
    if not src_tgt:
        return [], []
    src_deduped, tgt_deduped = zip(*src_tgt)
    return list(src_deduped), list(tgt_deduped)
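# Example (output order is arbitrary, since sets are unordered):
#   pair_dedup_lists(["a", "a", "b"], ["x", "x", "y"])
#   -> (["a", "b"], ["x", "y"]), up to ordering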
def pair_dedup_files(src_file: str, tgt_file: str):
    """
    Removes duplicates from two files by pairing their lines and removing
    duplicates from the pairs.

    Args:
        src_file (str): path to the source language file to deduplicate.
        tgt_file (str): path to the target language file to deduplicate.
    """
    src_lines = read_lines(src_file)
    tgt_lines = read_lines(tgt_file)
    len_before = len(src_lines)
    src_deduped, tgt_deduped = pair_dedup_lists(src_lines, tgt_lines)
    len_after = len(src_deduped)
    num_duplicates = len_before - len_after
    print(f"Dropped duplicate pairs in {src_file}: num duplicates -> {num_duplicates}")
    create_txt(src_file, src_deduped)
    create_txt(tgt_file, tgt_deduped)
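# Note: pair_dedup_files rewrites src_file and tgt_file in place.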
def strip_and_normalize(line: str) -> str:
    """
    Strips and normalizes a string by lowercasing it and removing spaces and punctuation.

    Args:
        line (str): string to strip and normalize.

    Returns:
        str: stripped and normalized version of the input string.
    """
    # lowercase the line, remove spaces, and strip punctuation
    # one of the fastest ways to remove an exclusion list of characters
    # from a string:
    # https://towardsdatascience.com/how-to-efficiently-remove-punctuations-from-a-string-899ad4a059fb
    exclist = string.punctuation + "\u0964"  # "\u0964" is the Devanagari danda
    table_ = str.maketrans("", "", exclist)
    line = line.replace(" ", "").lower()
    # don't use this method, it is painfully slow:
    # line = "".join([i for i in line if i not in string.punctuation])
    line = line.translate(table_)
    return line
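# Example:
#   strip_and_normalize("Hello, World!")  # -> "helloworld"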
def expand_tupled_list(list_of_tuples: List[Tuple[str, str]]) -> Tuple[List[str], List[str]]:
    """
    Expands a list of tuples into two lists by extracting the first and second elements of the tuples.

    Args:
        list_of_tuples (List[Tuple[str, str]]): a list of tuples, where each tuple contains two strings.

    Returns:
        Tuple[List[str], List[str]]: a tuple containing two lists, the first holding the first
        elements of the tuples in `list_of_tuples` and the second holding the second elements.
    """
    # convert a list of tuples into two lists
    # https://stackoverflow.com/questions/8081545/how-to-convert-list-of-tuples-to-multiple-lists
    # [(en, as), (as, bn), (bn, gu)] -> [en, as, bn], [as, bn, gu]
    list_a, list_b = map(list, zip(*list_of_tuples))
    return list_a, list_b
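# Example:
#   expand_tupled_list([("en", "as"), ("bn", "gu")])  # -> (["en", "bn"], ["as", "gu"])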
def normalize_and_gather_all_benchmarks(devtest_dir: str) -> Dict[str, Dict[str, List[str]]]:
    """
    Normalizes and gathers all benchmark datasets from a directory into a dictionary.

    Args:
        devtest_dir (str): path to the directory containing one subdirectory per benchmark
            dataset; each benchmark subdirectory contains one subdirectory per language pair,
            named "src_lang-tgt_lang", which in turn holds four files: `dev.src_lang`,
            `dev.tgt_lang`, `test.src_lang`, and `test.tgt_lang`, representing the development
            and test sets for the pair.

    Returns:
        Dict[str, Dict[str, List[str]]]: a dictionary mapping each language pair (in the format
        "src_lang-tgt_lang") to a dictionary with keys "src" and "tgt", holding the normalized
        source and target lines gathered from all benchmark datasets.
    """
    devtest_pairs_normalized = defaultdict(lambda: defaultdict(list))
    for benchmark in os.listdir(devtest_dir):
        print(f"{devtest_dir}/{benchmark}")
        for pair in tqdm(os.listdir(f"{devtest_dir}/{benchmark}")):
            src_lang, tgt_lang = pair.split("-")
            src_dev = read_lines(f"{devtest_dir}/{benchmark}/{pair}/dev.{src_lang}")
            tgt_dev = read_lines(f"{devtest_dir}/{benchmark}/{pair}/dev.{tgt_lang}")
            src_test = read_lines(f"{devtest_dir}/{benchmark}/{pair}/test.{src_lang}")
            tgt_test = read_lines(f"{devtest_dir}/{benchmark}/{pair}/test.{tgt_lang}")
            # if the target data doesn't exist for a particular test set,
            # read_lines returns an empty list
            if tgt_test == [] or tgt_dev == []:
                print(f"{benchmark} does not have {src_lang}-{tgt_lang} data")
                continue
            # combine the dev and test sets into one
            src_devtest = src_dev + src_test
            tgt_devtest = tgt_dev + tgt_test
            src_devtest = [strip_and_normalize(line) for line in src_devtest]
            tgt_devtest = [strip_and_normalize(line) for line in tgt_devtest]
            devtest_pairs_normalized[pair]["src"].extend(src_devtest)
            devtest_pairs_normalized[pair]["tgt"].extend(tgt_devtest)
    # dedup the merged benchmark datasets
    for pair in devtest_pairs_normalized:
        src_devtest = devtest_pairs_normalized[pair]["src"]
        tgt_devtest = devtest_pairs_normalized[pair]["tgt"]
        src_devtest, tgt_devtest = pair_dedup_lists(src_devtest, tgt_devtest)
        devtest_pairs_normalized[pair]["src"] = src_devtest
        devtest_pairs_normalized[pair]["tgt"] = tgt_devtest
    return devtest_pairs_normalized
def remove_train_devtest_overlaps(train_dir: str, devtest_dir: str):
    """
    Removes overlapping data between the training and dev/test (benchmark)
    datasets for all language pairs.

    Args:
        train_dir (str): path of the directory containing the training data.
        devtest_dir (str): path of the directory containing the dev/test data.
    """
    devtest_pairs_normalized = normalize_and_gather_all_benchmarks(devtest_dir)
    all_src_sentences_normalized = []
    for key in devtest_pairs_normalized:
        all_src_sentences_normalized.extend(devtest_pairs_normalized[key]["src"])
    # remove duplicates across all test benchmarks and all language pairs;
    # this may not be the most optimal approach, but it is a tradeoff made to
    # keep the code general for now
    all_src_sentences_normalized = list(set(all_src_sentences_normalized))
    src_overlaps = []
    tgt_overlaps = []
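    # "Super strict" policy, as implemented below: a training pair is dropped if its
    # normalized source matches any benchmark source (across all language pairs), or
    # if its normalized target matches a benchmark target overlap collected so far.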
    pairs = os.listdir(train_dir)
    for pair in pairs:
        src_lang, tgt_lang = pair.split("-")
        new_src_train, new_tgt_train = [], []
        src_train = read_lines(f"{train_dir}/{pair}/train.{src_lang}")
        tgt_train = read_lines(f"{train_dir}/{pair}/train.{tgt_lang}")
        len_before = len(src_train)
        if len_before == 0:
            continue
        src_train_normalized = [strip_and_normalize(line) for line in src_train]
        tgt_train_normalized = [strip_and_normalize(line) for line in tgt_train]
        src_devtest_normalized = all_src_sentences_normalized
        tgt_devtest_normalized = devtest_pairs_normalized[pair]["tgt"]
        # compute all strict src and tgt overlaps for the language pair
        overlaps = set(src_train_normalized) & set(src_devtest_normalized)
        src_overlaps.extend(list(overlaps))
        overlaps = set(tgt_train_normalized) & set(tgt_devtest_normalized)
        tgt_overlaps.extend(list(overlaps))
        # sets offer O(1) membership lookups
        src_overlaps_set = set(src_overlaps)
        tgt_overlaps_set = set(tgt_overlaps)
        # drop the overlapping training pairs; enumerate keeps the index in
        # step with the iteration even when lines are skipped
        for idx, (src_line_norm, tgt_line_norm) in tqdm(
            enumerate(zip(src_train_normalized, tgt_train_normalized)), total=len_before
        ):
            if src_line_norm in src_overlaps_set:
                continue
            if tgt_line_norm in tgt_overlaps_set:
                continue
            new_src_train.append(src_train[idx])
            new_tgt_train.append(tgt_train[idx])
        len_after = len(new_src_train)
        print(
            f"Detected overlaps between train and devtest for {pair}: {len_before - len_after}"
        )
        print(f"Saving new files at {train_dir}/{pair}/")
        create_txt(f"{train_dir}/{pair}/train.{src_lang}", new_src_train)
        create_txt(f"{train_dir}/{pair}/train.{tgt_lang}", new_tgt_train)
if __name__ == "__main__":
    train_data_dir = sys.argv[1]
    # the benchmarks directory should contain all the test sets
    devtest_data_dir = sys.argv[2]
    remove_train_devtest_overlaps(train_data_dir, devtest_data_dir)
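# Usage (assuming this script is saved as, e.g., remove_train_devtest_overlaps.py):
#   python remove_train_devtest_overlaps.py <train_data_dir> <devtest_data_dir>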