Mirror of https://github.com/invoke-ai/InvokeAI
Remove callback-generator wrapper
commit 3d8da67be3
parent 957ee6d370
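Summary: this commit deletes the generic GeneratorToCallbackinator adapter, along with the TypeVar/ParamSpec machinery it required, and instead threads an optional callback parameter directly through generate_latents_from_embeddings, which stops yielding PipelineIntermediateState objects and simply returns (latents, attention_map_saver). A minimal sketch of the pattern change, using toy functions in place of the pipeline (all names below are illustrative, not from the codebase):

    from typing import Callable, Optional

    # Old style (removed): the producer yielded intermediate values and relied
    # on a wrapper to expose them as "callback plus return value".
    def old_style():
        for step in range(3):
            yield step  # intermediate state; the last yield doubled as the result

    # New style (this commit): the producer accepts the callback itself and returns.
    def new_style(callback: Optional[Callable[[int], None]] = None) -> int:
        latest = 0
        for step in range(3):
            latest = step
            if callback is not None:  # the same None-guard this commit adds
                callback(step)
        return latest

    print(new_style(callback=lambda s: print("intermediate:", s)))  # returns 2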
@@ -5,7 +5,7 @@ import inspect
 import math
 import secrets
 from dataclasses import dataclass, field
-from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
+from typing import Any, Callable, Generic, List, Optional, Type, Union

 import PIL.Image
 import einops
@@ -27,7 +27,6 @@ from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.outputs import BaseOutput
 from pydantic import Field
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
-from typing_extensions import ParamSpec

 from invokeai.app.services.config import InvokeAIAppConfig
 from .diffusion import (
@@ -161,33 +160,6 @@ def is_inpainting_model(unet: UNet2DConditionModel):
     return unet.conv_in.in_channels == 9


-CallbackType = TypeVar("CallbackType")
-ReturnType = TypeVar("ReturnType")
-ParamType = ParamSpec("ParamType")
-
-
-@dataclass(frozen=True)
-class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
-    """Convert a generator to a function with a callback and a return value."""
-
-    generator_method: Callable[ParamType, ReturnType]
-    callback_arg_type: Type[CallbackType]
-
-    def __call__(
-        self,
-        *args: ParamType.args,
-        callback: Callable[[CallbackType], Any] = None,
-        **kwargs: ParamType.kwargs,
-    ) -> ReturnType:
-        result = None
-        for result in self.generator_method(*args, **kwargs):
-            if callback is not None and isinstance(result, self.callback_arg_type):
-                callback(result)
-        if result is None:
-            raise AssertionError("why was that an empty generator?")
-        return result
-
-
 @dataclass
 class ControlNetData:
     model: ControlNetModel = Field(default=None)
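For reference, the deleted wrapper is reproduced below as a self-contained sketch, with a toy countdown generator added to demonstrate the behavior the commit inlines (the countdown demo is illustrative, not from the codebase):

    from dataclasses import dataclass
    from typing import Any, Callable, Generic, Type, TypeVar
    from typing_extensions import ParamSpec

    CallbackType = TypeVar("CallbackType")
    ReturnType = TypeVar("ReturnType")
    ParamType = ParamSpec("ParamType")

    @dataclass(frozen=True)
    class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
        """Convert a generator to a function with a callback and a return value."""

        generator_method: Callable[ParamType, ReturnType]
        callback_arg_type: Type[CallbackType]

        def __call__(
            self,
            *args: ParamType.args,
            callback: Callable[[CallbackType], Any] = None,
            **kwargs: ParamType.kwargs,
        ) -> ReturnType:
            result = None
            # Drain the generator, forwarding values of the watched type.
            for result in self.generator_method(*args, **kwargs):
                if callback is not None and isinstance(result, self.callback_arg_type):
                    callback(result)
            if result is None:
                raise AssertionError("why was that an empty generator?")
            # The last yielded value doubles as the return value.
            return result

    def countdown(n: int):
        yield from range(n, -1, -1)

    run = GeneratorToCallbackinator(countdown, int)
    final = run(3, callback=lambda i: print("tick", i))
    print("final:", final)  # prints tick 3..0, then final: 0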
@@ -375,10 +347,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if init_timestep.shape[0] == 0:
             return latents, None

-        infer_latents_from_embeddings = GeneratorToCallbackinator(
-            self.generate_latents_from_embeddings, PipelineIntermediateState
-        )
-
         if additional_guidance is None:
             additional_guidance = []

@@ -417,7 +385,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise))

         try:
-            result: PipelineIntermediateState = infer_latents_from_embeddings(
+            latents, attention_map_saver = self.generate_latents_from_embeddings(
                 latents,
                 timesteps,
                 conditioning_data,
@@ -428,13 +396,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         finally:
             self.invokeai_diffuser.model_forward_callback = self._unet_forward

-        latents = result.latents
-
         # restore unmasked part
         if mask is not None:
             latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype))

-        return latents, result.attention_map_saver
+        return latents, attention_map_saver

     def generate_latents_from_embeddings(
         self,
@@ -444,6 +410,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         *,
         additional_guidance: List[Callable] = None,
         control_data: List[ControlNetData] = None,
+        callback: Callable[[PipelineIntermediateState], None] = None,
     ):
         self._adjust_memory_efficient_attention(latents)
         if additional_guidance is None:
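With the new signature, a caller opts into progress reporting by passing callback directly. A minimal usage sketch, assuming an initialized StableDiffusionGeneratorPipeline and the latents, timesteps, and conditioning_data values visible at the call site above (report_progress is a hypothetical helper; other call-site arguments are elided):

    def report_progress(state: PipelineIntermediateState) -> None:
        # Hypothetical reporter; these fields are the ones this diff passes
        # when constructing PipelineIntermediateState.
        print(f"step {state.step + 1}/{state.total_steps} (timestep {state.timestep})")

    latents, attention_map_saver = pipeline.generate_latents_from_embeddings(
        latents,
        timesteps,
        conditioning_data,
        # ... any remaining arguments unchanged from the call site above ...
        callback=report_progress,
    )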
@@ -461,13 +428,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             extra_conditioning_info=extra_conditioning_info,
             step_count=len(self.scheduler.timesteps),
         ):
-            yield PipelineIntermediateState(
+            if callback is not None:
+                callback(PipelineIntermediateState(
                 step=-1,
                 order=self.scheduler.order,
                 total_steps=len(timesteps),
                 timestep=self.scheduler.config.num_train_timesteps,
                 latents=latents,
-            )
+            ))

             # print("timesteps:", timesteps)
             for i, t in enumerate(self.progress_bar(timesteps)):
@@ -500,7 +468,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 # attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
                 # self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)

-                yield PipelineIntermediateState(
+                if callback is not None:
+                    callback(PipelineIntermediateState(
                     step=i,
                     order=self.scheduler.order,
                     total_steps=len(timesteps),
@@ -508,7 +477,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                     latents=latents,
                     predicted_original=predicted_original,
                     attention_map_saver=attention_map_saver,
-                )
+                ))

             return latents, attention_map_saver

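Note on the net effect: once both yield statements are replaced with guarded callback(...) calls, generate_latents_from_embeddings contains no yield at all, so it is an ordinary function again and its final return statement hands back a plain (latents, attention_map_saver) tuple. That is exactly what lets the earlier hunk unpack the call as latents, attention_map_saver = self.generate_latents_from_embeddings(...) instead of reading .latents and .attention_map_saver off a PipelineIntermediateState returned by the deleted wrapper.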