Spaces:
Running
Running
| import json | |
| import gradio as gr | |
| from pathlib import Path | |
| import os | |
| import pickle | |
| from constants import OBJECTS, RECEPTACLES | |
| import pandas as pd | |
| from generate_video import generate_video | |
| from utils import * | |
def get_scene_info(scene_file_name: str) -> tuple[str, str, pd.DataFrame, pd.DataFrame]:
    """Collect the image, summary text and type-count tables for one scene.

    Args:
        scene_file_name: Scene identifier, forwarded to ``get_scene_dir_path``
            and ``get_scene_image_path``.

    Returns:
        scene_image_path: the value returned by ``get_scene_image_path``.
        markdown_description: text listing the number of rooms, the room
            types, and the number of objects and receptacles.
        receptacle_counter_df: the most frequent receptacle types and their
            counts, always exactly 10 rows (padded with blank rows).
        object_type_couter_df: the same table for object types that are in
            ``OBJECTS`` but not in ``RECEPTACLES``.
    """
    # Local import so this function does not depend on a file-level change.
    from collections import Counter

    table_row_number = 10  # both tables always show exactly this many rows

    scene_dir = get_scene_dir_path(scene_file_name)
    scene_image_path = get_scene_image_path(scene_file_name)
    scene_data = get_scene_data(scene_dir)
    room_counter = get_room_counter(scene_data)

    # The object type is the part of the object id before the first '|'.
    all_object_type_list = [object_info['id'].split('|', 1)[0] for object_info in scene_data['objects']]

    # Compute the set difference once instead of once per object.
    non_receptacle_types = OBJECTS - RECEPTACLES
    receptacle_type_list = [object_type for object_type in all_object_type_list if object_type in RECEPTACLES]
    object_type_list = [object_type for object_type in all_object_type_list if object_type in non_receptacle_types]

    def _top_rows(type_list):
        # most_common() sorts by descending count and breaks ties by first
        # appearance, so the table is the same on every run.
        rows = Counter(type_list).most_common(table_row_number)
        rows += [('', '')] * (table_row_number - len(rows))
        return rows

    receptacle_counter_df = pd.DataFrame(_top_rows(receptacle_type_list), columns=['Receptacle Type', 'Count'])
    object_type_couter_df = pd.DataFrame(_top_rows(object_type_list), columns=['Object Type', 'Count'])

    # Room types read like "4 Bedroom, 1 Living Room, 1 Kitchen, 1 Bathroom".
    room_types = ', '.join(f'{count} {room_type}' for room_type, count in room_counter.items())
    markdown_description = (
        "Scene Information\n"
        f"- Number of Rooms: {sum(room_counter.values())}\n"
        f"- Room Types: {room_types}\n"
        f"- Number of Objects: {len(object_type_list)}\n"
        f"- Number of Receptacles: {len(receptacle_type_list)}\n"
    )
    return scene_image_path, markdown_description, receptacle_counter_df, object_type_couter_df
def visualize_scene():
    """Build the scene-selection widgets and return the scene dropdown.

    Creates a dropdown over ``SCENE_FILE_NAME_LIST``, an overhead image, a
    statistics textbox and two count tables, all pre-filled from the first
    scene in the list, and re-runs ``get_scene_info`` whenever the dropdown
    value changes.

    Returns:
        The ``gr.Dropdown`` used to select the scene.
    """
    # Use grid layout components to organise the interface.
    # NOTE(review): the nesting of the Row/Column containers below was
    # reconstructed from a listing without indentation -- confirm it against
    # the rendered page.
    default_scene_file_name = SCENE_FILE_NAME_LIST[0]
    default_scene_image_path, default_text, default_receptacle_table, default_object_table = get_scene_info(default_scene_file_name)
    with gr.Row():
        dropdown = gr.Dropdown(choices=SCENE_FILE_NAME_LIST, label="Select Scene ID", value=default_scene_file_name)
    with gr.Row(equal_height=True):
        image = gr.Image(label="Scene Overhead View", show_label=False, value=default_scene_image_path)
        with gr.Column():
            text = gr.Textbox(label="Scene Statistics", value=default_text, lines=6)
            with gr.Row(equal_height=True):
                receptacle_table = gr.Dataframe(label="Receptacle Type Count", height=520, value=default_receptacle_table)
                object_table = gr.Dataframe(label="Object Type Count", height=520, value=default_object_table)
    # The output order must match the order of get_scene_info's return values.
    dropdown.change(fn=get_scene_info, inputs=dropdown, outputs=[image, text, receptacle_table, object_table])
    return dropdown
# Columns selected, in this order, for the compact-view people table.
PERSON_TABLE_COLUMNS = ['name', 'age', 'gender', 'personality', 'routine', 'occupation', 'thoughts', 'lifestyle']
# Number of person rows pre-created (initially hidden) in the full view.
TEMPLATE_ROW_NUMBER = 10
# Components per person row: image, name, age, gender, personality, routine.
PERSON_ELEMENT_NUM = 6
| def get_person_info_list(person_file_list: list[str|Path|dict]): | |
| person_info_list = [] | |
| for person_file in person_file_list: | |
| if isinstance(person_file, str) or isinstance(person_file, Path): | |
| with open(person_file, 'rb') as f: | |
| person_info = pickle.load(f) | |
| person_info = person_info['persona'] | |
| if 'image' not in person_info: | |
| image_path = os.path.join(os.path.dirname(person_file), 'avatar.jpg') | |
| person_info['image'] = image_path | |
| else: | |
| person_info = person_file['persona'] | |
| person_info_list.append(person_info) | |
| return person_info_list | |
def person_info_to_description(person_info):
    """Return a one-sentence description of a person.

    The sentence has the form
    ``"<name>, a <trait, trait, ...> <age> years old <man|woman>."``;
    ``'man'`` is used only when ``person_info['gender']`` equals ``'Male'``.
    """
    name = person_info['name']
    traits = ', '.join(person_info['personality'])
    age = person_info['age']
    if person_info['gender'] == 'Male':
        noun = 'man'
    else:
        noun = 'woman'
    return f"{name}, a {traits} {age} years old {noun}."
def person_info_to_elements(person_info):
    """Return the six values shown in one person row.

    The tuple is ``(image, name, age, gender, personality, routine)``: the
    image value is passed through unchanged and each of the remaining fields
    is rendered as a ``**Label:** value`` markdown string.
    """
    image = person_info['image']
    labelled_fields = (
        ('Name', person_info['name']),
        ('Age', person_info['age']),
        ('Gender', person_info['gender']),
        ('Personality', ', '.join(person_info['personality'])),
        ('Routine', person_info['routine']),
    )
    return (image, *(f"**{label}:** {value}" for label, value in labelled_fields))
def get_person_elements_from_row_elements(elements: list[str], row_index: int):
    """Return the slice of *elements* that belongs to row *row_index*.

    *elements* is the flat, row-major list of all person-row values, with
    ``PERSON_ELEMENT_NUM`` consecutive entries per row.
    """
    start = row_index * PERSON_ELEMENT_NUM
    stop = start + PERSON_ELEMENT_NUM
    return elements[start:stop]
def get_person_name_from_row_elements(row_elements: list[str]):
    """Return the person's name from one row's values.

    The name is read from the second entry of *row_elements* by removing
    every occurrence of the ``'**Name:** '`` label from it.
    """
    name_label = '**Name:** '
    name_markdown = row_elements[1]
    return name_markdown.replace(name_label, '')
def get_person_dataframe(row_index_to_person_name: list[str], person_name_to_info: dict[str, dict]):
    """Build the compact-view table for the people currently in the scene.

    Returns one row per name in *row_index_to_person_name* (in that order),
    restricted to ``PERSON_TABLE_COLUMNS``; with no names, returns an empty
    frame that has just those columns.
    """
    if len(row_index_to_person_name) > 0:
        records = [person_name_to_info[person_name] for person_name in row_index_to_person_name]
        return pd.DataFrame(records)[PERSON_TABLE_COLUMNS]
    return pd.DataFrame(columns=PERSON_TABLE_COLUMNS)
def get_max_person_number(scene_dir_path: Path) -> int:
    """Return the scene's ``'Bedroom'`` count from its room counter (0 if absent)."""
    rooms = get_room_counter(get_scene_data(scene_dir_path))
    return rooms.get('Bedroom', 0)
def create_person_page():
    """Create one initially hidden person row and return its components.

    Returns:
        A dict with keys ``'row'`` (the enclosing ``gr.Row``),
        ``'delete_button'`` and ``'elements'`` -- the tuple
        ``(image, name, age, gender, personality, routine)`` of display
        components.
    """
    # NOTE(review): container nesting was reconstructed from a listing
    # without indentation -- confirm against the rendered page.
    with gr.Row(visible=False) as row:
        # NOTE(review): Gradio documents `scale` as an integer -- confirm
        # that 0.25 behaves as intended.
        image = gr.Image(width=200, scale=0.25, show_label=False, interactive=False)
        with gr.Column():
            with gr.Row():
                name = gr.Markdown()
                delete_button = gr.Button("Delete",size='sm')
            age = gr.Markdown()
            gender = gr.Markdown()
            personality = gr.Markdown()
            routine = gr.Markdown()
            # gr.Markdown(f"**Occupation:** {person['occupation']}")
            # gr.Markdown(f"**Thoughts:** {person['thoughts']}")
            # gr.Markdown(f"**Lifestyle:** {person['lifestyle']}")
    return {
        'row': row,
        'delete_button': delete_button,
        'elements': (image, name, age, gender, personality, routine),
    }
def full_view(person_name_to_description: dict[str, str]):
    """Build the full-view widgets.

    Creates ``TEMPLATE_ROW_NUMBER`` person rows, an 'Add Person' dropdown
    whose choices are the descriptions ordered by person name, and an 'Add'
    button.

    Returns:
        ``(person_pages, add_person_dropdown, add_button)``.
    """
    # NOTE(review): assumes the dropdown and button sit inside the Blocks
    # container; the original nesting is not recoverable from the listing.
    descriptions_by_name = [person_name_to_description[name] for name in sorted(person_name_to_description)]
    with gr.Blocks():
        person_pages = []
        for _ in range(TEMPLATE_ROW_NUMBER):
            person_pages.append(create_person_page())
        add_person_dropdown = gr.Dropdown(choices=descriptions_by_name, label='Add Person')
        add_button = gr.Button('Add')
    return person_pages, add_person_dropdown, add_button
def compact_view():
    """Create the 'People Information' table, initially empty with the person columns."""
    empty_people = pd.DataFrame(columns=PERSON_TABLE_COLUMNS)
    return gr.Dataframe(value=empty_people, label='People Information')
def add_person_to_scene(
    max_person_number: int,
    person_dropdown_description: str, person_name_to_description: dict[str, str],
    row_index_to_person_name: list[str],
    person_name_to_info: dict[str, dict]
):
    """Show the person selected in the dropdown in the next free row.

    Args:
        max_person_number: number of people after which the add controls are
            hidden.
        person_dropdown_description: description currently selected in the
            'Add Person' dropdown.
        person_name_to_description: person name -> dropdown description.
        row_index_to_person_name: names already shown, in row order; the new
            name is appended to this list in place.
        person_name_to_info: person name -> persona record.

    Returns:
        A flat list matching the click handler's outputs: one update per row
        element, one visibility update per row, then the name list, the
        dropdown update, the add-button update, the compact-view dataframe
        and the generate-button update.

    If the selection matches no person, the person is already shown, or no
    template row is free, nothing is changed (previously this raised
    ``IndexError`` or added a duplicate row).
    """
    occupied_row_number = len(row_index_to_person_name)
    add_person_name = next(
        (
            name for name, description in person_name_to_description.items()
            if description == person_dropdown_description
        ),
        None,
    )
    if (
        add_person_name is None
        or add_person_name in row_index_to_person_name
        or occupied_row_number >= TEMPLATE_ROW_NUMBER
    ):
        # Nothing valid to add: leave every component as it is.
        untouched = [gr.update() for _ in range(TEMPLATE_ROW_NUMBER * PERSON_ELEMENT_NUM + TEMPLATE_ROW_NUMBER)]
        return untouched + [
            row_index_to_person_name,
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update()
        ]

    current_row_number = occupied_row_number + 1
    # Hide the add controls once the limit is reached or the rows run out.
    add_button_visible = current_row_number < min(max_person_number, TEMPLATE_ROW_NUMBER)

    add_person_elements = person_info_to_elements(person_name_to_info[add_person_name])
    row_elements_update_list = [gr.update() for _ in range(TEMPLATE_ROW_NUMBER * PERSON_ELEMENT_NUM)]
    row_visible_list = [gr.update() for _ in range(TEMPLATE_ROW_NUMBER)]
    # Only the newly used row gets values and becomes visible.
    row_elements_update_list[occupied_row_number * PERSON_ELEMENT_NUM: current_row_number * PERSON_ELEMENT_NUM] = add_person_elements
    row_visible_list[occupied_row_number] = gr.update(visible=True)
    row_index_to_person_name.append(add_person_name)

    # Offer only the people that are not in the scene yet, ordered by name.
    to_be_chose_person_name_list = [key for key in person_name_to_description if key not in row_index_to_person_name]
    to_be_chose_person_description_list = [person_name_to_description[key] for key in sorted(to_be_chose_person_name_list)]
    return row_elements_update_list + row_visible_list + [
        row_index_to_person_name,
        gr.update(choices=to_be_chose_person_description_list, visible=add_button_visible),
        gr.update(visible=add_button_visible),
        get_person_dataframe(row_index_to_person_name, person_name_to_info),
        gr.update(visible=True)
    ]
def delete_person_from_scene(
    person_name_to_info: dict[str, dict],
    row_index_to_person_name: list[str],
    button_row_index: int,
    person_name_to_description: dict[str, str],
    *row_elements: object
):
    """Remove the person in row *button_row_index* and close the gap.

    Args:
        person_name_to_info: person name -> persona record.
        row_index_to_person_name: names currently shown, in row order; the
            removed name is popped from this list in place.
        button_row_index: index of the row whose delete button was clicked.
        person_name_to_description: person name -> dropdown description.
        *row_elements: current values of every row element, flattened
            row-major with ``PERSON_ELEMENT_NUM`` values per row.

    Returns:
        A flat list matching the click handler's outputs: the row element
        values, one visibility update per row, then the add-button update,
        the dropdown update, the name list, the compact-view dataframe and
        the generate-button update.

    If *button_row_index* does not refer to an occupied row (for example a
    repeated click on a row that is already gone), nothing is changed
    (previously this raised ``IndexError`` or hid the wrong row).
    """
    current_row_number = len(row_index_to_person_name)
    row_elements = list(row_elements)
    person_row_visible = [gr.update() for _ in range(TEMPLATE_ROW_NUMBER)]

    if not 0 <= button_row_index < current_row_number:
        # Stale or invalid click: leave every component as it is.
        return row_elements + person_row_visible + [
            gr.update(),
            gr.update(),
            row_index_to_person_name,
            gr.update(),
            gr.update()
        ]

    # Shift the values of all later rows up by one row ...
    row_elements[button_row_index * PERSON_ELEMENT_NUM: (current_row_number-1) * PERSON_ELEMENT_NUM] = row_elements[(button_row_index + 1) * PERSON_ELEMENT_NUM: current_row_number * PERSON_ELEMENT_NUM]
    # ... and hide the last occupied row, which is now spare.
    person_row_visible[current_row_number-1] = gr.update(visible=False)
    row_index_to_person_name.pop(button_row_index)

    # Offer only the people that are not in the scene, ordered by name.
    to_be_chose_person_name_list = [key for key in person_name_to_description if key not in row_index_to_person_name]
    to_be_chose_person_description_list = [person_name_to_description[key] for key in sorted(to_be_chose_person_name_list)]
    return row_elements + person_row_visible + [
        gr.update(visible=True),
        gr.update(choices=to_be_chose_person_description_list, visible=True),
        row_index_to_person_name,
        get_person_dataframe(row_index_to_person_name, person_name_to_info),
        gr.update(visible=len(row_index_to_person_name) > 0)
    ]
def reset_person_rows(person_name_to_description: dict[str, str]):
    """Return the updates that put the person section back to its empty state.

    The result is a flat list: a hide-update for each of the
    ``TEMPLATE_ROW_NUMBER`` rows, then an empty name list, a show-update for
    the add button, a dropdown update offering every description (ordered by
    person name) and an empty people dataframe.
    """
    hide_every_row = [gr.update(visible=False) for _ in range(TEMPLATE_ROW_NUMBER)]
    no_people = []
    show_add_button = gr.update(visible=True)
    all_descriptions = [person_name_to_description[name] for name in sorted(person_name_to_description)]
    full_dropdown = gr.update(choices=all_descriptions, visible=True)
    empty_people = pd.DataFrame(columns=PERSON_TABLE_COLUMNS)
    return [*hide_every_row, no_people, show_add_button, full_dropdown, empty_people]
def reset_person_page(
    scene_file_name: str,
    person_name_to_description: dict[str, str],
):
    """Reset the person section for the scene *scene_file_name*.

    Returns the list from ``reset_person_rows`` followed by the scene's
    maximum person number and two hide-updates (wired by the caller to the
    generate button and the video).
    """
    # The person limit is the bedroom count of the scene.
    person_limit = get_max_person_number(SCENE_ROOT_DIR / scene_file_name)
    updates = reset_person_rows(person_name_to_description)
    return updates + [person_limit, gr.update(visible=False), gr.update(visible=False)]
def visualize_person(person_name_to_info: gr.State, person_name_to_description: gr.State, max_added_person: int):
    """Build the person section (full and compact views) and wire add/delete.

    Args:
        person_name_to_info: state passed as an input to the add and delete
            handlers.
        person_name_to_description: state passed as an input to the add and
            delete handlers; its ``.value`` seeds the 'Add Person' dropdown.
        max_added_person: initial person limit; wrapped in a ``gr.State``.

    Returns:
        ``(row_index_to_person_name, max_added_person, row_elements, rows,
        delete_buttons, add_button, add_person_dropdown, person_dataframe,
        generate_button)``, where ``max_added_person`` is the ``gr.State``
        created here, not the integer argument.
    """
    # Names of the people currently shown, in row order.
    row_index_to_person_name = gr.State([])
    # Rebinds the parameter name to a State holding the initial limit.
    max_added_person = gr.State(max_added_person)
    # NOTE(review): container nesting was reconstructed from a listing
    # without indentation -- confirm against the rendered page.
    with gr.Blocks():
        with gr.Tab(label='Full View'):
            person_pages, add_person_dropdown, add_button = full_view(person_name_to_description.value)
        with gr.Tab(label='Compact View'):
            person_dataframe = compact_view()
        generate_button = gr.Button("Generate Video", visible=False)
    # Flatten the per-row components row-major, PERSON_ELEMENT_NUM per row.
    row_elements = [element for page in person_pages for element in page['elements']]
    rows = [page['row'] for page in person_pages]
    delete_buttons: list[gr.Button] = [page['delete_button'] for page in person_pages]
    # The output order must match add_person_to_scene's return list.
    add_button.click(
        fn=add_person_to_scene,
        inputs=[max_added_person, add_person_dropdown, person_name_to_description, row_index_to_person_name, person_name_to_info],
        outputs=row_elements + rows + [row_index_to_person_name, add_person_dropdown, add_button, person_dataframe, generate_button]
    )
    for i, delete_button in enumerate(delete_buttons):
        # gr.State(i) passes the button's row index to the handler; the
        # output order must match delete_person_from_scene's return list.
        delete_button.click(
            fn=delete_person_from_scene,
            inputs=[person_name_to_info, row_index_to_person_name, gr.State(i), person_name_to_description] + row_elements,
            outputs=row_elements + rows + [add_button, add_person_dropdown, row_index_to_person_name, person_dataframe, generate_button]
        )
    return row_index_to_person_name, max_added_person, row_elements, rows, delete_buttons, add_button, add_person_dropdown, person_dataframe, generate_button
def generate_button_click_change_state():
    """Return two updates: hide the first output and show the second.

    The caller wires the outputs as ``[generate_button, video]``.
    """
    hide_first = gr.update(visible=False)
    show_second = gr.update(visible=True)
    return hide_first, show_second
def generate_button_click_change_video(
    scene_file_name: str,
    row_index_to_person_name: list[str],
    person_name_to_file: dict[str, dict[str, object]]
):
    """Generate the video for the selected scene and people.

    Args:
        scene_file_name: scene identifier passed to ``generate_video``.
        row_index_to_person_name: names of the people in the scene, in row
            order.
        person_name_to_file: person name -> the per-person entry handed to
            ``generate_video``.

    Returns:
        The video path as a string when ``generate_video`` returns a truthy
        path, otherwise an update that hides the video component.
    """
    person_files = [person_name_to_file[person_name] for person_name in row_index_to_person_name]
    video_path = generate_video(scene_file_name, person_files)
    if video_path:
        return str(video_path)
    return gr.update(visible=False)
def visualize_dynamic_generate(person_name_to_file, person_name_to_info, person_name_to_description):
    """Build the whole page: scene section, person section, video and wiring.

    Args:
        person_name_to_file: component passed as an input to the video
            generation handler.
        person_name_to_info: forwarded to ``visualize_person``.
        person_name_to_description: forwarded to ``visualize_person`` and
            passed as an input to the reset handler.

    Returns:
        ``(reset_inputs, reset_outputs)`` -- the component lists used for the
        ``reset_person_page`` handlers.
    """
    gr.Markdown("## Scene Information")
    scene_dropdown = visualize_scene()
    # Initial person limit for the scene that is selected by default.
    max_person_number = get_max_person_number(SCENE_ROOT_DIR / scene_dropdown.value)
    gr.Markdown("## Person Information")
    # max_person_number is rebound here to the gr.State returned by
    # visualize_person.
    row_index_to_person_name, max_person_number, _, rows, _, add_button, add_person_dropdown, person_dataframe, generate_button \
        = visualize_person(person_name_to_info, person_name_to_description, max_person_number)
    video = gr.Video(visible=False)
    clear_button = gr.Button("Clear")
    # Two handlers are registered on the same click: one produces the video,
    # the other hides the button and shows the video component.
    generate_button.click(
        fn=generate_button_click_change_video,
        inputs=[scene_dropdown, row_index_to_person_name, person_name_to_file],
        outputs=video
    )
    generate_button.click(
        fn=generate_button_click_change_state,
        outputs=[generate_button, video]
    )
    # The output order must match reset_person_page's return list.
    reset_inputs = [scene_dropdown, person_name_to_description]
    reset_outputs = rows + [row_index_to_person_name, add_button, add_person_dropdown, person_dataframe, max_person_number, generate_button, video]
    # Changing the scene and pressing "Clear" both reset the person section.
    scene_dropdown.change(
        fn=reset_person_page,
        inputs=reset_inputs,
        outputs=reset_outputs
    )
    clear_button.click(
        fn=reset_person_page,
        inputs=reset_inputs,
        outputs=reset_outputs
    )
    return reset_inputs, reset_outputs
# Script entry point: build the page inside a Blocks app and launch it.
if __name__ == '__main__':
    with gr.Blocks() as demo:
        # NOTE(review): visualize_dynamic_generate declares three required
        # parameters (person_name_to_file, person_name_to_info,
        # person_name_to_description) but is called with none, so this call
        # raises TypeError as written -- the person data needs to be loaded
        # and passed in here.
        visualize_dynamic_generate()
    demo.launch()