dn6 HF Staff commited on
Commit
f6ff21d
·
verified ·
1 Parent(s): 51b82a1

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. app.py +91 -0
  2. requirements.txt +14 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import spaces
3
+
4
+ import gradio as gr
5
+ from diffusers import ModularPipelineBlocks
6
+ from diffusers.utils import export_to_video, load_image
7
+ from diffusers.modular_pipelines import WanModularPipeline
8
+
9
class MatrixGameWanModularPipeline(WanModularPipeline):
    """Modular pipeline specialization for MatrixGameWan.

    Overrides the default latent sample dimensions of the base
    ``WanModularPipeline``.

    <Tip warning={true}>

    This is an experimental feature and is likely to change in the future.

    </Tip>
    """

    @property
    def default_sample_height(self):
        # Default sample height used when the caller supplies none.
        return 44

    @property
    def default_sample_width(self):
        # Default sample width used when the caller supplies none.
        return 80
27
+
28
+
29
# Load the Matrix-Game 2 modular pipeline blocks and the image-to-action block
# from the Hub. trust_remote_code=True executes code shipped in those repos —
# acceptable only because the repos are controlled by the app author.
blocks = ModularPipelineBlocks.from_pretrained("diffusers/matrix-game-2-modular", trust_remote_code=True)
image_to_action_block = ModularPipelineBlocks.from_pretrained("dn6/matrix-game-image-to-action", trust_remote_code=True)

# Insert the image-to-action block at position 0 so it runs before the
# generation blocks.
blocks.sub_blocks.insert("image_to_action", image_to_action_block, 0)

# NOTE(review): the blocks come from "diffusers/matrix-game-2-modular" but the
# pipeline components come from "diffusers-internal-dev/matrix-game-2-modular"
# — confirm the two repo ids are intentionally different.
pipe = MatrixGameWanModularPipeline(blocks, "diffusers-internal-dev/matrix-game-2-modular")
# Components run in bfloat16 on CUDA, except the VAE which is kept in float32
# (presumably for numerical stability — confirm).
pipe.load_components(trust_remote_code=True, device_map="cuda", torch_dtype={"default": torch.bfloat16, "vae": torch.float32})
36
+
37
@spaces.GPU(300)
def predict(image, prompt):
    """Generate a video from an input image and a text prompt.

    Runs the modular Matrix-Game pipeline for 141 frames, writes the first
    generated video to ``output.mp4``, and returns the path to that file.
    """
    result = pipe(image=image, prompt=prompt, num_frames=141)
    video_frames = result.values['videos'][0]
    return export_to_video(video_frames, "output.mp4")
41
+
42
+
43
# Example inputs for the UI (none provided yet).
examples = []

# Page-level CSS: centers the main column and caps its width; the logo/title
# rules style an optional header element.
css = """
#col-container {
    margin: 0 auto;
    max-width: 1024px;
}
#logo-title {
    text-align: center;
}
#logo-title img {
    width: 400px;
}
#edit_text{margin-top: -62px !important}
"""
58
+
59
# Build the Gradio UI: an input gallery and result gallery side by side, with
# a prompt box and run button below; generation is wired to both the button
# click and prompt submit.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        with gr.Row():
            with gr.Column():
                # NOTE(review): a Gallery produces a list of images, but
                # predict() takes a single `image` argument — confirm the
                # pipeline accepts the gallery payload as-is.
                input_images = gr.Gallery(label="Input Images",
                                          show_label=False,
                                          type="pil",
                                          interactive=True)

            with gr.Column():
                # NOTE(review): predict() returns a video file path; a Gallery
                # output component may be intended to be gr.Video — verify.
                result = gr.Gallery(label="Result", show_label=False, type="pil")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                placeholder="describe the edit instruction",
                container=False,
            )
            run_button = gr.Button("Run!", variant="primary")

        # Run generation from either the button or pressing Enter in the
        # prompt box.
        gr.on(
            triggers=[run_button.click, prompt.submit],
            fn=predict,
            inputs=[
                input_images,
                prompt,
            ],
            outputs=[result],
        )

if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==1.10.1
2
+ einops==0.8.1
3
+ flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
4
+ hf-transfer==0.1.9
5
+ hf-xet==1.1.8
6
+ huggingface-hub==0.34.4
7
+ imageio==2.37.0
8
+ imageio-ffmpeg==0.6.0
9
+ safetensors==0.6.2
10
+ sentencepiece==0.2.1
11
+ torch==2.7.0
12
+ torchao==0.12.0
13
+ torchvision==0.22.0
14
+ transformers==4.55.4