mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

Remove legacy/unused code

This commit is contained in:
parent da0184a786
commit a7e44678fb
@@ -1,253 +0,0 @@
-# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
-
-from functools import partial
-from typing import Literal, Optional, get_args
-
-import torch
-from pydantic import Field
-
-from invokeai.app.models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
-from invokeai.app.util.misc import SEED_MAX, get_random_seed
-from invokeai.backend.generator.inpaint import infill_methods
-
-from ...backend.generator import Inpaint, InvokeAIGenerator
-from ...backend.stable_diffusion import PipelineIntermediateState
-from ..util.step_callback import stable_diffusion_step_callback
-from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
-from .image import ImageOutput
-
-from ...backend.model_management import ModelPatcher, BaseModelType
-from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
-from .model import UNetField, VaeField
-from .compel import ConditioningField
-from contextlib import contextmanager, ExitStack, ContextDecorator
-
-SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
-INFILL_METHODS = Literal[tuple(infill_methods())]
-DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
-
-
-from .latent import get_scheduler
-
-
-class OldModelContext(ContextDecorator):
-    model: StableDiffusionGeneratorPipeline
-
-    def __init__(self, model):
-        self.model = model
-
-    def __enter__(self):
-        return self.model
-
-    def __exit__(self, *exc):
-        return False
-
-
-class OldModelInfo:
-    name: str
-    hash: str
-    context: OldModelContext
-
-    def __init__(self, name: str, hash: str, model: StableDiffusionGeneratorPipeline):
-        self.name = name
-        self.hash = hash
-        self.context = OldModelContext(
-            model=model,
-        )
-
-
-class InpaintInvocation(BaseInvocation):
-    """Generates an image using inpaint."""
-
-    type: Literal["inpaint"] = "inpaint"
-
-    positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
-    negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
-    seed: int = Field(
-        ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed
-    )
-    steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
-    width: int = Field(
-        default=512,
-        multiple_of=8,
-        gt=0,
-        description="The width of the resulting image",
-    )
-    height: int = Field(
-        default=512,
-        multiple_of=8,
-        gt=0,
-        description="The height of the resulting image",
-    )
-    cfg_scale: float = Field(
-        default=7.5,
-        ge=1,
-        description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt",
-    )
-    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use")
-    unet: UNetField = Field(default=None, description="UNet model")
-    vae: VaeField = Field(default=None, description="Vae model")
-
-    # Inputs
-    image: Optional[ImageField] = Field(description="The input image")
-    strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the original image")
-    fit: bool = Field(
-        default=True,
-        description="Whether or not the result should be fit to the aspect ratio of the input image",
-    )
-
-    # Inputs
-    mask: Optional[ImageField] = Field(description="The mask")
-    seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
-    seam_blur: int = Field(default=16, ge=0, description="The seam inpaint blur radius (px)")
-    seam_strength: float = Field(default=0.75, gt=0, le=1, description="The seam inpaint strength")
-    seam_steps: int = Field(default=30, ge=1, description="The number of steps to use for seam inpaint")
-    tile_size: int = Field(default=32, ge=1, description="The tile infill method size (px)")
-    infill_method: INFILL_METHODS = Field(
-        default=DEFAULT_INFILL_METHOD,
-        description="The method used to infill empty regions (px)",
-    )
-    inpaint_width: Optional[int] = Field(
-        default=None,
-        multiple_of=8,
-        gt=0,
-        description="The width of the inpaint region (px)",
-    )
-    inpaint_height: Optional[int] = Field(
-        default=None,
-        multiple_of=8,
-        gt=0,
-        description="The height of the inpaint region (px)",
-    )
-    inpaint_fill: Optional[ColorField] = Field(
-        default=ColorField(r=127, g=127, b=127, a=255),
-        description="The solid infill method color",
-    )
-    inpaint_replace: float = Field(
-        default=0.0,
-        ge=0.0,
-        le=1.0,
-        description="The amount by which to replace masked areas with latent noise",
-    )
-
-    # Schema customisation
-    class Config(InvocationConfig):
-        schema_extra = {
-            "ui": {"tags": ["stable-diffusion", "image"], "title": "Inpaint"},
-        }
-
-    def dispatch_progress(
-        self,
-        context: InvocationContext,
-        source_node_id: str,
-        base_model: BaseModelType,
-        intermediate_state: PipelineIntermediateState,
-    ) -> None:
-        stable_diffusion_step_callback(
-            context=context,
-            intermediate_state=intermediate_state,
-            node=self.dict(),
-            source_node_id=source_node_id,
-            base_model=base_model,
-        )
-
-    def get_conditioning(self, context, unet):
-        positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
-        c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
-        extra_conditioning_info = c.extra_conditioning
-
-        negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
-        uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
-
-        return (uc, c, extra_conditioning_info)
-
-    @contextmanager
-    def load_model_old_way(self, context, scheduler):
-        def _lora_loader():
-            for lora in self.unet.loras:
-                lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"}),
-                    context=context,
-                )
-                yield (lora_info.context.model, lora.weight)
-                del lora_info
-            return
-
-        unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.dict(),
-            context=context,
-        )
-        vae_info = context.services.model_manager.get_model(
-            **self.vae.vae.dict(),
-            context=context,
-        )
-
-        with vae_info as vae, ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
-            device = context.services.model_manager.mgr.cache.execution_device
-            dtype = context.services.model_manager.mgr.cache.precision
-
-            pipeline = StableDiffusionGeneratorPipeline(
-                vae=vae,
-                text_encoder=None,
-                tokenizer=None,
-                unet=unet,
-                scheduler=scheduler,
-                safety_checker=None,
-                feature_extractor=None,
-                requires_safety_checker=False,
-                precision="float16" if dtype == torch.float16 else "float32",
-                execution_device=device,
-            )
-
-            yield OldModelInfo(
-                name=self.unet.unet.model_name,
-                hash="<NO-HASH>",
-                model=pipeline,
-            )
-
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = None if self.image is None else context.services.images.get_pil_image(self.image.image_name)
-        mask = None if self.mask is None else context.services.images.get_pil_image(self.mask.image_name)
-
-        # Get the source node id (we are invoking the prepared node)
-        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
-        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
-
-        scheduler = get_scheduler(
-            context=context,
-            scheduler_info=self.unet.scheduler,
-            scheduler_name=self.scheduler,
-        )
-
-        with self.load_model_old_way(context, scheduler) as model:
-            conditioning = self.get_conditioning(context, model.context.model.unet)
-
-            outputs = Inpaint(model).generate(
-                conditioning=conditioning,
-                scheduler=scheduler,
-                init_image=image,
-                mask_image=mask,
-                step_callback=partial(self.dispatch_progress, context, source_node_id, self.unet.unet.base_model),
-                **self.dict(
-                    exclude={"positive_conditioning", "negative_conditioning", "scheduler", "image", "mask"}
-                ),  # Shorthand for passing all of the parameters above manually
-            )
-
-            # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
-            # each time it is called. We only need the first one.
-            generator_output = next(outputs)
-
-        image_dto = context.services.images.create(
-            image=generator_output.image,
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.GENERAL,
-            session_id=context.graph_execution_state_id,
-            node_id=self.id,
-            is_intermediate=self.is_intermediate,
-        )
-
-        return ImageOutput(
-            image=ImageField(image_name=image_dto.image_name),
-            width=image_dto.width,
-            height=image_dto.height,
-        )
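For context on the pattern being removed here: `load_model_old_way` above is a generator-based context manager that loads each sub-model through the model manager, assembles a throwaway pipeline, and keeps every handle alive until the caller's `with` block exits. A minimal, self-contained sketch of that shape (the classes below are toy stand-ins, not InvokeAI APIs):

from contextlib import contextmanager


class FakeModelHandle:
    """Toy stand-in for a model-manager entry that must stay loaded while in use."""

    def __init__(self, name: str):
        self.name = name

    def __enter__(self):
        print(f"load {self.name}")
        return self

    def __exit__(self, *exc):
        print(f"unload {self.name}")
        return False


@contextmanager
def load_model_old_way():
    # Nesting the handles in one `with` keeps them all resident until the
    # caller's block exits, mirroring the removed method above.
    with FakeModelHandle("vae") as vae, FakeModelHandle("unet") as unet:
        yield {"vae": vae, "unet": unet}


with load_model_old_way() as model:
    print("generating with", model["unet"].name)
# both handles are released here, after generation finishes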
@@ -4,7 +4,6 @@ from invokeai.app.models.exceptions import CanceledException
 from invokeai.app.models.image import ProgressImage
 from ..invocations.baseinvocation import InvocationContext
 from ...backend.util.util import image_to_dataURL
-from ...backend.generator.base import Generator
 from ...backend.stable_diffusion import PipelineIntermediateState
 from invokeai.app.services.config import InvokeAIAppConfig
 from ...backend.model_management.models import BaseModelType
@@ -118,57 +117,3 @@ def stable_diffusion_step_callback(
         step=intermediate_state.step,
         total_steps=node["steps"],
     )
-
-
-def stable_diffusion_xl_step_callback(
-    context: InvocationContext,
-    node: dict,
-    source_node_id: str,
-    sample,
-    step,
-    total_steps,
-):
-    if context.services.queue.is_canceled(context.graph_execution_state_id):
-        raise CanceledException
-
-    sdxl_latent_rgb_factors = torch.tensor(
-        [
-            # R G B
-            [0.3816, 0.4930, 0.5320],
-            [-0.3753, 0.1631, 0.1739],
-            [0.1770, 0.3588, -0.2048],
-            [-0.4350, -0.2644, -0.4289],
-        ],
-        dtype=sample.dtype,
-        device=sample.device,
-    )
-
-    sdxl_smooth_matrix = torch.tensor(
-        [
-            # [ 0.0478, 0.1285, 0.0478],
-            # [ 0.1285, 0.2948, 0.1285],
-            # [ 0.0478, 0.1285, 0.0478],
-            [0.0358, 0.0964, 0.0358],
-            [0.0964, 0.4711, 0.0964],
-            [0.0358, 0.0964, 0.0358],
-        ],
-        dtype=sample.dtype,
-        device=sample.device,
-    )
-
-    image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
-
-    (width, height) = image.size
-    width *= 8
-    height *= 8
-
-    dataURL = image_to_dataURL(image, image_format="JPEG")
-
-    context.services.events.emit_generator_progress(
-        graph_execution_state_id=context.graph_execution_state_id,
-        node=node,
-        source_node_id=source_node_id,
-        progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
-        step=step,
-        total_steps=total_steps,
-    )
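For reference, the preview image in the removed `stable_diffusion_xl_step_callback` comes from a single matrix multiply that projects the 4 latent channels onto RGB. A standalone sketch of that projection (the latents here are random stand-ins for the real denoising state, and the `sdxl_smooth_matrix` convolution is omitted):

import torch
from PIL import Image

# Random stand-in for a (1, 4, H/8, W/8) SDXL latent; not a real denoising state.
sample = torch.randn(1, 4, 64, 64)

# Latent -> RGB factors copied from the removed callback above.
sdxl_latent_rgb_factors = torch.tensor(
    [
        [0.3816, 0.4930, 0.5320],
        [-0.3753, 0.1631, 0.1739],
        [0.1770, 0.3588, -0.2048],
        [-0.4350, -0.2644, -0.4289],
    ]
)

# (4, H, W) -> (H, W, 4) @ (4, 3) -> (H, W, 3): one matmul per pixel.
latent_image = sample[0].permute(1, 2, 0) @ sdxl_latent_rgb_factors

# Rescale from roughly [-1, 1] to [0, 255] and convert to a PIL preview.
latents_ubyte = ((latent_image + 1) / 2).clamp(0, 1).mul(0xFF).byte().cpu()
preview = Image.fromarray(latents_ubyte.numpy())
print(preview.size)  # (64, 64) -- one-eighth of the final image resolution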
@@ -1,6 +1,5 @@
 """
 Initialization file for invokeai.backend
 """
-from .generator import InvokeAIGeneratorBasicParams, InvokeAIGenerator, InvokeAIGeneratorOutput, Img2Img, Inpaint
 from .model_management import ModelManager, ModelCache, BaseModelType, ModelType, SubModelType, ModelInfo
 from .model_management.models import SilenceWarnings
@@ -1,12 +0,0 @@
-"""
-Initialization file for the invokeai.generator package
-"""
-from .base import (
-    InvokeAIGenerator,
-    InvokeAIGeneratorBasicParams,
-    InvokeAIGeneratorOutput,
-    Img2Img,
-    Inpaint,
-    Generator,
-)
-from .inpaint import infill_methods
@@ -1,559 +0,0 @@
-"""
-Base class for invokeai.backend.generator.*
-including img2img, txt2img, and inpaint
-"""
-from __future__ import annotations
-
-import itertools
-import dataclasses
-import diffusers
-import os
-import random
-import traceback
-from abc import ABCMeta
-from argparse import Namespace
-from contextlib import nullcontext
-
-import cv2
-import numpy as np
-import torch
-from PIL import Image, ImageChops, ImageFilter
-from accelerate.utils import set_seed
-from diffusers import DiffusionPipeline
-from tqdm import trange
-from typing import Callable, List, Iterator, Optional, Tuple, Type, Union
-from dataclasses import dataclass, field
-from diffusers.schedulers import SchedulerMixin as Scheduler
-
-import invokeai.backend.util.logging as logger
-from ..image_util import configure_model_padding
-from ..util.util import rand_perlin_2d
-from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
-from ..stable_diffusion.schedulers import SCHEDULER_MAP
-
-downsampling = 8
-
-
-@dataclass
-class InvokeAIGeneratorBasicParams:
-    seed: Optional[int] = None
-    width: int = 512
-    height: int = 512
-    cfg_scale: float = 7.5
-    steps: int = 20
-    ddim_eta: float = 0.0
-    scheduler: str = "ddim"
-    precision: str = "float16"
-    perlin: float = 0.0
-    threshold: float = 0.0
-    seamless: bool = False
-    seamless_axes: List[str] = field(default_factory=lambda: ["x", "y"])
-    h_symmetry_time_pct: Optional[float] = None
-    v_symmetry_time_pct: Optional[float] = None
-    variation_amount: float = 0.0
-    with_variations: list = field(default_factory=list)
-
-
-@dataclass
-class InvokeAIGeneratorOutput:
-    """
-    InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation
-    operation, including the image, its seed, the model name used to generate the image
-    and the model hash, as well as all the generate() parameters that went into
-    generating the image (in .params, also available as attributes)
-    """
-
-    image: Image.Image
-    seed: int
-    model_hash: str
-    attention_maps_images: List[Image.Image]
-    params: Namespace
-
-
-# we are interposing a wrapper around the original Generator classes so that
-# old code that calls Generate will continue to work.
-class InvokeAIGenerator(metaclass=ABCMeta):
-    def __init__(
-        self,
-        model_info: dict,
-        params: InvokeAIGeneratorBasicParams = InvokeAIGeneratorBasicParams(),
-        **kwargs,
-    ):
-        self.model_info = model_info
-        self.params = params
-        self.kwargs = kwargs
-
-    def generate(
-        self,
-        conditioning: tuple,
-        scheduler,
-        callback: Optional[Callable] = None,
-        step_callback: Optional[Callable] = None,
-        iterations: int = 1,
-        **keyword_args,
-    ) -> Iterator[InvokeAIGeneratorOutput]:
-        """
-        Return an iterator across the indicated number of generations.
-        Each time the iterator is called it will return an InvokeAIGeneratorOutput
-        object. Use like this:
-
-            outputs = txt2img.generate(prompt='banana sushi', iterations=5)
-            for result in outputs:
-                print(result.image, result.seed)
-
-        In the typical case of wanting just a single image, iterations
-        defaults to 1; just do:
-
-            output = next(txt2img.generate(prompt='banana sushi'))
-
-        Pass None to get an infinite iterator.
-
-            outputs = txt2img.generate(prompt='banana sushi', iterations=None)
-            for o in outputs:
-                print(o.image, o.seed)
-
-        """
-        generator_args = dataclasses.asdict(self.params)
-        generator_args.update(keyword_args)
-
-        model_info = self.model_info
-        model_name = model_info.name
-        model_hash = model_info.hash
-        with model_info.context as model:
-            gen_class = self._generator_class()
-            generator = gen_class(model, self.params.precision, **self.kwargs)
-            if self.params.variation_amount > 0:
-                generator.set_variation(
-                    generator_args.get("seed"),
-                    generator_args.get("variation_amount"),
-                    generator_args.get("with_variations"),
-                )
-
-            if isinstance(model, DiffusionPipeline):
-                for component in [model.unet, model.vae]:
-                    configure_model_padding(
-                        component, generator_args.get("seamless", False), generator_args.get("seamless_axes")
-                    )
-            else:
-                configure_model_padding(
-                    model, generator_args.get("seamless", False), generator_args.get("seamless_axes")
-                )
-
-            iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
-            for i in iteration_count:
-                results = generator.generate(
-                    conditioning=conditioning,
-                    step_callback=step_callback,
-                    sampler=scheduler,
-                    **generator_args,
-                )
-                output = InvokeAIGeneratorOutput(
-                    image=results[0][0],
-                    seed=results[0][1],
-                    attention_maps_images=results[0][2],
-                    model_hash=model_hash,
-                    params=Namespace(model_name=model_name, **generator_args),
-                )
-                if callback:
-                    callback(output)
-                yield output
-
-    @classmethod
-    def schedulers(self) -> List[str]:
-        """
-        Return list of all the schedulers that we currently handle.
-        """
-        return list(SCHEDULER_MAP.keys())
-
-    def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
-        return generator_class(model, self.params.precision)
-
-    @classmethod
-    def _generator_class(cls) -> Type[Generator]:
-        """
-        In derived classes return the name of the generator to apply.
-        If you don't override will return the name of the derived
-        class, which nicely parallels the generator class names.
-        """
-        return Generator
-
-
-# ------------------------------------
-class Img2Img(InvokeAIGenerator):
-    def generate(
-        self, init_image: Union[Image.Image, torch.FloatTensor], strength: float = 0.75, **keyword_args
-    ) -> Iterator[InvokeAIGeneratorOutput]:
-        return super().generate(init_image=init_image, strength=strength, **keyword_args)
-
-    @classmethod
-    def _generator_class(cls):
-        from .img2img import Img2Img
-
-        return Img2Img
-
-
-# ------------------------------------
-# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
-class Inpaint(Img2Img):
-    def generate(
-        self,
-        mask_image: Union[Image.Image, torch.FloatTensor],
-        # Seam settings - when 0, doesn't fill seam
-        seam_size: int = 96,
-        seam_blur: int = 16,
-        seam_strength: float = 0.7,
-        seam_steps: int = 30,
-        tile_size: int = 32,
-        inpaint_replace=False,
-        infill_method=None,
-        inpaint_width=None,
-        inpaint_height=None,
-        inpaint_fill: Tuple[int, int, int, int] = (0x7F, 0x7F, 0x7F, 0xFF),
-        **keyword_args,
-    ) -> Iterator[InvokeAIGeneratorOutput]:
-        return super().generate(
-            mask_image=mask_image,
-            seam_size=seam_size,
-            seam_blur=seam_blur,
-            seam_strength=seam_strength,
-            seam_steps=seam_steps,
-            tile_size=tile_size,
-            inpaint_replace=inpaint_replace,
-            infill_method=infill_method,
-            inpaint_width=inpaint_width,
-            inpaint_height=inpaint_height,
-            inpaint_fill=inpaint_fill,
-            **keyword_args,
-        )
-
-    @classmethod
-    def _generator_class(cls):
-        from .inpaint import Inpaint
-
-        return Inpaint
-
-
-class Generator:
-    downsampling_factor: int
-    latent_channels: int
-    precision: str
-    model: DiffusionPipeline
-
-    def __init__(self, model: DiffusionPipeline, precision: str, **kwargs):
-        self.model = model
-        self.precision = precision
-        self.seed = None
-        self.latent_channels = model.unet.config.in_channels
-        self.downsampling_factor = downsampling  # BUG: should come from model or config
-        self.perlin = 0.0
-        self.threshold = 0
-        self.variation_amount = 0
-        self.with_variations = []
-        self.use_mps_noise = False
-        self.free_gpu_mem = None
-
-    # this is going to be overridden in img2img.py, txt2img.py and inpaint.py
-    def get_make_image(self, **kwargs):
-        """
-        Returns a function returning an image derived from the prompt and the initial image
-        Return value depends on the seed at the time you call it
-        """
-        raise NotImplementedError("image_iterator() must be implemented in a descendant class")
-
-    def set_variation(self, seed, variation_amount, with_variations):
-        self.seed = seed
-        self.variation_amount = variation_amount
-        self.with_variations = with_variations
-
-    def generate(
-        self,
-        width,
-        height,
-        sampler,
-        init_image=None,
-        iterations=1,
-        seed=None,
-        image_callback=None,
-        step_callback=None,
-        threshold=0.0,
-        perlin=0.0,
-        h_symmetry_time_pct=None,
-        v_symmetry_time_pct=None,
-        free_gpu_mem: bool = False,
-        **kwargs,
-    ):
-        scope = nullcontext
-        self.free_gpu_mem = free_gpu_mem
-        attention_maps_images = []
-        attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
-        make_image = self.get_make_image(
-            sampler=sampler,
-            init_image=init_image,
-            width=width,
-            height=height,
-            step_callback=step_callback,
-            threshold=threshold,
-            perlin=perlin,
-            h_symmetry_time_pct=h_symmetry_time_pct,
-            v_symmetry_time_pct=v_symmetry_time_pct,
-            attention_maps_callback=attention_maps_callback,
-            **kwargs,
-        )
-        results = []
-        seed = seed if seed is not None and seed >= 0 else self.new_seed()
-        first_seed = seed
-        seed, initial_noise = self.generate_initial_noise(seed, width, height)
-
-        # There used to be an additional self.model.ema_scope() here, but it breaks
-        # the inpaint-1.5 model. Not sure what it did.... ?
-        with scope(self.model.device.type):
-            for n in trange(iterations, desc="Generating"):
-                x_T = None
-                if self.variation_amount > 0:
-                    set_seed(seed)
-                    target_noise = self.get_noise(width, height)
-                    x_T = self.slerp(self.variation_amount, initial_noise, target_noise)
-                elif initial_noise is not None:
-                    # i.e. we specified particular variations
-                    x_T = initial_noise
-                else:
-                    set_seed(seed)
-                    try:
-                        x_T = self.get_noise(width, height)
-                    except:
-                        logger.error("An error occurred while getting initial noise")
-                        print(traceback.format_exc())
-
-                # Pass on the seed in case a layer beneath us needs to generate noise on its own.
-                image = make_image(x_T, seed)
-
-                results.append([image, seed, attention_maps_images])
-
-                if image_callback is not None:
-                    attention_maps_image = None if len(attention_maps_images) == 0 else attention_maps_images[-1]
-                    image_callback(
-                        image,
-                        seed,
-                        first_seed=first_seed,
-                        attention_maps_image=attention_maps_image,
-                    )
-
-                seed = self.new_seed()
-
-                # Free up memory from the last generation.
-                clear_cuda_cache = kwargs["clear_cuda_cache"] if "clear_cuda_cache" in kwargs else None
-                if clear_cuda_cache is not None:
-                    clear_cuda_cache()
-
-        return results
-
-    def sample_to_image(self, samples) -> Image.Image:
-        """
-        Given samples returned from a sampler, converts
-        it into a PIL Image
-        """
-        with torch.inference_mode():
-            image = self.model.decode_latents(samples)
-        return self.model.numpy_to_pil(image)[0]
-
-    def repaste_and_color_correct(
-        self,
-        result: Image.Image,
-        init_image: Image.Image,
-        init_mask: Image.Image,
-        mask_blur_radius: int = 8,
-    ) -> Image.Image:
-        if init_image is None or init_mask is None:
-            return result
-
-        # Get the original alpha channel of the mask if there is one.
-        # Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
-        pil_init_mask = init_mask.getchannel("A") if init_mask.mode == "RGBA" else init_mask.convert("L")
-        pil_init_image = init_image.convert("RGBA")  # Add an alpha channel if one doesn't exist
-
-        # Build an image with only visible pixels from source to use as reference for color-matching.
-        init_rgb_pixels = np.asarray(init_image.convert("RGB"), dtype=np.uint8)
-        init_a_pixels = np.asarray(pil_init_image.getchannel("A"), dtype=np.uint8)
-        init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8)
-
-        # Get numpy version of result
-        np_image = np.asarray(result, dtype=np.uint8)
-
-        # Mask and calculate mean and standard deviation
-        mask_pixels = init_a_pixels * init_mask_pixels > 0
-        np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :]
-        np_image_masked = np_image[mask_pixels, :]
-
-        if np_init_rgb_pixels_masked.size > 0:
-            init_means = np_init_rgb_pixels_masked.mean(axis=0)
-            init_std = np_init_rgb_pixels_masked.std(axis=0)
-            gen_means = np_image_masked.mean(axis=0)
-            gen_std = np_image_masked.std(axis=0)
-
-            # Color correct
-            np_matched_result = np_image.copy()
-            np_matched_result[:, :, :] = (
-                (
-                    (
-                        (np_matched_result[:, :, :].astype(np.float32) - gen_means[None, None, :])
-                        / gen_std[None, None, :]
-                    )
-                    * init_std[None, None, :]
-                    + init_means[None, None, :]
-                )
-                .clip(0, 255)
-                .astype(np.uint8)
-            )
-            matched_result = Image.fromarray(np_matched_result, mode="RGB")
-        else:
-            matched_result = Image.fromarray(np_image, mode="RGB")
-
-        # Blur the mask out (into init image) by specified amount
-        if mask_blur_radius > 0:
-            nm = np.asarray(pil_init_mask, dtype=np.uint8)
-            nmd = cv2.erode(
-                nm,
-                kernel=np.ones((3, 3), dtype=np.uint8),
-                iterations=int(mask_blur_radius / 2),
-            )
-            pmd = Image.fromarray(nmd, mode="L")
-            blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
-        else:
-            blurred_init_mask = pil_init_mask
-
-        multiplied_blurred_init_mask = ImageChops.multiply(blurred_init_mask, self.pil_image.split()[-1])
-
-        # Paste original on color-corrected generation (using blurred mask)
-        matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
-        return matched_result
-
-    @staticmethod
-    def sample_to_lowres_estimated_image(samples):
-        # originally adapted from code by @erucipe and @keturn here:
-        # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
-
-        # these updated numbers for v1.5 are from @torridgristle
-        v1_5_latent_rgb_factors = torch.tensor(
-            [
-                # R G B
-                [0.3444, 0.1385, 0.0670],  # L1
-                [0.1247, 0.4027, 0.1494],  # L2
-                [-0.3192, 0.2513, 0.2103],  # L3
-                [-0.1307, -0.1874, -0.7445],  # L4
-            ],
-            dtype=samples.dtype,
-            device=samples.device,
-        )
-
-        latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors
-        latents_ubyte = (
-            ((latent_image + 1) / 2).clamp(0, 1).mul(0xFF).byte()  # change scale from -1..1 to 0..1, then to 0..255
-        ).cpu()
-
-        return Image.fromarray(latents_ubyte.numpy())
-
-    def generate_initial_noise(self, seed, width, height):
-        initial_noise = None
-        if self.variation_amount > 0 or len(self.with_variations) > 0:
-            # use fixed initial noise plus random noise per iteration
-            set_seed(seed)
-            initial_noise = self.get_noise(width, height)
-            for v_seed, v_weight in self.with_variations:
-                seed = v_seed
-                set_seed(seed)
-                next_noise = self.get_noise(width, height)
-                initial_noise = self.slerp(v_weight, initial_noise, next_noise)
-            if self.variation_amount > 0:
-                random.seed()  # reset RNG to an actually random state, so we can get a random seed for variations
-                seed = random.randrange(0, np.iinfo(np.uint32).max)
-        return (seed, initial_noise)
-
-    def get_perlin_noise(self, width, height):
-        fixdevice = "cpu" if (self.model.device.type == "mps") else self.model.device
-        # limit noise to only the diffusion image channels, not the mask channels
-        input_channels = min(self.latent_channels, 4)
-        # round up to the nearest block of 8
-        temp_width = int((width + 7) / 8) * 8
-        temp_height = int((height + 7) / 8) * 8
-        noise = torch.stack(
-            [
-                rand_perlin_2d((temp_height, temp_width), (8, 8), device=self.model.device).to(fixdevice)
-                for _ in range(input_channels)
-            ],
-            dim=0,
-        ).to(self.model.device)
-        return noise[0:4, 0:height, 0:width]
-
-    def new_seed(self):
-        self.seed = random.randrange(0, np.iinfo(np.uint32).max)
-        return self.seed
-
-    def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
-        """
-        Spherical linear interpolation
-        Args:
-            t (float/np.ndarray): Float value between 0.0 and 1.0
-            v0 (np.ndarray): Starting vector
-            v1 (np.ndarray): Final vector
-            DOT_THRESHOLD (float): Threshold for considering the two vectors as
-                collinear. Not recommended to alter this.
-        Returns:
-            v2 (np.ndarray): Interpolation vector between v0 and v1
-        """
-        inputs_are_torch = False
-        if not isinstance(v0, np.ndarray):
-            inputs_are_torch = True
-            v0 = v0.detach().cpu().numpy()
-        if not isinstance(v1, np.ndarray):
-            inputs_are_torch = True
-            v1 = v1.detach().cpu().numpy()
-
-        dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
-        if np.abs(dot) > DOT_THRESHOLD:
-            v2 = (1 - t) * v0 + t * v1
-        else:
-            theta_0 = np.arccos(dot)
-            sin_theta_0 = np.sin(theta_0)
-            theta_t = theta_0 * t
-            sin_theta_t = np.sin(theta_t)
-            s0 = np.sin(theta_0 - theta_t) / sin_theta_0
-            s1 = sin_theta_t / sin_theta_0
-            v2 = s0 * v0 + s1 * v1
-
-        if inputs_are_torch:
-            v2 = torch.from_numpy(v2).to(self.model.device)
-
-        return v2
-
-    # this is a handy routine for debugging use. Given a generated sample,
-    # convert it into a PNG image and store it at the indicated path
-    def save_sample(self, sample, filepath):
-        image = self.sample_to_image(sample)
-        dirname = os.path.dirname(filepath) or "."
-        if not os.path.exists(dirname):
-            logger.info(f"creating directory {dirname}")
-            os.makedirs(dirname, exist_ok=True)
-        image.save(filepath, "PNG")
-
-    def torch_dtype(self) -> torch.dtype:
-        return torch.float16 if self.precision == "float16" else torch.float32
-
-    # returns a tensor filled with random numbers from a normal distribution
-    def get_noise(self, width, height):
-        device = self.model.device
-        # limit noise to only the diffusion image channels, not the mask channels
-        input_channels = min(self.latent_channels, 4)
-        x = torch.randn(
-            [
-                1,
-                input_channels,
-                height // self.downsampling_factor,
-                width // self.downsampling_factor,
-            ],
-            dtype=self.torch_dtype(),
-            device=device,
-        )
-        if self.perlin > 0.0:
-            perlin_noise = self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
-            x = (1 - self.perlin) * x + self.perlin * perlin_noise
-        return x
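The `slerp` removed above blends two noise tensors along an arc rather than a straight line; this matters because averaging two independent gaussian tensors shrinks their magnitude (by roughly 1/sqrt(2) at the midpoint), while spherical interpolation keeps it close to the endpoints. A small numpy-only sketch of the same math:

import numpy as np


def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    # Same math as the removed Generator.slerp, restricted to numpy inputs.
    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    if np.abs(dot) > DOT_THRESHOLD:  # nearly collinear: fall back to lerp
        return (1 - t) * v0 + t * v1
    theta_0 = np.arccos(dot)
    theta_t = theta_0 * t
    s0 = np.sin(theta_0 - theta_t) / np.sin(theta_0)
    s1 = np.sin(theta_t) / np.sin(theta_0)
    return s0 * v0 + s1 * v1


rng = np.random.default_rng(0)
v0 = rng.standard_normal(16)
v1 = rng.standard_normal(16)

mid = slerp(0.5, v0, v1)
lerp_mid = 0.5 * (v0 + v1)
# The slerp midpoint keeps a norm comparable to the endpoints, while the
# plain average of independent gaussian vectors is noticeably smaller.
print(np.linalg.norm(v0), np.linalg.norm(mid), np.linalg.norm(lerp_mid))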
@@ -1,92 +0,0 @@
-"""
-invokeai.backend.generator.img2img descends from .generator
-"""
-from typing import Optional
-
-import torch
-from accelerate.utils import set_seed
-from diffusers import logging
-
-from ..stable_diffusion import (
-    ConditioningData,
-    PostprocessingSettings,
-    StableDiffusionGeneratorPipeline,
-)
-from .base import Generator
-
-
-class Img2Img(Generator):
-    def __init__(self, model, precision):
-        super().__init__(model, precision)
-        self.init_latent = None  # by get_noise()
-
-    def get_make_image(
-        self,
-        sampler,
-        steps,
-        cfg_scale,
-        ddim_eta,
-        conditioning,
-        init_image,
-        strength,
-        step_callback=None,
-        threshold=0.0,
-        warmup=0.2,
-        perlin=0.0,
-        h_symmetry_time_pct=None,
-        v_symmetry_time_pct=None,
-        attention_maps_callback=None,
-        **kwargs,
-    ):
-        """
-        Returns a function returning an image derived from the prompt and the initial image
-        Return value depends on the seed at the time you call it.
-        """
-        self.perlin = perlin
-
-        # noinspection PyTypeChecker
-        pipeline: StableDiffusionGeneratorPipeline = self.model
-        pipeline.scheduler = sampler
-
-        uc, c, extra_conditioning_info = conditioning
-        conditioning_data = ConditioningData(
-            uc,
-            c,
-            cfg_scale,
-            extra_conditioning_info,
-            postprocessing_settings=PostprocessingSettings(
-                threshold=threshold,
-                warmup=warmup,
-                h_symmetry_time_pct=h_symmetry_time_pct,
-                v_symmetry_time_pct=v_symmetry_time_pct,
-            ),
-        ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
-
-        def make_image(x_T: torch.Tensor, seed: int):
-            # FIXME: use x_T for initial seeded noise
-            # We're not at the moment because the pipeline automatically resizes init_image if
-            # necessary, which the x_T input might not match.
-            # In the meantime, reset the seed prior to generating pipeline output so we at least get the same result.
-            logging.set_verbosity_error()  # quench safety check warnings
-            pipeline_output = pipeline.img2img_from_embeddings(
-                init_image,
-                strength,
-                steps,
-                conditioning_data,
-                noise_func=self.get_noise_like,
-                callback=step_callback,
-                seed=seed,
-            )
-            if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
-                attention_maps_callback(pipeline_output.attention_map_saver)
-            return pipeline.numpy_to_pil(pipeline_output.images)[0]
-
-        return make_image
-
-    def get_noise_like(self, like: torch.Tensor):
-        device = like.device
-        x = torch.randn_like(like, device=device)
-        if self.perlin > 0.0:
-            shape = like.shape
-            x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(shape[3], shape[2])
-        return x
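`get_noise_like` above draws gaussian noise shaped like an existing latent tensor and, when `perlin > 0`, mixes in a fraction of structured low-frequency noise as a convex blend. A standalone sketch with a made-up Perlin source (the removed code uses `rand_perlin_2d` instead):

import torch


def blend_noise(like: torch.Tensor, perlin: float, perlin_source) -> torch.Tensor:
    # Convex blend, as in the removed get_noise_like: 0.0 = pure gaussian,
    # 1.0 = pure structured noise.
    x = torch.randn_like(like)
    if perlin > 0.0:
        x = (1 - perlin) * x + perlin * perlin_source(like.shape)
    return x


def fake_perlin(shape):
    # Made-up stand-in for rand_perlin_2d: smooth noise obtained by
    # upsampling a coarse random grid.
    coarse = torch.randn(shape[0], shape[1], shape[2] // 8, shape[3] // 8)
    return torch.nn.functional.interpolate(coarse, size=shape[2:], mode="bilinear", align_corners=False)


latents = torch.zeros(1, 4, 64, 64)
noisy = blend_noise(latents, perlin=0.25, perlin_source=fake_perlin)
print(noisy.shape)  # torch.Size([1, 4, 64, 64])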
@@ -1,379 +0,0 @@
-"""
-invokeai.backend.generator.inpaint descends from .generator
-"""
-from __future__ import annotations
-
-import math
-from typing import Tuple, Union, Optional
-
-import cv2
-import numpy as np
-import torch
-from PIL import Image, ImageChops, ImageFilter, ImageOps
-
-from ..image_util import PatchMatch, debug_image
-from ..stable_diffusion.diffusers_pipeline import (
-    ConditioningData,
-    StableDiffusionGeneratorPipeline,
-    image_resized_to_grid_as_tensor,
-)
-from .img2img import Img2Img
-
-
-def infill_methods() -> list[str]:
-    methods = [
-        "tile",
-        "solid",
-    ]
-    if PatchMatch.patchmatch_available():
-        methods.insert(0, "patchmatch")
-    return methods
-
-
-class Inpaint(Img2Img):
-    def __init__(self, model, precision):
-        self.inpaint_height = 0
-        self.inpaint_width = 0
-        self.enable_image_debugging = False
-        self.init_latent = None
-        self.pil_image = None
-        self.pil_mask = None
-        self.mask_blur_radius = 0
-        self.infill_method = None
-        super().__init__(model, precision)
-
-    # Outpaint support code
-    def get_tile_images(self, image: np.ndarray, width=8, height=8):
-        _nrows, _ncols, depth = image.shape
-        _strides = image.strides
-
-        nrows, _m = divmod(_nrows, height)
-        ncols, _n = divmod(_ncols, width)
-        if _m != 0 or _n != 0:
-            return None
-
-        return np.lib.stride_tricks.as_strided(
-            np.ravel(image),
-            shape=(nrows, ncols, height, width, depth),
-            strides=(height * _strides[0], width * _strides[1], *_strides),
-            writeable=False,
-        )
-
-    def infill_patchmatch(self, im: Image.Image) -> Image.Image:
-        if im.mode != "RGBA":
-            return im
-
-        # Skip patchmatch if patchmatch isn't available
-        if not PatchMatch.patchmatch_available():
-            return im
-
-        # Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
-        im_patched_np = PatchMatch.inpaint(im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3)
-        im_patched = Image.fromarray(im_patched_np, mode="RGB")
-        return im_patched
-
-    def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: Optional[int] = None) -> Image.Image:
-        # Only fill if there's an alpha layer
-        if im.mode != "RGBA":
-            return im
-
-        a = np.asarray(im, dtype=np.uint8)
-
-        tile_size_tuple = (tile_size, tile_size)
-
-        # Get the image as tiles of a specified size
-        tiles = self.get_tile_images(a, *tile_size_tuple).copy()
-
-        # Get the mask as tiles
-        tiles_mask = tiles[:, :, :, :, 3]
-
-        # Find any mask tiles with any fully transparent pixels (we will be replacing these later)
-        tmask_shape = tiles_mask.shape
-        tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
-        n, ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
-        tiles_mask = tiles_mask > 0
-        tiles_mask = tiles_mask.reshape((n, ny)).all(axis=1)
-
-        # Get RGB tiles in single array and filter by the mask
-        tshape = tiles.shape
-        tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), *tiles.shape[2:]))
-        filtered_tiles = tiles_all[tiles_mask]
-
-        if len(filtered_tiles) == 0:
-            return im
-
-        # Find all invalid tiles and replace with a random valid tile
-        replace_count = (tiles_mask == False).sum()
-        rng = np.random.default_rng(seed=seed)
-        tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[
-            rng.choice(filtered_tiles.shape[0], replace_count), :, :, :
-        ]
-
-        # Convert back to an image
-        tiles_all = tiles_all.reshape(tshape)
-        tiles_all = tiles_all.swapaxes(1, 2)
-        st = tiles_all.reshape(
-            (
-                math.prod(tiles_all.shape[0:2]),
-                math.prod(tiles_all.shape[2:4]),
-                tiles_all.shape[4],
-            )
-        )
-        si = Image.fromarray(st, mode="RGBA")
-
-        return si
-
-    def mask_edge(self, mask: Image.Image, edge_size: int, edge_blur: int) -> Image.Image:
-        npimg = np.asarray(mask, dtype=np.uint8)
-
-        # Detect any partially transparent regions
-        npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))
-
-        # Detect hard edges
-        npedge = cv2.Canny(npimg, threshold1=100, threshold2=200)
-
-        # Combine
-        npmask = npgradient + npedge
-
-        # Expand
-        npmask = cv2.dilate(npmask, np.ones((3, 3), np.uint8), iterations=int(edge_size / 2))
-
-        new_mask = Image.fromarray(npmask)
-
-        if edge_blur > 0:
-            new_mask = new_mask.filter(ImageFilter.BoxBlur(edge_blur))
-
-        return ImageOps.invert(new_mask)
-
-    def seam_paint(
-        self,
-        im: Image.Image,
-        seam_size: int,
-        seam_blur: int,
-        seed,
-        steps,
-        cfg_scale,
-        ddim_eta,
-        conditioning,
-        strength,
-        noise,
-        infill_method,
-        step_callback,
-    ) -> Image.Image:
-        hard_mask = self.pil_image.split()[-1].copy()
-        mask = self.mask_edge(hard_mask, seam_size, seam_blur)
-
-        make_image = self.get_make_image(
-            steps,
-            cfg_scale,
-            ddim_eta,
-            conditioning,
-            init_image=im.copy().convert("RGBA"),
-            mask_image=mask,
-            strength=strength,
-            mask_blur_radius=0,
-            seam_size=0,
-            step_callback=step_callback,
-            inpaint_width=im.width,
-            inpaint_height=im.height,
-            infill_method=infill_method,
-        )
-
-        seam_noise = self.get_noise(im.width, im.height)
-
-        result = make_image(seam_noise, seed=None)
-
-        return result
-
-    @torch.no_grad()
-    def get_make_image(
-        self,
-        steps,
-        cfg_scale,
-        ddim_eta,
-        conditioning,
-        init_image: Union[Image.Image, torch.FloatTensor],
-        mask_image: Union[Image.Image, torch.FloatTensor],
-        strength: float,
-        mask_blur_radius: int = 8,
-        # Seam settings - when 0, doesn't fill seam
-        seam_size: int = 96,
-        seam_blur: int = 16,
-        seam_strength: float = 0.7,
-        seam_steps: int = 30,
-        tile_size: int = 32,
-        step_callback=None,
-        inpaint_replace=False,
-        enable_image_debugging=False,
-        infill_method=None,
-        inpaint_width=None,
-        inpaint_height=None,
-        inpaint_fill: Tuple[int, int, int, int] = (0x7F, 0x7F, 0x7F, 0xFF),
-        attention_maps_callback=None,
-        **kwargs,
-    ):
-        """
-        Returns a function returning an image derived from the prompt and
-        the initial image + mask. Return value depends on the seed at
-        the time you call it. kwargs are 'init_latent' and 'strength'
-        """
-
-        self.enable_image_debugging = enable_image_debugging
-        infill_method = infill_method or infill_methods()[0]
-        self.infill_method = infill_method
-
-        self.inpaint_width = inpaint_width
-        self.inpaint_height = inpaint_height
-
-        if isinstance(init_image, Image.Image):
-            self.pil_image = init_image.copy()
-
-            # Do infill
-            if infill_method == "patchmatch" and PatchMatch.patchmatch_available():
-                init_filled = self.infill_patchmatch(self.pil_image.copy())
-            elif infill_method == "tile":
-                init_filled = self.tile_fill_missing(self.pil_image.copy(), seed=self.seed, tile_size=tile_size)
-            elif infill_method == "solid":
-                solid_bg = Image.new("RGBA", init_image.size, inpaint_fill)
-                init_filled = Image.alpha_composite(solid_bg, init_image)
-            else:
-                raise ValueError(f"Non-supported infill type {infill_method}", infill_method)
-            init_filled.paste(init_image, (0, 0), init_image.split()[-1])
-
-            # Resize if requested for inpainting
-            if inpaint_width and inpaint_height:
-                init_filled = init_filled.resize((inpaint_width, inpaint_height))
-
-            debug_image(init_filled, "init_filled", debug_status=self.enable_image_debugging)
-
-            # Create init tensor
-            init_image = image_resized_to_grid_as_tensor(init_filled.convert("RGB"))
-
-        if isinstance(mask_image, Image.Image):
-            self.pil_mask = mask_image.copy()
-            debug_image(
-                mask_image,
-                "mask_image BEFORE multiply with pil_image",
-                debug_status=self.enable_image_debugging,
-            )
-
-            init_alpha = self.pil_image.getchannel("A")
-            if mask_image.mode != "L":
-                # FIXME: why do we get passed an RGB image here? We can only use single-channel.
-                mask_image = mask_image.convert("L")
-            mask_image = ImageChops.multiply(mask_image, init_alpha)
-            self.pil_mask = mask_image
-
-            # Resize if requested for inpainting
-            if inpaint_width and inpaint_height:
-                mask_image = mask_image.resize((inpaint_width, inpaint_height))
-
-            debug_image(
-                mask_image,
-                "mask_image AFTER multiply with pil_image",
-                debug_status=self.enable_image_debugging,
-            )
-            mask: torch.FloatTensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
-        else:
-            mask: torch.FloatTensor = mask_image
-
-        self.mask_blur_radius = mask_blur_radius
-
-        # noinspection PyTypeChecker
-        pipeline: StableDiffusionGeneratorPipeline = self.model
-
-        # todo: support cross-attention control
-        uc, c, _ = conditioning
-        conditioning_data = ConditioningData(uc, c, cfg_scale).add_scheduler_args_if_applicable(
-            pipeline.scheduler, eta=ddim_eta
-        )
-
-        def make_image(x_T: torch.Tensor, seed: int):
-            pipeline_output = pipeline.inpaint_from_embeddings(
-                init_image=init_image,
-                mask=1 - mask,  # expects white means "paint here."
-                strength=strength,
-                num_inference_steps=steps,
-                conditioning_data=conditioning_data,
-                noise_func=self.get_noise_like,
-                callback=step_callback,
-                seed=seed,
-            )
-
-            if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
-                attention_maps_callback(pipeline_output.attention_map_saver)
-
-            result = self.postprocess_size_and_mask(pipeline.numpy_to_pil(pipeline_output.images)[0])
-
-            # Seam paint if this is our first pass (seam_size set to 0 during seam painting)
-            if seam_size > 0:
-                old_image = self.pil_image or init_image
-                old_mask = self.pil_mask or mask_image
-
-                result = self.seam_paint(
-                    result,
-                    seam_size,
-                    seam_blur,
-                    seed,
-                    seam_steps,
-                    cfg_scale,
-                    ddim_eta,
-                    conditioning,
-                    seam_strength,
-                    x_T,
-                    infill_method,
-                    step_callback,
-                )
-
-                # Restore original settings
-                self.get_make_image(
-                    steps,
-                    cfg_scale,
-                    ddim_eta,
-                    conditioning,
-                    old_image,
-                    old_mask,
-                    strength,
-                    mask_blur_radius,
-                    seam_size,
-                    seam_blur,
-                    seam_strength,
-                    seam_steps,
-                    tile_size,
-                    step_callback,
-                    inpaint_replace,
-                    enable_image_debugging,
-                    inpaint_width=inpaint_width,
-                    inpaint_height=inpaint_height,
-                    infill_method=infill_method,
-                    **kwargs,
-                )
-
-            return result
-
-        return make_image
-
-    def sample_to_image(self, samples) -> Image.Image:
-        gen_result = super().sample_to_image(samples).convert("RGB")
-        return self.postprocess_size_and_mask(gen_result)
-
-    def postprocess_size_and_mask(self, gen_result: Image.Image) -> Image.Image:
-        debug_image(gen_result, "gen_result", debug_status=self.enable_image_debugging)
-
-        # Resize if necessary
-        if self.inpaint_width and self.inpaint_height:
-            gen_result = gen_result.resize(self.pil_image.size)
-
-        if self.pil_image is None or self.pil_mask is None:
-            return gen_result
-
-        corrected_result = self.repaste_and_color_correct(
-            gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius
-        )
-        debug_image(
-            corrected_result,
-            "corrected_result",
-            debug_status=self.enable_image_debugging,
-        )
-
-        return corrected_result
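`get_tile_images` above uses `np.lib.stride_tricks.as_strided` to view an RGBA image as a grid of tiles without copying pixel data, and `tile_fill_missing` then replaces tiles containing transparent pixels with randomly chosen fully opaque ones. A standalone sketch of the zero-copy tile view and the opacity test:

import numpy as np


def tile_view(image: np.ndarray, tile: int) -> np.ndarray:
    # Zero-copy (rows, cols, tile, tile, channels) view, as in the removed
    # get_tile_images; dimensions must be divisible by the tile size.
    h, w, c = image.shape
    assert h % tile == 0 and w % tile == 0
    sh, sw, sc = image.strides
    return np.lib.stride_tricks.as_strided(
        image,
        shape=(h // tile, w // tile, tile, tile, c),
        strides=(tile * sh, tile * sw, sh, sw, sc),
        writeable=False,
    )


rgba = np.random.randint(0, 256, (64, 64, 4), dtype=np.uint8)
tiles = tile_view(rgba, 16)
print(tiles.shape)  # (4, 4, 16, 16, 4)

# A tile is usable as fill only if every pixel in it is at least partially
# opaque, matching the reshape(...).all(axis=1) test in tile_fill_missing.
opaque = (tiles[..., 3] > 0).reshape(16, -1).all(axis=1)
print(opaque.sum(), "of", opaque.size, "tiles are fully opaque")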
@ -1,18 +1,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import inspect
|
|
||||||
import math
|
|
||||||
import secrets
|
|
||||||
from collections.abc import Sequence
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
import inspect
|
||||||
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
|
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from accelerate.utils import set_seed
|
|
||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
@@ -23,15 +19,11 @@ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
     StableDiffusionPipeline,
 )
 
-from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
-    StableDiffusionImg2ImgPipeline,
-)
 from diffusers.pipelines.stable_diffusion.safety_checker import (
     StableDiffusionSafetyChecker,
 )
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
-from diffusers.utils import PIL_INTERPOLATION
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.outputs import BaseOutput
 from torchvision.transforms.functional import resize as tv_resize
@@ -45,7 +37,6 @@ from .diffusion import (
     InvokeAIDiffuserComponent,
     PostprocessingSettings,
 )
-from .offloading import FullyLoadedModelGroup, ModelGroup
 
 
 @dataclass
@@ -287,9 +278,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-    _model_group: ModelGroup
-
-    ID_LENGTH = 8
 
     def __init__(
         self,
@@ -328,9 +316,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             # control_model=control_model,
         )
         self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
-
-        self._model_group = FullyLoadedModelGroup(execution_device or self.unet.device)
-        self._model_group.install(*self._submodels)
         self.control_model = control_model
 
     def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
@@ -373,28 +358,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         self.disable_attention_slicing()
 
     def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
-        # overridden method; types match the superclass.
-        if torch_device is None:
-            return self
-        self._model_group.set_device(torch.device(torch_device))
-        self._model_group.ready()
+        raise Exception("Should not be called")
 
     @property
     def device(self) -> torch.device:
-        return self._model_group.execution_device
+        return self.unet.device
 
-    @property
-    def _submodels(self) -> Sequence[torch.nn.Module]:
-        module_names, _, _ = self.extract_init_dict(dict(self.config))
-        submodels = []
-        for name in module_names.keys():
-            if hasattr(self, name):
-                value = getattr(self, name)
-            else:
-                value = getattr(self.config, name)
-            if isinstance(value, torch.nn.Module):
-                submodels.append(value)
-        return submodels
-
     def latents_from_embeddings(
         self,
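The deleted `_submodels` property rebuilt the module list from the pipeline config so the model group could manage every sub-model. Should that enumeration be needed again, diffusers' public `DiffusionPipeline.components` mapping carries the same information; a hedged sketch, where the `submodels` helper is illustrative and not part of the codebase:

```python
# Hedged sketch: enumerate a pipeline's torch modules via the public
# `components` dict instead of the removed extract_init_dict walk.
import torch

def submodels(pipeline) -> list[torch.nn.Module]:
    # components maps names such as "unet", "vae", "text_encoder" to objects;
    # keep only the entries that are actual torch modules.
    return [m for m in pipeline.components.values() if isinstance(m, torch.nn.Module)]
```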
@@ -414,7 +382,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if self.scheduler.config.get("cpu_only", False):
             scheduler_device = torch.device("cpu")
         else:
-            scheduler_device = self._model_group.device_for(self.unet)
+            scheduler_device = self.unet.device
 
         if timesteps is None:
             self.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
@@ -511,7 +479,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             (batch_size,),
             timesteps[0],
             dtype=timesteps.dtype,
-            device=self._model_group.device_for(self.unet),
+            device=self.unet.device,
         )
 
         yield PipelineIntermediateState(
@@ -655,185 +623,3 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             cross_attention_kwargs=cross_attention_kwargs,
             **kwargs,
         ).sample
-
-    def img2img_from_embeddings(
-        self,
-        init_image: Union[torch.FloatTensor, PIL.Image.Image],
-        strength: float,
-        num_inference_steps: int,
-        conditioning_data: ConditioningData,
-        *,
-        callback: Callable[[PipelineIntermediateState], None] = None,
-        noise_func=None,
-        seed=None,
-    ) -> InvokeAIStableDiffusionPipelineOutput:
-        if isinstance(init_image, PIL.Image.Image):
-            init_image = image_resized_to_grid_as_tensor(init_image.convert("RGB"))
-
-        if init_image.dim() == 3:
-            init_image = einops.rearrange(init_image, "c h w -> 1 c h w")
-
-        # 6. Prepare latent variables
-        initial_latents = self.non_noised_latents_from_image(
-            init_image,
-            device=self._model_group.device_for(self.unet),
-            dtype=self.unet.dtype,
-        )
-        if seed is not None:
-            set_seed(seed)
-        noise = noise_func(initial_latents)
-
-        return self.img2img_from_latents_and_embeddings(
-            initial_latents,
-            num_inference_steps,
-            conditioning_data,
-            strength,
-            noise,
-            callback,
-        )
-
-    def get_img2img_timesteps(self, num_inference_steps: int, strength: float, device=None) -> (torch.Tensor, int):
-        img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
-        assert img2img_pipeline.scheduler is self.scheduler
-
-        if self.scheduler.config.get("cpu_only", False):
-            scheduler_device = torch.device("cpu")
-        else:
-            scheduler_device = self._model_group.device_for(self.unet)
-
-        img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
-        timesteps, adjusted_steps = img2img_pipeline.get_timesteps(
-            num_inference_steps, strength, device=scheduler_device
-        )
-        # Workaround for low strength resulting in zero timesteps.
-        # TODO: submit upstream fix for zero-step img2img
-        if timesteps.numel() == 0:
-            timesteps = self.scheduler.timesteps[-1:]
-            adjusted_steps = timesteps.numel()
-        return timesteps, adjusted_steps
-
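`get_img2img_timesteps` delegated the strength handling to diffusers' `StableDiffusionImg2ImgPipeline.get_timesteps`, which keeps only the tail of the schedule. A hedged sketch of that truncation together with the zero-step workaround above; `img2img_timesteps` is illustrative, not a drop-in replacement:

```python
# Hedged sketch: strength decides how much of the schedule img2img runs.
# With 50 steps and strength 0.3, sampling starts 35 steps in and runs 15.
import torch

def img2img_timesteps(timesteps: torch.Tensor, num_inference_steps: int, strength: float):
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    kept = timesteps[t_start:]
    if kept.numel() == 0:
        # the workaround above: strength so low that no steps survive
        kept = timesteps[-1:]
    return kept, kept.numel()
```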
-    def inpaint_from_embeddings(
-        self,
-        init_image: torch.FloatTensor,
-        mask: torch.FloatTensor,
-        strength: float,
-        num_inference_steps: int,
-        conditioning_data: ConditioningData,
-        *,
-        callback: Callable[[PipelineIntermediateState], None] = None,
-        noise_func=None,
-        seed=None,
-    ) -> InvokeAIStableDiffusionPipelineOutput:
-        device = self._model_group.device_for(self.unet)
-        latents_dtype = self.unet.dtype
-
-        if isinstance(init_image, PIL.Image.Image):
-            init_image = image_resized_to_grid_as_tensor(init_image.convert("RGB"))
-
-        init_image = init_image.to(device=device, dtype=latents_dtype)
-        mask = mask.to(device=device, dtype=latents_dtype)
-
-        if init_image.dim() == 3:
-            init_image = init_image.unsqueeze(0)
-
-        timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
-
-        # 6. Prepare latent variables
-        # can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
-        # because we have our own noise function
-        init_image_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
-        if seed is not None:
-            set_seed(seed)
-        noise = noise_func(init_image_latents)
-
-        if mask.dim() == 3:
-            mask = mask.unsqueeze(0)
-        latent_mask = tv_resize(mask, init_image_latents.shape[-2:], T.InterpolationMode.BILINEAR).to(
-            device=device, dtype=latents_dtype
-        )
-
-        guidance: List[Callable] = []
-
-        if is_inpainting_model(self.unet):
-            # You'd think the inpainting model wouldn't be paying attention to the area it is going to repaint
-            # (that's why there's a mask!) but it seems to really want that blanked out.
-            masked_init_image = init_image * torch.where(mask < 0.5, 1, 0)
-            masked_latents = self.non_noised_latents_from_image(masked_init_image, device=device, dtype=latents_dtype)
-
-            # TODO: we should probably pass this in so we don't have to try/finally around setting it.
-            self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
-                self._unet_forward, latent_mask, masked_latents
-            )
-        else:
-            guidance.append(AddsMaskGuidance(latent_mask, init_image_latents, self.scheduler, noise))
-
-        try:
-            result_latents, result_attention_maps = self.latents_from_embeddings(
-                latents=init_image_latents
-                if strength < 1.0
-                else torch.zeros_like(
-                    init_image_latents, device=init_image_latents.device, dtype=init_image_latents.dtype
-                ),
-                num_inference_steps=num_inference_steps,
-                conditioning_data=conditioning_data,
-                noise=noise,
-                timesteps=timesteps,
-                additional_guidance=guidance,
-                callback=callback,
-            )
-        finally:
-            self.invokeai_diffuser.model_forward_callback = self._unet_forward
-
-        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-        torch.cuda.empty_cache()
-
-        with torch.inference_mode():
-            image = self.decode_latents(result_latents)
-        output = InvokeAIStableDiffusionPipelineOutput(
-            images=image,
-            nsfw_content_detected=[],
-            attention_map_saver=result_attention_maps,
-        )
-        return self.check_for_safety(output, dtype=self.unet.dtype)
-
-    def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype):
-        init_image = init_image.to(device=device, dtype=dtype)
-        with torch.inference_mode():
-            self._model_group.load(self.vae)
-            init_latent_dist = self.vae.encode(init_image).latent_dist
-            init_latents = init_latent_dist.sample().to(dtype=dtype)  # FIXME: uses torch.randn. make reproducible!
-
-        init_latents = 0.18215 * init_latents
-        return init_latents
-
-    def check_for_safety(self, output, dtype):
-        with torch.inference_mode():
-            screened_images, has_nsfw_concept = self.run_safety_checker(output.images, dtype=dtype)
-        screened_attention_map_saver = None
-        if has_nsfw_concept is None or not has_nsfw_concept:
-            screened_attention_map_saver = output.attention_map_saver
-        return InvokeAIStableDiffusionPipelineOutput(
-            screened_images,
-            has_nsfw_concept,
-            # block the attention maps if NSFW content is detected
-            attention_map_saver=screened_attention_map_saver,
-        )
-
-    def run_safety_checker(self, image, device=None, dtype=None):
-        # overriding to use the model group for device info instead of requiring the caller to know.
-        if self.safety_checker is not None:
-            device = self._model_group.device_for(self.safety_checker)
-        return super().run_safety_checker(image, device, dtype)
-
-    def decode_latents(self, latents):
-        # Explicit call to get the vae loaded, since `decode` isn't the forward method.
-        self._model_group.load(self.vae)
-        return super().decode_latents(latents)
-
-    def debug_latents(self, latents, msg):
-        from invokeai.backend.image_util import debug_image
-
-        with torch.inference_mode():
-            decoded = self.numpy_to_pil(self.decode_latents(latents))
-        for i, img in enumerate(decoded):
-            debug_image(img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True)
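The two mask strategies in the removed inpaint path differ by model type: inpainting UNets get the mask and masked-image latents concatenated onto their input (`AddsMaskLatents`), while ordinary UNets have the unmasked region overwritten each step with appropriately re-noised original latents (`AddsMaskGuidance`). A hedged sketch of that second, per-step blend; `apply_mask_guidance` is illustrative, not the actual class:

```python
# Hedged sketch of per-step mask guidance for a non-inpainting UNet: re-noise
# the original latents to the current timestep, then keep them wherever the
# mask (1 = repaint, 0 = preserve) says not to repaint.
import torch

def apply_mask_guidance(scheduler, step_latents: torch.Tensor,
                        init_latents: torch.Tensor, noise: torch.Tensor,
                        mask: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    noised_init = scheduler.add_noise(init_latents, noise, t)
    # lerp keeps noised_init where mask == 0 and step_latents where mask == 1
    return torch.lerp(noised_init, step_latents, mask)
```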
@@ -295,7 +295,6 @@ class InvokeAIDiffuserComponent:
     ) -> torch.Tensor:
         if postprocessing_settings is not None:
             percent_through = step_index / total_step_count
-            latents = self.apply_threshold(postprocessing_settings, latents, percent_through)
             latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
         return latents
 
@@ -516,63 +515,6 @@ class InvokeAIDiffuserComponent:
         combined_next_x = unconditioned_next_x + scaled_delta
         return combined_next_x
 
-    def apply_threshold(
-        self,
-        postprocessing_settings: PostprocessingSettings,
-        latents: torch.Tensor,
-        percent_through: float,
-    ) -> torch.Tensor:
-        if postprocessing_settings.threshold is None or postprocessing_settings.threshold == 0.0:
-            return latents
-
-        threshold = postprocessing_settings.threshold
-        warmup = postprocessing_settings.warmup
-
-        if percent_through < warmup:
-            current_threshold = threshold + threshold * 5 * (1 - (percent_through / warmup))
-        else:
-            current_threshold = threshold
-
-        if current_threshold <= 0:
-            return latents
-
-        maxval = latents.max().item()
-        minval = latents.min().item()
-
-        scale = 0.7  # default value from #395
-
-        if self.debug_thresholding:
-            std, mean = [i.item() for i in torch.std_mean(latents)]
-            outside = torch.count_nonzero((latents < -current_threshold) | (latents > current_threshold))
-            logger.info(f"Threshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})")
-            logger.debug(f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}")
-            logger.debug(f"{outside / latents.numel() * 100:.2f}% values outside threshold")
-
-        if maxval < current_threshold and minval > -current_threshold:
-            return latents
-
-        num_altered = 0
-
-        # MPS torch.rand_like is fine because torch.rand_like is wrapped in generate.py!
-
-        if maxval > current_threshold:
-            latents = torch.clone(latents)
-            maxval = np.clip(maxval * scale, 1, current_threshold)
-            num_altered += torch.count_nonzero(latents > maxval)
-            latents[latents > maxval] = torch.rand_like(latents[latents > maxval]) * maxval
-
-        if minval < -current_threshold:
-            latents = torch.clone(latents)
-            minval = np.clip(minval * scale, -current_threshold, -1)
-            num_altered += torch.count_nonzero(latents < minval)
-            latents[latents < minval] = torch.rand_like(latents[latents < minval]) * minval
-
-        if self.debug_thresholding:
-            logger.debug(f"min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})")
-            logger.debug(f"{num_altered / latents.numel() * 100:.2f}% values altered")
-
-        return latents
-
     def apply_symmetry(
         self,
         postprocessing_settings: PostprocessingSettings,
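For reference, the heart of the deleted thresholding is an outlier clamp on the latents whose limit is relaxed early in sampling. A simplified, hedged sketch using a deterministic clamp in place of the original's random re-draw of out-of-range values:

```python
# Hedged sketch: the warmup formula is taken from the removed code; the clamp
# replaces its randomized reassignment of outliers.
import torch

def threshold_latents(latents: torch.Tensor, threshold: float,
                      percent_through: float, warmup: float) -> torch.Tensor:
    if threshold is None or threshold <= 0:
        return latents
    if percent_through < warmup:
        # early steps tolerate magnitudes up to 6x the configured threshold
        threshold = threshold + threshold * 5 * (1 - percent_through / warmup)
    return latents.clamp(-threshold, threshold)
```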
@@ -634,18 +576,6 @@ class InvokeAIDiffuserComponent:
         self.last_percent_through = percent_through
         return latents.to(device=dev)
 
-    def estimate_percent_through(self, step_index, sigma):
-        if step_index is not None and self.cross_attention_control_context is not None:
-            # percent_through will never reach 1.0 (but this is intended)
-            return float(step_index) / float(self.cross_attention_control_context.step_count)
-        # find the best possible index of the current sigma in the sigma sequence
-        smaller_sigmas = torch.nonzero(self.model.sigmas <= sigma)
-        sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0
-        # flip because sigmas[0] is for the fully denoised image
-        # percent_through must be <1
-        return 1.0 - float(sigma_index + 1) / float(self.model.sigmas.shape[0])
-        # print('estimated percent_through', percent_through, 'from sigma', sigma.item())
-
     # todo: make this work
     @classmethod
     def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
@@ -1,253 +0,0 @@
-from __future__ import annotations
-
-import warnings
-import weakref
-from abc import ABCMeta, abstractmethod
-from collections.abc import MutableMapping
-from typing import Callable, Union
-
-import torch
-from accelerate.utils import send_to_device
-from torch.utils.hooks import RemovableHandle
-
-OFFLOAD_DEVICE = torch.device("cpu")
-
-
-class _NoModel:
-    """Symbol that indicates no model is loaded.
-
-    (We can't weakref.ref(None), so this was my best idea at the time to come up with something
-    type-checkable.)
-    """
-
-    def __bool__(self):
-        return False
-
-    def to(self, device: torch.device):
-        pass
-
-    def __repr__(self):
-        return "<NO MODEL>"
-
-
-NO_MODEL = _NoModel()
-
-
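The `_NoModel` sentinel exists because CPython cannot create weak references to `None`, and the lazy group below tracks its current model through `weakref.ref`. A quick demonstration of the constraint:

```python
# Why a sentinel instead of None: weakref.ref(None) is a TypeError.
import weakref

class _Sentinel:
    pass

NOTHING = _Sentinel()
ref = weakref.ref(NOTHING)  # fine: plain user-defined objects are weak-referenceable
# weakref.ref(None)         # TypeError: cannot create weak reference to 'NoneType' object
```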
-class ModelGroup(metaclass=ABCMeta):
-    """
-    A group of models.
-
-    The use case I had in mind when writing this is the sub-models used by a DiffusionPipeline,
-    e.g. its text encoder, U-net, VAE, etc.
-
-    Those models are :py:class:`diffusers.ModelMixin`, but "model" is interchangeable with
-    :py:class:`torch.nn.Module` here.
-    """
-
-    def __init__(self, execution_device: torch.device):
-        self.execution_device = execution_device
-
-    @abstractmethod
-    def install(self, *models: torch.nn.Module):
-        """Add models to this group."""
-        pass
-
-    @abstractmethod
-    def uninstall(self, models: torch.nn.Module):
-        """Remove models from this group."""
-        pass
-
-    @abstractmethod
-    def uninstall_all(self):
-        """Remove all models from this group."""
-
-    @abstractmethod
-    def load(self, model: torch.nn.Module):
-        """Load this model to the execution device."""
-        pass
-
-    @abstractmethod
-    def offload_current(self):
-        """Offload the current model(s) from the execution device."""
-        pass
-
-    @abstractmethod
-    def ready(self):
-        """Ready this group for use."""
-        pass
-
-    @abstractmethod
-    def set_device(self, device: torch.device):
-        """Change which device models from this group will execute on."""
-        pass
-
-    @abstractmethod
-    def device_for(self, model) -> torch.device:
-        """Get the device the given model will execute on.
-
-        The model should already be a member of this group.
-        """
-        pass
-
-    @abstractmethod
-    def __contains__(self, model):
-        """Check if the model is a member of this group."""
-        pass
-
-    def __repr__(self) -> str:
-        return f"<{self.__class__.__name__} object at {id(self):x}: " f"device={self.execution_device} >"
-
-
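For context, a hedged usage sketch of this interface with the `FullyLoadedModelGroup` subclass defined further down; the `pipe` variable is illustrative:

```python
# Hedged sketch: how callers used a ModelGroup before this commit removed it.
import torch

group = FullyLoadedModelGroup(torch.device("cuda"))
group.install(pipe.unet, pipe.vae, pipe.text_encoder)  # `pipe` is illustrative
group.ready()            # move every installed model to the execution device
assert pipe.unet in group
group.offload_current()  # push them all back to the CPU
```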
-class LazilyLoadedModelGroup(ModelGroup):
-    """
-    Only one model from this group is loaded on the GPU at a time.
-
-    Running the forward method of a model will displace the previously-loaded model,
-    offloading it to CPU.
-
-    If you call other methods on the model, e.g. ``model.encode(x)`` instead of ``model(x)``,
-    you will need to explicitly load it with :py:method:`.load(model)`.
-
-    This implementation relies on pytorch forward-pre-hooks, and it will copy forward arguments
-    to the appropriate execution device, as long as they are positional arguments and not keyword
-    arguments. (I didn't make the rules; that's the way the pytorch 1.13 API works for hooks.)
-    """
-
-    _hooks: MutableMapping[torch.nn.Module, RemovableHandle]
-    _current_model_ref: Callable[[], Union[torch.nn.Module, _NoModel]]
-
-    def __init__(self, execution_device: torch.device):
-        super().__init__(execution_device)
-        self._hooks = weakref.WeakKeyDictionary()
-        self._current_model_ref = weakref.ref(NO_MODEL)
-
-    def install(self, *models: torch.nn.Module):
-        for model in models:
-            self._hooks[model] = model.register_forward_pre_hook(self._pre_hook)
-
-    def uninstall(self, *models: torch.nn.Module):
-        for model in models:
-            hook = self._hooks.pop(model)
-            hook.remove()
-            if self.is_current_model(model):
-                # no longer hooked by this object, so don't claim to manage it
-                self.clear_current_model()
-
-    def uninstall_all(self):
-        self.uninstall(*self._hooks.keys())
-
-    def _pre_hook(self, module: torch.nn.Module, forward_input):
-        self.load(module)
-        if len(forward_input) == 0:
-            warnings.warn(
-                f"Hook for {module.__class__.__name__} got no input. " f"Inputs must be positional, not keywords.",
-                stacklevel=3,
-            )
-        return send_to_device(forward_input, self.execution_device)
-
-    def load(self, module):
-        if not self.is_current_model(module):
-            self.offload_current()
-            self._load(module)
-
-    def offload_current(self):
-        module = self._current_model_ref()
-        if module is not NO_MODEL:
-            module.to(OFFLOAD_DEVICE)
-        self.clear_current_model()
-
-    def _load(self, module: torch.nn.Module) -> torch.nn.Module:
-        assert self.is_empty(), f"A model is already loaded: {self._current_model_ref()}"
-        module = module.to(self.execution_device)
-        self.set_current_model(module)
-        return module
-
-    def is_current_model(self, model: torch.nn.Module) -> bool:
-        """Is the given model the one currently loaded on the execution device?"""
-        return self._current_model_ref() is model
-
-    def is_empty(self):
-        """Are none of this group's models loaded on the execution device?"""
-        return self._current_model_ref() is NO_MODEL
-
-    def set_current_model(self, value):
-        self._current_model_ref = weakref.ref(value)
-
-    def clear_current_model(self):
-        self._current_model_ref = weakref.ref(NO_MODEL)
-
-    def set_device(self, device: torch.device):
-        if device == self.execution_device:
-            return
-        self.execution_device = device
-        current = self._current_model_ref()
-        if current is not NO_MODEL:
-            current.to(device)
-
-    def device_for(self, model):
-        if model not in self:
-            raise KeyError(f"This does not manage this model {type(model).__name__}", model)
-        return self.execution_device  # this implementation only dispatches to one device
-
-    def ready(self):
-        pass  # always ready to load on-demand
-
-    def __contains__(self, model):
-        return model in self._hooks
-
-    def __repr__(self) -> str:
-        return (
-            f"<{self.__class__.__name__} object at {id(self):x}: "
-            f"current_model={type(self._current_model_ref()).__name__} >"
-        )
-
-
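The lazy group above hinges on `register_forward_pre_hook`: the hook fires before `forward()`, may swap the model onto the GPU, and can rewrite the positional arguments it returns. A minimal self-contained demonstration of that mechanism:

```python
# Minimal demo of the forward-pre-hook mechanism LazilyLoadedModelGroup uses.
# A pre-hook sees only positional arguments, which is why the class warns when
# a module is called with keywords alone.
import torch

net = torch.nn.Linear(4, 4)

def pre_hook(module, args):
    print(f"about to run {module.__class__.__name__}")
    return args  # a returned tuple replaces forward()'s positional arguments

handle = net.register_forward_pre_hook(pre_hook)
net(torch.zeros(1, 4))  # prints the message, then runs forward as usual
handle.remove()
```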
-class FullyLoadedModelGroup(ModelGroup):
-    """
-    A group of models without any implicit loading or unloading.
-
-    :py:meth:`.ready` loads _all_ the models to the execution device at once.
-    """
-
-    _models: weakref.WeakSet
-
-    def __init__(self, execution_device: torch.device):
-        super().__init__(execution_device)
-        self._models = weakref.WeakSet()
-
-    def install(self, *models: torch.nn.Module):
-        for model in models:
-            self._models.add(model)
-            model.to(self.execution_device)
-
-    def uninstall(self, *models: torch.nn.Module):
-        for model in models:
-            self._models.remove(model)
-
-    def uninstall_all(self):
-        self.uninstall(*self._models)
-
-    def load(self, model):
-        model.to(self.execution_device)
-
-    def offload_current(self):
-        for model in self._models:
-            model.to(OFFLOAD_DEVICE)
-
-    def ready(self):
-        for model in self._models:
-            self.load(model)
-
-    def set_device(self, device: torch.device):
-        self.execution_device = device
-        for model in self._models:
-            if model.device != OFFLOAD_DEVICE:
-                model.to(device)
-
-    def device_for(self, model):
-        if model not in self:
-            raise KeyError("This does not manage this model f{type(model).__name__}", model)
-        return self.execution_device  # this implementation only dispatches to one device
-
-    def __contains__(self, model):
-        return model in self._models
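Both concrete groups held their members in weak containers (`WeakKeyDictionary`, `WeakSet`), so belonging to a group never kept a model alive after its pipeline released it. A small demonstration of that property:

```python
# WeakSet membership does not pin the module: once the last strong reference
# goes away, the set forgets it.
import weakref

import torch

models = weakref.WeakSet()
module = torch.nn.Linear(2, 2)
models.add(module)
assert module in models
del module          # drop the only strong reference
assert len(models) == 0
```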