Spaces:
Runtime error
Runtime error
import argparse
import re
from pathlib import Path

import nbformat
import nbconvert
from traitlets.config import Config
# Notebook file names that are excluded from the CI tests: patch_notebooks
# does not generate a test_ copy for them.
EXCLUDED_NOTEBOOKS = [
    "data-preparation-ct-scan.ipynb",
    "pytorch-monai-training.ipynb",
]

# Source fragment that identifies the cell creating the inference-device
# dropdown; that cell is rewritten when a test device is requested.
DEVICE_WIDGET = "device = widgets.Dropdown("
| def disable_gradio_debug(nb, notebook_path): | |
| found = False | |
| for cell in nb["cells"]: | |
| if "gradio" in cell["source"] and "debug" in cell["source"]: | |
| found = True | |
| cell["source"] = cell["source"].replace("debug=True", "debug=False") | |
| if found: | |
| print(f"Disabled gradio debug mode for {notebook_path}") | |
| return nb | |
def disable_skip_ext(nb, notebook_path, test_device=""):
    """Neutralise ``%%skip`` cell magics in ``nb`` for the CI test run.

    By default, every cell containing ``%%skip`` has its ``%%skip ...`` line
    removed, so the remainder of the cell always runs.

    Suppose ``test_device`` is non-empty and some cell contains both
    ``skip_for_device = "<TEST_DEVICE>" in device.value`` and
    ``to_quantize = widgets.Checkbox(value=not skip_for_device``. That means
    the notebook's ``to_quantize`` checkbox defaults to off on that device.
    Every ``%%skip`` cell from that cell onwards is then wrapped in a
    triple-quoted string, so it becomes an inert string literal instead.

    :param nb: Notebook node; the ``source`` of its ``cells`` is modified in place.
    :param notebook_path: Path of the notebook, used only in the log message.
    :param test_device: Device the notebook is tested on. It is upper-cased
        before matching, the same way patch_notebooks upper-cases the value it
        writes into the device dropdown, so "gpu" and "GPU" behave alike.
    :returns: The same ``nb`` object.
    """
    found = False
    # None = "device marker not seen yet"; False = never look for it.
    skip_for_device = None if test_device else False
    device_marker = (
        f'skip_for_device = "{test_device.upper()}" in device.value' if test_device else None
    )
    for cell in nb["cells"]:
        source = cell["source"]
        if skip_for_device is None:
            if (
                device_marker in source
                and "to_quantize = widgets.Checkbox(value=not skip_for_device" in source
            ):
                skip_for_device = True
        if "%%skip" in source:
            found = True
            if skip_for_device:
                # The magic line ends up inside the string, so it is not executed.
                cell["source"] = '"""\n' + source + '\n"""'
            else:
                cell["source"] = re.sub(r"%%skip.*.\n", "\n", source)
    if found:
        print(f"Disabled skip extension mode for {notebook_path}")
    return nb
def remove_ov_install(cell):
    """Remove OpenVINO packages from the install lines of ``cell``.

    patch_notebooks calls this when ``skip_ov_install`` is requested. The
    presumed reason is that the test environment already provides OpenVINO;
    confirm against the CI setup.

    The source is processed line by line. A line is considered only if it
    contains "openvino" and is not a Python ``import``/``from`` statement.
    It is rewritten if at least one of its space-separated tokens names an
    OpenVINO package. Such tokens are: exactly ``openvino``; anything
    containing ``openvino>`` or ``openvino=``; or anything containing
    ``openvino-dev``, ``openvino-nightly`` or ``openvino-tokenizers``.
    A matching line is handled as follows:

    * It is dropped entirely if only pip keywords (``%pip``, ``install``),
      options (``-q`` ...) and ``https://`` URLs remain.
    * Otherwise it is replaced by the line without the OpenVINO tokens,
      followed by the original line commented out with ``# ``.

    NOTE: empty tokens produced by leading or repeated spaces count as
    remaining requirements. An indented install line is therefore always
    rewritten, never dropped, which also avoids leaving an empty indented
    block behind.

    :param cell: Notebook cell; its ``source`` string is replaced in place.
    """
    ov_package_names = ("openvino-dev", "openvino-nightly", "openvino-tokenizers")

    def is_ov_package(part):
        # Substring tests so that quoted specifiers such as '"openvino>=2024.0"' match.
        if any(name in part for name in ov_package_names):
            return True
        return "openvino>" in part or "openvino=" in part or part == "openvino"

    def has_additional_deps(str_part):
        # pip magic/keywords, options and URLs are not requirements on their own.
        if "%pip" in str_part or "install" in str_part:
            return False
        return not str_part.startswith(("-", "https://"))

    updated_lines = []
    for line in cell["source"].split("\n"):
        stripped = line.lstrip()
        if "openvino" not in line or stripped.startswith(("import ", "from ")):
            # Not an install line (e.g. ``import openvino as ov``): keep verbatim.
            updated_lines.append(line)
            continue
        parts = line.split(" ")
        remaining = [part for part in parts if not is_ov_package(part)]
        if len(remaining) == len(parts):
            # Mentions openvino but installs no OpenVINO package.
            updated_lines.append(line)
        elif any(has_additional_deps(part) for part in remaining):
            # The empty leading tokens make " ".join restore the original
            # indentation, so no extra indent must be added here.
            updated_lines.append(" ".join(remaining) + "\n# " + line)
        # Otherwise the line installed only OpenVINO packages: drop it.
    cell["source"] = "\n".join(updated_lines)
def patch_notebooks(notebooks_dir, test_device="", skip_ov_install=False):
    """
    Patch notebooks in notebooks directory with replacement values
    found in notebook metadata to speed up test execution.
    This function is specific for the OpenVINO notebooks
    GitHub Actions CI.

    For example: change nr of epochs from 15 to 1 in
    tensorflow-training-openvino-nncf.ipynb by adding
    {"test_replace": {"epochs = 15": "epochs = 1"}} to the cell
    metadata of the cell that contains `epochs = 15`

    Every ``*.ipynb`` found recursively under ``notebooks_dir`` is processed,
    except files whose name starts with ``test_`` and files listed in
    ``EXCLUDED_NOTEBOOKS``. For each, a patched copy with all outputs cleared
    is written next to it as ``test_<original name>``. The original notebook
    file is not modified.

    :param notebooks_dir: Directory that contains the notebook subdirectories.
        For example: openvino_notebooks/notebooks
    :param test_device: Optional device name (e.g. "GPU"). When non-empty, the
        cell that creates the device dropdown is rewritten: the upper-cased
        device is prepended to the options and selected as the value. The
        device is also passed to ``disable_skip_ext``.
    :param skip_ov_install: When True, OpenVINO packages are stripped from
        every cell containing ``%pip`` (see ``remove_ov_install``).
    :raises ValueError: If a ``test_replace`` key does not occur in its cell.
    """
    nb_convert_config = Config()
    nb_convert_config.NotebookExporter.preprocessors = ["nbconvert.preprocessors.ClearOutputPreprocessor"]
    output_remover = nbconvert.NotebookExporter(nb_convert_config)

    # Materialise the listing first: test_*.ipynb files are written into the
    # same directories while we iterate, and sorting keeps the log order stable.
    for notebookfile in sorted(Path(notebooks_dir).glob("**/*.ipynb")):
        if notebookfile.name.startswith("test_") or notebookfile.name in EXCLUDED_NOTEBOOKS:
            continue
        nb = nbformat.read(notebookfile, as_version=nbformat.NO_CONVERT)
        found = False
        device_found = False
        for cell in nb["cells"]:
            if skip_ov_install and "%pip" in cell["source"]:
                remove_ov_install(cell)
            if test_device and DEVICE_WIDGET in cell["source"]:
                device_found = True
                device = test_device.upper()
                # Select the test device and make sure it is one of the options.
                cell["source"] = re.sub(r"value=.*,", f"value='{device}',", cell["source"])
                cell["source"] = re.sub(r"options=", f"options=['{device}'] + ", cell["source"])
                print(f"Replaced testing device to {test_device}")
            replace_dict = cell.get("metadata", {}).get("test_replace")
            if replace_dict is not None:
                found = True
                for source_value, target_value in replace_dict.items():
                    if source_value not in cell["source"]:
                        raise ValueError(f"Processing {notebookfile} failed: {source_value} does not exist in cell")
                    cell["source"] = cell["source"].replace(source_value, target_value)
                    print(f"Processed {notebookfile}: {source_value} -> {target_value}")
                if replace_dict:
                    # Mark the cell once, however many replacements were applied.
                    cell["source"] = "# Modified for testing\n" + cell["source"]
        if test_device and not device_found:
            print(f"No device replacement found for {notebookfile}")
        if not found:
            print(f"No replacements found for {notebookfile}")
        disable_gradio_debug(nb, notebookfile)
        # Use the parameter, not the CLI's global ``args`` (undefined when imported).
        disable_skip_ext(nb, notebookfile, test_device)
        nb_without_out, _ = output_remover.from_notebook_node(nb)
        with notebookfile.with_name(f"test_{notebookfile.name}").open("w", encoding="utf-8") as out_file:
            out_file.write(nb_without_out)
if __name__ == "__main__":
    # Command-line entry point. ``args`` is deliberately kept at module level.
    parser = argparse.ArgumentParser(prog="Notebook patcher")
    parser.add_argument(
        "notebooks_dir",
        nargs="?",  # without this the default below would never apply
        default=".",
        help="directory searched recursively for *.ipynb files (default: current directory)",
    )
    parser.add_argument(
        "-td",
        "--test_device",
        default="",
        help="device to select in the notebooks' device dropdown, e.g. CPU or GPU",
    )
    parser.add_argument(
        "--skip_ov_install",
        action="store_true",
        help="remove OpenVINO packages from %%pip install lines",
    )
    args = parser.parse_args()
    if not Path(args.notebooks_dir).is_dir():
        # Report as a usage error (exit status 2) instead of a traceback.
        parser.error(f"'{args.notebooks_dir}' is not an existing directory")
    patch_notebooks(args.notebooks_dir, args.test_device, args.skip_ov_install)