Added type annotations for pipeline init args
#3
by
guiyrt
- opened
- pipeline.py +2 -2
pipeline.py
CHANGED
|
@@ -18,7 +18,7 @@ from typing import Optional, Tuple, Union
|
|
| 18 |
|
| 19 |
import torch
|
| 20 |
|
| 21 |
-
from diffusers import DiffusionPipeline, ImagePipelineOutput
|
| 22 |
|
| 23 |
|
| 24 |
class CustomPipeline(DiffusionPipeline):
|
|
@@ -33,7 +33,7 @@ class CustomPipeline(DiffusionPipeline):
|
|
| 33 |
[`DDPMScheduler`], or [`DDIMScheduler`].
|
| 34 |
"""
|
| 35 |
|
| 36 |
-
def __init__(self, unet, scheduler):
|
| 37 |
super().__init__()
|
| 38 |
self.register_modules(unet=unet, scheduler=scheduler)
|
| 39 |
|
|
|
|
| 18 |
|
| 19 |
import torch
|
| 20 |
|
| 21 |
+
from diffusers import DiffusionPipeline, ImagePipelineOutput, SchedulerMixin, UNet2DModel
|
| 22 |
|
| 23 |
|
| 24 |
class CustomPipeline(DiffusionPipeline):
|
|
|
|
| 33 |
[`DDPMScheduler`], or [`DDIMScheduler`].
|
| 34 |
"""
|
| 35 |
|
| 36 |
+
def __init__(self, unet: UNet2DModel, scheduler: SchedulerMixin):
|
| 37 |
super().__init__()
|
| 38 |
self.register_modules(unet=unet, scheduler=scheduler)
|
| 39 |
|