Merge branch 'main' into feat/refactor_generation_backend

This commit is contained in:
Sergey Borisov
2023-08-10 04:32:16 +03:00
43 changed files with 1970 additions and 407 deletions

View File

@ -1,18 +1,19 @@
from __future__ import annotations
import dataclasses
import inspect
import math
import secrets
from dataclasses import dataclass, field
import inspect
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
from pydantic import Field
import math
import einops
import PIL.Image
import numpy as np
import einops
import psutil
import torch
import torchvision.transforms as T
from accelerate.utils import set_seed
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
@ -27,17 +28,18 @@ from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.outputs import BaseOutput
from pydantic import Field
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from typing_extensions import ParamSpec
from invokeai.app.services.config import InvokeAIAppConfig
from ..util import CPU_DEVICE, normalize_device
from .diffusion import (
AttentionMapSaver,
InvokeAIDiffuserComponent,
PostprocessingSettings,
)
from ..util import normalize_device
@dataclass
@ -292,9 +294,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
safety_checker: Optional[StableDiffusionSafetyChecker],
feature_extractor: Optional[CLIPFeatureExtractor],
requires_safety_checker: bool = False,
precision: str = "float32",
control_model: ControlNetModel = None,
execution_device: Optional[torch.device] = None,
):
super().__init__(
vae,
@ -335,12 +335,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
return
if self.device.type == "cpu" or self.device.type == "mps":
if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
mem_free = psutil.virtual_memory().free
elif self.device.type == "cuda":
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device))
elif self.unet.device.type == "cuda":
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.unet.device))
else:
raise ValueError(f"unrecognized device {self.device}")
raise ValueError(f"unrecognized device {self.unet.device}")
# input tensor of [1, 4, h/8, w/8]
# output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
@ -363,10 +363,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
raise Exception("Should not be called")
@property
def device(self) -> torch.device:
return self.unet.device
def latents_from_embeddings(
self,
latents: torch.Tensor,