mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

commit f86d388786
refactor(diffusers_pipeline): remove unused pipeline methods 🚮 (#4175)
@@ -1,26 +1,23 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
 
+from contextlib import contextmanager, ContextDecorator
 from functools import partial
 from typing import Literal, Optional, get_args
 
-import torch
 from pydantic import Field
 
 from invokeai.app.models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
 from invokeai.app.util.misc import SEED_MAX, get_random_seed
 from invokeai.backend.generator.inpaint import infill_methods
-
-from ...backend.generator import Inpaint, InvokeAIGenerator
-from ...backend.stable_diffusion import PipelineIntermediateState
-from ..util.step_callback import stable_diffusion_step_callback
 from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
-from .image import ImageOutput
-
-from ...backend.model_management.lora import ModelPatcher
-from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
-from .model import UNetField, VaeField
 from .compel import ConditioningField
-from contextlib import contextmanager, ExitStack, ContextDecorator
+from .image import ImageOutput
+from .model import UNetField, VaeField
+from ..util.step_callback import stable_diffusion_step_callback
+from ...backend.generator import Inpaint, InvokeAIGenerator
+from ...backend.model_management.lora import ModelPatcher
+from ...backend.stable_diffusion import PipelineIntermediateState
+from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
 
 SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
 INFILL_METHODS = Literal[tuple(infill_methods())]
@@ -193,8 +190,6 @@ class InpaintInvocation(BaseInvocation):
             safety_checker=None,
             feature_extractor=None,
             requires_safety_checker=False,
-            precision="float16" if dtype == torch.float16 else "float32",
-            execution_device=device,
         )
 
         yield OldModelInfo(
@@ -5,15 +5,26 @@ from typing import List, Literal, Optional, Union
 
 import einops
 import torch
-from diffusers import ControlNetModel
 from diffusers.image_processor import VaeImageProcessor
+from diffusers.models.attention_processor import (
+    AttnProcessor2_0,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
 from diffusers.schedulers import SchedulerMixin as Scheduler
 from pydantic import BaseModel, Field, validator
 
 from invokeai.app.invocations.metadata import CoreMetadata
+from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 from invokeai.backend.model_management.models import ModelType, SilenceWarnings
-
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
+from .compel import ConditioningField
+from .controlnet_image_processors import ControlField
+from .image import ImageOutput
+from .model import ModelInfo, UNetField, VaeField
+from ..models.image import ImageCategory, ImageField, ResourceOrigin
 from ...backend.model_management import ModelPatcher
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.stable_diffusion.diffusers_pipeline import (
@@ -24,23 +35,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
 )
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
-from ...backend.model_management import ModelPatcher
 from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision
-from ..models.image import ImageCategory, ImageField, ResourceOrigin
-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
-from .compel import ConditioningField
-from .controlnet_image_processors import ControlField
-from .image import ImageOutput
-from .model import ModelInfo, UNetField, VaeField
-from invokeai.app.util.controlnet_utils import prepare_control_image
-
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
-    XFormersAttnProcessor,
-)
 
 DEFAULT_PRECISION = choose_precision(choose_torch_device())
 
@@ -231,7 +226,6 @@ class TextToLatentsInvocation(BaseInvocation):
             safety_checker=None,
             feature_extractor=None,
             requires_safety_checker=False,
-            precision="float16" if unet.dtype == torch.float16 else "float32",
         )
 
     def prep_control_data(
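
Both invocation hunks above drop the same idea: the pipeline constructor no longer accepts a `precision=` string (or, in the inpaint path, an `execution_device=`), because both can be read off the unet the pipeline wraps. A hedged sketch of that inference, with a plain Linear standing in for a UNet2DConditionModel (the stand-in and variable names are illustrative, not InvokeAI API):

import torch

unet = torch.nn.Linear(4, 4).half()  # stand-in for a UNet2DConditionModel
# What the deleted kwargs used to carry, now derived from the model itself:
precision = "float16" if unet.weight.dtype == torch.float16 else "float32"
device = unet.weight.device
print(precision, device)  # float16 cpu
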
@@ -1,25 +1,11 @@
 """
 invokeai.backend.generator.img2img descends from .generator
 """
-from typing import Optional
-
-import torch
-from accelerate.utils import set_seed
-from diffusers import logging
 
-from ..stable_diffusion import (
-    ConditioningData,
-    PostprocessingSettings,
-    StableDiffusionGeneratorPipeline,
-)
 from .base import Generator
 
 
 class Img2Img(Generator):
-    def __init__(self, model, precision):
-        super().__init__(model, precision)
-        self.init_latent = None  # by get_noise()
-
     def get_make_image(
         self,
         sampler,
@@ -42,51 +28,4 @@ class Img2Img(Generator):
         Returns a function returning an image derived from the prompt and the initial image
         Return value depends on the seed at the time you call it.
         """
-        self.perlin = perlin
-
-        # noinspection PyTypeChecker
-        pipeline: StableDiffusionGeneratorPipeline = self.model
-        pipeline.scheduler = sampler
-
-        uc, c, extra_conditioning_info = conditioning
-        conditioning_data = ConditioningData(
-            uc,
-            c,
-            cfg_scale,
-            extra_conditioning_info,
-            postprocessing_settings=PostprocessingSettings(
-                threshold=threshold,
-                warmup=warmup,
-                h_symmetry_time_pct=h_symmetry_time_pct,
-                v_symmetry_time_pct=v_symmetry_time_pct,
-            ),
-        ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
-
-        def make_image(x_T: torch.Tensor, seed: int):
-            # FIXME: use x_T for initial seeded noise
-            # We're not at the moment because the pipeline automatically resizes init_image if
-            # necessary, which the x_T input might not match.
-            # In the meantime, reset the seed prior to generating pipeline output so we at least get the same result.
-            logging.set_verbosity_error()  # quench safety check warnings
-            pipeline_output = pipeline.img2img_from_embeddings(
-                init_image,
-                strength,
-                steps,
-                conditioning_data,
-                noise_func=self.get_noise_like,
-                callback=step_callback,
-                seed=seed,
-            )
-            if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
-                attention_maps_callback(pipeline_output.attention_map_saver)
-            return pipeline.numpy_to_pil(pipeline_output.images)[0]
-
-        return make_image
-
-    def get_noise_like(self, like: torch.Tensor):
-        device = like.device
-        x = torch.randn_like(like, device=device)
-        if self.perlin > 0.0:
-            shape = like.shape
-            x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(shape[3], shape[2])
-        return x
+        raise NotImplementedError("replaced by invokeai.app.invocations.latent.LatentsToLatentsInvocation")
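
The entire sampling body of Img2Img.get_make_image is replaced with a fail-fast stub; image-to-image now runs through the latent invocations instead of the legacy generator. The new behavior in miniature (class reduced to the one method for illustration):

class Img2Img:
    def get_make_image(self, *args, **kwargs):
        raise NotImplementedError("replaced by invokeai.app.invocations.latent.LatentsToLatentsInvocation")

try:
    Img2Img().get_make_image()
except NotImplementedError as err:
    print(err)  # points callers at the replacement invocation
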
@@ -377,3 +377,11 @@ class Inpaint(Img2Img):
         )
 
         return corrected_result
+
+    def get_noise_like(self, like: torch.Tensor):
+        device = like.device
+        x = torch.randn_like(like, device=device)
+        if self.perlin > 0.0:
+            shape = like.shape
+            x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(shape[3], shape[2])
+        return x
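
get_noise_like moves from Img2Img down to Inpaint, its last remaining caller. It blends standard Gaussian noise with Perlin noise at the ratio self.perlin. A standalone rendering of that blend, with plain random noise standing in for Generator.get_perlin_noise (the stand-in is an assumption, not the real Perlin implementation):

import torch

def blend_noise(like: torch.Tensor, perlin: float) -> torch.Tensor:
    x = torch.randn_like(like)
    if perlin > 0.0:
        _, _, h, w = like.shape
        perlin_noise = torch.randn(h, w, device=like.device)  # stand-in for get_perlin_noise(w, h)
        x = (1 - perlin) * x + perlin * perlin_noise  # convex blend, as in the diff
    return x

noise = blend_noise(torch.zeros(1, 4, 64, 64), perlin=0.25)
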
@@ -4,25 +4,21 @@ import dataclasses
 import inspect
 import math
 import secrets
-from collections.abc import Sequence
 from dataclasses import dataclass, field
 from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
-from pydantic import Field
 
-import einops
 import PIL.Image
-import numpy as np
-from accelerate.utils import set_seed
+import einops
 import psutil
 import torch
 import torchvision.transforms as T
+from accelerate.utils import set_seed
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.models.controlnet import ControlNetModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
     StableDiffusionPipeline,
 )
-
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
     StableDiffusionImg2ImgPipeline,
 )
@@ -31,21 +27,20 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
 )
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
-from diffusers.utils import PIL_INTERPOLATION
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.outputs import BaseOutput
+from pydantic import Field
 from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from typing_extensions import ParamSpec
 
 from invokeai.app.services.config import InvokeAIAppConfig
-from ..util import CPU_DEVICE, normalize_device
 from .diffusion import (
     AttentionMapSaver,
     InvokeAIDiffuserComponent,
     PostprocessingSettings,
 )
-from .offloading import FullyLoadedModelGroup, ModelGroup
+from ..util import normalize_device
 
 
 @dataclass
@@ -289,8 +284,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-    _model_group: ModelGroup
-
     ID_LENGTH = 8
 
     def __init__(
@@ -303,9 +296,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         safety_checker: Optional[StableDiffusionSafetyChecker],
         feature_extractor: Optional[CLIPFeatureExtractor],
         requires_safety_checker: bool = False,
-        precision: str = "float32",
         control_model: ControlNetModel = None,
-        execution_device: Optional[torch.device] = None,
     ):
         super().__init__(
             vae,
@@ -330,9 +321,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             # control_model=control_model,
         )
         self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
-
-        self._model_group = FullyLoadedModelGroup(execution_device or self.unet.device)
-        self._model_group.install(*self._submodels)
         self.control_model = control_model
 
     def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
@@ -368,72 +356,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         else:
             self.disable_attention_slicing()
 
-    def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
-        # overridden method; types match the superclass.
-        if torch_device is None:
-            return self
-        self._model_group.set_device(torch.device(torch_device))
-        self._model_group.ready()
-
-    @property
-    def device(self) -> torch.device:
-        return self._model_group.execution_device
-
-    @property
-    def _submodels(self) -> Sequence[torch.nn.Module]:
-        module_names, _, _ = self.extract_init_dict(dict(self.config))
-        submodels = []
-        for name in module_names.keys():
-            if hasattr(self, name):
-                value = getattr(self, name)
-            else:
-                value = getattr(self.config, name)
-            if isinstance(value, torch.nn.Module):
-                submodels.append(value)
-        return submodels
-
-    def image_from_embeddings(
-        self,
-        latents: torch.Tensor,
-        num_inference_steps: int,
-        conditioning_data: ConditioningData,
-        *,
-        noise: torch.Tensor,
-        callback: Callable[[PipelineIntermediateState], None] = None,
-        run_id=None,
-    ) -> InvokeAIStableDiffusionPipelineOutput:
-        r"""
-        Function invoked when calling the pipeline for generation.
-
-        :param conditioning_data:
-        :param latents: Pre-generated un-noised latents, to be used as inputs for
-            image generation. Can be used to tweak the same generation with different prompts.
-        :param num_inference_steps: The number of denoising steps. More denoising steps usually lead to a higher quality
-            image at the expense of slower inference.
-        :param noise: Noise to add to the latents, sampled from a Gaussian distribution.
-        :param callback:
-        :param run_id:
-        """
-        result_latents, result_attention_map_saver = self.latents_from_embeddings(
-            latents,
-            num_inference_steps,
-            conditioning_data,
-            noise=noise,
-            run_id=run_id,
-            callback=callback,
-        )
-        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-        torch.cuda.empty_cache()
-
-        with torch.inference_mode():
-            image = self.decode_latents(result_latents)
-            output = InvokeAIStableDiffusionPipelineOutput(
-                images=image,
-                nsfw_content_detected=[],
-                attention_map_saver=result_attention_map_saver,
-            )
-            return self.check_for_safety(output, dtype=conditioning_data.dtype)
-
     def latents_from_embeddings(
         self,
         latents: torch.Tensor,
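
The deleted to(), device, _submodels, and image_from_embeddings overrides existed only to route device management through the ModelGroup; with it gone, the pipeline falls back to stock diffusers behavior, where moving the container moves its registered submodules. A toy stand-in (not the diffusers implementation) showing why the hand-rolled bookkeeping could be dropped:

import torch

class TinyPipeline(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.unet = torch.nn.Linear(4, 4)
        self.vae = torch.nn.Linear(4, 4)

    @property
    def device(self) -> torch.device:
        # One execution device for every submodule, since .to() moves them together.
        return self.unet.weight.device

pipe = TinyPipeline().to("cpu")
print(pipe.device)  # cpu
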
@@ -450,7 +372,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if self.scheduler.config.get("cpu_only", False):
             scheduler_device = torch.device("cpu")
         else:
-            scheduler_device = self._model_group.device_for(self.unet)
+            scheduler_device = self.unet.device
 
         if timesteps is None:
             self.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
@@ -504,7 +426,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             (batch_size,),
             timesteps[0],
             dtype=timesteps.dtype,
-            device=self._model_group.device_for(self.unet),
+            device=self.unet.device,
         )
         latents = self.scheduler.add_noise(latents, noise, batched_t)
 
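
The one-line substitutions above are the heart of the refactor: every self._model_group.device_for(self.unet) lookup becomes self.unet.device, trusting the model to report where it lives. The plain-torch equivalent of that lookup (the Linear here is an illustrative stand-in; diffusers models expose .device directly):

import torch

unet = torch.nn.Linear(4, 4)  # stand-in; real unets are diffusers ModelMixin subclasses
scheduler_device = next(unet.parameters()).device  # what unet.device reports
print(scheduler_device)  # cpu
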
@@ -700,79 +622,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             **kwargs,
         ).sample
 
-    def img2img_from_embeddings(
-        self,
-        init_image: Union[torch.FloatTensor, PIL.Image.Image],
-        strength: float,
-        num_inference_steps: int,
-        conditioning_data: ConditioningData,
-        *,
-        callback: Callable[[PipelineIntermediateState], None] = None,
-        run_id=None,
-        noise_func=None,
-        seed=None,
-    ) -> InvokeAIStableDiffusionPipelineOutput:
-        if isinstance(init_image, PIL.Image.Image):
-            init_image = image_resized_to_grid_as_tensor(init_image.convert("RGB"))
-
-        if init_image.dim() == 3:
-            init_image = einops.rearrange(init_image, "c h w -> 1 c h w")
-
-        # 6. Prepare latent variables
-        initial_latents = self.non_noised_latents_from_image(
-            init_image,
-            device=self._model_group.device_for(self.unet),
-            dtype=self.unet.dtype,
-        )
-        if seed is not None:
-            set_seed(seed)
-        noise = noise_func(initial_latents)
-
-        return self.img2img_from_latents_and_embeddings(
-            initial_latents,
-            num_inference_steps,
-            conditioning_data,
-            strength,
-            noise,
-            run_id,
-            callback,
-        )
-
-    def img2img_from_latents_and_embeddings(
-        self,
-        initial_latents,
-        num_inference_steps,
-        conditioning_data: ConditioningData,
-        strength,
-        noise: torch.Tensor,
-        run_id=None,
-        callback=None,
-    ) -> InvokeAIStableDiffusionPipelineOutput:
-        timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
-        result_latents, result_attention_maps = self.latents_from_embeddings(
-            latents=initial_latents
-            if strength < 1.0
-            else torch.zeros_like(initial_latents, device=initial_latents.device, dtype=initial_latents.dtype),
-            num_inference_steps=num_inference_steps,
-            conditioning_data=conditioning_data,
-            timesteps=timesteps,
-            noise=noise,
-            run_id=run_id,
-            callback=callback,
-        )
-
-        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-        torch.cuda.empty_cache()
-
-        with torch.inference_mode():
-            image = self.decode_latents(result_latents)
-            output = InvokeAIStableDiffusionPipelineOutput(
-                images=image,
-                nsfw_content_detected=[],
-                attention_map_saver=result_attention_maps,
-            )
-            return self.check_for_safety(output, dtype=conditioning_data.dtype)
-
     def get_img2img_timesteps(self, num_inference_steps: int, strength: float, device=None) -> (torch.Tensor, int):
         img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
         assert img2img_pipeline.scheduler is self.scheduler
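
With img2img_from_embeddings and img2img_from_latents_and_embeddings gone, get_img2img_timesteps and latents_from_embeddings remain as the building blocks callers compose. One detail worth noting from the deleted body: when strength >= 1.0 it discarded the image latents and started from zeros, i.e. pure noise. That rule in isolation:

import torch

def initial_latents_for_strength(latents: torch.Tensor, strength: float) -> torch.Tensor:
    # strength < 1.0 keeps the encoded image; strength >= 1.0 denoises from scratch
    return latents if strength < 1.0 else torch.zeros_like(latents)

assert initial_latents_for_strength(torch.ones(1, 4, 8, 8), strength=1.0).abs().sum() == 0
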
@@ -780,7 +629,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if self.scheduler.config.get("cpu_only", False):
             scheduler_device = torch.device("cpu")
         else:
-            scheduler_device = self._model_group.device_for(self.unet)
+            scheduler_device = self.unet.device
 
         img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
         timesteps, adjusted_steps = img2img_pipeline.get_timesteps(
@@ -806,7 +655,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         noise_func=None,
         seed=None,
     ) -> InvokeAIStableDiffusionPipelineOutput:
-        device = self._model_group.device_for(self.unet)
+        device = self.unet.device
         latents_dtype = self.unet.dtype
 
         if isinstance(init_image, PIL.Image.Image):
@@ -877,42 +726,17 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 nsfw_content_detected=[],
                 attention_map_saver=result_attention_maps,
             )
-            return self.check_for_safety(output, dtype=conditioning_data.dtype)
+            return output
 
     def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype):
         init_image = init_image.to(device=device, dtype=dtype)
         with torch.inference_mode():
-            self._model_group.load(self.vae)
             init_latent_dist = self.vae.encode(init_image).latent_dist
            init_latents = init_latent_dist.sample().to(dtype=dtype)  # FIXME: uses torch.randn. make reproducible!
 
            init_latents = 0.18215 * init_latents
            return init_latents
 
-    def check_for_safety(self, output, dtype):
-        with torch.inference_mode():
-            screened_images, has_nsfw_concept = self.run_safety_checker(output.images, dtype=dtype)
-        screened_attention_map_saver = None
-        if has_nsfw_concept is None or not has_nsfw_concept:
-            screened_attention_map_saver = output.attention_map_saver
-        return InvokeAIStableDiffusionPipelineOutput(
-            screened_images,
-            has_nsfw_concept,
-            # block the attention maps if NSFW content is detected
-            attention_map_saver=screened_attention_map_saver,
-        )
-
-    def run_safety_checker(self, image, device=None, dtype=None):
-        # overriding to use the model group for device info instead of requiring the caller to know.
-        if self.safety_checker is not None:
-            device = self._model_group.device_for(self.safety_checker)
-        return super().run_safety_checker(image, device, dtype)
-
-    def decode_latents(self, latents):
-        # Explicit call to get the vae loaded, since `decode` isn't the forward method.
-        self._model_group.load(self.vae)
-        return super().decode_latents(latents)
-
     def debug_latents(self, latents, msg):
         from invokeai.backend.image_util import debug_image
 
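
check_for_safety and the run_safety_checker/decode_latents overrides are deleted, and the inpaint path now returns its output directly. The one behavior the deleted helper carried beyond the stock checker was gating the attention maps on the NSFW result; that logic in isolation:

def screen_attention_maps(attention_map_saver, has_nsfw_concept):
    if has_nsfw_concept is None or not has_nsfw_concept:
        return attention_map_saver
    return None  # block the attention maps if NSFW content is detected

assert screen_attention_maps("maps", has_nsfw_concept=False) == "maps"
assert screen_attention_maps("maps", has_nsfw_concept=True) is None
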
@@ -1,253 +0,0 @@
-from __future__ import annotations
-
-import warnings
-import weakref
-from abc import ABCMeta, abstractmethod
-from collections.abc import MutableMapping
-from typing import Callable, Union
-
-import torch
-from accelerate.utils import send_to_device
-from torch.utils.hooks import RemovableHandle
-
-OFFLOAD_DEVICE = torch.device("cpu")
-
-
-class _NoModel:
-    """Symbol that indicates no model is loaded.
-
-    (We can't weakref.ref(None), so this was my best idea at the time to come up with something
-    type-checkable.)
-    """
-
-    def __bool__(self):
-        return False
-
-    def to(self, device: torch.device):
-        pass
-
-    def __repr__(self):
-        return "<NO MODEL>"
-
-
-NO_MODEL = _NoModel()
-
-
-class ModelGroup(metaclass=ABCMeta):
-    """
-    A group of models.
-
-    The use case I had in mind when writing this is the sub-models used by a DiffusionPipeline,
-    e.g. its text encoder, U-net, VAE, etc.
-
-    Those models are :py:class:`diffusers.ModelMixin`, but "model" is interchangeable with
-    :py:class:`torch.nn.Module` here.
-    """
-
-    def __init__(self, execution_device: torch.device):
-        self.execution_device = execution_device
-
-    @abstractmethod
-    def install(self, *models: torch.nn.Module):
-        """Add models to this group."""
-        pass
-
-    @abstractmethod
-    def uninstall(self, models: torch.nn.Module):
-        """Remove models from this group."""
-        pass
-
-    @abstractmethod
-    def uninstall_all(self):
-        """Remove all models from this group."""
-
-    @abstractmethod
-    def load(self, model: torch.nn.Module):
-        """Load this model to the execution device."""
-        pass
-
-    @abstractmethod
-    def offload_current(self):
-        """Offload the current model(s) from the execution device."""
-        pass
-
-    @abstractmethod
-    def ready(self):
-        """Ready this group for use."""
-        pass
-
-    @abstractmethod
-    def set_device(self, device: torch.device):
-        """Change which device models from this group will execute on."""
-        pass
-
-    @abstractmethod
-    def device_for(self, model) -> torch.device:
-        """Get the device the given model will execute on.
-
-        The model should already be a member of this group.
-        """
-        pass
-
-    @abstractmethod
-    def __contains__(self, model):
-        """Check if the model is a member of this group."""
-        pass
-
-    def __repr__(self) -> str:
-        return f"<{self.__class__.__name__} object at {id(self):x}: " f"device={self.execution_device} >"
-
-
-class LazilyLoadedModelGroup(ModelGroup):
-    """
-    Only one model from this group is loaded on the GPU at a time.
-
-    Running the forward method of a model will displace the previously-loaded model,
-    offloading it to CPU.
-
-    If you call other methods on the model, e.g. ``model.encode(x)`` instead of ``model(x)``,
-    you will need to explicitly load it with :py:method:`.load(model)`.
-
-    This implementation relies on pytorch forward-pre-hooks, and it will copy forward arguments
-    to the appropriate execution device, as long as they are positional arguments and not keyword
-    arguments. (I didn't make the rules; that's the way the pytorch 1.13 API works for hooks.)
-    """
-
-    _hooks: MutableMapping[torch.nn.Module, RemovableHandle]
-    _current_model_ref: Callable[[], Union[torch.nn.Module, _NoModel]]
-
-    def __init__(self, execution_device: torch.device):
-        super().__init__(execution_device)
-        self._hooks = weakref.WeakKeyDictionary()
-        self._current_model_ref = weakref.ref(NO_MODEL)
-
-    def install(self, *models: torch.nn.Module):
-        for model in models:
-            self._hooks[model] = model.register_forward_pre_hook(self._pre_hook)
-
-    def uninstall(self, *models: torch.nn.Module):
-        for model in models:
-            hook = self._hooks.pop(model)
-            hook.remove()
-            if self.is_current_model(model):
-                # no longer hooked by this object, so don't claim to manage it
-                self.clear_current_model()
-
-    def uninstall_all(self):
-        self.uninstall(*self._hooks.keys())
-
-    def _pre_hook(self, module: torch.nn.Module, forward_input):
-        self.load(module)
-        if len(forward_input) == 0:
-            warnings.warn(
-                f"Hook for {module.__class__.__name__} got no input. " f"Inputs must be positional, not keywords.",
-                stacklevel=3,
-            )
-        return send_to_device(forward_input, self.execution_device)
-
-    def load(self, module):
-        if not self.is_current_model(module):
-            self.offload_current()
-            self._load(module)
-
-    def offload_current(self):
-        module = self._current_model_ref()
-        if module is not NO_MODEL:
-            module.to(OFFLOAD_DEVICE)
-        self.clear_current_model()
-
-    def _load(self, module: torch.nn.Module) -> torch.nn.Module:
-        assert self.is_empty(), f"A model is already loaded: {self._current_model_ref()}"
-        module = module.to(self.execution_device)
-        self.set_current_model(module)
-        return module
-
-    def is_current_model(self, model: torch.nn.Module) -> bool:
-        """Is the given model the one currently loaded on the execution device?"""
-        return self._current_model_ref() is model
-
-    def is_empty(self):
-        """Are none of this group's models loaded on the execution device?"""
-        return self._current_model_ref() is NO_MODEL
-
-    def set_current_model(self, value):
-        self._current_model_ref = weakref.ref(value)
-
-    def clear_current_model(self):
-        self._current_model_ref = weakref.ref(NO_MODEL)
-
-    def set_device(self, device: torch.device):
-        if device == self.execution_device:
-            return
-        self.execution_device = device
-        current = self._current_model_ref()
-        if current is not NO_MODEL:
-            current.to(device)
-
-    def device_for(self, model):
-        if model not in self:
-            raise KeyError(f"This does not manage this model {type(model).__name__}", model)
-        return self.execution_device  # this implementation only dispatches to one device
-
-    def ready(self):
-        pass  # always ready to load on-demand
-
-    def __contains__(self, model):
-        return model in self._hooks
-
-    def __repr__(self) -> str:
-        return (
-            f"<{self.__class__.__name__} object at {id(self):x}: "
-            f"current_model={type(self._current_model_ref()).__name__} >"
-        )
-
-
-class FullyLoadedModelGroup(ModelGroup):
-    """
-    A group of models without any implicit loading or unloading.
-
-    :py:meth:`.ready` loads _all_ the models to the execution device at once.
-    """
-
-    _models: weakref.WeakSet
-
-    def __init__(self, execution_device: torch.device):
-        super().__init__(execution_device)
-        self._models = weakref.WeakSet()
-
-    def install(self, *models: torch.nn.Module):
-        for model in models:
-            self._models.add(model)
-            model.to(self.execution_device)
-
-    def uninstall(self, *models: torch.nn.Module):
-        for model in models:
-            self._models.remove(model)
-
-    def uninstall_all(self):
-        self.uninstall(*self._models)
-
-    def load(self, model):
-        model.to(self.execution_device)
-
-    def offload_current(self):
-        for model in self._models:
-            model.to(OFFLOAD_DEVICE)
-
-    def ready(self):
-        for model in self._models:
-            self.load(model)
-
-    def set_device(self, device: torch.device):
-        self.execution_device = device
-        for model in self._models:
-            if model.device != OFFLOAD_DEVICE:
-                model.to(device)
-
-    def device_for(self, model):
-        if model not in self:
-            raise KeyError("This does not manage this model f{type(model).__name__}", model)
-        return self.execution_device  # this implementation only dispatches to one device
-
-    def __contains__(self, model):
-        return model in self._models
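
The deleted offloading.py is the module the rest of the commit stops depending on. Its LazilyLoadedModelGroup kept at most one model on the execution device by registering forward-pre-hooks that loaded a module, and moved its positional inputs, just before forward ran. A minimal sketch of that trick, simplified from the deleted class (no weakrefs or current-model tracking):

import torch

class LazyLoader:
    def __init__(self, execution_device: torch.device):
        self.execution_device = execution_device

    def install(self, model: torch.nn.Module):
        # Fires before every forward(); returning a tuple replaces the inputs.
        model.register_forward_pre_hook(self._pre_hook)

    def _pre_hook(self, module: torch.nn.Module, forward_input):
        module.to(self.execution_device)  # load on demand
        return tuple(t.to(self.execution_device) for t in forward_input)

loader = LazyLoader(torch.device("cpu"))
layer = torch.nn.Linear(4, 4)
loader.install(layer)
print(layer(torch.randn(1, 4)).shape)  # torch.Size([1, 4])
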