Spaces:
Runtime error
Runtime error
fix: add timeout to prevent unexcepted tasl
Browse files- app.py +3 -0
- modules/model.py +11 -0
app.py
CHANGED
|
@@ -60,6 +60,7 @@ samplers_k_diffusion = [
|
|
| 60 |
# ]
|
| 61 |
|
| 62 |
start_time = time.time()
|
|
|
|
| 63 |
|
| 64 |
scheduler = DDIMScheduler.from_pretrained(
|
| 65 |
base_model,
|
|
@@ -257,6 +258,8 @@ def inference(
|
|
| 257 |
"sampler_opt": sampler_opt,
|
| 258 |
"pww_state": state,
|
| 259 |
"pww_attn_weight": g_strength,
|
|
|
|
|
|
|
| 260 |
}
|
| 261 |
|
| 262 |
if img_input is not None:
|
|
|
|
| 60 |
# ]
|
| 61 |
|
| 62 |
start_time = time.time()
|
| 63 |
+
timeout = 120
|
| 64 |
|
| 65 |
scheduler = DDIMScheduler.from_pretrained(
|
| 66 |
base_model,
|
|
|
|
| 258 |
"sampler_opt": sampler_opt,
|
| 259 |
"pww_state": state,
|
| 260 |
"pww_attn_weight": g_strength,
|
| 261 |
+
"start_time": start_time,
|
| 262 |
+
"timeout": timeout,
|
| 263 |
}
|
| 264 |
|
| 265 |
if img_input is not None:
|
modules/model.py
CHANGED
|
@@ -6,6 +6,7 @@ import re
|
|
| 6 |
from collections import defaultdict
|
| 7 |
from typing import List, Optional, Union
|
| 8 |
|
|
|
|
| 9 |
import k_diffusion
|
| 10 |
import numpy as np
|
| 11 |
import PIL
|
|
@@ -446,6 +447,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
|
| 446 |
pww_attn_weight=1.0,
|
| 447 |
sampler_name="",
|
| 448 |
sampler_opt={},
|
|
|
|
|
|
|
| 449 |
scale_ratio=8.0,
|
| 450 |
):
|
| 451 |
sampler = self.get_scheduler(sampler_name)
|
|
@@ -504,6 +507,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
|
| 504 |
|
| 505 |
def model_fn(x, sigma):
|
| 506 |
|
|
|
|
|
|
|
|
|
|
| 507 |
latent_model_input = torch.cat([x] * 2)
|
| 508 |
weight_func = lambda w, sigma, qk: w * math.log(1 + sigma) * qk.max()
|
| 509 |
encoder_state = {
|
|
@@ -617,6 +623,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
|
| 617 |
pww_attn_weight=1.0,
|
| 618 |
sampler_name="",
|
| 619 |
sampler_opt={},
|
|
|
|
|
|
|
| 620 |
):
|
| 621 |
sampler = self.get_scheduler(sampler_name)
|
| 622 |
# 1. Check inputs. Raise error if not correct
|
|
@@ -667,6 +675,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
|
| 667 |
|
| 668 |
def model_fn(x, sigma):
|
| 669 |
|
|
|
|
|
|
|
|
|
|
| 670 |
latent_model_input = torch.cat([x] * 2)
|
| 671 |
weight_func = lambda w, sigma, qk: w * math.log(1 + sigma) * qk.max()
|
| 672 |
encoder_state = {
|
|
|
|
| 6 |
from collections import defaultdict
|
| 7 |
from typing import List, Optional, Union
|
| 8 |
|
| 9 |
+
import time
|
| 10 |
import k_diffusion
|
| 11 |
import numpy as np
|
| 12 |
import PIL
|
|
|
|
| 447 |
pww_attn_weight=1.0,
|
| 448 |
sampler_name="",
|
| 449 |
sampler_opt={},
|
| 450 |
+
start_time=-1,
|
| 451 |
+
timeout=180,
|
| 452 |
scale_ratio=8.0,
|
| 453 |
):
|
| 454 |
sampler = self.get_scheduler(sampler_name)
|
|
|
|
| 507 |
|
| 508 |
def model_fn(x, sigma):
|
| 509 |
|
| 510 |
+
if start_time > 0 and timeout > 0:
|
| 511 |
+
assert (time.time() - start_time) < timeout, "inference process timed out"
|
| 512 |
+
|
| 513 |
latent_model_input = torch.cat([x] * 2)
|
| 514 |
weight_func = lambda w, sigma, qk: w * math.log(1 + sigma) * qk.max()
|
| 515 |
encoder_state = {
|
|
|
|
| 623 |
pww_attn_weight=1.0,
|
| 624 |
sampler_name="",
|
| 625 |
sampler_opt={},
|
| 626 |
+
start_time=-1,
|
| 627 |
+
timeout=180,
|
| 628 |
):
|
| 629 |
sampler = self.get_scheduler(sampler_name)
|
| 630 |
# 1. Check inputs. Raise error if not correct
|
|
|
|
| 675 |
|
| 676 |
def model_fn(x, sigma):
|
| 677 |
|
| 678 |
+
if start_time > 0 and timeout > 0:
|
| 679 |
+
assert (time.time() - start_time) < timeout, "inference process timed out"
|
| 680 |
+
|
| 681 |
latent_model_input = torch.cat([x] * 2)
|
| 682 |
weight_func = lambda w, sigma, qk: w * math.log(1 + sigma) * qk.max()
|
| 683 |
encoder_state = {
|