Remove callback-generator wrapper

This commit is contained in:
Sergey Borisov 2023-08-14 03:35:15 +03:00
parent 957ee6d370
commit 3d8da67be3

View File

@ -5,7 +5,7 @@ import inspect
import math
import secrets
from dataclasses import dataclass, field
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
from typing import Any, Callable, Generic, List, Optional, Type, Union
import PIL.Image
import einops
@ -27,7 +27,6 @@ from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.outputs import BaseOutput
from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from typing_extensions import ParamSpec
from invokeai.app.services.config import InvokeAIAppConfig
from .diffusion import (
@ -161,33 +160,6 @@ def is_inpainting_model(unet: UNet2DConditionModel):
return unet.conv_in.in_channels == 9
CallbackType = TypeVar("CallbackType")
ReturnType = TypeVar("ReturnType")
ParamType = ParamSpec("ParamType")
@dataclass(frozen=True)
class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
"""Convert a generator to a function with a callback and a return value."""
generator_method: Callable[ParamType, ReturnType]
callback_arg_type: Type[CallbackType]
def __call__(
self,
*args: ParamType.args,
callback: Callable[[CallbackType], Any] = None,
**kwargs: ParamType.kwargs,
) -> ReturnType:
result = None
for result in self.generator_method(*args, **kwargs):
if callback is not None and isinstance(result, self.callback_arg_type):
callback(result)
if result is None:
raise AssertionError("why was that an empty generator?")
return result
@dataclass
class ControlNetData:
model: ControlNetModel = Field(default=None)
@ -375,10 +347,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if init_timestep.shape[0] == 0:
return latents, None
infer_latents_from_embeddings = GeneratorToCallbackinator(
self.generate_latents_from_embeddings, PipelineIntermediateState
)
if additional_guidance is None:
additional_guidance = []
@ -417,7 +385,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise))
try:
result: PipelineIntermediateState = infer_latents_from_embeddings(
latents, attention_map_saver = self.generate_latents_from_embeddings(
latents,
timesteps,
conditioning_data,
@ -428,13 +396,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
finally:
self.invokeai_diffuser.model_forward_callback = self._unet_forward
latents = result.latents
# restore unmasked part
if mask is not None:
latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype))
return latents, result.attention_map_saver
return latents, attention_map_saver
def generate_latents_from_embeddings(
self,
@ -444,6 +410,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
*,
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
):
self._adjust_memory_efficient_attention(latents)
if additional_guidance is None:
@ -461,13 +428,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
extra_conditioning_info=extra_conditioning_info,
step_count=len(self.scheduler.timesteps),
):
yield PipelineIntermediateState(
step=-1,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=self.scheduler.config.num_train_timesteps,
latents=latents,
)
if callback is not None:
callback(PipelineIntermediateState(
step=-1,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=self.scheduler.config.num_train_timesteps,
latents=latents,
))
# print("timesteps:", timesteps)
for i, t in enumerate(self.progress_bar(timesteps)):
@ -500,15 +468,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
# self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
yield PipelineIntermediateState(
step=i,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=int(t),
latents=latents,
predicted_original=predicted_original,
attention_map_saver=attention_map_saver,
)
if callback is not None:
callback(PipelineIntermediateState(
step=i,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=int(t),
latents=latents,
predicted_original=predicted_original,
attention_map_saver=attention_map_saver,
))
return latents, attention_map_saver