Update syngen_diffusion_pipeline.py
making syngen a little more efficient

syngen_diffusion_pipeline.py  (+71, -23)  CHANGED
@@ -19,8 +19,6 @@ from diffusers.utils import (
 )
 from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
 
-from compute_loss import get_attention_map_index_to_wordpiece, split_indices, calculate_positive_loss, calculate_negative_loss, get_indices, start_token, end_token, \
-    align_wordpieces_indices, extract_attribution_indices
 
 logger = logging.get_logger(__name__)
 
@@ -40,6 +38,9 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
                          requires_safety_checker)
 
         self.parser = spacy.load("en_core_web_trf")
+        self.subtrees_indices = None
+        self.doc = None
+        # self.doc = ""#self.parser(prompt)
 
     def _aggregate_and_get_attention_maps_per_token(self):
         attention_maps = self.attention_store.aggregate_attention(
@@ -105,6 +106,7 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         syngen_step_size: float = 20.0,
+        parsed_prompt: str=None
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -165,7 +167,7 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
                 A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
                 `self.processor` in
                 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
-            syngen_step_size (`
+            syngen_step_size (`float`, *optional*, default to 20.0):
                 Controls the step size of each SynGen update.
 
         Examples:
@@ -177,6 +179,11 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
                 list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
                 (nsfw) content, according to the `safety_checker`.
         """
+
+        if parsed_prompt:
+            self.doc = parsed_prompt
+        else:
+            self.doc = self.parser(prompt)
         # 0. Default height and width to unet
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor
@@ -234,7 +241,7 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
             latents,
         )
 
-        # 6. Prepare extra step kwargs.
+        # 6. Prepare extra step kwargs.
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
         # NEW - stores the attention calculated in the unet
@@ -251,16 +258,17 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # NEW
-
-
-
-
-
-
-
-
-
-
+                if i < 25:
+                    latents = self._syngen_step(
+                        latents,
+                        text_embeddings,
+                        t,
+                        i,
+                        syngen_step_size,
+                        cross_attention_kwargs,
+                        prompt,
+                        max_iter_to_alter=25,
+                    )
 
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = (
@@ -325,6 +333,9 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.final_offload_hook.offload()
 
+        self.doc = None
+        self.subtrees_indices = None
+
         if not return_dict:
             return (image, has_nsfw_concept)
 
@@ -332,6 +343,8 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
             images=image, nsfw_content_detected=has_nsfw_concept
         )
 
+
+
     def _syngen_step(
         self,
         latents,
@@ -358,12 +371,9 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
                 cross_attention_kwargs=cross_attention_kwargs,
             ).sample
             self.unet.zero_grad()
-
            # Get attention maps
             attention_maps = self._aggregate_and_get_attention_maps_per_token()
-
             loss = self._compute_loss(attention_maps=attention_maps, prompt=prompt)
-
             # Perform gradient update
             if i < max_iter_to_alter:
                 if loss != 0:
@@ -393,7 +403,9 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
         prompt: Union[str, List[str]],
         attn_map_idx_to_wp,
     ) -> torch.Tensor:
-
+        if not self.subtrees_indices:
+            self.subtrees_indices = self._extract_attribution_indices(prompt)
+        subtrees_indices = self.subtrees_indices
         loss = 0
 
         for subtree_indices in subtrees_indices:
@@ -474,15 +486,24 @@ class SynGenDiffusionPipeline(StableDiffusionPipeline):
                     collected_spacy_indices.add(collected_idx)
 
             paired_indices.append(curr_collected_wp_indices)
-
+
         return paired_indices
 
+
     def _extract_attribution_indices(self, prompt):
-
-
-
+        # extract standard attribution indices
+        pairs = extract_attribution_indices(self.doc)
+
+        # extract attribution indices with verbs in between
+        pairs_2 = extract_attribution_indices_with_verb_root(self.doc)
+        pairs_3 = extract_attribution_indices_with_verbs(self.doc)
+        # make sure there are no duplicates
+        pairs = unify_lists(pairs, pairs_2, pairs_3)
 
 
+        print(f"Final pairs collected: {pairs}")
+        paired_indices = self._align_indices(prompt, pairs)
+        return paired_indices
 
 def _get_attention_maps_list(
         attention_maps: torch.Tensor
@@ -492,4 +513,31 @@ def _get_attention_maps_list(
         attention_maps[:, :, i] for i in range(attention_maps.shape[2])
     ]
 
-    return attention_maps_list
+    return attention_maps_list
+
+def is_sublist(sub, main):
+    # This function checks if 'sub' is a sublist of 'main'
+    return len(sub) < len(main) and all(item in main for item in sub)
+
+def unify_lists(lists_1, lists_2, lists_3):
+    unified_list = lists_1 + lists_2 + lists_3
+    sorted_list = sorted(unified_list, key=len)
+    seen = set()
+
+    result = []
+
+    for i in range(len(sorted_list)):
+        if tuple(sorted_list[i]) in seen:  # Skip if already added
+            continue
+
+        sublist_to_add = True
+        for j in range(i + 1, len(sorted_list)):
+            if is_sublist(sorted_list[i], sorted_list[j]):
+                sublist_to_add = False
+                break
+
+        if sublist_to_add:
+            result.append(sorted_list[i])
+            seen.add(tuple(sorted_list[i]))
+
+    return result
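Note: the efficiency gain comes from computing prompt-dependent state once per call rather than once per denoising step: self.doc is parsed in __call__ (or taken from parsed_prompt), self.subtrees_indices is filled lazily on the first _compute_loss invocation, and both are cleared before the pipeline returns. Below is a minimal, self-contained sketch of that per-call caching pattern; the class and helper names are illustrative stand-ins, not code from the repo.

class CachingPipeline:
    def __init__(self, parser):
        self.parser = parser            # stand-in for spacy.load("en_core_web_trf")
        self.doc = None                 # parsed prompt, shared by all steps of one call
        self.subtrees_indices = None    # attribution pairs, computed lazily

    def __call__(self, prompt, parsed_prompt=None, num_inference_steps=50):
        # Parse once per call, or accept a pre-parsed prompt from the caller.
        self.doc = parsed_prompt if parsed_prompt is not None else self.parser(prompt)
        losses = [self._compute_loss(prompt) for _ in range(num_inference_steps)]
        # Clear cached state so the next call starts fresh.
        self.doc = None
        self.subtrees_indices = None
        return losses

    def _compute_loss(self, prompt):
        # The expensive extraction runs only on the first step of each call.
        if self.subtrees_indices is None:
            self.subtrees_indices = self._extract_attribution_indices(prompt)
        return len(self.subtrees_indices)  # stand-in for the real attention loss

    def _extract_attribution_indices(self, prompt):
        # Stand-in for the spaCy-based subtree extraction.
        return [[i, i + 1] for i in range(len(prompt.split()) - 1)]

pipe = CachingPipeline(parser=str.split)  # dummy "parser"
print(pipe("a red ball and a blue car", num_inference_steps=3))  # [6, 6, 6]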
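The new module-level helpers is_sublist and unify_lists merge the three extraction passes and drop any index list that is strictly contained in a longer one. A quick check of that behaviour, assuming syngen_diffusion_pipeline.py (and its heavyweight dependencies) is importable from the working directory:

from syngen_diffusion_pipeline import is_sublist, unify_lists

pairs_1 = [[2, 3]]        # e.g. a span found by the standard attribution pass
pairs_2 = [[2, 3, 4]]     # the same span extended by one token from a verb-based pass
pairs_3 = [[6, 7]]

print(is_sublist([2, 3], [2, 3, 4]))           # True: strictly shorter and fully contained
print(unify_lists(pairs_1, pairs_2, pairs_3))  # [[6, 7], [2, 3, 4]]: the contained [2, 3] is dropped

Sorting the merged candidates by length before the scan means each list only has to be compared against the longer lists that come after it.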
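Downstream, the new parsed_prompt argument lets a caller reuse one syntactic parse across several generations of the same prompt. A hypothetical usage sketch, assuming the pipeline loads like any StableDiffusionPipeline subclass; the checkpoint id is only an example, and, despite the str annotation, the value is used exactly where self.parser(prompt) would be, so a pre-parsed spaCy Doc from the pipeline's own parser is the natural thing to pass:

import torch
from syngen_diffusion_pipeline import SynGenDiffusionPipeline

# Example checkpoint id, not taken from the diff.
pipe = SynGenDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda")

prompt = "a red ball and a blue car"
doc = pipe.parser(prompt)  # self.parser is the en_core_web_trf model loaded in __init__

for seed in range(4):
    generator = torch.Generator(device="cuda").manual_seed(seed)
    image = pipe(prompt, parsed_prompt=doc, generator=generator).images[0]
    image.save(f"syngen_{seed}.png")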