mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Remove callback-generator wrapper
This commit is contained in:
parent
957ee6d370
commit
3d8da67be3
@ -5,7 +5,7 @@ import inspect
|
||||
import math
|
||||
import secrets
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
|
||||
from typing import Any, Callable, Generic, List, Optional, Type, Union
|
||||
|
||||
import PIL.Image
|
||||
import einops
|
||||
@ -27,7 +27,6 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.outputs import BaseOutput
|
||||
from pydantic import Field
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from .diffusion import (
|
||||
@ -161,33 +160,6 @@ def is_inpainting_model(unet: UNet2DConditionModel):
|
||||
return unet.conv_in.in_channels == 9
|
||||
|
||||
|
||||
CallbackType = TypeVar("CallbackType")
|
||||
ReturnType = TypeVar("ReturnType")
|
||||
ParamType = ParamSpec("ParamType")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
|
||||
"""Convert a generator to a function with a callback and a return value."""
|
||||
|
||||
generator_method: Callable[ParamType, ReturnType]
|
||||
callback_arg_type: Type[CallbackType]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
*args: ParamType.args,
|
||||
callback: Callable[[CallbackType], Any] = None,
|
||||
**kwargs: ParamType.kwargs,
|
||||
) -> ReturnType:
|
||||
result = None
|
||||
for result in self.generator_method(*args, **kwargs):
|
||||
if callback is not None and isinstance(result, self.callback_arg_type):
|
||||
callback(result)
|
||||
if result is None:
|
||||
raise AssertionError("why was that an empty generator?")
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class ControlNetData:
|
||||
model: ControlNetModel = Field(default=None)
|
||||
@ -375,10 +347,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if init_timestep.shape[0] == 0:
|
||||
return latents, None
|
||||
|
||||
infer_latents_from_embeddings = GeneratorToCallbackinator(
|
||||
self.generate_latents_from_embeddings, PipelineIntermediateState
|
||||
)
|
||||
|
||||
if additional_guidance is None:
|
||||
additional_guidance = []
|
||||
|
||||
@ -417,7 +385,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise))
|
||||
|
||||
try:
|
||||
result: PipelineIntermediateState = infer_latents_from_embeddings(
|
||||
latents, attention_map_saver = self.generate_latents_from_embeddings(
|
||||
latents,
|
||||
timesteps,
|
||||
conditioning_data,
|
||||
@ -428,13 +396,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
finally:
|
||||
self.invokeai_diffuser.model_forward_callback = self._unet_forward
|
||||
|
||||
latents = result.latents
|
||||
|
||||
# restore unmasked part
|
||||
if mask is not None:
|
||||
latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype))
|
||||
|
||||
return latents, result.attention_map_saver
|
||||
return latents, attention_map_saver
|
||||
|
||||
def generate_latents_from_embeddings(
|
||||
self,
|
||||
@ -444,6 +410,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
*,
|
||||
additional_guidance: List[Callable] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
):
|
||||
self._adjust_memory_efficient_attention(latents)
|
||||
if additional_guidance is None:
|
||||
@ -461,13 +428,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
step_count=len(self.scheduler.timesteps),
|
||||
):
|
||||
yield PipelineIntermediateState(
|
||||
if callback is not None:
|
||||
callback(PipelineIntermediateState(
|
||||
step=-1,
|
||||
order=self.scheduler.order,
|
||||
total_steps=len(timesteps),
|
||||
timestep=self.scheduler.config.num_train_timesteps,
|
||||
latents=latents,
|
||||
)
|
||||
))
|
||||
|
||||
# print("timesteps:", timesteps)
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
@ -500,7 +468,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
# attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
|
||||
# self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
|
||||
|
||||
yield PipelineIntermediateState(
|
||||
if callback is not None:
|
||||
callback(PipelineIntermediateState(
|
||||
step=i,
|
||||
order=self.scheduler.order,
|
||||
total_steps=len(timesteps),
|
||||
@ -508,7 +477,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
latents=latents,
|
||||
predicted_original=predicted_original,
|
||||
attention_map_saver=attention_map_saver,
|
||||
)
|
||||
))
|
||||
|
||||
return latents, attention_map_saver
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user