Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Compare commits: v4.2.9.dev...feat/contr

8 Commits
- 448f6a04f4
- cf6941f665
- 6bd74de8f1
- 54c8d542dc
- 75c2df3016
- 8ac8be44a2
- 5ab2164bdc
- 5b11bcdfb8
@@ -4,7 +4,9 @@ from functools import partial
 from typing import Literal, Optional, Union
 
 import numpy as np
+from diffusers import ControlNetModel
 from torch import Tensor
+import torch
 
 from pydantic import BaseModel, Field
 
@@ -53,6 +55,9 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
     seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
     model: str = Field(default="", description="The model to use (currently ignored)")
     progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
+    control_model: Optional[str] = Field(default=None, description="The control model to use")
+    control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
+    # control_strength: Optional[float] = Field(default=1.0, ge=0, le=1, description="The strength of the controlnet")
     # fmt: on
 
     # TODO: pass this an emitter method or something? or a session for dispatching?
@@ -70,20 +75,36 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
     )
 
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        # Handle invalid model parameter
         model = choose_model(context.services.model_manager, self.model)
 
+        # loading controlnet image (currently requires pre-processed image)
+        control_image = (
+            None if self.control_image is None
+            else context.services.images.get(
+                self.control_image.image_type, self.control_image.image_name
+            )
+        )
+        # loading controlnet model
+        if (self.control_model is None or self.control_model==''):
+            control_model = None
+        else:
+            # FIXME: change this to dropdown menu?
+            control_model = ControlNetModel.from_pretrained(self.control_model,
+                                                            torch_dtype=torch.float16).to("cuda")
+
         # Get the source node id (we are invoking the prepared node)
         graph_execution_state = context.services.graph_execution_manager.get(
             context.graph_execution_state_id
         )
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]
 
-        outputs = Txt2Img(model).generate(
+        txt2img = Txt2Img(model, control_model=control_model)
+        outputs = txt2img.generate(
             prompt=self.prompt,
             step_callback=partial(self.dispatch_progress, context, source_node_id),
+            control_image=control_image,
             **self.dict(
-                exclude={"prompt"}
+                exclude={"prompt", "control_image" }
             ), # Shorthand for passing all of the parameters above manually
         )
         # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
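For context, a minimal standalone sketch (not part of the diff) of the ControlNet loading pattern the new `invoke()` code uses. The checkpoint id below is only an example; any diffusers-format ControlNet repo id or local path could be substituted, which is what the node's `control_model` string field is for.

```python
import torch
from diffusers import ControlNetModel

# Example checkpoint id (an assumption, not taken from the diff); loaded in fp16 and moved
# to CUDA, mirroring the ControlNetModel.from_pretrained(...).to("cuda") call in invoke().
control_model = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny",
    torch_dtype=torch.float16,
).to("cuda")
```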
@@ -86,9 +86,11 @@ class InvokeAIGenerator(metaclass=ABCMeta):
     def __init__(self,
                  model_info: dict,
                  params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
+                 **kwargs,
                  ):
         self.model_info=model_info
         self.params=params
+        self.kwargs = kwargs
 
     def generate(self,
                  prompt: str='',
@@ -129,9 +131,12 @@ class InvokeAIGenerator(metaclass=ABCMeta):
                 model=model,
                 scheduler_name=generator_args.get('scheduler')
             )
-            uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
+
+            # get conditioning from prompt via Compel package
+            uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt, model=model)
+
             gen_class = self._generator_class()
-            generator = gen_class(model, self.params.precision)
+            generator = gen_class(model, self.params.precision, **self.kwargs)
             if self.params.variation_amount > 0:
                 generator.set_variation(generator_args.get('seed'),
                                         generator_args.get('variation_amount'),
@@ -281,7 +286,7 @@ class Generator:
     precision: str
     model: DiffusionPipeline
 
-    def __init__(self, model: DiffusionPipeline, precision: str):
+    def __init__(self, model: DiffusionPipeline, precision: str, **kwargs):
         self.model = model
         self.precision = precision
         self.seed = None
@@ -354,7 +359,6 @@ class Generator:
         seed = seed if seed is not None and seed >= 0 else self.new_seed()
         first_seed = seed
         seed, initial_noise = self.generate_initial_noise(seed, width, height)
-
         # There used to be an additional self.model.ema_scope() here, but it breaks
         # the inpaint-1.5 model. Not sure what it did.... ?
         with scope(self.model.device.type):
@@ -4,6 +4,10 @@ invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
 import PIL.Image
 import torch
 
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
+
 from ..stable_diffusion import (
     ConditioningData,
     PostprocessingSettings,
@@ -13,8 +17,13 @@ from .base import Generator
 
 
 class Txt2Img(Generator):
-    def __init__(self, model, precision):
-        super().__init__(model, precision)
+    def __init__(self, model, precision,
+                 control_model: Optional[Union[ControlNetModel, List[ControlNetModel]]] = None,
+                 **kwargs):
+        self.control_model = control_model
+        if isinstance(self.control_model, list):
+            self.control_model = MultiControlNetModel(self.control_model)
+        super().__init__(model, precision, **kwargs)
 
     @torch.no_grad()
     def get_make_image(
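A hedged illustration of the normalization the new constructor performs: a plain list of ControlNetModel instances is wrapped in a single MultiControlNetModel so downstream code can treat one or many control nets uniformly. The helper name `normalize_control_model` is invented for illustration and does not appear in the diff.

```python
from typing import List, Optional, Union

from diffusers.models.controlnet import ControlNetModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel


def normalize_control_model(
    control_model: Optional[Union[ControlNetModel, List[ControlNetModel]]]
):
    # A list of ControlNets becomes one MultiControlNetModel; a single model or None passes through.
    if isinstance(control_model, list):
        return MultiControlNetModel(control_model)
    return control_model
```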
@@ -42,9 +51,12 @@ class Txt2Img(Generator):
         kwargs are 'width' and 'height'
         """
         self.perlin = perlin
+        control_image = kwargs.get("control_image", None)
+        do_classifier_free_guidance = cfg_scale > 1.0
 
         # noinspection PyTypeChecker
         pipeline: StableDiffusionGeneratorPipeline = self.model
+        pipeline.control_model = self.control_model
         pipeline.scheduler = sampler
 
         uc, c, extra_conditioning_info = conditioning
@@ -61,6 +73,37 @@ class Txt2Img(Generator):
             ),
         ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
 
+        # FIXME: still need to test with different widths, heights, devices, dtypes
+        #        and add in batch_size, num_images_per_prompt?
+        if control_image is not None:
+            if isinstance(self.control_model, ControlNetModel):
+                control_image = pipeline.prepare_control_image(
+                    image=control_image,
+                    do_classifier_free_guidance=do_classifier_free_guidance,
+                    width=width,
+                    height=height,
+                    # batch_size=batch_size * num_images_per_prompt,
+                    # num_images_per_prompt=num_images_per_prompt,
+                    device=self.control_model.device,
+                    dtype=self.control_model.dtype,
+                )
+            elif isinstance(self.control_model, MultiControlNetModel):
+                images = []
+                for image_ in control_image:
+                    image_ = self.model.prepare_control_image(
+                        image=image_,
+                        do_classifier_free_guidance=do_classifier_free_guidance,
+                        width=width,
+                        height=height,
+                        # batch_size=batch_size * num_images_per_prompt,
+                        # num_images_per_prompt=num_images_per_prompt,
+                        device=self.control_model.device,
+                        dtype=self.control_model.dtype,
+                    )
+                    images.append(image_)
+                control_image = images
+            kwargs["control_image"] = control_image
+
         def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
             pipeline_output = pipeline.image_from_embeddings(
                 latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
@@ -68,6 +111,7 @@ class Txt2Img(Generator):
                 num_inference_steps=steps,
                 conditioning_data=conditioning_data,
                 callback=step_callback,
+                **kwargs,
             )
 
             if (
@@ -9,16 +9,20 @@ from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
 
 import einops
 import PIL.Image
+import numpy as np
 from accelerate.utils import set_seed
 import psutil
 import torch
 import torchvision.transforms as T
 from compel import EmbeddingsProvider
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
     StableDiffusionPipeline,
 )
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
+
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
     StableDiffusionImg2ImgPipeline,
 )
@@ -27,6 +31,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
 )
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
+from diffusers.utils import PIL_INTERPOLATION
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.outputs import BaseOutput
 from torchvision.transforms.functional import resize as tv_resize
@@ -304,6 +309,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         feature_extractor: Optional[CLIPFeatureExtractor],
         requires_safety_checker: bool = False,
         precision: str = "float32",
+        control_model: ControlNetModel = None,
     ):
         super().__init__(
             vae,
@@ -324,6 +330,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            # FIXME: can't currently register control module
+            # control_model=control_model,
         )
         self.invokeai_diffuser = InvokeAIDiffuserComponent(
             self.unet, self._unet_forward, is_running_diffusers=True
@@ -343,6 +351,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
         self._model_group = FullyLoadedModelGroup(self.unet.device)
         self._model_group.install(*self._submodels)
+        self.control_model = control_model
 
     def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
         """
@@ -464,6 +473,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         noise: torch.Tensor,
         callback: Callable[[PipelineIntermediateState], None] = None,
         run_id=None,
+        **kwargs,
     ) -> InvokeAIStableDiffusionPipelineOutput:
         r"""
         Function invoked when calling the pipeline for generation.
@@ -484,6 +494,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             noise=noise,
             run_id=run_id,
             callback=callback,
+            **kwargs,
         )
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()
@@ -508,6 +519,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         additional_guidance: List[Callable] = None,
         run_id=None,
         callback: Callable[[PipelineIntermediateState], None] = None,
+        **kwargs,
     ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
         if timesteps is None:
             self.scheduler.set_timesteps(
@@ -525,6 +537,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             additional_guidance=additional_guidance,
             run_id=run_id,
             callback=callback,
+            **kwargs,
         )
         return result.latents, result.attention_map_saver
 
@@ -537,6 +550,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         noise: torch.Tensor,
         run_id: str = None,
         additional_guidance: List[Callable] = None,
+        **kwargs,
     ):
         self._adjust_memory_efficient_attention(latents)
         if run_id is None:
@@ -575,6 +589,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                     step_index=i,
                     total_step_count=len(timesteps),
                     additional_guidance=additional_guidance,
+                    **kwargs,
                 )
                 latents = step_output.prev_sample
 
@@ -615,6 +630,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         step_index: int,
         total_step_count: int,
         additional_guidance: List[Callable] = None,
+        **kwargs,
     ):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]
@@ -626,6 +642,33 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         # i.e. before or after passing it to InvokeAIDiffuserComponent
         latent_model_input = self.scheduler.scale_model_input(latents, timestep)
 
+        if (self.control_model is not None) and (kwargs.get("control_image") is not None):
+            control_image = kwargs.get("control_image")  # should be a processed tensor derived from the control image(s)
+            control_scale = kwargs.get("control_scale", 1.0)  # control_scale default is 1.0
+            # handling case where using multiple control models but only specifying single control_scale
+            # so reshape control_scale to match number of control models
+            if isinstance(self.control_model, MultiControlNetModel) and isinstance(control_scale, float):
+                control_scale = [control_scale] * len(self.control_model.nets)
+            if conditioning_data.guidance_scale > 1.0:
+                # expand the latents input to control model if doing classifier free guidance
+                # (which I think for now is always true, there is conditional elsewhere that stops execution if
+                #  classifier_free_guidance is <= 1.0 ?)
+                latent_control_input = torch.cat([latent_model_input] * 2)
+            else:
+                latent_control_input = latent_model_input
+            # controlnet inference
+            down_block_res_samples, mid_block_res_sample = self.control_model(
+                latent_control_input,
+                timestep,
+                encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
+                                                 conditioning_data.text_embeddings]),
+                controlnet_cond=control_image,
+                conditioning_scale=control_scale,
+                return_dict=False,
+            )
+        else:
+            down_block_res_samples, mid_block_res_sample = None, None
+
         # predict the noise residual
         noise_pred = self.invokeai_diffuser.do_diffusion_step(
             latent_model_input,
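For orientation, a hedged sketch of where the residuals computed above finally land: diffusers' UNet2DConditionModel accepts them directly in its forward pass as `down_block_additional_residuals` / `mid_block_additional_residual`, which is what the **kwargs threading through `_unet_forward` in the following hunks relies on. The wrapper function below is invented for illustration; the argument names mirror the diff.

```python
import torch
from diffusers import UNet2DConditionModel


def unet_forward_with_controlnet(
    unet: UNet2DConditionModel,
    latent_model_input: torch.Tensor,   # duplicated (uncond + cond) latents under CFG
    timestep: torch.Tensor,
    text_embeddings: torch.Tensor,      # matching (uncond + cond) text embeddings
    down_block_res_samples,             # residuals from the ControlNet call above, or None
    mid_block_res_sample,
) -> torch.Tensor:
    # The UNet adds the ControlNet residuals to its skip connections internally.
    return unet(
        latent_model_input,
        timestep,
        encoder_hidden_states=text_embeddings,
        down_block_additional_residuals=down_block_res_samples,
        mid_block_additional_residual=mid_block_res_sample,
    ).sample
```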
@@ -635,6 +678,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             conditioning_data.guidance_scale,
             step_index=step_index,
             total_step_count=total_step_count,
+            down_block_additional_residuals=down_block_res_samples,
+            mid_block_additional_residual=mid_block_res_sample,
         )
 
         # compute the previous noisy sample x_t -> x_t-1
@@ -656,6 +701,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         t,
         text_embeddings,
         cross_attention_kwargs: Optional[dict[str, Any]] = None,
+        **kwargs,
     ):
         """predict the noise residual"""
         if is_inpainting_model(self.unet) and latents.size(1) == 4:
@@ -675,7 +721,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
         # First three args should be positional, not keywords, so torch hooks can see them.
         return self.unet(
-            latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs
+            latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs,
+            **kwargs,
         ).sample
 
     def img2img_from_embeddings(
@@ -937,3 +984,48 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 debug_image(
                     img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
                 )
+
+    # Copied from diffusers pipeline_stable_diffusion_controlnet.py
+    # Returns torch.Tensor of shape (batch_size, 3, height, width)
+    def prepare_control_image(
+        self,
+        image,
+        width=512,
+        height=512,
+        batch_size=1,
+        num_images_per_prompt=1,
+        device="cuda",
+        dtype=torch.float16,
+        do_classifier_free_guidance=True,
+    ):
+        if not isinstance(image, torch.Tensor):
+            if isinstance(image, PIL.Image.Image):
+                image = [image]
+
+            if isinstance(image[0], PIL.Image.Image):
+                images = []
+                for image_ in image:
+                    image_ = image_.convert("RGB")
+                    image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
+                    image_ = np.array(image_)
+                    image_ = image_[None, :]
+                    images.append(image_)
+                image = images
+                image = np.concatenate(image, axis=0)
+                image = np.array(image).astype(np.float32) / 255.0
+                image = image.transpose(0, 3, 1, 2)
+                image = torch.from_numpy(image)
+            elif isinstance(image[0], torch.Tensor):
+                image = torch.cat(image, dim=0)
+
+        image_batch_size = image.shape[0]
+        if image_batch_size == 1:
+            repeat_by = batch_size
+        else:
+            # image batch size is the same as prompt batch size
+            repeat_by = num_images_per_prompt
+        image = image.repeat_interleave(repeat_by, dim=0)
+        image = image.to(device=device, dtype=dtype)
+        if do_classifier_free_guidance:
+            image = torch.cat([image] * 2)
+        return image
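A hedged usage sketch for the helper above (the file path and the `pipeline` object are placeholders, not taken from the diff): a single PIL control image comes back as a `(2, 3, height, width)` tensor when classifier-free guidance is enabled, because the batch is duplicated to line up with the uncond/cond latent batch.

```python
import PIL.Image
import torch

# Placeholder path; e.g. a pre-processed Canny edge map matching the generation size.
control_pil = PIL.Image.open("control_hint.png")

# `pipeline` is assumed to be an already-constructed StableDiffusionGeneratorPipeline.
control_tensor = pipeline.prepare_control_image(
    image=control_pil,
    width=512,
    height=512,
    device="cuda",
    dtype=torch.float16,
    do_classifier_free_guidance=True,
)
assert control_tensor.shape == (2, 3, 512, 512)  # batch of 2 = duplicated for CFG
```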
@@ -1,7 +1,6 @@
 # adapted from bloc97's CrossAttentionControl colab
 # https://github.com/bloc97/CrossAttentionControl
 
-
 import enum
 import math
 from typing import Callable, Optional
@@ -168,6 +168,7 @@ class InvokeAIDiffuserComponent:
         unconditional_guidance_scale: float,
         step_index: Optional[int] = None,
         total_step_count: Optional[int] = None,
+        **kwargs,
     ):
         """
         :param x: current latents
@@ -196,7 +197,7 @@ class InvokeAIDiffuserComponent:
 
         if wants_hybrid_conditioning:
             unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
-                x, sigma, unconditioning, conditioning
+                x, sigma, unconditioning, conditioning, **kwargs,
             )
         elif wants_cross_attention_control:
             (
@@ -208,13 +209,14 @@ class InvokeAIDiffuserComponent:
                 unconditioning,
                 conditioning,
                 cross_attention_control_types_to_do,
+                **kwargs,
             )
         elif self.sequential_guidance:
             (
                 unconditioned_next_x,
                 conditioned_next_x,
             ) = self._apply_standard_conditioning_sequentially(
-                x, sigma, unconditioning, conditioning
+                x, sigma, unconditioning, conditioning, **kwargs,
             )
 
         else:
@@ -222,7 +224,7 @@ class InvokeAIDiffuserComponent:
                 unconditioned_next_x,
                 conditioned_next_x,
             ) = self._apply_standard_conditioning(
-                x, sigma, unconditioning, conditioning
+                x, sigma, unconditioning, conditioning, **kwargs,
             )
 
         combined_next_x = self._combine(
@@ -269,13 +271,13 @@ class InvokeAIDiffuserComponent:
 
     # methods below are called from do_diffusion_step and should be considered private to this class.
 
-    def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
+    def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
         # fast batched path
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
         both_conditionings = torch.cat([unconditioning, conditioning])
         both_results = self.model_forward_callback(
-            x_twice, sigma_twice, both_conditionings
+            x_twice, sigma_twice, both_conditionings, **kwargs,
        )
         unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
         if conditioned_next_x.device.type == "mps":
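For context, a hedged sketch of the standard classifier-free guidance mix that the unconditioned/conditioned halves returned here ultimately feed into. The actual combination lives in `self._combine`; the function below is the textbook formula, written out for illustration rather than copied from that method.

```python
import torch


def combine_cfg(
    unconditioned_next_x: torch.Tensor,
    conditioned_next_x: torch.Tensor,
    guidance_scale: float,
) -> torch.Tensor:
    # Classifier-free guidance: move the prediction away from the unconditioned result
    # and toward the conditioned one, scaled by guidance_scale.
    return unconditioned_next_x + guidance_scale * (conditioned_next_x - unconditioned_next_x)
```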
@@ -289,16 +291,17 @@ class InvokeAIDiffuserComponent:
         sigma,
         unconditioning: torch.Tensor,
         conditioning: torch.Tensor,
+        **kwargs,
     ):
         # low-memory sequential path
-        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
-        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning)
+        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
+        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
         if conditioned_next_x.device.type == "mps":
             # prevent a result filled with zeros. seems to be a torch bug.
             conditioned_next_x = conditioned_next_x.clone()
         return unconditioned_next_x, conditioned_next_x
 
-    def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
+    def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
         assert isinstance(conditioning, dict)
         assert isinstance(unconditioning, dict)
         x_twice = torch.cat([x] * 2)
@@ -313,7 +316,7 @@ class InvokeAIDiffuserComponent:
             else:
                 both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
         unconditioned_next_x, conditioned_next_x = self.model_forward_callback(
-            x_twice, sigma_twice, both_conditionings
+            x_twice, sigma_twice, both_conditionings, **kwargs,
         ).chunk(2)
         return unconditioned_next_x, conditioned_next_x
 
@@ -324,6 +327,7 @@ class InvokeAIDiffuserComponent:
         unconditioning,
         conditioning,
         cross_attention_control_types_to_do,
+        **kwargs,
     ):
         if self.is_running_diffusers:
             return self._apply_cross_attention_controlled_conditioning__diffusers(
@@ -332,6 +336,7 @@ class InvokeAIDiffuserComponent:
                 unconditioning,
                 conditioning,
                 cross_attention_control_types_to_do,
+                **kwargs,
             )
         else:
             return self._apply_cross_attention_controlled_conditioning__compvis(
@@ -340,6 +345,7 @@ class InvokeAIDiffuserComponent:
                 unconditioning,
                 conditioning,
                 cross_attention_control_types_to_do,
+                **kwargs,
             )
 
     def _apply_cross_attention_controlled_conditioning__diffusers(
@@ -349,6 +355,7 @@ class InvokeAIDiffuserComponent:
         unconditioning,
         conditioning,
         cross_attention_control_types_to_do,
+        **kwargs,
     ):
         context: Context = self.cross_attention_control_context
 
@@ -364,6 +371,7 @@ class InvokeAIDiffuserComponent:
             sigma,
             unconditioning,
             {"swap_cross_attn_context": cross_attn_processor_context},
+            **kwargs,
         )
 
         # do requested cross attention types for conditioning (positive prompt)
@@ -375,6 +383,7 @@ class InvokeAIDiffuserComponent:
             sigma,
             conditioning,
             {"swap_cross_attn_context": cross_attn_processor_context},
+            **kwargs,
         )
         return unconditioned_next_x, conditioned_next_x
 
@@ -385,6 +394,7 @@ class InvokeAIDiffuserComponent:
         unconditioning,
         conditioning,
         cross_attention_control_types_to_do,
+        **kwargs,
     ):
         # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
         # slower non-batched path (20% slower on mac MPS)
@@ -398,13 +408,13 @@ class InvokeAIDiffuserComponent:
         context: Context = self.cross_attention_control_context
 
         try:
-            unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
+            unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
 
             # process x using the original prompt, saving the attention maps
             # print("saving attention maps for", cross_attention_control_types_to_do)
             for ca_type in cross_attention_control_types_to_do:
                 context.request_save_attention_maps(ca_type)
-            _ = self.model_forward_callback(x, sigma, conditioning)
+            _ = self.model_forward_callback(x, sigma, conditioning, **kwargs,)
             context.clear_requests(cleanup=False)
 
             # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
@@ -415,7 +425,7 @@ class InvokeAIDiffuserComponent:
             self.conditioning.cross_attention_control_args.edited_conditioning
         )
         conditioned_next_x = self.model_forward_callback(
-            x, sigma, edited_conditioning
+            x, sigma, edited_conditioning, **kwargs,
         )
         context.clear_requests(cleanup=True)
 