diff --git a/.dockerignore b/.dockerignore index 0c177e5912..cfdc7fc735 100644 --- a/.dockerignore +++ b/.dockerignore @@ -4,9 +4,9 @@ !ldm !pyproject.toml -# ignore frontend but whitelist dist -invokeai/frontend/ -!invokeai/frontend/dist/ +# ignore frontend/web but whitelist dist +invokeai/frontend/web/ +!invokeai/frontend/web/dist/ # ignore invokeai/assets but whitelist invokeai/assets/web invokeai/assets/ diff --git a/.github/workflows/build-container.yml b/.github/workflows/build-container.yml index 8444c76a61..0fabbdf038 100644 --- a/.github/workflows/build-container.yml +++ b/.github/workflows/build-container.yml @@ -9,16 +9,13 @@ on: - 'dev/docker/*' paths: - 'pyproject.toml' + - '.dockerignore' - 'invokeai/**' - - 'invokeai/backend/**' - - 'invokeai/configs/**' - - 'invokeai/frontend/dist/**' - 'docker/Dockerfile' tags: - 'v*.*.*' workflow_dispatch: - jobs: docker: if: github.event.pull_request.draft == false @@ -56,9 +53,9 @@ jobs: tags: | type=ref,event=branch type=ref,event=tag - type=semver,pattern={{version}} - type=semver,pattern={{major}}.{{minor}} - type=semver,pattern={{major}} + type=pep440,pattern={{version}} + type=pep440,pattern={{major}}.{{minor}} + type=pep440,pattern={{major}} type=sha,enable=true,prefix=sha-,format=short flavor: | latest=${{ matrix.flavor == 'cuda' && github.ref == 'refs/heads/main' }} @@ -94,7 +91,7 @@ jobs: context: . file: ${{ env.DOCKERFILE }} platforms: ${{ env.PLATFORMS }} - push: ${{ github.ref == 'refs/heads/main' || github.ref == 'refs/tags/*' }} + push: ${{ github.ref == 'refs/heads/main' || github.ref_type == 'tag' }} tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} build-args: PIP_EXTRA_INDEX_URL=${{ matrix.pip-extra-index-url }} diff --git a/docker/Dockerfile b/docker/Dockerfile index 2c3320cc0f..1c2b991028 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -63,7 +63,7 @@ RUN --mount=type=cache,target=${PIP_CACHE_DIR} \ # Install requirements COPY --link pyproject.toml . 
-COPY --link ldm/invoke/_version.py ldm/invoke/__init__.py ldm/invoke/ +COPY --link invokeai/version/invokeai_version.py invokeai/version/__init__.py invokeai/version/ ARG PIP_EXTRA_INDEX_URL ENV PIP_EXTRA_INDEX_URL ${PIP_EXTRA_INDEX_URL} RUN --mount=type=cache,target=${PIP_CACHE_DIR} \ diff --git a/invokeai/backend/generate.py b/invokeai/backend/generate.py index 8f2992db0c..ee5241bca1 100644 --- a/invokeai/backend/generate.py +++ b/invokeai/backend/generate.py @@ -5,6 +5,7 @@ import gc import importlib +import logging import os import random import re @@ -19,24 +20,20 @@ import numpy as np import skimage import torch import transformers +from PIL import Image, ImageOps +from accelerate.utils import set_seed from diffusers.pipeline_utils import DiffusionPipeline from diffusers.utils.import_utils import is_xformers_available from omegaconf import OmegaConf -from PIL import Image, ImageOps -from pytorch_lightning import logging, seed_everything -from .model_management import ModelManager from .args import metadata_from_png from .generator import infill_methods from .globals import Globals, global_cache_dir from .image_util import InitImageResizer, PngWriter, Txt2Mask, configure_model_padding +from .model_management import ModelManager from .prompting import get_uc_and_c_and_ec -from .stable_diffusion import ( - DDIMSampler, - HuggingFaceConceptsLibrary, - KSampler, - PLMSSampler, -) +from .prompting.conditioning import log_tokenization +from .stable_diffusion import HuggingFaceConceptsLibrary from .util import choose_precision, choose_torch_device @@ -484,7 +481,7 @@ class Generate: if sampler_name and (sampler_name != self.sampler_name): self.sampler_name = sampler_name - self._set_sampler() + self._set_scheduler() # apply the concepts library to the prompt prompt = self.huggingface_concepts_library.replace_concepts_with_triggers( @@ -493,11 +490,6 @@ class Generate: self.model.textual_inversion_manager.get_all_trigger_strings(), ) - # bit of a hack to change the cached sampler's karras threshold to - # whatever the user asked for - if karras_max is not None and isinstance(self.sampler, KSampler): - self.sampler.adjust_settings(karras_max=karras_max) - tic = time.time() if self._has_cuda(): torch.cuda.reset_peak_memory_stats() @@ -715,7 +707,7 @@ class Generate: prompt, model=self.model, skip_normalize_legacy_blend=opt.skip_normalize, - log_tokens=invokeai.backend.prompting.conditioning.log_tokenization, + log_tokens=log_tokenization, ) if tool in ("gfpgan", "codeformer", "upscale"): @@ -959,7 +951,7 @@ class Generate: # uncache generators so they pick up new models self.generators = {} - seed_everything(random.randrange(0, np.iinfo(np.uint32).max)) + set_seed(random.randrange(0, np.iinfo(np.uint32).max)) if self.embedding_path is not None: print(f">> Loading embeddings from {self.embedding_path}") for root, _, files in os.walk(self.embedding_path): @@ -973,7 +965,7 @@ class Generate: ) self.model_name = model_name - self._set_sampler() # requires self.model_name to be set first + self._set_scheduler() # requires self.model_name to be set first return self.model def load_huggingface_concepts(self, concepts: list[str]): @@ -1105,44 +1097,6 @@ class Generate: def is_legacy_model(self, model_name) -> bool: return self.model_manager.is_legacy(model_name) - def _set_sampler(self): - if isinstance(self.model, DiffusionPipeline): - return self._set_scheduler() - else: - return self._set_sampler_legacy() - - # very repetitive code - can this be simplified? 
The KSampler names are - # consistent, at least - def _set_sampler_legacy(self): - msg = f">> Setting Sampler to {self.sampler_name}" - if self.sampler_name == "plms": - self.sampler = PLMSSampler(self.model, device=self.device) - elif self.sampler_name == "ddim": - self.sampler = DDIMSampler(self.model, device=self.device) - elif self.sampler_name == "k_dpm_2_a": - self.sampler = KSampler(self.model, "dpm_2_ancestral", device=self.device) - elif self.sampler_name == "k_dpm_2": - self.sampler = KSampler(self.model, "dpm_2", device=self.device) - elif self.sampler_name == "k_dpmpp_2_a": - self.sampler = KSampler( - self.model, "dpmpp_2s_ancestral", device=self.device - ) - elif self.sampler_name == "k_dpmpp_2": - self.sampler = KSampler(self.model, "dpmpp_2m", device=self.device) - elif self.sampler_name == "k_euler_a": - self.sampler = KSampler(self.model, "euler_ancestral", device=self.device) - elif self.sampler_name == "k_euler": - self.sampler = KSampler(self.model, "euler", device=self.device) - elif self.sampler_name == "k_heun": - self.sampler = KSampler(self.model, "heun", device=self.device) - elif self.sampler_name == "k_lms": - self.sampler = KSampler(self.model, "lms", device=self.device) - else: - msg = f">> Unsupported Sampler: {self.sampler_name}, Defaulting to plms" - self.sampler = PLMSSampler(self.model, device=self.device) - - print(msg) - def _set_scheduler(self): default = self.model.scheduler diff --git a/invokeai/backend/generator/base.py b/invokeai/backend/generator/base.py index 831c941ff4..a834e9dba3 100644 --- a/invokeai/backend/generator/base.py +++ b/invokeai/backend/generator/base.py @@ -5,7 +5,6 @@ including img2img, txt2img, and inpaint from __future__ import annotations import os -import os.path as osp import random import traceback from contextlib import nullcontext @@ -14,15 +13,12 @@ from pathlib import Path import cv2 import numpy as np import torch -from diffusers import DiffusionPipeline -from einops import rearrange from PIL import Image, ImageChops, ImageFilter -from pytorch_lightning import seed_everything +from accelerate.utils import set_seed +from diffusers import DiffusionPipeline from tqdm import trange import invokeai.assets.web as web_assets - -from ..stable_diffusion.diffusion.ddpm import DiffusionWrapper from ..util.util import rand_perlin_2d downsampling = 8 @@ -33,9 +29,9 @@ class Generator: downsampling_factor: int latent_channels: int precision: str - model: DiffusionWrapper | DiffusionPipeline + model: DiffusionPipeline - def __init__(self, model: DiffusionWrapper | DiffusionPipeline, precision: str): + def __init__(self, model: DiffusionPipeline, precision: str): self.model = model self.precision = precision self.seed = None @@ -116,14 +112,14 @@ class Generator: for n in trange(iterations, desc="Generating"): x_T = None if self.variation_amount > 0: - seed_everything(seed) + set_seed(seed) target_noise = self.get_noise(width, height) x_T = self.slerp(self.variation_amount, initial_noise, target_noise) elif initial_noise is not None: # i.e. 
we specified particular variations x_T = initial_noise else: - seed_everything(seed) + set_seed(seed) try: x_T = self.get_noise(width, height) except: @@ -283,11 +279,11 @@ class Generator: initial_noise = None if self.variation_amount > 0 or len(self.with_variations) > 0: # use fixed initial noise plus random noise per iteration - seed_everything(seed) + set_seed(seed) initial_noise = self.get_noise(width, height) for v_seed, v_weight in self.with_variations: seed = v_seed - seed_everything(seed) + set_seed(seed) next_noise = self.get_noise(width, height) initial_noise = self.slerp(v_weight, initial_noise, next_noise) if self.variation_amount > 0: diff --git a/invokeai/backend/restoration/outcrop.py b/invokeai/backend/restoration/outcrop.py index 0778d7cc8f..e0f110f71e 100644 --- a/invokeai/backend/restoration/outcrop.py +++ b/invokeai/backend/restoration/outcrop.py @@ -1,7 +1,6 @@ import math -import warnings -from PIL import Image, ImageFilter +from PIL import Image class Outcrop(object): @@ -27,7 +26,7 @@ class Outcrop(object): # switch samplers temporarily curr_sampler = self.generate.sampler self.generate.sampler_name = opt.sampler_name - self.generate._set_sampler() + self.generate._set_scheduler() def wrapped_callback(img, seed, **kwargs): preferred_seed = ( diff --git a/invokeai/backend/stable_diffusion/__init__.py b/invokeai/backend/stable_diffusion/__init__.py index 05886f7b10..55333d3589 100644 --- a/invokeai/backend/stable_diffusion/__init__.py +++ b/invokeai/backend/stable_diffusion/__init__.py @@ -9,8 +9,5 @@ from .diffusers_pipeline import ( ) from .diffusion import InvokeAIDiffuserComponent from .diffusion.cross_attention_map_saving import AttentionMapSaver -from .diffusion.ddim import DDIMSampler -from .diffusion.ksampler import KSampler -from .diffusion.plms import PLMSSampler from .diffusion.shared_invokeai_diffusion import PostprocessingSettings from .textual_inversion_manager import TextualInversionManager diff --git a/invokeai/backend/stable_diffusion/attention.py b/invokeai/backend/stable_diffusion/attention.py deleted file mode 100644 index 484b42c0bd..0000000000 --- a/invokeai/backend/stable_diffusion/attention.py +++ /dev/null @@ -1,290 +0,0 @@ -import math -from inspect import isfunction -from typing import Callable, Optional - -import torch -import torch.nn.functional as F -from einops import rearrange, repeat -from torch import einsum, nn - -from .diffusion import InvokeAICrossAttentionMixin -from .diffusionmodules.util import checkpoint - - -def exists(val): - return val is not None - - -def uniq(arr): - return {el: True for el in arr}.keys() - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def max_neg_value(t): - return -torch.finfo(t.dtype).max - - -def init_(tensor): - dim = tensor.shape[-1] - std = 1 / math.sqrt(dim) - tensor.uniform_(-std, std) - return tensor - - -# feedforward -class GEGLU(nn.Module): - def __init__(self, dim_in, dim_out): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) - - def forward(self, x): - x, gate = self.proj(x).chunk(2, dim=-1) - return x * F.gelu(gate) - - -class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): - super().__init__() - inner_dim = int(dim * mult) - dim_out = default(dim_out, dim) - project_in = ( - nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) - if not glu - else GEGLU(dim, inner_dim) - ) - - self.net = nn.Sequential( - project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) - ) - - 
def forward(self, x): - return self.net(x) - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - -def Normalize(in_channels): - return torch.nn.GroupNorm( - num_groups=32, num_channels=in_channels, eps=1e-6, affine=True - ) - - -class LinearAttention(nn.Module): - def __init__(self, dim, heads=4, dim_head=32): - super().__init__() - self.heads = heads - hidden_dim = dim_head * heads - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) - self.to_out = nn.Conv2d(hidden_dim, dim, 1) - - def forward(self, x): - b, c, h, w = x.shape - qkv = self.to_qkv(x) - q, k, v = rearrange( - qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 - ) - k = k.softmax(dim=-1) - context = torch.einsum("bhdn,bhen->bhde", k, v) - out = torch.einsum("bhde,bhdn->bhen", context, q) - out = rearrange( - out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w - ) - return self.to_out(out) - - -class SpatialSelfAttention(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.k = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.v = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.proj_out = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q.shape - q = rearrange(q, "b c h w -> b (h w) c") - k = rearrange(k, "b c h w -> b c (h w)") - w_ = torch.einsum("bij,bjk->bik", q, k) - - w_ = w_ * (int(c) ** (-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = rearrange(v, "b c h w -> b c (h w)") - w_ = rearrange(w_, "b i j -> b j i") - h_ = torch.einsum("bij,bjk->bik", v, w_) - h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) - h_ = self.proj_out(h_) - - return x + h_ - - -def get_mem_free_total(device): - # only on cuda - if not torch.cuda.is_available(): - return None - stats = torch.cuda.memory_stats(device) - mem_active = stats["active_bytes.all.current"] - mem_reserved = stats["reserved_bytes.all.current"] - mem_free_cuda, _ = torch.cuda.mem_get_info(device) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch - return mem_free_total - - -class CrossAttention(nn.Module, InvokeAICrossAttentionMixin): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): - super().__init__() - InvokeAICrossAttentionMixin.__init__(self) - inner_dim = dim_head * heads - context_dim = default(context_dim, query_dim) - - self.scale = dim_head**-0.5 - self.heads = heads - - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(context_dim, inner_dim, bias=False) - self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - - self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) - ) - - def forward(self, x, context=None, mask=None): - h = self.heads - - q = self.to_q(x) - context = default(context, x) - k = self.to_k(context) * self.scale - v = self.to_v(context) - del context, x - - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) - - # don't apply scale twice - 
cached_scale = self.scale - self.scale = 1 - r = self.get_invokeai_attention_mem_efficient(q, k, v) - self.scale = cached_scale - - hidden_states = rearrange(r, "(b h) n d -> b n (h d)", h=h) - return self.to_out(hidden_states) - - -class BasicTransformerBlock(nn.Module): - def __init__( - self, - dim, - n_heads, - d_head, - dropout=0.0, - context_dim=None, - gated_ff=True, - checkpoint=True, - ): - super().__init__() - self.attn1 = CrossAttention( - query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout - ) # is a self-attention - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = CrossAttention( - query_dim=dim, - context_dim=context_dim, - heads=n_heads, - dim_head=d_head, - dropout=dropout, - ) # is self-attn if context is none - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - self.norm3 = nn.LayerNorm(dim) - self.checkpoint = checkpoint - - def forward(self, x, context=None): - return checkpoint( - self._forward, (x, context), self.parameters(), self.checkpoint - ) - - def _forward(self, x, context=None): - x = x.contiguous() if x.device.type == "mps" else x - x += self.attn1(self.norm1(x.clone())) - x += self.attn2(self.norm2(x.clone()), context=context) - x += self.ff(self.norm3(x.clone())) - return x - - -class SpatialTransformer(nn.Module): - """ - Transformer block for image-like data. - First, project the input (aka embedding) - and reshape to b, t, d. - Then apply standard transformer action. - Finally, reshape to image - """ - - def __init__( - self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None - ): - super().__init__() - self.in_channels = in_channels - inner_dim = n_heads * d_head - self.norm = Normalize(in_channels) - - self.proj_in = nn.Conv2d( - in_channels, inner_dim, kernel_size=1, stride=1, padding=0 - ) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim - ) - for d in range(depth) - ] - ) - - self.proj_out = zero_module( - nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) - ) - - def forward(self, x, context=None): - # note: if no context is given, cross-attention defaults to self-attention - b, c, h, w = x.shape - x_in = x - x = self.norm(x) - x = self.proj_in(x) - x = rearrange(x, "b c h w -> b (h w) c") - for block in self.transformer_blocks: - x = block(x, context=context) - x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) - x = self.proj_out(x) - return x + x_in diff --git a/invokeai/backend/stable_diffusion/autoencoder.py b/invokeai/backend/stable_diffusion/autoencoder.py deleted file mode 100644 index 2bc7fa84f6..0000000000 --- a/invokeai/backend/stable_diffusion/autoencoder.py +++ /dev/null @@ -1,565 +0,0 @@ -from contextlib import contextmanager - -import pytorch_lightning as pl -import torch -import torch.nn.functional as F -from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer - -from ..util import instantiate_from_config -from .diffusionmodules.model import Decoder, Encoder -from .distributions.distributions import DiagonalGaussianDistribution - - -class VQModel(pl.LightningModule): - def __init__( - self, - ddconfig, - lossconfig, - n_embed, - embed_dim, - ckpt_path=None, - ignore_keys=[], - image_key="image", - colorize_nlabels=None, - monitor=None, - batch_resize_range=None, - scheduler_config=None, - lr_g_factor=1.0, - remap=None, - sane_index_shape=False, # tell vector quantizer to return indices as bhw - use_ema=False, - ): - 
super().__init__() - self.embed_dim = embed_dim - self.n_embed = n_embed - self.image_key = image_key - self.encoder = Encoder(**ddconfig) - self.decoder = Decoder(**ddconfig) - self.loss = instantiate_from_config(lossconfig) - self.quantize = VectorQuantizer( - n_embed, - embed_dim, - beta=0.25, - remap=remap, - sane_index_shape=sane_index_shape, - ) - self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) - if colorize_nlabels is not None: - assert type(colorize_nlabels) == int - self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) - if monitor is not None: - self.monitor = monitor - self.batch_resize_range = batch_resize_range - if self.batch_resize_range is not None: - print( - f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}." - ) - - self.use_ema = use_ema - if self.use_ema: - self.model_ema = LitEma(self) - print(f">> Keeping EMAs of {len(list(self.model_ema.buffers()))}.") - - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) - self.scheduler_config = scheduler_config - self.lr_g_factor = lr_g_factor - - @contextmanager - def ema_scope(self, context=None): - if self.use_ema: - self.model_ema.store(self.parameters()) - self.model_ema.copy_to(self) - if context is not None: - print(f"{context}: Switched to EMA weights") - try: - yield None - finally: - if self.use_ema: - self.model_ema.restore(self.parameters()) - if context is not None: - print(f"{context}: Restored training weights") - - def init_from_ckpt(self, path, ignore_keys=list()): - sd = torch.load(path, map_location="cpu")["state_dict"] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) - del sd[k] - missing, unexpected = self.load_state_dict(sd, strict=False) - print( - f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" - ) - if len(missing) > 0: - print(f"Missing Keys: {missing}") - print(f"Unexpected Keys: {unexpected}") - - def on_train_batch_end(self, *args, **kwargs): - if self.use_ema: - self.model_ema(self) - - def encode(self, x): - h = self.encoder(x) - h = self.quant_conv(h) - quant, emb_loss, info = self.quantize(h) - return quant, emb_loss, info - - def encode_to_prequant(self, x): - h = self.encoder(x) - h = self.quant_conv(h) - return h - - def decode(self, quant): - quant = self.post_quant_conv(quant) - dec = self.decoder(quant) - return dec - - def decode_code(self, code_b): - quant_b = self.quantize.embed_code(code_b) - dec = self.decode(quant_b) - return dec - - def forward(self, input, return_pred_indices=False): - quant, diff, (_, _, ind) = self.encode(input) - dec = self.decode(quant) - if return_pred_indices: - return dec, diff, ind - return dec, diff - - def get_input(self, batch, k): - x = batch[k] - if len(x.shape) == 3: - x = x[..., None] - x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() - if self.batch_resize_range is not None: - lower_size = self.batch_resize_range[0] - upper_size = self.batch_resize_range[1] - if self.global_step <= 4: - # do the first few batches with max size to avoid later oom - new_resize = upper_size - else: - new_resize = np.random.choice( - np.arange(lower_size, upper_size + 16, 16) - ) - if new_resize != x.shape[2]: - x = F.interpolate(x, size=new_resize, mode="bicubic") - x = x.detach() - return x - - def training_step(self, 
batch, batch_idx, optimizer_idx): - # https://github.com/pytorch/pytorch/issues/37142 - # try not to fool the heuristics - x = self.get_input(batch, self.image_key) - xrec, qloss, ind = self(x, return_pred_indices=True) - - if optimizer_idx == 0: - # autoencode - aeloss, log_dict_ae = self.loss( - qloss, - x, - xrec, - optimizer_idx, - self.global_step, - last_layer=self.get_last_layer(), - split="train", - predicted_indices=ind, - ) - - self.log_dict( - log_dict_ae, - prog_bar=False, - logger=True, - on_step=True, - on_epoch=True, - ) - return aeloss - - if optimizer_idx == 1: - # discriminator - discloss, log_dict_disc = self.loss( - qloss, - x, - xrec, - optimizer_idx, - self.global_step, - last_layer=self.get_last_layer(), - split="train", - ) - self.log_dict( - log_dict_disc, - prog_bar=False, - logger=True, - on_step=True, - on_epoch=True, - ) - return discloss - - def validation_step(self, batch, batch_idx): - log_dict = self._validation_step(batch, batch_idx) - with self.ema_scope(): - log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") - return log_dict - - def _validation_step(self, batch, batch_idx, suffix=""): - x = self.get_input(batch, self.image_key) - xrec, qloss, ind = self(x, return_pred_indices=True) - aeloss, log_dict_ae = self.loss( - qloss, - x, - xrec, - 0, - self.global_step, - last_layer=self.get_last_layer(), - split="val" + suffix, - predicted_indices=ind, - ) - - discloss, log_dict_disc = self.loss( - qloss, - x, - xrec, - 1, - self.global_step, - last_layer=self.get_last_layer(), - split="val" + suffix, - predicted_indices=ind, - ) - rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] - self.log( - f"val{suffix}/rec_loss", - rec_loss, - prog_bar=True, - logger=True, - on_step=False, - on_epoch=True, - sync_dist=True, - ) - self.log( - f"val{suffix}/aeloss", - aeloss, - prog_bar=True, - logger=True, - on_step=False, - on_epoch=True, - sync_dist=True, - ) - if version.parse(pl.__version__) >= version.parse("1.4.0"): - del log_dict_ae[f"val{suffix}/rec_loss"] - self.log_dict(log_dict_ae) - self.log_dict(log_dict_disc) - return self.log_dict - - def configure_optimizers(self): - lr_d = self.learning_rate - lr_g = self.lr_g_factor * self.learning_rate - print("lr_d", lr_d) - print("lr_g", lr_g) - opt_ae = torch.optim.Adam( - list(self.encoder.parameters()) - + list(self.decoder.parameters()) - + list(self.quantize.parameters()) - + list(self.quant_conv.parameters()) - + list(self.post_quant_conv.parameters()), - lr=lr_g, - betas=(0.5, 0.9), - ) - opt_disc = torch.optim.Adam( - self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9) - ) - - if self.scheduler_config is not None: - scheduler = instantiate_from_config(self.scheduler_config) - - print("Setting up LambdaLR scheduler...") - scheduler = [ - { - "scheduler": LambdaLR(opt_ae, lr_lambda=scheduler.schedule), - "interval": "step", - "frequency": 1, - }, - { - "scheduler": LambdaLR(opt_disc, lr_lambda=scheduler.schedule), - "interval": "step", - "frequency": 1, - }, - ] - return [opt_ae, opt_disc], scheduler - return [opt_ae, opt_disc], [] - - def get_last_layer(self): - return self.decoder.conv_out.weight - - def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): - log = dict() - x = self.get_input(batch, self.image_key) - x = x.to(self.device) - if only_inputs: - log["inputs"] = x - return log - xrec, _ = self(x) - if x.shape[1] > 3: - # colorize with random projection - assert xrec.shape[1] > 3 - x = self.to_rgb(x) - xrec = self.to_rgb(xrec) - log["inputs"] = x - 
log["reconstructions"] = xrec - if plot_ema: - with self.ema_scope(): - xrec_ema, _ = self(x) - if x.shape[1] > 3: - xrec_ema = self.to_rgb(xrec_ema) - log["reconstructions_ema"] = xrec_ema - return log - - def to_rgb(self, x): - assert self.image_key == "segmentation" - if not hasattr(self, "colorize"): - self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) - x = F.conv2d(x, weight=self.colorize) - x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 - return x - - -class VQModelInterface(VQModel): - def __init__(self, embed_dim, *args, **kwargs): - super().__init__(embed_dim=embed_dim, *args, **kwargs) - self.embed_dim = embed_dim - - def encode(self, x): - h = self.encoder(x) - h = self.quant_conv(h) - return h - - def decode(self, h, force_not_quantize=False): - # also go through quantization layer - if not force_not_quantize: - quant, emb_loss, info = self.quantize(h) - else: - quant = h - quant = self.post_quant_conv(quant) - dec = self.decoder(quant) - return dec - - -class AutoencoderKL(pl.LightningModule): - def __init__( - self, - ddconfig, - lossconfig, - embed_dim, - ckpt_path=None, - ignore_keys=[], - image_key="image", - colorize_nlabels=None, - monitor=None, - ): - super().__init__() - self.image_key = image_key - self.encoder = Encoder(**ddconfig) - self.decoder = Decoder(**ddconfig) - self.loss = instantiate_from_config(lossconfig) - assert ddconfig["double_z"] - self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) - self.embed_dim = embed_dim - if colorize_nlabels is not None: - assert type(colorize_nlabels) == int - self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) - if monitor is not None: - self.monitor = monitor - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) - - def init_from_ckpt(self, path, ignore_keys=list()): - sd = torch.load(path, map_location="cpu")["state_dict"] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) - del sd[k] - self.load_state_dict(sd, strict=False) - print(f"Restored from {path}") - - def encode(self, x): - h = self.encoder(x) - moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) - return posterior - - def decode(self, z): - z = self.post_quant_conv(z) - dec = self.decoder(z) - return dec - - def forward(self, input, sample_posterior=True): - posterior = self.encode(input) - if sample_posterior: - z = posterior.sample() - else: - z = posterior.mode() - dec = self.decode(z) - return dec, posterior - - def get_input(self, batch, k): - x = batch[k] - if len(x.shape) == 3: - x = x[..., None] - x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() - return x - - def training_step(self, batch, batch_idx, optimizer_idx): - inputs = self.get_input(batch, self.image_key) - reconstructions, posterior = self(inputs) - - if optimizer_idx == 0: - # train encoder+decoder+logvar - aeloss, log_dict_ae = self.loss( - inputs, - reconstructions, - posterior, - optimizer_idx, - self.global_step, - last_layer=self.get_last_layer(), - split="train", - ) - self.log( - "aeloss", - aeloss, - prog_bar=True, - logger=True, - on_step=True, - on_epoch=True, - ) - self.log_dict( - log_dict_ae, - prog_bar=False, - logger=True, - on_step=True, - on_epoch=False, - ) - return aeloss - - if optimizer_idx == 1: - # train the discriminator - discloss, 
log_dict_disc = self.loss( - inputs, - reconstructions, - posterior, - optimizer_idx, - self.global_step, - last_layer=self.get_last_layer(), - split="train", - ) - - self.log( - "discloss", - discloss, - prog_bar=True, - logger=True, - on_step=True, - on_epoch=True, - ) - self.log_dict( - log_dict_disc, - prog_bar=False, - logger=True, - on_step=True, - on_epoch=False, - ) - return discloss - - def validation_step(self, batch, batch_idx): - inputs = self.get_input(batch, self.image_key) - reconstructions, posterior = self(inputs) - aeloss, log_dict_ae = self.loss( - inputs, - reconstructions, - posterior, - 0, - self.global_step, - last_layer=self.get_last_layer(), - split="val", - ) - - discloss, log_dict_disc = self.loss( - inputs, - reconstructions, - posterior, - 1, - self.global_step, - last_layer=self.get_last_layer(), - split="val", - ) - - self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) - self.log_dict(log_dict_ae) - self.log_dict(log_dict_disc) - return self.log_dict - - def configure_optimizers(self): - lr = self.learning_rate - opt_ae = torch.optim.Adam( - list(self.encoder.parameters()) - + list(self.decoder.parameters()) - + list(self.quant_conv.parameters()) - + list(self.post_quant_conv.parameters()), - lr=lr, - betas=(0.5, 0.9), - ) - opt_disc = torch.optim.Adam( - self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9) - ) - return [opt_ae, opt_disc], [] - - def get_last_layer(self): - return self.decoder.conv_out.weight - - @torch.no_grad() - def log_images(self, batch, only_inputs=False, **kwargs): - log = dict() - x = self.get_input(batch, self.image_key) - x = x.to(self.device) - if not only_inputs: - xrec, posterior = self(x) - if x.shape[1] > 3: - # colorize with random projection - assert xrec.shape[1] > 3 - x = self.to_rgb(x) - xrec = self.to_rgb(xrec) - log["samples"] = self.decode(torch.randn_like(posterior.sample())) - log["reconstructions"] = xrec - log["inputs"] = x - return log - - def to_rgb(self, x): - assert self.image_key == "segmentation" - if not hasattr(self, "colorize"): - self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) - x = F.conv2d(x, weight=self.colorize) - x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 - return x - - -class IdentityFirstStage(torch.nn.Module): - def __init__(self, *args, vq_interface=False, **kwargs): - self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff - super().__init__() - - def encode(self, x, *args, **kwargs): - return x - - def decode(self, x, *args, **kwargs): - return x - - def quantize(self, x, *args, **kwargs): - if self.vq_interface: - return x, None, [None, None, None] - return x - - def forward(self, x, *args, **kwargs): - return x diff --git a/invokeai/backend/stable_diffusion/data/__init__.py b/invokeai/backend/stable_diffusion/data/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/invokeai/backend/stable_diffusion/data/base.py b/invokeai/backend/stable_diffusion/data/base.py deleted file mode 100644 index 1b6a138bf7..0000000000 --- a/invokeai/backend/stable_diffusion/data/base.py +++ /dev/null @@ -1,25 +0,0 @@ -from abc import abstractmethod - -from torch.utils.data import ChainDataset, ConcatDataset, Dataset, IterableDataset - - -class Txt2ImgIterableBaseDataset(IterableDataset): - """ - Define an interface to make the IterableDatasets for text2img data chainable - """ - - def __init__(self, num_records=0, valid_ids=None, size=256): - super().__init__() - self.num_records = num_records - 
self.valid_ids = valid_ids - self.sample_ids = valid_ids - self.size = size - - print(f"{self.__class__.__name__} dataset contains {self.__len__()} examples.") - - def __len__(self): - return self.num_records - - @abstractmethod - def __iter__(self): - pass diff --git a/invokeai/backend/stable_diffusion/data/imagenet.py b/invokeai/backend/stable_diffusion/data/imagenet.py deleted file mode 100644 index 84bad27590..0000000000 --- a/invokeai/backend/stable_diffusion/data/imagenet.py +++ /dev/null @@ -1,453 +0,0 @@ -import glob -import os -import pickle -import shutil -import tarfile -from functools import partial - -import albumentations -import cv2 -import numpy as np -import PIL -import taming.data.utils as tdu -import torchvision.transforms.functional as TF -import yaml -from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light -from omegaconf import OmegaConf -from PIL import Image -from taming.data.imagenet import ( - ImagePaths, - download, - give_synsets_from_indices, - retrieve, - str_to_indices, -) -from torch.utils.data import Dataset, Subset -from tqdm import tqdm - - -def synset2idx(path_to_yaml="data/index_synset.yaml"): - with open(path_to_yaml) as f: - di2s = yaml.load(f) - return dict((v, k) for k, v in di2s.items()) - - -class ImageNetBase(Dataset): - def __init__(self, config=None): - self.config = config or OmegaConf.create() - if not type(self.config) == dict: - self.config = OmegaConf.to_container(self.config) - self.keep_orig_class_label = self.config.get("keep_orig_class_label", False) - self.process_images = True # if False we skip loading & processing images and self.data contains filepaths - self._prepare() - self._prepare_synset_to_human() - self._prepare_idx_to_synset() - self._prepare_human_to_integer_label() - self._load() - - def __len__(self): - return len(self.data) - - def __getitem__(self, i): - return self.data[i] - - def _prepare(self): - raise NotImplementedError() - - def _filter_relpaths(self, relpaths): - ignore = set( - [ - "n06596364_9591.JPEG", - ] - ) - relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] - if "sub_indices" in self.config: - indices = str_to_indices(self.config["sub_indices"]) - synsets = give_synsets_from_indices( - indices, path_to_yaml=self.idx2syn - ) # returns a list of strings - self.synset2idx = synset2idx(path_to_yaml=self.idx2syn) - files = [] - for rpath in relpaths: - syn = rpath.split("/")[0] - if syn in synsets: - files.append(rpath) - return files - else: - return relpaths - - def _prepare_synset_to_human(self): - SIZE = 2655750 - URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" - self.human_dict = os.path.join(self.root, "synset_human.txt") - if ( - not os.path.exists(self.human_dict) - or not os.path.getsize(self.human_dict) == SIZE - ): - download(URL, self.human_dict) - - def _prepare_idx_to_synset(self): - URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" - self.idx2syn = os.path.join(self.root, "index_synset.yaml") - if not os.path.exists(self.idx2syn): - download(URL, self.idx2syn) - - def _prepare_human_to_integer_label(self): - URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1" - self.human2integer = os.path.join( - self.root, "imagenet1000_clsidx_to_labels.txt" - ) - if not os.path.exists(self.human2integer): - download(URL, self.human2integer) - with open(self.human2integer, "r") as f: - lines = f.read().splitlines() - assert len(lines) == 1000 - self.human2integer_dict = dict() - for line in 
lines: - value, key = line.split(":") - self.human2integer_dict[key] = int(value) - - def _load(self): - with open(self.txt_filelist, "r") as f: - self.relpaths = f.read().splitlines() - l1 = len(self.relpaths) - self.relpaths = self._filter_relpaths(self.relpaths) - print( - "Removed {} files from filelist during filtering.".format( - l1 - len(self.relpaths) - ) - ) - - self.synsets = [p.split("/")[0] for p in self.relpaths] - self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths] - - unique_synsets = np.unique(self.synsets) - class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets)) - if not self.keep_orig_class_label: - self.class_labels = [class_dict[s] for s in self.synsets] - else: - self.class_labels = [self.synset2idx[s] for s in self.synsets] - - with open(self.human_dict, "r") as f: - human_dict = f.read().splitlines() - human_dict = dict(line.split(maxsplit=1) for line in human_dict) - - self.human_labels = [human_dict[s] for s in self.synsets] - - labels = { - "relpath": np.array(self.relpaths), - "synsets": np.array(self.synsets), - "class_label": np.array(self.class_labels), - "human_label": np.array(self.human_labels), - } - - if self.process_images: - self.size = retrieve(self.config, "size", default=256) - self.data = ImagePaths( - self.abspaths, - labels=labels, - size=self.size, - random_crop=self.random_crop, - ) - else: - self.data = self.abspaths - - -class ImageNetTrain(ImageNetBase): - NAME = "ILSVRC2012_train" - URL = "http://www.image-net.org/challenges/LSVRC/2012/" - AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2" - FILES = [ - "ILSVRC2012_img_train.tar", - ] - SIZES = [ - 147897477120, - ] - - def __init__(self, process_images=True, data_root=None, **kwargs): - self.process_images = process_images - self.data_root = data_root - super().__init__(**kwargs) - - def _prepare(self): - if self.data_root: - self.root = os.path.join(self.data_root, self.NAME) - else: - cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) - self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) - - self.datadir = os.path.join(self.root, "data") - self.txt_filelist = os.path.join(self.root, "filelist.txt") - self.expected_length = 1281167 - self.random_crop = retrieve( - self.config, "ImageNetTrain/random_crop", default=True - ) - if not tdu.is_prepared(self.root): - # prep - print("Preparing dataset {} in {}".format(self.NAME, self.root)) - - datadir = self.datadir - if not os.path.exists(datadir): - path = os.path.join(self.root, self.FILES[0]) - if ( - not os.path.exists(path) - or not os.path.getsize(path) == self.SIZES[0] - ): - import academictorrents as at - - atpath = at.get(self.AT_HASH, datastore=self.root) - assert atpath == path - - print("Extracting {} to {}".format(path, datadir)) - os.makedirs(datadir, exist_ok=True) - with tarfile.open(path, "r:") as tar: - tar.extractall(path=datadir) - - print("Extracting sub-tars.") - subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) - for subpath in tqdm(subpaths): - subdir = subpath[: -len(".tar")] - os.makedirs(subdir, exist_ok=True) - with tarfile.open(subpath, "r:") as tar: - tar.extractall(path=subdir) - - filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) - filelist = [os.path.relpath(p, start=datadir) for p in filelist] - filelist = sorted(filelist) - filelist = "\n".join(filelist) + "\n" - with open(self.txt_filelist, "w") as f: - f.write(filelist) - - tdu.mark_prepared(self.root) - - -class ImageNetValidation(ImageNetBase): - NAME = 
"ILSVRC2012_validation" - URL = "http://www.image-net.org/challenges/LSVRC/2012/" - AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5" - VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1" - FILES = [ - "ILSVRC2012_img_val.tar", - "validation_synset.txt", - ] - SIZES = [ - 6744924160, - 1950000, - ] - - def __init__(self, process_images=True, data_root=None, **kwargs): - self.data_root = data_root - self.process_images = process_images - super().__init__(**kwargs) - - def _prepare(self): - if self.data_root: - self.root = os.path.join(self.data_root, self.NAME) - else: - cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) - self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) - self.datadir = os.path.join(self.root, "data") - self.txt_filelist = os.path.join(self.root, "filelist.txt") - self.expected_length = 50000 - self.random_crop = retrieve( - self.config, "ImageNetValidation/random_crop", default=False - ) - if not tdu.is_prepared(self.root): - # prep - print("Preparing dataset {} in {}".format(self.NAME, self.root)) - - datadir = self.datadir - if not os.path.exists(datadir): - path = os.path.join(self.root, self.FILES[0]) - if ( - not os.path.exists(path) - or not os.path.getsize(path) == self.SIZES[0] - ): - import academictorrents as at - - atpath = at.get(self.AT_HASH, datastore=self.root) - assert atpath == path - - print("Extracting {} to {}".format(path, datadir)) - os.makedirs(datadir, exist_ok=True) - with tarfile.open(path, "r:") as tar: - tar.extractall(path=datadir) - - vspath = os.path.join(self.root, self.FILES[1]) - if ( - not os.path.exists(vspath) - or not os.path.getsize(vspath) == self.SIZES[1] - ): - download(self.VS_URL, vspath) - - with open(vspath, "r") as f: - synset_dict = f.read().splitlines() - synset_dict = dict(line.split() for line in synset_dict) - - print("Reorganizing into synset folders") - synsets = np.unique(list(synset_dict.values())) - for s in synsets: - os.makedirs(os.path.join(datadir, s), exist_ok=True) - for k, v in synset_dict.items(): - src = os.path.join(datadir, k) - dst = os.path.join(datadir, v) - shutil.move(src, dst) - - filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) - filelist = [os.path.relpath(p, start=datadir) for p in filelist] - filelist = sorted(filelist) - filelist = "\n".join(filelist) + "\n" - with open(self.txt_filelist, "w") as f: - f.write(filelist) - - tdu.mark_prepared(self.root) - - -class ImageNetSR(Dataset): - def __init__( - self, - size=None, - degradation=None, - downscale_f=4, - min_crop_f=0.5, - max_crop_f=1.0, - random_crop=True, - ): - """ - Imagenet Superresolution Dataloader - Performs following ops in order: - 1. crops a crop of size s from image either as random or center crop - 2. resizes crop to size with cv2.area_interpolation - 3. degrades resized crop with degradation_fn - - :param size: resizing to size after cropping - :param degradation: degradation_fn, e.g. 
cv_bicubic or bsrgan_light - :param downscale_f: Low Resolution Downsample factor - :param min_crop_f: determines crop size s, - where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f) - :param max_crop_f: "" - :param data_root: - :param random_crop: - """ - self.base = self.get_base() - assert size - assert (size / downscale_f).is_integer() - self.size = size - self.LR_size = int(size / downscale_f) - self.min_crop_f = min_crop_f - self.max_crop_f = max_crop_f - assert max_crop_f <= 1.0 - self.center_crop = not random_crop - - self.image_rescaler = albumentations.SmallestMaxSize( - max_size=size, interpolation=cv2.INTER_AREA - ) - - self.pil_interpolation = ( - False # gets reset later if incase interp_op is from pillow - ) - - if degradation == "bsrgan": - self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f) - - elif degradation == "bsrgan_light": - self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f) - - else: - interpolation_fn = { - "cv_nearest": cv2.INTER_NEAREST, - "cv_bilinear": cv2.INTER_LINEAR, - "cv_bicubic": cv2.INTER_CUBIC, - "cv_area": cv2.INTER_AREA, - "cv_lanczos": cv2.INTER_LANCZOS4, - "pil_nearest": PIL.Image.NEAREST, - "pil_bilinear": PIL.Image.BILINEAR, - "pil_bicubic": PIL.Image.BICUBIC, - "pil_box": PIL.Image.BOX, - "pil_hamming": PIL.Image.HAMMING, - "pil_lanczos": PIL.Image.LANCZOS, - }[degradation] - - self.pil_interpolation = degradation.startswith("pil_") - - if self.pil_interpolation: - self.degradation_process = partial( - TF.resize, - size=self.LR_size, - interpolation=interpolation_fn, - ) - - else: - self.degradation_process = albumentations.SmallestMaxSize( - max_size=self.LR_size, interpolation=interpolation_fn - ) - - def __len__(self): - return len(self.base) - - def __getitem__(self, i): - example = self.base[i] - image = Image.open(example["file_path_"]) - - if not image.mode == "RGB": - image = image.convert("RGB") - - image = np.array(image).astype(np.uint8) - - min_side_len = min(image.shape[:2]) - crop_side_len = min_side_len * np.random.uniform( - self.min_crop_f, self.max_crop_f, size=None - ) - crop_side_len = int(crop_side_len) - - if self.center_crop: - self.cropper = albumentations.CenterCrop( - height=crop_side_len, width=crop_side_len - ) - - else: - self.cropper = albumentations.RandomCrop( - height=crop_side_len, width=crop_side_len - ) - - image = self.cropper(image=image)["image"] - image = self.image_rescaler(image=image)["image"] - - if self.pil_interpolation: - image_pil = PIL.Image.fromarray(image) - LR_image = self.degradation_process(image_pil) - LR_image = np.array(LR_image).astype(np.uint8) - - else: - LR_image = self.degradation_process(image=image)["image"] - - example["image"] = (image / 127.5 - 1.0).astype(np.float32) - example["LR_image"] = (LR_image / 127.5 - 1.0).astype(np.float32) - - return example - - -class ImageNetSRTrain(ImageNetSR): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def get_base(self): - with open("data/imagenet_train_hr_indices.p", "rb") as f: - indices = pickle.load(f) - dset = ImageNetTrain( - process_images=False, - ) - return Subset(dset, indices) - - -class ImageNetSRValidation(ImageNetSR): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def get_base(self): - with open("data/imagenet_val_hr_indices.p", "rb") as f: - indices = pickle.load(f) - dset = ImageNetValidation( - process_images=False, - ) - return Subset(dset, indices) diff --git a/invokeai/backend/stable_diffusion/data/lsun.py 
b/invokeai/backend/stable_diffusion/data/lsun.py deleted file mode 100644 index e9c2543f10..0000000000 --- a/invokeai/backend/stable_diffusion/data/lsun.py +++ /dev/null @@ -1,124 +0,0 @@ -import os - -import numpy as np -import PIL -from PIL import Image -from torch.utils.data import Dataset -from torchvision import transforms - - -class LSUNBase(Dataset): - def __init__( - self, - txt_file, - data_root, - size=None, - interpolation="bicubic", - flip_p=0.5, - ): - self.data_paths = txt_file - self.data_root = data_root - with open(self.data_paths, "r") as f: - self.image_paths = f.read().splitlines() - self._length = len(self.image_paths) - self.labels = { - "relative_file_path_": [l for l in self.image_paths], - "file_path_": [os.path.join(self.data_root, l) for l in self.image_paths], - } - - self.size = size - self.interpolation = { - "linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, - }[interpolation] - self.flip = transforms.RandomHorizontalFlip(p=flip_p) - - def __len__(self): - return self._length - - def __getitem__(self, i): - example = dict((k, self.labels[k][i]) for k in self.labels) - image = Image.open(example["file_path_"]) - if not image.mode == "RGB": - image = image.convert("RGB") - - # default to score-sde preprocessing - img = np.array(image).astype(np.uint8) - crop = min(img.shape[0], img.shape[1]) - ( - h, - w, - ) = ( - img.shape[0], - img.shape[1], - ) - img = img[ - (h - crop) // 2 : (h + crop) // 2, - (w - crop) // 2 : (w + crop) // 2, - ] - - image = Image.fromarray(img) - if self.size is not None: - image = image.resize((self.size, self.size), resample=self.interpolation) - - image = self.flip(image) - image = np.array(image).astype(np.uint8) - example["image"] = (image / 127.5 - 1.0).astype(np.float32) - return example - - -class LSUNChurchesTrain(LSUNBase): - def __init__(self, **kwargs): - super().__init__( - txt_file="data/lsun/church_outdoor_train.txt", - data_root="data/lsun/churches", - **kwargs, - ) - - -class LSUNChurchesValidation(LSUNBase): - def __init__(self, flip_p=0.0, **kwargs): - super().__init__( - txt_file="data/lsun/church_outdoor_val.txt", - data_root="data/lsun/churches", - flip_p=flip_p, - **kwargs, - ) - - -class LSUNBedroomsTrain(LSUNBase): - def __init__(self, **kwargs): - super().__init__( - txt_file="data/lsun/bedrooms_train.txt", - data_root="data/lsun/bedrooms", - **kwargs, - ) - - -class LSUNBedroomsValidation(LSUNBase): - def __init__(self, flip_p=0.0, **kwargs): - super().__init__( - txt_file="data/lsun/bedrooms_val.txt", - data_root="data/lsun/bedrooms", - flip_p=flip_p, - **kwargs, - ) - - -class LSUNCatsTrain(LSUNBase): - def __init__(self, **kwargs): - super().__init__( - txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs - ) - - -class LSUNCatsValidation(LSUNBase): - def __init__(self, flip_p=0.0, **kwargs): - super().__init__( - txt_file="data/lsun/cat_val.txt", - data_root="data/lsun/cats", - flip_p=flip_p, - **kwargs, - ) diff --git a/invokeai/backend/stable_diffusion/data/personalized.py b/invokeai/backend/stable_diffusion/data/personalized.py deleted file mode 100644 index fc8297a68a..0000000000 --- a/invokeai/backend/stable_diffusion/data/personalized.py +++ /dev/null @@ -1,199 +0,0 @@ -import os -import random - -import numpy as np -import PIL -from PIL import Image -from torch.utils.data import Dataset -from torchvision import transforms - -imagenet_templates_smallest = [ - "a photo of a {}", -] - 
-imagenet_templates_small = [ - "a photo of a {}", - "a rendering of a {}", - "a cropped photo of the {}", - "the photo of a {}", - "a photo of a clean {}", - "a photo of a dirty {}", - "a dark photo of the {}", - "a photo of my {}", - "a photo of the cool {}", - "a close-up photo of a {}", - "a bright photo of the {}", - "a cropped photo of a {}", - "a photo of the {}", - "a good photo of the {}", - "a photo of one {}", - "a close-up photo of the {}", - "a rendition of the {}", - "a photo of the clean {}", - "a rendition of a {}", - "a photo of a nice {}", - "a good photo of a {}", - "a photo of the nice {}", - "a photo of the small {}", - "a photo of the weird {}", - "a photo of the large {}", - "a photo of a cool {}", - "a photo of a small {}", -] - -imagenet_dual_templates_small = [ - "a photo of a {} with {}", - "a rendering of a {} with {}", - "a cropped photo of the {} with {}", - "the photo of a {} with {}", - "a photo of a clean {} with {}", - "a photo of a dirty {} with {}", - "a dark photo of the {} with {}", - "a photo of my {} with {}", - "a photo of the cool {} with {}", - "a close-up photo of a {} with {}", - "a bright photo of the {} with {}", - "a cropped photo of a {} with {}", - "a photo of the {} with {}", - "a good photo of the {} with {}", - "a photo of one {} with {}", - "a close-up photo of the {} with {}", - "a rendition of the {} with {}", - "a photo of the clean {} with {}", - "a rendition of a {} with {}", - "a photo of a nice {} with {}", - "a good photo of a {} with {}", - "a photo of the nice {} with {}", - "a photo of the small {} with {}", - "a photo of the weird {} with {}", - "a photo of the large {} with {}", - "a photo of a cool {} with {}", - "a photo of a small {} with {}", -] - -per_img_token_list = [ - "א", - "ב", - "ג", - "ד", - "ה", - "ו", - "ז", - "ח", - "ט", - "י", - "כ", - "ל", - "מ", - "נ", - "ס", - "ע", - "פ", - "צ", - "ק", - "ר", - "ש", - "ת", -] - - -class PersonalizedBase(Dataset): - def __init__( - self, - data_root, - size=None, - repeats=100, - interpolation="bicubic", - flip_p=0.5, - set="train", - placeholder_token="*", - per_image_tokens=False, - center_crop=False, - mixing_prob=0.25, - coarse_class_text=None, - ): - self.data_root = data_root - - self.image_paths = [ - os.path.join(self.data_root, file_path) - for file_path in os.listdir(self.data_root) - if file_path != ".DS_Store" - ] - - # self._length = len(self.image_paths) - self.num_images = len(self.image_paths) - self._length = self.num_images - - self.placeholder_token = placeholder_token - - self.per_image_tokens = per_image_tokens - self.center_crop = center_crop - self.mixing_prob = mixing_prob - - self.coarse_class_text = coarse_class_text - - if per_image_tokens: - assert self.num_images < len( - per_img_token_list - ), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'." 
- - if set == "train": - self._length = self.num_images * repeats - - self.size = size - self.interpolation = { - "linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, - }[interpolation] - self.flip = transforms.RandomHorizontalFlip(p=flip_p) - - def __len__(self): - return self._length - - def __getitem__(self, i): - example = {} - image = Image.open(self.image_paths[i % self.num_images]) - - if not image.mode == "RGB": - image = image.convert("RGB") - - placeholder_string = self.placeholder_token - if self.coarse_class_text: - placeholder_string = f"{self.coarse_class_text} {placeholder_string}" - - if self.per_image_tokens and np.random.uniform() < self.mixing_prob: - text = random.choice(imagenet_dual_templates_small).format( - placeholder_string, per_img_token_list[i % self.num_images] - ) - else: - text = random.choice(imagenet_templates_small).format(placeholder_string) - - example["caption"] = text - - # default to score-sde preprocessing - img = np.array(image).astype(np.uint8) - - if self.center_crop: - crop = min(img.shape[0], img.shape[1]) - ( - h, - w, - ) = ( - img.shape[0], - img.shape[1], - ) - img = img[ - (h - crop) // 2 : (h + crop) // 2, - (w - crop) // 2 : (w + crop) // 2, - ] - - image = Image.fromarray(img) - if self.size is not None: - image = image.resize((self.size, self.size), resample=self.interpolation) - - image = self.flip(image) - image = np.array(image).astype(np.uint8) - example["image"] = (image / 127.5 - 1.0).astype(np.float32) - return example diff --git a/invokeai/backend/stable_diffusion/data/personalized_style.py b/invokeai/backend/stable_diffusion/data/personalized_style.py deleted file mode 100644 index 246c25e930..0000000000 --- a/invokeai/backend/stable_diffusion/data/personalized_style.py +++ /dev/null @@ -1,170 +0,0 @@ -import os -import random - -import numpy as np -import PIL -from PIL import Image -from torch.utils.data import Dataset -from torchvision import transforms - -imagenet_templates_small = [ - "a painting in the style of {}", - "a rendering in the style of {}", - "a cropped painting in the style of {}", - "the painting in the style of {}", - "a clean painting in the style of {}", - "a dirty painting in the style of {}", - "a dark painting in the style of {}", - "a picture in the style of {}", - "a cool painting in the style of {}", - "a close-up painting in the style of {}", - "a bright painting in the style of {}", - "a cropped painting in the style of {}", - "a good painting in the style of {}", - "a close-up painting in the style of {}", - "a rendition in the style of {}", - "a nice painting in the style of {}", - "a small painting in the style of {}", - "a weird painting in the style of {}", - "a large painting in the style of {}", -] - -imagenet_dual_templates_small = [ - "a painting in the style of {} with {}", - "a rendering in the style of {} with {}", - "a cropped painting in the style of {} with {}", - "the painting in the style of {} with {}", - "a clean painting in the style of {} with {}", - "a dirty painting in the style of {} with {}", - "a dark painting in the style of {} with {}", - "a cool painting in the style of {} with {}", - "a close-up painting in the style of {} with {}", - "a bright painting in the style of {} with {}", - "a cropped painting in the style of {} with {}", - "a good painting in the style of {} with {}", - "a painting of one {} in the style of {}", - "a nice painting in the style of {} with {}", - "a small painting in the 
style of {} with {}", - "a weird painting in the style of {} with {}", - "a large painting in the style of {} with {}", -] - -per_img_token_list = [ - "א", - "ב", - "ג", - "ד", - "ה", - "ו", - "ז", - "ח", - "ט", - "י", - "כ", - "ל", - "מ", - "נ", - "ס", - "ע", - "פ", - "צ", - "ק", - "ר", - "ש", - "ת", -] - - -class PersonalizedBase(Dataset): - def __init__( - self, - data_root, - size=None, - repeats=100, - interpolation="bicubic", - flip_p=0.5, - set="train", - placeholder_token="*", - per_image_tokens=False, - center_crop=False, - ): - self.data_root = data_root - - self.image_paths = [ - os.path.join(self.data_root, file_path) - for file_path in os.listdir(self.data_root) - if file_path != ".DS_Store" - ] - - # self._length = len(self.image_paths) - self.num_images = len(self.image_paths) - self._length = self.num_images - - self.placeholder_token = placeholder_token - - self.per_image_tokens = per_image_tokens - self.center_crop = center_crop - - if per_image_tokens: - assert self.num_images < len( - per_img_token_list - ), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'." - - if set == "train": - self._length = self.num_images * repeats - - self.size = size - self.interpolation = { - "linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, - }[interpolation] - self.flip = transforms.RandomHorizontalFlip(p=flip_p) - - def __len__(self): - return self._length - - def __getitem__(self, i): - example = {} - image = Image.open(self.image_paths[i % self.num_images]) - - if not image.mode == "RGB": - image = image.convert("RGB") - - if self.per_image_tokens and np.random.uniform() < 0.25: - text = random.choice(imagenet_dual_templates_small).format( - self.placeholder_token, per_img_token_list[i % self.num_images] - ) - else: - text = random.choice(imagenet_templates_small).format( - self.placeholder_token - ) - - example["caption"] = text - - # default to score-sde preprocessing - img = np.array(image).astype(np.uint8) - - if self.center_crop: - crop = min(img.shape[0], img.shape[1]) - ( - h, - w, - ) = ( - img.shape[0], - img.shape[1], - ) - img = img[ - (h - crop) // 2 : (h + crop) // 2, - (w - crop) // 2 : (w + crop) // 2, - ] - - image = Image.fromarray(img) - if self.size is not None: - image = image.resize((self.size, self.size), resample=self.interpolation) - - image = self.flip(image) - image = np.array(image).astype(np.uint8) - example["image"] = (image / 127.5 - 1.0).astype(np.float32) - return example diff --git a/invokeai/backend/stable_diffusion/diffusion/classifier.py b/invokeai/backend/stable_diffusion/diffusion/classifier.py deleted file mode 100644 index 89aba16ee9..0000000000 --- a/invokeai/backend/stable_diffusion/diffusion/classifier.py +++ /dev/null @@ -1,330 +0,0 @@ -import os -from copy import deepcopy -from glob import glob - -import pytorch_lightning as pl -import torch -from einops import rearrange -from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel -from ldm.util import default, instantiate_from_config, ismap, log_txt_as_img -from natsort import natsorted -from omegaconf import OmegaConf -from torch.nn import functional as F -from torch.optim import AdamW -from torch.optim.lr_scheduler import LambdaLR - -__models__ = {"class_label": EncoderUNetModel, "segmentation": UNetModel} - - -def disabled_train(self, mode=True): - """Overwrite model.train with 
this function to make sure train/eval mode - does not change anymore.""" - return self - - -class NoisyLatentImageClassifier(pl.LightningModule): - def __init__( - self, - diffusion_path, - num_classes, - ckpt_path=None, - pool="attention", - label_key=None, - diffusion_ckpt_path=None, - scheduler_config=None, - weight_decay=1.0e-2, - log_steps=10, - monitor="val/loss", - *args, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.num_classes = num_classes - # get latest config of diffusion model - diffusion_config = natsorted( - glob(os.path.join(diffusion_path, "configs", "*-project.yaml")) - )[-1] - self.diffusion_config = OmegaConf.load(diffusion_config).model - self.diffusion_config.params.ckpt_path = diffusion_ckpt_path - self.load_diffusion() - - self.monitor = monitor - self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 - self.log_time_interval = self.diffusion_model.num_timesteps // log_steps - self.log_steps = log_steps - - self.label_key = ( - label_key - if not hasattr(self.diffusion_model, "cond_stage_key") - else self.diffusion_model.cond_stage_key - ) - - assert ( - self.label_key is not None - ), "label_key neither in diffusion model nor in model.params" - - if self.label_key not in __models__: - raise NotImplementedError() - - self.load_classifier(ckpt_path, pool) - - self.scheduler_config = scheduler_config - self.use_scheduler = self.scheduler_config is not None - self.weight_decay = weight_decay - - def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): - sd = torch.load(path, map_location="cpu") - if "state_dict" in list(sd.keys()): - sd = sd["state_dict"] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) - del sd[k] - missing, unexpected = ( - self.load_state_dict(sd, strict=False) - if not only_model - else self.model.load_state_dict(sd, strict=False) - ) - print( - f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" - ) - if len(missing) > 0: - print(f"Missing Keys: {missing}") - if len(unexpected) > 0: - print(f"Unexpected Keys: {unexpected}") - - def load_diffusion(self): - model = instantiate_from_config(self.diffusion_config) - self.diffusion_model = model.eval() - self.diffusion_model.train = disabled_train - for param in self.diffusion_model.parameters(): - param.requires_grad = False - - def load_classifier(self, ckpt_path, pool): - model_config = deepcopy(self.diffusion_config.params.unet_config.params) - model_config.in_channels = ( - self.diffusion_config.params.unet_config.params.out_channels - ) - model_config.out_channels = self.num_classes - if self.label_key == "class_label": - model_config.pool = pool - - self.model = __models__[self.label_key](**model_config) - if ckpt_path is not None: - print( - "#####################################################################" - ) - print(f'load from ckpt "{ckpt_path}"') - print( - "#####################################################################" - ) - self.init_from_ckpt(ckpt_path) - - @torch.no_grad() - def get_x_noisy(self, x, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x)) - continuous_sqrt_alpha_cumprod = None - if self.diffusion_model.use_continuous_noise: - continuous_sqrt_alpha_cumprod = ( - self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) - ) - # todo: make sure t+1 is correct here - - return self.diffusion_model.q_sample( - x_start=x, - t=t, - noise=noise, - 
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod, - ) - - def forward(self, x_noisy, t, *args, **kwargs): - return self.model(x_noisy, t) - - @torch.no_grad() - def get_input(self, batch, k): - x = batch[k] - if len(x.shape) == 3: - x = x[..., None] - x = rearrange(x, "b h w c -> b c h w") - x = x.to(memory_format=torch.contiguous_format).float() - return x - - @torch.no_grad() - def get_conditioning(self, batch, k=None): - if k is None: - k = self.label_key - assert k is not None, "Needs to provide label key" - - targets = batch[k].to(self.device) - - if self.label_key == "segmentation": - targets = rearrange(targets, "b h w c -> b c h w") - for down in range(self.numd): - h, w = targets.shape[-2:] - targets = F.interpolate(targets, size=(h // 2, w // 2), mode="nearest") - - # targets = rearrange(targets,'b c h w -> b h w c') - - return targets - - def compute_top_k(self, logits, labels, k, reduction="mean"): - _, top_ks = torch.topk(logits, k, dim=1) - if reduction == "mean": - return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() - elif reduction == "none": - return (top_ks == labels[:, None]).float().sum(dim=-1) - - def on_train_epoch_start(self): - # save some memory - self.diffusion_model.model.to("cpu") - - @torch.no_grad() - def write_logs(self, loss, logits, targets): - log_prefix = "train" if self.training else "val" - log = {} - log[f"{log_prefix}/loss"] = loss.mean() - log[f"{log_prefix}/acc@1"] = self.compute_top_k( - logits, targets, k=1, reduction="mean" - ) - log[f"{log_prefix}/acc@5"] = self.compute_top_k( - logits, targets, k=5, reduction="mean" - ) - - self.log_dict( - log, - prog_bar=False, - logger=True, - on_step=self.training, - on_epoch=True, - ) - self.log("loss", log[f"{log_prefix}/loss"], prog_bar=True, logger=False) - self.log( - "global_step", - self.global_step, - logger=False, - on_epoch=False, - prog_bar=True, - ) - lr = self.optimizers().param_groups[0]["lr"] - self.log( - "lr_abs", - lr, - on_step=True, - logger=True, - on_epoch=False, - prog_bar=True, - ) - - def shared_step(self, batch, t=None): - x, *_ = self.diffusion_model.get_input( - batch, k=self.diffusion_model.first_stage_key - ) - targets = self.get_conditioning(batch) - if targets.dim() == 4: - targets = targets.argmax(dim=1) - if t is None: - t = torch.randint( - 0, - self.diffusion_model.num_timesteps, - (x.shape[0],), - device=self.device, - ).long() - else: - t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() - x_noisy = self.get_x_noisy(x, t) - logits = self(x_noisy, t) - - loss = F.cross_entropy(logits, targets, reduction="none") - - self.write_logs(loss.detach(), logits.detach(), targets.detach()) - - loss = loss.mean() - return loss, logits, x_noisy, targets - - def training_step(self, batch, batch_idx): - loss, *_ = self.shared_step(batch) - return loss - - def reset_noise_accs(self): - self.noisy_acc = { - t: {"acc@1": [], "acc@5": []} - for t in range( - 0, - self.diffusion_model.num_timesteps, - self.diffusion_model.log_every_t, - ) - } - - def on_validation_start(self): - self.reset_noise_accs() - - @torch.no_grad() - def validation_step(self, batch, batch_idx): - loss, *_ = self.shared_step(batch) - - for t in self.noisy_acc: - _, logits, _, targets = self.shared_step(batch, t) - self.noisy_acc[t]["acc@1"].append( - self.compute_top_k(logits, targets, k=1, reduction="mean") - ) - self.noisy_acc[t]["acc@5"].append( - self.compute_top_k(logits, targets, k=5, reduction="mean") - ) - - return loss - - def configure_optimizers(self): - 
optimizer = AdamW( - self.model.parameters(), - lr=self.learning_rate, - weight_decay=self.weight_decay, - ) - - if self.use_scheduler: - scheduler = instantiate_from_config(self.scheduler_config) - - print("Setting up LambdaLR scheduler...") - scheduler = [ - { - "scheduler": LambdaLR(optimizer, lr_lambda=scheduler.schedule), - "interval": "step", - "frequency": 1, - } - ] - return [optimizer], scheduler - - return optimizer - - @torch.no_grad() - def log_images(self, batch, N=8, *args, **kwargs): - log = dict() - x = self.get_input(batch, self.diffusion_model.first_stage_key) - log["inputs"] = x - - y = self.get_conditioning(batch) - - if self.label_key == "class_label": - y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) - log["labels"] = y - - if ismap(y): - log["labels"] = self.diffusion_model.to_rgb(y) - - for step in range(self.log_steps): - current_time = step * self.log_time_interval - - _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) - - log[f"inputs@t{current_time}"] = x_noisy - - pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) - pred = rearrange(pred, "b h w c -> b c h w") - - log[f"pred@t{current_time}"] = self.diffusion_model.to_rgb(pred) - - for key in log: - log[key] = log[key][:N] - - return log diff --git a/invokeai/backend/stable_diffusion/diffusion/ddim.py b/invokeai/backend/stable_diffusion/diffusion/ddim.py deleted file mode 100644 index 87f6f2166b..0000000000 --- a/invokeai/backend/stable_diffusion/diffusion/ddim.py +++ /dev/null @@ -1,113 +0,0 @@ -"""SAMPLING ONLY.""" - -import torch - -from ..diffusionmodules.util import noise_like -from .sampler import Sampler -from .shared_invokeai_diffusion import InvokeAIDiffuserComponent - - -class DDIMSampler(Sampler): - def __init__(self, model, schedule="linear", device=None, **kwargs): - super().__init__(model, schedule, model.num_timesteps, device) - - self.invokeai_diffuser = InvokeAIDiffuserComponent( - self.model, - model_forward_callback=lambda x, sigma, cond: self.model.apply_model( - x, sigma, cond - ), - ) - - def prepare_to_sample(self, t_enc, **kwargs): - super().prepare_to_sample(t_enc, **kwargs) - - extra_conditioning_info = kwargs.get("extra_conditioning_info", None) - all_timesteps_count = kwargs.get("all_timesteps_count", t_enc) - - if ( - extra_conditioning_info is not None - and extra_conditioning_info.wants_cross_attention_control - ): - self.invokeai_diffuser.override_cross_attention( - extra_conditioning_info, step_count=all_timesteps_count - ) - else: - self.invokeai_diffuser.restore_default_cross_attention() - - # This is the central routine - @torch.no_grad() - def p_sample( - self, - x, - c, - t, - index, - repeat_noise=False, - use_original_steps=False, - quantize_denoised=False, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - step_count: int = 1000, # total number of steps - **kwargs, - ): - b, *_, device = *x.shape, x.device - - if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: - # damian0815 would like to know when/if this code path is used - e_t = self.model.apply_model(x, t, c) - else: - # step_index counts in the opposite direction to index - step_index = step_count - (index + 1) - e_t = self.invokeai_diffuser.do_diffusion_step( - x, - t, - unconditional_conditioning, - c, - unconditional_guidance_scale, - step_index=step_index, - ) - if score_corrector is not None: - assert 
self.model.parameterization == "eps" - e_t = score_corrector.modify_score( - self.model, e_t, x, t, c, **corrector_kwargs - ) - - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = ( - self.model.alphas_cumprod_prev - if use_original_steps - else self.ddim_alphas_prev - ) - sqrt_one_minus_alphas = ( - self.model.sqrt_one_minus_alphas_cumprod - if use_original_steps - else self.ddim_sqrt_one_minus_alphas - ) - sigmas = ( - self.model.ddim_sigmas_for_original_num_steps - if use_original_steps - else self.ddim_sigmas - ) - # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) - sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full( - (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device - ) - - # current prediction for x_0 - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - if quantize_denoised: - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - # direction pointing to x_t - dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.0: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0, None diff --git a/invokeai/backend/stable_diffusion/diffusion/ddpm.py b/invokeai/backend/stable_diffusion/diffusion/ddpm.py deleted file mode 100644 index 6741498303..0000000000 --- a/invokeai/backend/stable_diffusion/diffusion/ddpm.py +++ /dev/null @@ -1,2125 +0,0 @@ -""" -wild mixture of -https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py -https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py -https://github.com/CompVis/taming-transformers --- merci -""" - -import os -import urllib -from contextlib import contextmanager -from functools import partial - -import numpy as np -import pytorch_lightning as pl -import torch -import torch.nn as nn -from einops import rearrange, repeat -from omegaconf import ListConfig -from pytorch_lightning.utilities.distributed import rank_zero_only -from torch.optim.lr_scheduler import LambdaLR -from torchvision.utils import make_grid -from tqdm import tqdm - -from ...util.util import ( - count_params, - default, - exists, - instantiate_from_config, - isimage, - ismap, - log_txt_as_img, - mean_flat, -) -from ..autoencoder import AutoencoderKL, IdentityFirstStage, VQModelInterface -from ..diffusionmodules.util import extract_into_tensor, make_beta_schedule, noise_like -from ..distributions.distributions import DiagonalGaussianDistribution, normal_kl -from ..ema import LitEma -from ..textual_inversion_manager import TextualInversionManager -from .ddim import DDIMSampler - -__conditioning_keys__ = { - "concat": "c_concat", - "crossattn": "c_crossattn", - "adm": "y", -} - - -def disabled_train(self, mode=True): - """Overwrite model.train with this function to make sure train/eval mode - does not change anymore.""" - return self - - -def uniform_on_device(r1, r2, shape, device): - return (r1 - r2) * torch.rand(*shape, device=device) + r2 - - -class DDPM(pl.LightningModule): - # classic DDPM with Gaussian diffusion, in image space - def __init__( - self, - 
unet_config, - timesteps=1000, - beta_schedule="linear", - loss_type="l2", - ckpt_path=None, - ignore_keys=[], - load_only_unet=False, - monitor="val/loss", - use_ema=True, - first_stage_key="image", - image_size=256, - channels=3, - log_every_t=100, - clip_denoised=True, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3, - given_betas=None, - original_elbo_weight=0.0, - embedding_reg_weight=0.0, - v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta - l_simple_weight=1.0, - conditioning_key=None, - parameterization="eps", # all assuming fixed variance schedules - scheduler_config=None, - use_positional_encodings=False, - learn_logvar=False, - logvar_init=0.0, - ): - super().__init__() - assert parameterization in [ - "eps", - "x0", - ], 'currently only supporting "eps" and "x0"' - self.parameterization = parameterization - print( - f" | {self.__class__.__name__}: Running in {self.parameterization}-prediction mode" - ) - self.cond_stage_model = None - self.clip_denoised = clip_denoised - self.log_every_t = log_every_t - self.first_stage_key = first_stage_key - self.image_size = image_size # try conv? - self.channels = channels - self.use_positional_encodings = use_positional_encodings - self.model = DiffusionWrapper(unet_config, conditioning_key) - count_params(self.model, verbose=True) - self.use_ema = use_ema - if self.use_ema: - self.model_ema = LitEma(self.model) - print(f" | Keeping EMAs of {len(list(self.model_ema.buffers()))}.") - - self.use_scheduler = scheduler_config is not None - if self.use_scheduler: - self.scheduler_config = scheduler_config - - self.v_posterior = v_posterior - self.original_elbo_weight = original_elbo_weight - self.l_simple_weight = l_simple_weight - self.embedding_reg_weight = embedding_reg_weight - - if monitor is not None: - self.monitor = monitor - if ckpt_path is not None: - self.init_from_ckpt( - ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet - ) - - self.register_schedule( - given_betas=given_betas, - beta_schedule=beta_schedule, - timesteps=timesteps, - linear_start=linear_start, - linear_end=linear_end, - cosine_s=cosine_s, - ) - - self.loss_type = loss_type - - self.learn_logvar = learn_logvar - self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) - if self.learn_logvar: - self.logvar = nn.Parameter(self.logvar, requires_grad=True) - - def register_schedule( - self, - given_betas=None, - beta_schedule="linear", - timesteps=1000, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3, - ): - if exists(given_betas): - betas = given_betas - else: - betas = make_beta_schedule( - beta_schedule, - timesteps, - linear_start=linear_start, - linear_end=linear_end, - cosine_s=cosine_s, - ) - alphas = 1.0 - betas - alphas_cumprod = np.cumprod(alphas, axis=0) - alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) - - (timesteps,) = betas.shape - self.num_timesteps = int(timesteps) - self.linear_start = linear_start - self.linear_end = linear_end - assert ( - alphas_cumprod.shape[0] == self.num_timesteps - ), "alphas have to be defined for each timestep" - - to_torch = partial(torch.tensor, dtype=torch.float32) - - self.register_buffer("betas", to_torch(betas)) - self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) - self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) - self.register_buffer( 
- "sqrt_one_minus_alphas_cumprod", - to_torch(np.sqrt(1.0 - alphas_cumprod)), - ) - self.register_buffer( - "log_one_minus_alphas_cumprod", - to_torch(np.log(1.0 - alphas_cumprod)), - ) - self.register_buffer( - "sqrt_recip_alphas_cumprod", - to_torch(np.sqrt(1.0 / alphas_cumprod)), - ) - self.register_buffer( - "sqrt_recipm1_alphas_cumprod", - to_torch(np.sqrt(1.0 / alphas_cumprod - 1)), - ) - - # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = (1 - self.v_posterior) * betas * ( - 1.0 - alphas_cumprod_prev - ) / (1.0 - alphas_cumprod) + self.v_posterior * betas - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - self.register_buffer("posterior_variance", to_torch(posterior_variance)) - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.register_buffer( - "posterior_log_variance_clipped", - to_torch(np.log(np.maximum(posterior_variance, 1e-20))), - ) - self.register_buffer( - "posterior_mean_coef1", - to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)), - ) - self.register_buffer( - "posterior_mean_coef2", - to_torch( - (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) - ), - ) - - if self.parameterization == "eps": - lvlb_weights = self.betas**2 / ( - 2 - * self.posterior_variance - * to_torch(alphas) - * (1 - self.alphas_cumprod) - ) - elif self.parameterization == "x0": - lvlb_weights = ( - 0.5 - * np.sqrt(torch.Tensor(alphas_cumprod)) - / (2.0 * 1 - torch.Tensor(alphas_cumprod)) - ) - else: - raise NotImplementedError("mu not supported") - # TODO how to choose this term - lvlb_weights[0] = lvlb_weights[1] - self.register_buffer("lvlb_weights", lvlb_weights, persistent=False) - assert not torch.isnan(self.lvlb_weights).all() - - @contextmanager - def ema_scope(self, context=None): - if self.use_ema: - self.model_ema.store(self.model.parameters()) - self.model_ema.copy_to(self.model) - if context is not None: - print(f"{context}: Switched to EMA weights") - try: - yield None - finally: - if self.use_ema: - self.model_ema.restore(self.model.parameters()) - if context is not None: - print(f"{context}: Restored training weights") - - def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): - sd = torch.load(path, map_location="cpu") - if "state_dict" in list(sd.keys()): - sd = sd["state_dict"] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) - del sd[k] - missing, unexpected = ( - self.load_state_dict(sd, strict=False) - if not only_model - else self.model.load_state_dict(sd, strict=False) - ) - print( - f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" - ) - if len(missing) > 0: - print(f"Missing Keys: {missing}") - if len(unexpected) > 0: - print(f"Unexpected Keys: {unexpected}") - - def q_mean_variance(self, x_start, t): - """ - Get the distribution q(x_t | x_0). - :param x_start: the [N x C x ...] tensor of noiseless inputs. - :param t: the number of diffusion steps (minus 1). Here, 0 means one step. - :return: A tuple (mean, variance, log_variance), all of x_start's shape. 
- """ - mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) - log_variance = extract_into_tensor( - self.log_one_minus_alphas_cumprod, t, x_start.shape - ) - return mean, variance, log_variance - - def predict_start_from_noise(self, x_t, t, noise): - return ( - extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) - * noise - ) - - def q_posterior(self, x_start, x_t, t): - posterior_mean = ( - extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start - + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t - ) - posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) - posterior_log_variance_clipped = extract_into_tensor( - self.posterior_log_variance_clipped, t, x_t.shape - ) - return ( - posterior_mean, - posterior_variance, - posterior_log_variance_clipped, - ) - - def p_mean_variance(self, x, t, clip_denoised: bool): - model_out = self.model(x, t) - if self.parameterization == "eps": - x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) - elif self.parameterization == "x0": - x_recon = model_out - if clip_denoised: - x_recon.clamp_(-1.0, 1.0) - - ( - model_mean, - posterior_variance, - posterior_log_variance, - ) = self.q_posterior(x_start=x_recon, x_t=x, t=t) - return model_mean, posterior_variance, posterior_log_variance - - @torch.no_grad() - def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): - b, *_, device = *x.shape, x.device - model_mean, _, model_log_variance = self.p_mean_variance( - x=x, t=t, clip_denoised=clip_denoised - ) - noise = noise_like(x.shape, device, repeat_noise) - # no noise when t == 0 - nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) - return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise - - @torch.no_grad() - def p_sample_loop(self, shape, return_intermediates=False): - device = self.betas.device - b = shape[0] - img = torch.randn(shape, device=device) - intermediates = [img] - for i in tqdm( - reversed(range(0, self.num_timesteps)), - desc="Sampling t", - total=self.num_timesteps, - dynamic_ncols=True, - ): - img = self.p_sample( - img, - torch.full((b,), i, device=device, dtype=torch.long), - clip_denoised=self.clip_denoised, - ) - if i % self.log_every_t == 0 or i == self.num_timesteps - 1: - intermediates.append(img) - if return_intermediates: - return img, intermediates - return img - - @torch.no_grad() - def sample(self, batch_size=16, return_intermediates=False): - image_size = self.image_size - channels = self.channels - return self.p_sample_loop( - (batch_size, channels, image_size, image_size), - return_intermediates=return_intermediates, - ) - - def q_sample(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - return ( - extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) - * noise - ) - - def get_loss(self, pred, target, mean=True): - if self.loss_type == "l1": - loss = (target - pred).abs() - if mean: - loss = loss.mean() - elif self.loss_type == "l2": - if mean: - loss = torch.nn.functional.mse_loss(target, pred) - else: - loss = torch.nn.functional.mse_loss(target, pred, reduction="none") - else: - raise NotImplementedError("unknown loss type '{loss_type}'") - - return loss - - def 
p_losses(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) - model_out = self.model(x_noisy, t) - - loss_dict = {} - if self.parameterization == "eps": - target = noise - elif self.parameterization == "x0": - target = x_start - else: - raise NotImplementedError( - f"Paramterization {self.parameterization} not yet supported" - ) - - loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) - - log_prefix = "train" if self.training else "val" - - loss_dict.update({f"{log_prefix}/loss_simple": loss.mean()}) - loss_simple = loss.mean() * self.l_simple_weight - - loss_vlb = (self.lvlb_weights[t] * loss).mean() - loss_dict.update({f"{log_prefix}/loss_vlb": loss_vlb}) - - loss = loss_simple + self.original_elbo_weight * loss_vlb - - loss_dict.update({f"{log_prefix}/loss": loss}) - - return loss, loss_dict - - def forward(self, x, *args, **kwargs): - # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size - # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' - t = torch.randint( - 0, self.num_timesteps, (x.shape[0],), device=self.device - ).long() - return self.p_losses(x, t, *args, **kwargs) - - def get_input(self, batch, k): - x = batch[k] - if len(x.shape) == 3: - x = x[..., None] - x = rearrange(x, "b h w c -> b c h w") - x = x.to(memory_format=torch.contiguous_format).float() - return x - - def shared_step(self, batch): - x = self.get_input(batch, self.first_stage_key) - loss, loss_dict = self(x) - return loss, loss_dict - - def training_step(self, batch, batch_idx): - loss, loss_dict = self.shared_step(batch) - - self.log_dict( - loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True - ) - - self.log( - "global_step", - self.global_step, - prog_bar=True, - logger=True, - on_step=True, - on_epoch=False, - ) - - if self.use_scheduler: - lr = self.optimizers().param_groups[0]["lr"] - self.log( - "lr_abs", - lr, - prog_bar=True, - logger=True, - on_step=True, - on_epoch=False, - ) - - return loss - - @torch.no_grad() - def validation_step(self, batch, batch_idx): - _, loss_dict_no_ema = self.shared_step(batch) - with self.ema_scope(): - _, loss_dict_ema = self.shared_step(batch) - loss_dict_ema = {key + "_ema": loss_dict_ema[key] for key in loss_dict_ema} - self.log_dict( - loss_dict_no_ema, - prog_bar=False, - logger=True, - on_step=False, - on_epoch=True, - ) - self.log_dict( - loss_dict_ema, - prog_bar=False, - logger=True, - on_step=False, - on_epoch=True, - ) - - def on_train_batch_end(self, *args, **kwargs): - if self.use_ema: - self.model_ema(self.model) - - def _get_rows_from_list(self, samples): - n_imgs_per_row = len(samples) - denoise_grid = rearrange(samples, "n b c h w -> b n c h w") - denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w") - denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) - return denoise_grid - - @torch.no_grad() - def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): - log = dict() - x = self.get_input(batch, self.first_stage_key) - N = min(x.shape[0], N) - n_row = min(x.shape[0], n_row) - x = x.to(self.device)[:N] - log["inputs"] = x - - # get diffusion row - diffusion_row = list() - x_start = x[:n_row] - - for t in range(self.num_timesteps): - if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), "1 -> b", b=n_row) - t = t.to(self.device).long() - noise = torch.randn_like(x_start) - x_noisy = 
self.q_sample(x_start=x_start, t=t, noise=noise) - diffusion_row.append(x_noisy) - - log["diffusion_row"] = self._get_rows_from_list(diffusion_row) - - if sample: - # get denoise row - with self.ema_scope("Plotting"): - samples, denoise_row = self.sample( - batch_size=N, return_intermediates=True - ) - - log["samples"] = samples - log["denoise_row"] = self._get_rows_from_list(denoise_row) - - if return_keys: - if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: - return log - else: - return {key: log[key] for key in return_keys} - return log - - def configure_optimizers(self): - lr = self.learning_rate - params = list(self.model.parameters()) - if self.learn_logvar: - params = params + [self.logvar] - opt = torch.optim.AdamW(params, lr=lr) - return opt - - -class LatentDiffusion(DDPM): - """main class""" - - def __init__( - self, - first_stage_config, - cond_stage_config, - personalization_config, - num_timesteps_cond=None, - cond_stage_key="image", - cond_stage_trainable=False, - concat_mode=True, - cond_stage_forward=None, - conditioning_key=None, - scale_factor=1.0, - scale_by_std=False, - *args, - **kwargs, - ): - self.num_timesteps_cond = default(num_timesteps_cond, 1) - self.scale_by_std = scale_by_std - assert self.num_timesteps_cond <= kwargs["timesteps"] - # for backwards compatibility after implementation of DiffusionWrapper - if conditioning_key is None: - conditioning_key = "concat" if concat_mode else "crossattn" - if cond_stage_config == "__is_unconditional__": - conditioning_key = None - ckpt_path = kwargs.pop("ckpt_path", None) - ignore_keys = kwargs.pop("ignore_keys", []) - super().__init__(conditioning_key=conditioning_key, *args, **kwargs) - self.concat_mode = concat_mode - self.cond_stage_trainable = cond_stage_trainable - self.cond_stage_key = cond_stage_key - - try: - self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 - except: - self.num_downs = 0 - if not scale_by_std: - self.scale_factor = scale_factor - else: - self.register_buffer("scale_factor", torch.tensor(scale_factor)) - self.instantiate_first_stage(first_stage_config) - self.instantiate_cond_stage(cond_stage_config) - - self.cond_stage_forward = cond_stage_forward - self.clip_denoised = False - self.bbox_tokenizer = None - - self.restarted_from_ckpt = False - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys) - self.restarted_from_ckpt = True - - self.cond_stage_model.train = disabled_train - for param in self.cond_stage_model.parameters(): - param.requires_grad = False - - self.model.eval() - self.model.train = disabled_train - for param in self.model.parameters(): - param.requires_grad = False - - self.embedding_manager = self.instantiate_embedding_manager( - personalization_config, self.cond_stage_model - ) - self.textual_inversion_manager = TextualInversionManager( - tokenizer=self.cond_stage_model.tokenizer, - text_encoder=self.cond_stage_model.transformer, - full_precision=True, - ) - # this circular component dependency is gross and bad, needs to be rethought - self.cond_stage_model.set_textual_inversion_manager( - self.textual_inversion_manager - ) - - self.emb_ckpt_counter = 0 - - # if self.embedding_manager.is_clip: - # self.cond_stage_model.update_embedding_func(self.embedding_manager) - - for param in self.embedding_manager.embedding_parameters(): - param.requires_grad = True - - def make_cond_schedule( - self, - ): - self.cond_ids = torch.full( - size=(self.num_timesteps,), - fill_value=self.num_timesteps - 1, - dtype=torch.long, - ) - ids = 
torch.round( - torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond) - ).long() - self.cond_ids[: self.num_timesteps_cond] = ids - - @rank_zero_only - @torch.no_grad() - def on_train_batch_start(self, batch, batch_idx, dataloader_idx=None): - # only for very first batch - if ( - self.scale_by_std - and self.current_epoch == 0 - and self.global_step == 0 - and batch_idx == 0 - and not self.restarted_from_ckpt - ): - assert ( - self.scale_factor == 1.0 - ), "rather not use custom rescaling and std-rescaling simultaneously" - # set rescale weight to 1./std of encodings - print("### USING STD-RESCALING ###") - x = super().get_input(batch, self.first_stage_key) - x = x.to(self.device) - encoder_posterior = self.encode_first_stage(x) - z = self.get_first_stage_encoding(encoder_posterior).detach() - del self.scale_factor - self.register_buffer("scale_factor", 1.0 / z.flatten().std()) - print(f"setting self.scale_factor to {self.scale_factor}") - print("### USING STD-RESCALING ###") - - def register_schedule( - self, - given_betas=None, - beta_schedule="linear", - timesteps=1000, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3, - ): - super().register_schedule( - given_betas, - beta_schedule, - timesteps, - linear_start, - linear_end, - cosine_s, - ) - - self.shorten_cond_schedule = self.num_timesteps_cond > 1 - if self.shorten_cond_schedule: - self.make_cond_schedule() - - def instantiate_first_stage(self, config): - model = instantiate_from_config(config) - self.first_stage_model = model.eval() - self.first_stage_model.train = disabled_train - for param in self.first_stage_model.parameters(): - param.requires_grad = False - - def instantiate_cond_stage(self, config): - if not self.cond_stage_trainable: - if config == "__is_first_stage__": - print("Using first stage also as cond stage.") - self.cond_stage_model = self.first_stage_model - elif config == "__is_unconditional__": - print(f"Training {self.__class__.__name__} as an unconditional model.") - self.cond_stage_model = None - # self.be_unconditional = True - else: - model = instantiate_from_config(config) - self.cond_stage_model = model.eval() - self.cond_stage_model.train = disabled_train - for param in self.cond_stage_model.parameters(): - param.requires_grad = False - else: - assert config != "__is_first_stage__" - assert config != "__is_unconditional__" - try: - model = instantiate_from_config(config) - except urllib.error.URLError: - raise SystemExit( - "* Couldn't load a dependency. Try running scripts/preload_models.py from an internet-conected machine." 
- ) - self.cond_stage_model = model - - def instantiate_embedding_manager(self, config, embedder): - model = instantiate_from_config(config, embedder=embedder) - - if config.params.get( - "embedding_manager_ckpt", None - ): # do not load if missing OR empty string - model.load(config.params.embedding_manager_ckpt) - - return model - - def _get_denoise_row_from_list( - self, samples, desc="", force_no_decoder_quantization=False - ): - denoise_row = [] - for zd in tqdm(samples, desc=desc): - denoise_row.append( - self.decode_first_stage( - zd.to(self.device), - force_not_quantize=force_no_decoder_quantization, - ) - ) - n_imgs_per_row = len(denoise_row) - denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W - denoise_grid = rearrange(denoise_row, "n b c h w -> b n c h w") - denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w") - denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) - return denoise_grid - - def get_first_stage_encoding(self, encoder_posterior): - if isinstance(encoder_posterior, DiagonalGaussianDistribution): - z = encoder_posterior.sample() - elif isinstance(encoder_posterior, torch.Tensor): - z = encoder_posterior - else: - raise NotImplementedError( - f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" - ) - return self.scale_factor * z - - def get_learned_conditioning(self, c, **kwargs): - if self.cond_stage_forward is None: - if hasattr(self.cond_stage_model, "encode") and callable( - self.cond_stage_model.encode - ): - c = self.cond_stage_model.encode( - c, embedding_manager=self.embedding_manager, **kwargs - ) - if isinstance(c, DiagonalGaussianDistribution): - c = c.mode() - else: - c = self.cond_stage_model(c, **kwargs) - else: - assert hasattr(self.cond_stage_model, self.cond_stage_forward) - c = getattr(self.cond_stage_model, self.cond_stage_forward)(c, **kwargs) - return c - - def meshgrid(self, h, w): - y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) - x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) - - arr = torch.cat([y, x], dim=-1) - return arr - - def delta_border(self, h, w): - """ - :param h: height - :param w: width - :return: normalized distance to image border, - wtith min distance = 0 at border and max dist = 0.5 at image center - """ - lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) - arr = self.meshgrid(h, w) / lower_right_corner - dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] - dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] - edge_dist = torch.min( - torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1 - )[0] - return edge_dist - - def get_weighting(self, h, w, Ly, Lx, device): - weighting = self.delta_border(h, w) - weighting = torch.clip( - weighting, - self.split_input_params["clip_min_weight"], - self.split_input_params["clip_max_weight"], - ) - weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) - - if self.split_input_params["tie_braker"]: - L_weighting = self.delta_border(Ly, Lx) - L_weighting = torch.clip( - L_weighting, - self.split_input_params["clip_min_tie_weight"], - self.split_input_params["clip_max_tie_weight"], - ) - - L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) - weighting = weighting * L_weighting - return weighting - - def get_fold_unfold( - self, x, kernel_size, stride, uf=1, df=1 - ): # todo load once not every time, shorten code - """ - :param x: img of size (bs, c, h, w) - :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) - """ - bs, nc, h, w = x.shape - - 
# number of crops in image - Ly = (h - kernel_size[0]) // stride[0] + 1 - Lx = (w - kernel_size[1]) // stride[1] + 1 - - if uf == 1 and df == 1: - fold_params = dict( - kernel_size=kernel_size, dilation=1, padding=0, stride=stride - ) - unfold = torch.nn.Unfold(**fold_params) - - fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) - - weighting = self.get_weighting( - kernel_size[0], kernel_size[1], Ly, Lx, x.device - ).to(x.dtype) - normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap - weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) - - elif uf > 1 and df == 1: - fold_params = dict( - kernel_size=kernel_size, dilation=1, padding=0, stride=stride - ) - unfold = torch.nn.Unfold(**fold_params) - - fold_params2 = dict( - kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), - dilation=1, - padding=0, - stride=(stride[0] * uf, stride[1] * uf), - ) - fold = torch.nn.Fold( - output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2 - ) - - weighting = self.get_weighting( - kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device - ).to(x.dtype) - normalization = fold(weighting).view( - 1, 1, h * uf, w * uf - ) # normalizes the overlap - weighting = weighting.view( - (1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx) - ) - - elif df > 1 and uf == 1: - fold_params = dict( - kernel_size=kernel_size, dilation=1, padding=0, stride=stride - ) - unfold = torch.nn.Unfold(**fold_params) - - fold_params2 = dict( - kernel_size=(kernel_size[0] // df, kernel_size[0] // df), - dilation=1, - padding=0, - stride=(stride[0] // df, stride[1] // df), - ) - fold = torch.nn.Fold( - output_size=(x.shape[2] // df, x.shape[3] // df), - **fold_params2, - ) - - weighting = self.get_weighting( - kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device - ).to(x.dtype) - normalization = fold(weighting).view( - 1, 1, h // df, w // df - ) # normalizes the overlap - weighting = weighting.view( - (1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx) - ) - - else: - raise NotImplementedError - - return fold, unfold, normalization, weighting - - @torch.no_grad() - def get_input( - self, - batch, - k, - return_first_stage_outputs=False, - force_c_encode=False, - cond_key=None, - return_original_cond=False, - bs=None, - ): - x = super().get_input(batch, k) - if bs is not None: - x = x[:bs] - x = x.to(self.device) - encoder_posterior = self.encode_first_stage(x) - z = self.get_first_stage_encoding(encoder_posterior).detach() - - if self.model.conditioning_key is not None: - if cond_key is None: - cond_key = self.cond_stage_key - if cond_key != self.first_stage_key: - if cond_key in ["caption", "coordinates_bbox"]: - xc = batch[cond_key] - elif cond_key == "class_label": - xc = batch - else: - xc = super().get_input(batch, cond_key).to(self.device) - else: - xc = x - if not self.cond_stage_trainable or force_c_encode: - if isinstance(xc, dict) or isinstance(xc, list): - # import pudb; pudb.set_trace() - c = self.get_learned_conditioning(xc) - else: - c = self.get_learned_conditioning(xc.to(self.device)) - else: - c = xc - if bs is not None: - c = c[:bs] - - if self.use_positional_encodings: - pos_x, pos_y = self.compute_latent_shifts(batch) - ckey = __conditioning_keys__[self.model.conditioning_key] - c = {ckey: c, "pos_x": pos_x, "pos_y": pos_y} - - else: - c = None - xc = None - if self.use_positional_encodings: - pos_x, pos_y = self.compute_latent_shifts(batch) - c = {"pos_x": pos_x, "pos_y": pos_y} - out = [z, c] - if return_first_stage_outputs: 
- xrec = self.decode_first_stage(z) - out.extend([x, xrec]) - if return_original_cond: - out.append(xc) - return out - - @torch.no_grad() - def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): - if predict_cids: - if z.dim() == 4: - z = torch.argmax(z.exp(), dim=1).long() - z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) - z = rearrange(z, "b h w c -> b c h w").contiguous() - - z = 1.0 / self.scale_factor * z - - if hasattr(self, "split_input_params"): - if self.split_input_params["patch_distributed_vq"]: - ks = self.split_input_params["ks"] # eg. (128, 128) - stride = self.split_input_params["stride"] # eg. (64, 64) - uf = self.split_input_params["vqf"] - bs, nc, h, w = z.shape - if ks[0] > h or ks[1] > w: - ks = (min(ks[0], h), min(ks[1], w)) - print("reducing Kernel") - - if stride[0] > h or stride[1] > w: - stride = (min(stride[0], h), min(stride[1], w)) - print("reducing stride") - - fold, unfold, normalization, weighting = self.get_fold_unfold( - z, ks, stride, uf=uf - ) - - z = unfold(z) # (bn, nc * prod(**ks), L) - # 1. Reshape to img shape - z = z.view( - (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) - ) # (bn, nc, ks[0], ks[1], L ) - - # 2. apply model loop over last dim - if isinstance(self.first_stage_model, VQModelInterface): - output_list = [ - self.first_stage_model.decode( - z[:, :, :, :, i], - force_not_quantize=predict_cids or force_not_quantize, - ) - for i in range(z.shape[-1]) - ] - else: - output_list = [ - self.first_stage_model.decode(z[:, :, :, :, i]) - for i in range(z.shape[-1]) - ] - - o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) - o = o * weighting - # Reverse 1. reshape to img shape - o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) - # stitch crops together - decoded = fold(o) - decoded = decoded / normalization # norm is shape (1, 1, h, w) - return decoded - else: - if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode( - z, - force_not_quantize=predict_cids or force_not_quantize, - ) - else: - return self.first_stage_model.decode(z) - - else: - if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode( - z, force_not_quantize=predict_cids or force_not_quantize - ) - else: - return self.first_stage_model.decode(z) - - # same as above but without decorator - def differentiable_decode_first_stage( - self, z, predict_cids=False, force_not_quantize=False - ): - if predict_cids: - if z.dim() == 4: - z = torch.argmax(z.exp(), dim=1).long() - z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) - z = rearrange(z, "b h w c -> b c h w").contiguous() - - z = 1.0 / self.scale_factor * z - - if hasattr(self, "split_input_params"): - if self.split_input_params["patch_distributed_vq"]: - ks = self.split_input_params["ks"] # eg. (128, 128) - stride = self.split_input_params["stride"] # eg. (64, 64) - uf = self.split_input_params["vqf"] - bs, nc, h, w = z.shape - if ks[0] > h or ks[1] > w: - ks = (min(ks[0], h), min(ks[1], w)) - print("reducing Kernel") - - if stride[0] > h or stride[1] > w: - stride = (min(stride[0], h), min(stride[1], w)) - print("reducing stride") - - fold, unfold, normalization, weighting = self.get_fold_unfold( - z, ks, stride, uf=uf - ) - - z = unfold(z) # (bn, nc * prod(**ks), L) - # 1. Reshape to img shape - z = z.view( - (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) - ) # (bn, nc, ks[0], ks[1], L ) - - # 2. 
apply model loop over last dim - if isinstance(self.first_stage_model, VQModelInterface): - output_list = [ - self.first_stage_model.decode( - z[:, :, :, :, i], - force_not_quantize=predict_cids or force_not_quantize, - ) - for i in range(z.shape[-1]) - ] - else: - output_list = [ - self.first_stage_model.decode(z[:, :, :, :, i]) - for i in range(z.shape[-1]) - ] - - o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) - o = o * weighting - # Reverse 1. reshape to img shape - o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) - # stitch crops together - decoded = fold(o) - decoded = decoded / normalization # norm is shape (1, 1, h, w) - return decoded - else: - if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode( - z, - force_not_quantize=predict_cids or force_not_quantize, - ) - else: - return self.first_stage_model.decode(z) - - else: - if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode( - z, force_not_quantize=predict_cids or force_not_quantize - ) - else: - return self.first_stage_model.decode(z) - - @torch.no_grad() - def encode_first_stage(self, x): - if hasattr(self, "split_input_params"): - if self.split_input_params["patch_distributed_vq"]: - ks = self.split_input_params["ks"] # eg. (128, 128) - stride = self.split_input_params["stride"] # eg. (64, 64) - df = self.split_input_params["vqf"] - self.split_input_params["original_image_size"] = x.shape[-2:] - bs, nc, h, w = x.shape - if ks[0] > h or ks[1] > w: - ks = (min(ks[0], h), min(ks[1], w)) - print("reducing Kernel") - - if stride[0] > h or stride[1] > w: - stride = (min(stride[0], h), min(stride[1], w)) - print("reducing stride") - - fold, unfold, normalization, weighting = self.get_fold_unfold( - x, ks, stride, df=df - ) - z = unfold(x) # (bn, nc * prod(**ks), L) - # Reshape to img shape - z = z.view( - (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) - ) # (bn, nc, ks[0], ks[1], L ) - - output_list = [ - self.first_stage_model.encode(z[:, :, :, :, i]) - for i in range(z.shape[-1]) - ] - - o = torch.stack(output_list, axis=-1) - o = o * weighting - - # Reverse reshape to img shape - o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) - # stitch crops together - decoded = fold(o) - decoded = decoded / normalization - return decoded - - else: - return self.first_stage_model.encode(x) - else: - return self.first_stage_model.encode(x) - - def shared_step(self, batch, **kwargs): - x, c = self.get_input(batch, self.first_stage_key) - loss = self(x, c) - return loss - - def forward(self, x, c, *args, **kwargs): - t = torch.randint( - 0, self.num_timesteps, (x.shape[0],), device=self.device - ).long() - if self.model.conditioning_key is not None: - assert c is not None - if self.cond_stage_trainable: - c = self.get_learned_conditioning(c) - if self.shorten_cond_schedule: # TODO: drop this option - tc = self.cond_ids[t].to(self.device) - c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) - - return self.p_losses(x, c, t, *args, **kwargs) - - def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset - def rescale_bbox(bbox): - x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) - y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) - w = min(bbox[2] / crop_coordinates[2], 1 - x0) - h = min(bbox[3] / crop_coordinates[3], 1 - y0) - return x0, y0, w, h - - return [rescale_bbox(b) for b in bboxes] - - def apply_model(self, x_noisy, t, 
cond, return_ids=False): - if isinstance(cond, dict): - # hybrid case, cond is exptected to be a dict - pass - else: - if not isinstance(cond, list): - cond = [cond] - key = ( - "c_concat" if self.model.conditioning_key == "concat" else "c_crossattn" - ) - cond = {key: cond} - - if hasattr(self, "split_input_params"): - assert len(cond) == 1 # todo can only deal with one conditioning atm - assert not return_ids - ks = self.split_input_params["ks"] # eg. (128, 128) - stride = self.split_input_params["stride"] # eg. (64, 64) - - h, w = x_noisy.shape[-2:] - - fold, unfold, normalization, weighting = self.get_fold_unfold( - x_noisy, ks, stride - ) - - z = unfold(x_noisy) # (bn, nc * prod(**ks), L) - # Reshape to img shape - z = z.view( - (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) - ) # (bn, nc, ks[0], ks[1], L ) - z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] - - if ( - self.cond_stage_key in ["image", "LR_image", "segmentation", "bbox_img"] - and self.model.conditioning_key - ): # todo check for completeness - c_key = next(iter(cond.keys())) # get key - c = next(iter(cond.values())) # get value - assert len(c) == 1 # todo extend to list with more than one elem - c = c[0] # get element - - c = unfold(c) - c = c.view( - (c.shape[0], -1, ks[0], ks[1], c.shape[-1]) - ) # (bn, nc, ks[0], ks[1], L ) - - cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] - - elif self.cond_stage_key == "coordinates_bbox": - assert ( - "original_image_size" in self.split_input_params - ), "BoudingBoxRescaling is missing original_image_size" - - # assuming padding of unfold is always 0 and its dilation is always 1 - n_patches_per_row = int((w - ks[0]) / stride[0] + 1) - full_img_h, full_img_w = self.split_input_params["original_image_size"] - # as we are operating on latents, we need the factor from the original image size to the - # spatial latent size to properly rescale the crops for regenerating the bbox annotations - num_downs = self.first_stage_model.encoder.num_resolutions - 1 - rescale_latent = 2 ** (num_downs) - - # get top left positions of patches as conforming for the bbbox tokenizer, therefore we - # need to rescale the tl patch coordinates to be in between (0,1) - tl_patch_coordinates = [ - ( - rescale_latent - * stride[0] - * (patch_nr % n_patches_per_row) - / full_img_w, - rescale_latent - * stride[1] - * (patch_nr // n_patches_per_row) - / full_img_h, - ) - for patch_nr in range(z.shape[-1]) - ] - - # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) - patch_limits = [ - ( - x_tl, - y_tl, - rescale_latent * ks[0] / full_img_w, - rescale_latent * ks[1] / full_img_h, - ) - for x_tl, y_tl in tl_patch_coordinates - ] - # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] - - # tokenize crop coordinates for the bounding boxes of the respective patches - patch_limits_tknzd = [ - torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to( - self.device - ) - for bbox in patch_limits - ] # list of length l with tensors of shape (1, 2) - print(patch_limits_tknzd[0].shape) - # cut tknzd crop position from conditioning - assert isinstance(cond, dict), "cond must be dict to be fed into model" - cut_cond = cond["c_crossattn"][0][..., :-2].to(self.device) - print(cut_cond.shape) - - adapted_cond = torch.stack( - [torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd] - ) - adapted_cond = rearrange(adapted_cond, "l b n -> (l b) n") - print(adapted_cond.shape) - adapted_cond 
= self.get_learned_conditioning(adapted_cond) - print(adapted_cond.shape) - adapted_cond = rearrange( - adapted_cond, "(l b) n d -> l b n d", l=z.shape[-1] - ) - print(adapted_cond.shape) - - cond_list = [{"c_crossattn": [e]} for e in adapted_cond] - - else: - cond_list = [ - cond for i in range(z.shape[-1]) - ] # Todo make this more efficient - - # apply model by loop over crops - output_list = [ - self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1]) - ] - assert not isinstance( - output_list[0], tuple - ) # todo cant deal with multiple model outputs check this never happens - - o = torch.stack(output_list, axis=-1) - o = o * weighting - # Reverse reshape to img shape - o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) - # stitch crops together - x_recon = fold(o) / normalization - - else: - x_recon = self.model(x_noisy, t, **cond) - - if isinstance(x_recon, tuple) and not return_ids: - return x_recon[0] - else: - return x_recon - - def _predict_eps_from_xstart(self, x_t, t, pred_xstart): - return ( - extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - pred_xstart - ) / extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) - - def _prior_bpd(self, x_start): - """ - Get the prior KL term for the variational lower-bound, measured in - bits-per-dim. - This term can't be optimized, as it only depends on the encoder. - :param x_start: the [N x C x ...] tensor of inputs. - :return: a batch of [N] KL values (in bits), one per batch element. - """ - batch_size = x_start.shape[0] - t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) - qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) - kl_prior = normal_kl( - mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 - ) - return mean_flat(kl_prior) / np.log(2.0) - - def p_losses(self, x_start, cond, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) - model_output = self.apply_model(x_noisy, t, cond) - - loss_dict = {} - prefix = "train" if self.training else "val" - - if self.parameterization == "x0": - target = x_start - elif self.parameterization == "eps": - target = noise - else: - raise NotImplementedError() - - loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) - loss_dict.update({f"{prefix}/loss_simple": loss_simple.mean()}) - - logvar_t = self.logvar[t.item()].to(self.device) - loss = loss_simple / torch.exp(logvar_t) + logvar_t - # loss = loss_simple / torch.exp(self.logvar) + self.logvar - if self.learn_logvar: - loss_dict.update({f"{prefix}/loss_gamma": loss.mean()}) - loss_dict.update({"logvar": self.logvar.data.mean()}) - - loss = self.l_simple_weight * loss.mean() - - loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) - loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() - loss_dict.update({f"{prefix}/loss_vlb": loss_vlb}) - loss += self.original_elbo_weight * loss_vlb - loss_dict.update({f"{prefix}/loss": loss}) - - if self.embedding_reg_weight > 0: - loss_embedding_reg = ( - self.embedding_manager.embedding_to_coarse_loss().mean() - ) - - loss_dict.update({f"{prefix}/loss_emb_reg": loss_embedding_reg}) - - loss += self.embedding_reg_weight * loss_embedding_reg - loss_dict.update({f"{prefix}/loss": loss}) - - return loss, loss_dict - - def p_mean_variance( - self, - x, - c, - t, - clip_denoised: bool, - return_codebook_ids=False, - quantize_denoised=False, - return_x0=False, 
- score_corrector=None, - corrector_kwargs=None, - ): - t_in = t - model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) - - if score_corrector is not None: - assert self.parameterization == "eps" - model_out = score_corrector.modify_score( - self, model_out, x, t, c, **corrector_kwargs - ) - - if return_codebook_ids: - model_out, logits = model_out - - if self.parameterization == "eps": - x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) - elif self.parameterization == "x0": - x_recon = model_out - else: - raise NotImplementedError() - - if clip_denoised: - x_recon.clamp_(-1.0, 1.0) - if quantize_denoised: - x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) - ( - model_mean, - posterior_variance, - posterior_log_variance, - ) = self.q_posterior(x_start=x_recon, x_t=x, t=t) - if return_codebook_ids: - return ( - model_mean, - posterior_variance, - posterior_log_variance, - logits, - ) - elif return_x0: - return ( - model_mean, - posterior_variance, - posterior_log_variance, - x_recon, - ) - else: - return model_mean, posterior_variance, posterior_log_variance - - @torch.no_grad() - def p_sample( - self, - x, - c, - t, - clip_denoised=False, - repeat_noise=False, - return_codebook_ids=False, - quantize_denoised=False, - return_x0=False, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - ): - b, *_, device = *x.shape, x.device - outputs = self.p_mean_variance( - x=x, - c=c, - t=t, - clip_denoised=clip_denoised, - return_codebook_ids=return_codebook_ids, - quantize_denoised=quantize_denoised, - return_x0=return_x0, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - ) - if return_codebook_ids: - raise DeprecationWarning("Support dropped.") - model_mean, _, model_log_variance, logits = outputs - elif return_x0: - model_mean, _, model_log_variance, x0 = outputs - else: - model_mean, _, model_log_variance = outputs - - noise = noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.0: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - # no noise when t == 0 - nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) - - if return_codebook_ids: - return model_mean + nonzero_mask * ( - 0.5 * model_log_variance - ).exp() * noise, logits.argmax(dim=1) - if return_x0: - return ( - model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, - x0, - ) - else: - return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise - - @torch.no_grad() - def progressive_denoising( - self, - cond, - shape, - verbose=True, - callback=None, - quantize_denoised=False, - img_callback=None, - mask=None, - x0=None, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - batch_size=None, - x_T=None, - start_T=None, - log_every_t=None, - ): - if not log_every_t: - log_every_t = self.log_every_t - timesteps = self.num_timesteps - if batch_size is not None: - b = batch_size if batch_size is not None else shape[0] - shape = [batch_size] + list(shape) - else: - b = batch_size = shape[0] - if x_T is None: - img = torch.randn(shape, device=self.device) - else: - img = x_T - intermediates = [] - if cond is not None: - if isinstance(cond, dict): - cond = { - key: cond[key][:batch_size] - if not isinstance(cond[key], list) - else list(map(lambda x: x[:batch_size], cond[key])) - for key in cond - } - else: - cond = ( - [c[:batch_size] for c in cond] - if isinstance(cond, list) - else cond[:batch_size] 
- ) - - if start_T is not None: - timesteps = min(timesteps, start_T) - iterator = ( - tqdm( - reversed(range(0, timesteps)), - desc="Progressive Generation", - total=timesteps, - ) - if verbose - else reversed(range(0, timesteps)) - ) - if type(temperature) == float: - temperature = [temperature] * timesteps - - for i in iterator: - ts = torch.full((b,), i, device=self.device, dtype=torch.long) - if self.shorten_cond_schedule: - assert self.model.conditioning_key != "hybrid" - tc = self.cond_ids[ts].to(cond.device) - cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) - - img, x0_partial = self.p_sample( - img, - cond, - ts, - clip_denoised=self.clip_denoised, - quantize_denoised=quantize_denoised, - return_x0=True, - temperature=temperature[i], - noise_dropout=noise_dropout, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - ) - if mask is not None: - assert x0 is not None - img_orig = self.q_sample(x0, ts) - img = img_orig * mask + (1.0 - mask) * img - - if i % log_every_t == 0 or i == timesteps - 1: - intermediates.append(x0_partial) - if callback: - callback(i) - if img_callback: - img_callback(img, i) - return img, intermediates - - @torch.no_grad() - def p_sample_loop( - self, - cond, - shape, - return_intermediates=False, - x_T=None, - verbose=True, - callback=None, - timesteps=None, - quantize_denoised=False, - mask=None, - x0=None, - img_callback=None, - start_T=None, - log_every_t=None, - ): - if not log_every_t: - log_every_t = self.log_every_t - device = self.betas.device - b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - - intermediates = [img] - if timesteps is None: - timesteps = self.num_timesteps - - if start_T is not None: - timesteps = min(timesteps, start_T) - iterator = ( - tqdm( - reversed(range(0, timesteps)), - desc="Sampling t", - total=timesteps, - ) - if verbose - else reversed(range(0, timesteps)) - ) - - if mask is not None: - assert x0 is not None - assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match - - for i in iterator: - ts = torch.full((b,), i, device=device, dtype=torch.long) - if self.shorten_cond_schedule: - assert self.model.conditioning_key != "hybrid" - tc = self.cond_ids[ts].to(cond.device) - cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) - - img = self.p_sample( - img, - cond, - ts, - clip_denoised=self.clip_denoised, - quantize_denoised=quantize_denoised, - ) - if mask is not None: - img_orig = self.q_sample(x0, ts) - img = img_orig * mask + (1.0 - mask) * img - - if i % log_every_t == 0 or i == timesteps - 1: - intermediates.append(img) - if callback: - callback(i) - if img_callback: - img_callback(img, i) - - if return_intermediates: - return img, intermediates - return img - - @torch.no_grad() - def sample( - self, - cond, - batch_size=16, - return_intermediates=False, - x_T=None, - verbose=True, - timesteps=None, - quantize_denoised=False, - mask=None, - x0=None, - shape=None, - **kwargs, - ): - if shape is None: - shape = ( - batch_size, - self.channels, - self.image_size, - self.image_size, - ) - if cond is not None: - if isinstance(cond, dict): - cond = { - key: cond[key][:batch_size] - if not isinstance(cond[key], list) - else list(map(lambda x: x[:batch_size], cond[key])) - for key in cond - } - else: - cond = ( - [c[:batch_size] for c in cond] - if isinstance(cond, list) - else cond[:batch_size] - ) - return self.p_sample_loop( - cond, - shape, - return_intermediates=return_intermediates, - x_T=x_T, - 
verbose=verbose, - timesteps=timesteps, - quantize_denoised=quantize_denoised, - mask=mask, - x0=x0, - ) - - @torch.no_grad() - def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): - if ddim: - ddim_sampler = DDIMSampler(self) - shape = (self.channels, self.image_size, self.image_size) - samples, intermediates = ddim_sampler.sample( - ddim_steps, batch_size, shape, cond, verbose=False, **kwargs - ) - - else: - samples, intermediates = self.sample( - cond=cond, - batch_size=batch_size, - return_intermediates=True, - **kwargs, - ) - - return samples, intermediates - - @torch.no_grad() - def get_unconditional_conditioning(self, batch_size, null_label=None): - if null_label is not None: - xc = null_label - if isinstance(xc, ListConfig): - xc = list(xc) - if isinstance(xc, dict) or isinstance(xc, list): - c = self.get_learned_conditioning(xc) - else: - if hasattr(xc, "to"): - xc = xc.to(self.device) - c = self.get_learned_conditioning(xc) - else: - # todo: get null label from cond_stage_model - raise NotImplementedError() - c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device) - return c - - @torch.no_grad() - def log_images( - self, - batch, - N=8, - n_row=4, - sample=True, - ddim_steps=50, - ddim_eta=1.0, - return_keys=None, - quantize_denoised=True, - inpaint=False, - plot_denoise_rows=False, - plot_progressive_rows=False, - plot_diffusion_rows=False, - **kwargs, - ): - use_ddim = ddim_steps is not None - - log = dict() - z, c, x, xrec, xc = self.get_input( - batch, - self.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=True, - return_original_cond=True, - bs=N, - ) - N = min(x.shape[0], N) - n_row = min(x.shape[0], n_row) - log["inputs"] = x - log["reconstruction"] = xrec - if self.model.conditioning_key is not None: - if hasattr(self.cond_stage_model, "decode"): - xc = self.cond_stage_model.decode(c) - log["conditioning"] = xc - elif self.cond_stage_key in ["caption"]: - xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"]) - log["conditioning"] = xc - elif self.cond_stage_key == "class_label": - xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) - log["conditioning"] = xc - elif isimage(xc): - log["conditioning"] = xc - if ismap(xc): - log["original_conditioning"] = self.to_rgb(xc) - - if plot_diffusion_rows: - # get diffusion row - diffusion_row = list() - z_start = z[:n_row] - for t in range(self.num_timesteps): - if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), "1 -> b", b=n_row) - t = t.to(self.device).long() - noise = torch.randn_like(z_start) - z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) - diffusion_row.append(self.decode_first_stage(z_noisy)) - - diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W - diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w") - diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w") - diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) - log["diffusion_row"] = diffusion_grid - - if sample: - # get denoise row - with self.ema_scope("Plotting"): - samples, z_denoise_row = self.sample_log( - cond=c, - batch_size=N, - ddim=use_ddim, - ddim_steps=ddim_steps, - eta=ddim_eta, - ) - # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) - x_samples = self.decode_first_stage(samples) - log["samples"] = x_samples - if plot_denoise_rows: - denoise_grid = self._get_denoise_row_from_list(z_denoise_row) - log["denoise_row"] = denoise_grid - 
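# The guided sample_log() call just below passes unconditional_guidance_scale=5.0
# together with an empty-prompt conditioning. The underlying blend is essentially
# the standard classifier-free guidance combination; a hedged, standalone sketch
# with illustrative names (not the repository's API):
import torch

def cfg_combine(eps_uncond, eps_cond, guidance_scale):
    # start from the unconditional prediction and move guidance_scale times
    # further along the (conditional - unconditional) direction
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)

if __name__ == "__main__":
    eps_uncond = torch.randn(1, 4, 64, 64)
    eps_cond = torch.randn(1, 4, 64, 64)
    print(cfg_combine(eps_uncond, eps_cond, 5.0).shape)  # torch.Size([1, 4, 64, 64])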
- uc = self.get_learned_conditioning(len(c) * [""]) - sample_scaled, _ = self.sample_log( - cond=c, - batch_size=N, - ddim=use_ddim, - ddim_steps=ddim_steps, - eta=ddim_eta, - unconditional_guidance_scale=5.0, - unconditional_conditioning=uc, - ) - log["samples_scaled"] = self.decode_first_stage(sample_scaled) - - if ( - quantize_denoised - and not isinstance(self.first_stage_model, AutoencoderKL) - and not isinstance(self.first_stage_model, IdentityFirstStage) - ): - # also display when quantizing x0 while sampling - with self.ema_scope("Plotting Quantized Denoised"): - samples, z_denoise_row = self.sample_log( - cond=c, - batch_size=N, - ddim=use_ddim, - ddim_steps=ddim_steps, - eta=ddim_eta, - quantize_denoised=True, - ) - # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, - # quantize_denoised=True) - x_samples = self.decode_first_stage(samples.to(self.device)) - log["samples_x0_quantized"] = x_samples - - if inpaint: - # make a simple center square - b, h, w = z.shape[0], z.shape[2], z.shape[3] - mask = torch.ones(N, h, w).to(self.device) - # zeros will be filled in - mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0 - mask = mask[:, None, ...] - with self.ema_scope("Plotting Inpaint"): - samples, _ = self.sample_log( - cond=c, - batch_size=N, - ddim=use_ddim, - eta=ddim_eta, - ddim_steps=ddim_steps, - x0=z[:N], - mask=mask, - ) - x_samples = self.decode_first_stage(samples.to(self.device)) - log["samples_inpainting"] = x_samples - log["mask"] = mask - - # outpaint - with self.ema_scope("Plotting Outpaint"): - samples, _ = self.sample_log( - cond=c, - batch_size=N, - ddim=use_ddim, - eta=ddim_eta, - ddim_steps=ddim_steps, - x0=z[:N], - mask=mask, - ) - x_samples = self.decode_first_stage(samples.to(self.device)) - log["samples_outpainting"] = x_samples - - if plot_progressive_rows: - with self.ema_scope("Plotting Progressives"): - img, progressives = self.progressive_denoising( - c, - shape=(self.channels, self.image_size, self.image_size), - batch_size=N, - ) - prog_row = self._get_denoise_row_from_list( - progressives, desc="Progressive Generation" - ) - log["progressive_row"] = prog_row - - if return_keys: - if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: - return log - else: - return {key: log[key] for key in return_keys} - return log - - def configure_optimizers(self): - lr = self.learning_rate - - if self.embedding_manager is not None: - params = list(self.embedding_manager.embedding_parameters()) - # params = list(self.cond_stage_model.transformer.text_model.embeddings.embedding_manager.embedding_parameters()) - else: - params = list(self.model.parameters()) - if self.cond_stage_trainable: - print(f"{self.__class__.__name__}: Also optimizing conditioner params!") - params = params + list(self.cond_stage_model.parameters()) - if self.learn_logvar: - print("Diffusion model optimizing logvar") - params.append(self.logvar) - opt = torch.optim.AdamW(params, lr=lr) - if self.use_scheduler: - assert "target" in self.scheduler_config - scheduler = instantiate_from_config(self.scheduler_config) - - print("Setting up LambdaLR scheduler...") - scheduler = [ - { - "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), - "interval": "step", - "frequency": 1, - } - ] - return [opt], scheduler - return opt - - @torch.no_grad() - def to_rgb(self, x): - x = x.float() - if not hasattr(self, "colorize"): - self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) - x = nn.functional.conv2d(x, weight=self.colorize) - x = 2.0 * (x - 
x.min()) / (x.max() - x.min()) - 1.0 - return x - - @rank_zero_only - def on_save_checkpoint(self, checkpoint): - checkpoint.clear() - - if os.path.isdir(self.trainer.checkpoint_callback.dirpath): - self.embedding_manager.save( - os.path.join(self.trainer.checkpoint_callback.dirpath, "embeddings.pt") - ) - - if (self.global_step - self.emb_ckpt_counter) > 500: - self.embedding_manager.save( - os.path.join( - self.trainer.checkpoint_callback.dirpath, - f"embeddings_gs-{self.global_step}.pt", - ) - ) - - self.emb_ckpt_counter += 500 - - -class DiffusionWrapper(pl.LightningModule): - def __init__(self, diff_model_config, conditioning_key): - super().__init__() - self.diffusion_model = instantiate_from_config(diff_model_config) - self.conditioning_key = conditioning_key - assert self.conditioning_key in [ - None, - "concat", - "crossattn", - "hybrid", - "adm", - ] - - def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): - if self.conditioning_key is None: - out = self.diffusion_model(x, t) - elif self.conditioning_key == "concat": - xc = torch.cat([x] + c_concat, dim=1) - out = self.diffusion_model(xc, t) - elif self.conditioning_key == "crossattn": - cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(x, t, context=cc) - elif self.conditioning_key == "hybrid": - cc = torch.cat(c_crossattn, 1) - xc = torch.cat([x] + c_concat, dim=1) - out = self.diffusion_model(xc, t, context=cc) - elif self.conditioning_key == "adm": - cc = c_crossattn[0] - out = self.diffusion_model(x, t, y=cc) - else: - raise NotImplementedError() - - return out - - -class Layout2ImgDiffusion(LatentDiffusion): - # TODO: move all layout-specific hacks to this class - def __init__(self, cond_stage_key, *args, **kwargs): - assert ( - cond_stage_key == "coordinates_bbox" - ), 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' - super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) - - def log_images(self, batch, N=8, *args, **kwargs): - logs = super().log_images(batch=batch, N=N, *args, **kwargs) - - key = "train" if self.training else "validation" - dset = self.trainer.datamodule.datasets[key] - mapper = dset.conditional_builders[self.cond_stage_key] - - bbox_imgs = [] - map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno)) - for tknzd_bbox in batch[self.cond_stage_key][:N]: - bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256)) - bbox_imgs.append(bboximg) - - cond_img = torch.stack(bbox_imgs, dim=0) - logs["bbox_image"] = cond_img - return logs - - -class LatentInpaintDiffusion(LatentDiffusion): - def __init__( - self, - concat_keys=("mask", "masked_image"), - masked_image_key="masked_image", - finetune_keys=None, - *args, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.masked_image_key = masked_image_key - assert self.masked_image_key in concat_keys - self.concat_keys = concat_keys - - @torch.no_grad() - def get_input( - self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False - ): - # note: restricted to non-trainable encoders currently - assert ( - not self.cond_stage_trainable - ), "trainable cond stages not yet supported for inpainting" - z, c, x, xrec, xc = super().get_input( - batch, - self.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=True, - return_original_cond=True, - bs=bs, - ) - - assert exists(self.concat_keys) - c_cat = list() - for ck in self.concat_keys: - cc = ( - rearrange(batch[ck], "b h w c -> b c h w") - .to(memory_format=torch.contiguous_format) - .float() - ) - 
if bs is not None: - cc = cc[:bs] - cc = cc.to(self.device) - bchw = z.shape - if ck != self.masked_image_key: - cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) - else: - cc = self.get_first_stage_encoding(self.encode_first_stage(cc)) - c_cat.append(cc) - c_cat = torch.cat(c_cat, dim=1) - all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} - if return_first_stage_outputs: - return z, all_conds, x, xrec, xc - return z, all_conds diff --git a/invokeai/backend/stable_diffusion/diffusion/ksampler.py b/invokeai/backend/stable_diffusion/diffusion/ksampler.py deleted file mode 100644 index eddcc11ea8..0000000000 --- a/invokeai/backend/stable_diffusion/diffusion/ksampler.py +++ /dev/null @@ -1,339 +0,0 @@ -"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers""" - -import k_diffusion as K -import torch -from torch import nn - -from .cross_attention_map_saving import AttentionMapSaver -from .sampler import Sampler -from .shared_invokeai_diffusion import InvokeAIDiffuserComponent - -# at this threshold, the scheduler will stop using the Karras -# noise schedule and start using the model's schedule -STEP_THRESHOLD = 30 - - -def cfg_apply_threshold(result, threshold=0.0, scale=0.7): - if threshold <= 0.0: - return result - maxval = 0.0 + torch.max(result).cpu().numpy() - minval = 0.0 + torch.min(result).cpu().numpy() - if maxval < threshold and minval > -threshold: - return result - if maxval > threshold: - maxval = min(max(1, scale * maxval), threshold) - if minval < -threshold: - minval = max(min(-1, scale * minval), -threshold) - return torch.clamp(result, min=minval, max=maxval) - - -class CFGDenoiser(nn.Module): - def __init__(self, model, threshold=0, warmup=0): - super().__init__() - self.inner_model = model - self.threshold = threshold - self.warmup_max = warmup - self.warmup = max(warmup / 10, 1) - self.invokeai_diffuser = InvokeAIDiffuserComponent( - model, - model_forward_callback=lambda x, sigma, cond: self.inner_model( - x, sigma, cond=cond - ), - ) - - def prepare_to_sample(self, t_enc, **kwargs): - extra_conditioning_info = kwargs.get("extra_conditioning_info", None) - - if ( - extra_conditioning_info is not None - and extra_conditioning_info.wants_cross_attention_control - ): - self.invokeai_diffuser.override_cross_attention( - extra_conditioning_info, step_count=t_enc - ) - else: - self.invokeai_diffuser.restore_default_cross_attention() - - def forward(self, x, sigma, uncond, cond, cond_scale): - next_x = self.invokeai_diffuser.do_diffusion_step( - x, sigma, uncond, cond, cond_scale - ) - if self.warmup < self.warmup_max: - thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max)) - self.warmup += 1 - else: - thresh = self.threshold - if thresh > self.threshold: - thresh = self.threshold - return cfg_apply_threshold(next_x, thresh) - - -class KSampler(Sampler): - def __init__(self, model, schedule="lms", device=None, **kwargs): - denoiser = K.external.CompVisDenoiser(model) - super().__init__( - denoiser, - schedule, - steps=model.num_timesteps, - ) - self.sigmas = None - self.ds = None - self.s_in = None - self.karras_max = kwargs.get("karras_max", STEP_THRESHOLD) - if self.karras_max is None: - self.karras_max = STEP_THRESHOLD - - def make_schedule( - self, - ddim_num_steps, - ddim_discretize="uniform", - ddim_eta=0.0, - verbose=False, - ): - outer_model = self.model - self.model = outer_model.inner_model - super().make_schedule( - ddim_num_steps, - ddim_discretize="uniform", - ddim_eta=0.0, - 
verbose=False, - ) - self.model = outer_model - self.ddim_num_steps = ddim_num_steps - # we don't need both of these sigmas, but storing them here to make - # comparison easier later on - self.model_sigmas = self.model.get_sigmas(ddim_num_steps) - self.karras_sigmas = K.sampling.get_sigmas_karras( - n=ddim_num_steps, - sigma_min=self.model.sigmas[0].item(), - sigma_max=self.model.sigmas[-1].item(), - rho=7.0, - device=self.device, - ) - - if ddim_num_steps >= self.karras_max: - print( - f">> Ksampler using model noise schedule (steps >= {self.karras_max})" - ) - self.sigmas = self.model_sigmas - else: - print( - f">> Ksampler using karras noise schedule (steps < {self.karras_max})" - ) - self.sigmas = self.karras_sigmas - - # ALERT: We are completely overriding the sample() method in the base class, which - # means that inpainting will not work. To get this to work we need to be able to - # modify the inner loop of k_heun, k_lms, etc, as is done in an ugly way - # in the lstein/k-diffusion branch. - - @torch.no_grad() - def decode( - self, - z_enc, - cond, - t_enc, - img_callback=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - use_original_steps=False, - init_latent=None, - mask=None, - **kwargs, - ): - samples, _ = self.sample( - batch_size=1, - S=t_enc, - x_T=z_enc, - shape=z_enc.shape[1:], - conditioning=cond, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - img_callback=img_callback, - x0=init_latent, - mask=mask, - **kwargs, - ) - return samples - - # this is a no-op, provided here for compatibility with ddim and plms samplers - @torch.no_grad() - def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): - return x0 - - # Most of these arguments are ignored and are only present for compatibility with - # other samples - @torch.no_grad() - def sample( - self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - attention_maps_callback=None, - quantize_x0=False, - eta=0.0, - mask=None, - x0=None, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None, - threshold=0, - perlin=0, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - **kwargs, - ): - def route_callback(k_callback_values): - if img_callback is not None: - img_callback(k_callback_values["x"], k_callback_values["i"]) - - # if make_schedule() hasn't been called, we do it now - if self.sigmas is None: - self.make_schedule( - ddim_num_steps=S, - ddim_eta=eta, - verbose=False, - ) - - # sigmas are set up in make_schedule - we take the last steps items - sigmas = self.sigmas[-S - 1 :] - - # x_T is variation noise. When an init image is provided (in x0) we need to add - # more randomness to the starting image. - if x_T is not None: - if x0 is not None: - x = x_T + torch.randn_like(x0, device=self.device) * sigmas[0] - else: - x = x_T * sigmas[0] - else: - x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] - - model_wrap_cfg = CFGDenoiser( - self.model, threshold=threshold, warmup=max(0.8 * S, S - 10) - ) - model_wrap_cfg.prepare_to_sample( - S, extra_conditioning_info=extra_conditioning_info - ) - - # setup attention maps saving. 
checks for None are because there are multiple code paths to get here. - attention_map_saver = None - if attention_maps_callback is not None and extra_conditioning_info is not None: - eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1 - attention_map_token_ids = range(1, eos_token_index) - attention_map_saver = AttentionMapSaver( - token_ids=attention_map_token_ids, latents_shape=x.shape[-2:] - ) - model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving( - attention_map_saver - ) - - extra_args = { - "cond": conditioning, - "uncond": unconditional_conditioning, - "cond_scale": unconditional_guidance_scale, - } - print( - f">> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)" - ) - sampling_result = ( - K.sampling.__dict__[f"sample_{self.schedule}"]( - model_wrap_cfg, - x, - sigmas, - extra_args=extra_args, - callback=route_callback, - ), - None, - ) - if attention_map_saver is not None: - attention_maps_callback(attention_map_saver) - return sampling_result - - # this code will support inpainting if and when ksampler API modified or - # a workaround is found. - @torch.no_grad() - def p_sample( - self, - img, - cond, - ts, - index, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - extra_conditioning_info=None, - **kwargs, - ): - if self.model_wrap is None: - self.model_wrap = CFGDenoiser(self.model) - extra_args = { - "cond": cond, - "uncond": unconditional_conditioning, - "cond_scale": unconditional_guidance_scale, - } - if self.s_in is None: - self.s_in = img.new_ones([img.shape[0]]) - if self.ds is None: - self.ds = [] - - # terrible, confusing names here - steps = self.ddim_num_steps - t_enc = self.t_enc - - # sigmas is a full steps in length, but t_enc might - # be less. We start in the middle of the sigma array - # and work our way to the end after t_enc steps. - # index starts at t_enc and works its way to zero, - # so the actual formula for indexing into sigmas: - # sigma_index = (steps-index) - s_index = t_enc - index - 1 - self.model_wrap.prepare_to_sample( - s_index, extra_conditioning_info=extra_conditioning_info - ) - img = K.sampling.__dict__[f"_{self.schedule}"]( - self.model_wrap, - img, - self.sigmas, - s_index, - s_in=self.s_in, - ds=self.ds, - extra_args=extra_args, - ) - - return img, None, None - - # REVIEW THIS METHOD: it has never been tested. In particular, - # we should not be multiplying by self.sigmas[0] if we - # are at an intermediate step in img2img. See similar in - # sample() which does work. - def get_initial_image(self, x_T, shape, steps): - print(f"WARNING: ksampler.get_initial_image(): get_initial_image needs testing") - x = torch.randn(shape, device=self.device) * self.sigmas[0] - if x_T is not None: - return x_T + x - else: - return x - - def prepare_to_sample(self, t_enc, **kwargs): - self.t_enc = t_enc - self.model_wrap = None - self.ds = None - self.s_in = None - - def q_sample(self, x0, ts): - """ - Overrides parent method to return the q_sample of the inner model. 
- """ - return self.model.inner_model.q_sample(x0, ts) - - def conditioning_key(self) -> str: - return self.model.inner_model.model.conditioning_key diff --git a/invokeai/backend/stable_diffusion/diffusion/plms.py b/invokeai/backend/stable_diffusion/diffusion/plms.py deleted file mode 100644 index df37afcc24..0000000000 --- a/invokeai/backend/stable_diffusion/diffusion/plms.py +++ /dev/null @@ -1,143 +0,0 @@ -"""SAMPLING ONLY.""" - -from functools import partial - -import numpy as np -import torch -from tqdm import tqdm - -from ...util import choose_torch_device -from ..diffusionmodules.util import noise_like -from .sampler import Sampler -from .shared_invokeai_diffusion import InvokeAIDiffuserComponent - - -class PLMSSampler(Sampler): - def __init__(self, model, schedule="linear", device=None, **kwargs): - super().__init__(model, schedule, model.num_timesteps, device) - - def prepare_to_sample(self, t_enc, **kwargs): - super().prepare_to_sample(t_enc, **kwargs) - - extra_conditioning_info = kwargs.get("extra_conditioning_info", None) - all_timesteps_count = kwargs.get("all_timesteps_count", t_enc) - - if ( - extra_conditioning_info is not None - and extra_conditioning_info.wants_cross_attention_control - ): - self.invokeai_diffuser.override_cross_attention( - extra_conditioning_info, step_count=all_timesteps_count - ) - else: - self.invokeai_diffuser.restore_default_cross_attention() - - # this is the essential routine - @torch.no_grad() - def p_sample( - self, - x, # image, called 'img' elsewhere - c, # conditioning, called 'cond' elsewhere - t, # timesteps, called 'ts' elsewhere - index, - repeat_noise=False, - use_original_steps=False, - quantize_denoised=False, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - old_eps=[], - t_next=None, - step_count: int = 1000, # total number of steps - **kwargs, - ): - b, *_, device = *x.shape, x.device - - def get_model_output(x, t): - if ( - unconditional_conditioning is None - or unconditional_guidance_scale == 1.0 - ): - # damian0815 would like to know when/if this code path is used - e_t = self.model.apply_model(x, t, c) - else: - # step_index counts in the opposite direction to index - step_index = step_count - (index + 1) - e_t = self.invokeai_diffuser.do_diffusion_step( - x, - t, - unconditional_conditioning, - c, - unconditional_guidance_scale, - step_index=step_index, - ) - if score_corrector is not None: - assert self.model.parameterization == "eps" - e_t = score_corrector.modify_score( - self.model, e_t, x, t, c, **corrector_kwargs - ) - - return e_t - - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = ( - self.model.alphas_cumprod_prev - if use_original_steps - else self.ddim_alphas_prev - ) - sqrt_one_minus_alphas = ( - self.model.sqrt_one_minus_alphas_cumprod - if use_original_steps - else self.ddim_sqrt_one_minus_alphas - ) - sigmas = ( - self.model.ddim_sigmas_for_original_num_steps - if use_original_steps - else self.ddim_sigmas - ) - - def get_x_prev_and_pred_x0(e_t, index): - # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) - sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full( - (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device - ) - - # current prediction for x_0 - 
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - if quantize_denoised: - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - # direction pointing to x_t - dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.0: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0 - - e_t = get_model_output(x, t) - if len(old_eps) == 0: - # Pseudo Improved Euler (2nd order) - x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) - e_t_next = get_model_output(x_prev, t_next) - e_t_prime = (e_t + e_t_next) / 2 - elif len(old_eps) == 1: - # 2nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (3 * e_t - old_eps[-1]) / 2 - elif len(old_eps) == 2: - # 3nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 - elif len(old_eps) >= 3: - # 4nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = ( - 55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3] - ) / 24 - - x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) - - return x_prev, pred_x0, e_t diff --git a/invokeai/backend/stable_diffusion/diffusion/sampler.py b/invokeai/backend/stable_diffusion/diffusion/sampler.py deleted file mode 100644 index beb74eaefb..0000000000 --- a/invokeai/backend/stable_diffusion/diffusion/sampler.py +++ /dev/null @@ -1,454 +0,0 @@ -""" -invokeai.models.diffusion.sampler - -Base class for invokeai.models.diffusion.ddim, invokeai.models.diffusion.ksampler, etc -""" -from functools import partial - -import numpy as np -import torch -from tqdm import tqdm - -from ...util import choose_torch_device -from ..diffusionmodules.util import ( - extract_into_tensor, - make_ddim_sampling_parameters, - make_ddim_timesteps, - noise_like, -) -from .shared_invokeai_diffusion import InvokeAIDiffuserComponent - - -class Sampler(object): - def __init__(self, model, schedule="linear", steps=None, device=None, **kwargs): - self.model = model - self.ddim_timesteps = None - self.ddpm_num_timesteps = steps - self.schedule = schedule - self.device = device or choose_torch_device() - self.invokeai_diffuser = InvokeAIDiffuserComponent( - self.model, - model_forward_callback=lambda x, sigma, cond: self.model.apply_model( - x, sigma, cond - ), - ) - - def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: - if attr.device != torch.device(self.device): - attr = attr.to(torch.float32).to(torch.device(self.device)) - setattr(self, name, attr) - - # This method was copied over from ddim.py and probably does stuff that is - # ddim-specific. Disentangle at some point. 
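# The deleted plms.py above blends the current eps prediction with up to three
# earlier ones using Adams-Bashforth coefficients. A minimal sketch of just that
# blending step, with illustrative names; the first-step pseudo improved-Euler
# branch, which needs a second model evaluation, is omitted here.
import torch

def plms_eps_prime(e_t, old_eps):
    # coefficients as in the deleted sampler: 2nd, 3rd and 4th order
    # pseudo linear multistep, depending on how much history is available
    if len(old_eps) == 1:
        return (3 * e_t - old_eps[-1]) / 2
    if len(old_eps) == 2:
        return (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    if len(old_eps) >= 3:
        return (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
    return e_t  # no history yet; the deleted code runs an improved-Euler step instead

if __name__ == "__main__":
    history = [torch.randn(1, 4, 64, 64) for _ in range(3)]
    print(plms_eps_prime(torch.randn(1, 4, 64, 64), history).shape)  # torch.Size([1, 4, 64, 64])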
- def make_schedule( - self, - ddim_num_steps, - ddim_discretize="uniform", - ddim_eta=0.0, - verbose=False, - ): - self.total_steps = ddim_num_steps - self.ddim_timesteps = make_ddim_timesteps( - ddim_discr_method=ddim_discretize, - num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps, - verbose=verbose, - ) - alphas_cumprod = self.model.alphas_cumprod - assert ( - alphas_cumprod.shape[0] == self.ddpm_num_timesteps - ), "alphas have to be defined for each timestep" - to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) - - self.register_buffer("betas", to_torch(self.model.betas)) - self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) - self.register_buffer( - "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) - ) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer( - "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) - ) - self.register_buffer( - "sqrt_one_minus_alphas_cumprod", - to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), - ) - self.register_buffer( - "log_one_minus_alphas_cumprod", - to_torch(np.log(1.0 - alphas_cumprod.cpu())), - ) - self.register_buffer( - "sqrt_recip_alphas_cumprod", - to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())), - ) - self.register_buffer( - "sqrt_recipm1_alphas_cumprod", - to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), - ) - - # ddim sampling parameters - ( - ddim_sigmas, - ddim_alphas, - ddim_alphas_prev, - ) = make_ddim_sampling_parameters( - alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta, - verbose=verbose, - ) - self.register_buffer("ddim_sigmas", ddim_sigmas) - self.register_buffer("ddim_alphas", ddim_alphas) - self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) - self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) - sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) - / (1 - self.alphas_cumprod) - * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) - ) - self.register_buffer( - "ddim_sigmas_for_original_num_steps", - sigmas_for_original_sampling_steps, - ) - - @torch.no_grad() - def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): - # fast, but does not allow for exact reconstruction - # t serves as an index to gather the correct alphas - if use_original_steps: - sqrt_alphas_cumprod = self.sqrt_alphas_cumprod - sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod - else: - sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) - sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas - - if noise is None: - noise = torch.randn_like(x0) - return ( - extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 - + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise - ) - - @torch.no_grad() - def sample( - self, - S, # S is steps - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, # TODO: this is very confusing because it is called "step_callback" elsewhere. Change. - quantize_x0=False, - eta=0.0, - mask=None, - x0=None, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - verbose=False, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
- **kwargs, - ): - if conditioning is not None: - if isinstance(conditioning, dict): - ctmp = conditioning[list(conditioning.keys())[0]] - while isinstance(ctmp, list): - ctmp = ctmp[0] - cbs = ctmp.shape[0] - if cbs != batch_size: - print( - f"Warning: Got {cbs} conditionings but batch-size is {batch_size}" - ) - else: - if conditioning.shape[0] != batch_size: - print( - f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" - ) - - # check to see if make_schedule() has run, and if not, run it - if self.ddim_timesteps is None: - self.make_schedule( - ddim_num_steps=S, - ddim_eta=eta, - verbose=False, - ) - - ts = self.get_timesteps(S) - - # sampling - C, H, W = shape - shape = (batch_size, C, H, W) - samples, intermediates = self.do_sampling( - conditioning, - shape, - timesteps=ts, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, - x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - steps=S, - **kwargs, - ) - return samples, intermediates - - @torch.no_grad() - def do_sampling( - self, - cond, - shape, - timesteps=None, - x_T=None, - ddim_use_original_steps=False, - callback=None, - quantize_denoised=False, - mask=None, - x0=None, - img_callback=None, - log_every_t=100, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - steps=None, - **kwargs, - ): - b = shape[0] - time_range = ( - list(reversed(range(0, timesteps))) - if ddim_use_original_steps - else np.flip(timesteps) - ) - - total_steps = steps - - iterator = tqdm( - time_range, - desc=f"{self.__class__.__name__}", - total=total_steps, - dynamic_ncols=True, - ) - old_eps = [] - self.prepare_to_sample(t_enc=total_steps, all_timesteps_count=steps, **kwargs) - img = self.get_initial_image(x_T, shape, total_steps) - - # probably don't need this at all - intermediates = {"x_inter": [img], "pred_x0": [img]} - - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full((b,), step, device=self.device, dtype=torch.long) - ts_next = torch.full( - (b,), - time_range[min(i + 1, len(time_range) - 1)], - device=self.device, - dtype=torch.long, - ) - - if mask is not None: - assert x0 is not None - img_orig = self.model.q_sample( - x0, ts - ) # TODO: deterministic forward pass? 
- img = img_orig * mask + (1.0 - mask) * img - - outs = self.p_sample( - img, - cond, - ts, - index=index, - use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, - temperature=temperature, - noise_dropout=noise_dropout, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - old_eps=old_eps, - t_next=ts_next, - step_count=steps, - ) - img, pred_x0, e_t = outs - - old_eps.append(e_t) - if len(old_eps) >= 4: - old_eps.pop(0) - if callback: - callback(i) - if img_callback: - img_callback(img, i) - - if index % log_every_t == 0 or index == total_steps - 1: - intermediates["x_inter"].append(img) - intermediates["pred_x0"].append(pred_x0) - - return img, intermediates - - # NOTE that decode() and sample() are almost the same code, and do the same thing. - # The variable names are changed in order to be confusing. - @torch.no_grad() - def decode( - self, - x_latent, - cond, - t_start, - img_callback=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - use_original_steps=False, - init_latent=None, - mask=None, - all_timesteps_count=None, - **kwargs, - ): - timesteps = ( - np.arange(self.ddpm_num_timesteps) - if use_original_steps - else self.ddim_timesteps - ) - timesteps = timesteps[:t_start] - - time_range = np.flip(timesteps) - total_steps = timesteps.shape[0] - print( - f">> Running {self.__class__.__name__} sampling starting at step {self.total_steps - t_start} of {self.total_steps} ({total_steps} new sampling steps)" - ) - - iterator = tqdm(time_range, desc="Decoding image", total=total_steps) - x_dec = x_latent - x0 = init_latent - self.prepare_to_sample( - t_enc=total_steps, all_timesteps_count=all_timesteps_count, **kwargs - ) - - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full( - (x_latent.shape[0],), - step, - device=x_latent.device, - dtype=torch.long, - ) - - ts_next = torch.full( - (x_latent.shape[0],), - time_range[min(i + 1, len(time_range) - 1)], - device=self.device, - dtype=torch.long, - ) - - if mask is not None: - assert x0 is not None - xdec_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass? - x_dec = xdec_orig * mask + (1.0 - mask) * x_dec - - outs = self.p_sample( - x_dec, - cond, - ts, - index=index, - use_original_steps=use_original_steps, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - t_next=ts_next, - step_count=len(self.ddim_timesteps), - ) - - x_dec, pred_x0, e_t = outs - if img_callback: - img_callback(x_dec, i) - - return x_dec - - def get_initial_image(self, x_T, shape, timesteps=None): - if x_T is None: - return torch.randn(shape, device=self.device) - else: - return x_T - - def p_sample( - self, - img, - cond, - ts, - index, - repeat_noise=False, - use_original_steps=False, - quantize_denoised=False, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - old_eps=None, - t_next=None, - steps=None, - ): - raise NotImplementedError( - "p_sample() must be implemented in a descendent class" - ) - - def prepare_to_sample(self, t_enc, **kwargs): - """ - Hook that will be called right before the very first invocation of p_sample() - to allow subclass to do additional initialization. 
t_enc corresponds to the actual - number of steps that will be run, and may be less than total steps if img2img is - active. - """ - pass - - def get_timesteps(self, ddim_steps): - """ - The ddim and plms samplers work on timesteps. This method is called after - ddim_timesteps are created in make_schedule(), and selects the portion of - timesteps that will be used for sampling, depending on the t_enc in img2img. - """ - return self.ddim_timesteps[:ddim_steps] - - def q_sample(self, x0, ts): - """ - Returns self.model.q_sample(x0,ts). Is overridden in the k* samplers to - return self.model.inner_model.q_sample(x0,ts) - """ - return self.model.q_sample(x0, ts) - - def conditioning_key(self) -> str: - return self.model.model.conditioning_key - - def uses_inpainting_model(self) -> bool: - return self.conditioning_key() in ("hybrid", "concat") - - def adjust_settings(self, **kwargs): - """ - This is a catch-all method for adjusting any instance variables - after the sampler is instantiated. No type-checking performed - here, so use with care! - """ - for k in kwargs.keys(): - try: - setattr(self, k, kwargs[k]) - except AttributeError: - print( - f"** Warning: attempt to set unknown attribute {k} in sampler of type {type(self)}" - ) diff --git a/invokeai/backend/stable_diffusion/diffusionmodules/__init__.py b/invokeai/backend/stable_diffusion/diffusionmodules/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/invokeai/backend/stable_diffusion/diffusionmodules/model.py b/invokeai/backend/stable_diffusion/diffusionmodules/model.py deleted file mode 100644 index 62cb45d508..0000000000 --- a/invokeai/backend/stable_diffusion/diffusionmodules/model.py +++ /dev/null @@ -1,1081 +0,0 @@ -# pytorch_diffusion + derived encoder decoder -import gc -import math - -import numpy as np -import psutil -import torch -import torch.nn as nn -from einops import rearrange -from torch.nn.functional import silu - -from ...util import instantiate_from_config -from ..attention import LinearAttention - - -def get_timestep_embedding(timesteps, embedding_dim): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: - From Fairseq. - Build sinusoidal embeddings. - This matches the implementation in tensor2tensor, but differs slightly - from the description in Section 3.5 of "Attention Is All You Need". 
- """ - assert len(timesteps.shape) == 1 - - half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) - emb = emb.to(device=timesteps.device) - emb = timesteps.float()[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb - - -def Normalize(in_channels, num_groups=32): - return torch.nn.GroupNorm( - num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True - ) - - -class Upsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x): - cpu_m1_cond = ( - True - if hasattr(torch.backends, "mps") - and torch.backends.mps.is_available() - and x.size()[0] * x.size()[1] * x.size()[2] * x.size()[3] % 2**27 == 0 - else False - ) - if cpu_m1_cond: - x = x.to("cpu") # send to cpu - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - if self.with_conv: - x = self.conv(x) - if cpu_m1_cond: - x = x.to("mps") # return to mps - return x - - -class Downsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=3, stride=2, padding=0 - ) - - def forward(self, x): - if self.with_conv: - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - else: - x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) - return x - - -class ResnetBlock(nn.Module): - def __init__( - self, - *, - in_channels, - out_channels=None, - conv_shortcut=False, - dropout, - temb_channels=512, - ): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - - self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, out_channels) - self.norm2 = Normalize(out_channels) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d( - out_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - else: - self.nin_shortcut = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=1, stride=1, padding=0 - ) - - def forward(self, x, temb): - if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): - x_size = x.size() - if (x_size[0] * x_size[1] * x_size[2] * x_size[3]) % 2**29 == 0: - self.to("cpu") - x = x.to("cpu") - else: - self.to("mps") - x = x.to("mps") - h = self.norm1(x) - h = silu(h) - h = self.conv1(h) - - if temb is not None: - h = h + self.temb_proj(silu(temb))[:, :, None, None] - - h = self.norm2(h) - h = silu(h) - h = self.dropout(h) - h = self.conv2(h) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) - - return x + h - - -class 
LinAttnBlock(LinearAttention): - """to match AttnBlock usage""" - - def __init__(self, in_channels): - super().__init__(dim=in_channels, heads=1, dim_head=in_channels) - - -class AttnBlock(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.k = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.v = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.proj_out = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q1 = self.q(h_) - k1 = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q1.shape - - q2 = q1.reshape(b, c, h * w) - del q1 - - q = q2.permute(0, 2, 1) # b,hw,c - del q2 - - k = k1.reshape(b, c, h * w) # b,c,hw - del k1 - - h_ = torch.zeros_like(k, device=q.device) - - if q.device.type == "cuda": - stats = torch.cuda.memory_stats(q.device) - mem_active = stats["active_bytes.all.current"] - mem_reserved = stats["reserved_bytes.all.current"] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch - - tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4 - mem_required = tensor_size * 2.5 - steps = 1 - - if mem_required > mem_free_total: - steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) - - slice_size = ( - q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - ) - - else: - if psutil.virtual_memory().available / (1024**3) < 12: - slice_size = 1 - else: - slice_size = min( - q.shape[1], math.floor(2**30 / (q.shape[0] * q.shape[1])) - ) - - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - - w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w2 = w1 * (int(c) ** (-0.5)) - del w1 - w3 = torch.nn.functional.softmax(w2, dim=2) - del w2 - - # attend to values - v1 = v.reshape(b, c, h * w) - w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - del w3 - - h_[:, :, i:end] = torch.bmm( - v1, w4 - ) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - del v1, w4 - - h2 = h_.reshape(b, c, h, w) - del h_ - - h3 = self.proj_out(h2) - del h2 - - h3 += x - - return h3 - - -def make_attn(in_channels, attn_type="vanilla"): - assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown" - print(f" | Making attention of type '{attn_type}' with {in_channels} in_channels") - if attn_type == "vanilla": - return AttnBlock(in_channels) - elif attn_type == "none": - return nn.Identity(in_channels) - else: - return LinAttnBlock(in_channels) - - -class Model(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - use_timestep=True, - use_linear_attn=False, - attn_type="vanilla", - ): - super().__init__() - if use_linear_attn: - attn_type = "linear" - self.ch = ch - self.temb_ch = self.ch * 4 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - self.use_timestep = use_timestep - if self.use_timestep: - # timestep embedding - self.temb = nn.Module() - self.temb.dense = nn.ModuleList( - [ - 
torch.nn.Linear(self.ch, self.temb_ch), - torch.nn.Linear(self.temb_ch, self.temb_ch), - ] - ) - - # downsampling - self.conv_in = torch.nn.Conv2d( - in_channels, self.ch, kernel_size=3, stride=1, padding=1 - ) - - curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - skip_in = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - if i_block == self.num_res_blocks: - skip_in = ch * in_ch_mult[i_level] - block.append( - ResnetBlock( - in_channels=block_in + skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, out_ch, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x, t=None, context=None): - # assert x.shape[2] == x.shape[3] == self.resolution - if context is not None: - # assume aligned context, cat along channel axis - x = torch.cat((x, context), dim=1) - if self.use_timestep: - # timestep embedding - assert t is not None - temb = get_timestep_embedding(t, self.ch) - temb = self.temb.dense[0](temb) - temb = silu(temb) - temb = self.temb.dense[1](temb) - else: - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block]( - torch.cat([h, hs.pop()], dim=1), temb - ) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - 
h = self.up[i_level].upsample(h) - - # end - h = self.norm_out(h) - h = silu(h) - h = self.conv_out(h) - return h - - def get_last_layer(self): - return self.conv_out.weight - - -class Encoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - double_z=True, - use_linear_attn=False, - attn_type="vanilla", - **ignore_kwargs, - ): - super().__init__() - if use_linear_attn: - attn_type = "linear" - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - # downsampling - self.conv_in = torch.nn.Conv2d( - in_channels, self.ch, kernel_size=3, stride=1, padding=1 - ) - - curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) - self.in_ch_mult = in_ch_mult - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, - 2 * z_channels if double_z else z_channels, - kernel_size=3, - stride=1, - padding=1, - ) - - def forward(self, x): - # timestep embedding - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # end - h = self.norm_out(h) - h = silu(h) - h = self.conv_out(h) - return h - - -class Decoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - give_pre_end=False, - tanh_out=False, - use_linear_attn=False, - attn_type="vanilla", - **ignorekwargs, - ): - super().__init__() - if use_linear_attn: - attn_type = "linear" - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - self.give_pre_end = give_pre_end - self.tanh_out = tanh_out - - # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = 
(1,) + tuple(ch_mult) - block_in = ch * ch_mult[self.num_resolutions - 1] - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.z_shape = (1, z_channels, curr_res, curr_res) - print( - " | Working with z of shape {} = {} dimensions.".format( - self.z_shape, np.prod(self.z_shape) - ) - ) - - # z to block_in - self.conv_in = torch.nn.Conv2d( - z_channels, block_in, kernel_size=3, stride=1, padding=1 - ) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, out_ch, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, z): - # assert z.shape[1:] == self.z_shape[1:] - self.last_z_shape = z.shape - - # timestep embedding - temb = None - - # z to block_in - h = self.conv_in(z) - - # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # prepare for up sampling - gc.collect() - if h.device.type == "cuda": - torch.cuda.empty_cache() - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](h, temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - if self.give_pre_end: - return h - - h = self.norm_out(h) - h = silu(h) - h = self.conv_out(h) - if self.tanh_out: - h = torch.tanh(h) - return h - - -class SimpleDecoder(nn.Module): - def __init__(self, in_channels, out_channels, *args, **kwargs): - super().__init__() - self.model = nn.ModuleList( - [ - nn.Conv2d(in_channels, in_channels, 1), - ResnetBlock( - in_channels=in_channels, - out_channels=2 * in_channels, - temb_channels=0, - dropout=0.0, - ), - ResnetBlock( - in_channels=2 * in_channels, - out_channels=4 * in_channels, - temb_channels=0, - dropout=0.0, - ), - ResnetBlock( - in_channels=4 * in_channels, - out_channels=2 * in_channels, - temb_channels=0, - dropout=0.0, - ), - nn.Conv2d(2 * in_channels, in_channels, 1), - Upsample(in_channels, with_conv=True), - ] - ) - # end - self.norm_out = Normalize(in_channels) - self.conv_out = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x): - for i, layer in enumerate(self.model): - if i in [1, 2, 3]: - x = layer(x, None) - else: - x = layer(x) - - h = self.norm_out(x) - h = silu(h) - x = self.conv_out(h) - return x - - -class UpsampleDecoder(nn.Module): - def __init__( - self, - in_channels, - 
out_channels, - ch, - num_res_blocks, - resolution, - ch_mult=(2, 2), - dropout=0.0, - ): - super().__init__() - # upsampling - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - block_in = in_channels - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.res_blocks = nn.ModuleList() - self.upsample_blocks = nn.ModuleList() - for i_level in range(self.num_resolutions): - res_block = [] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - res_block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - self.res_blocks.append(nn.ModuleList(res_block)) - if i_level != self.num_resolutions - 1: - self.upsample_blocks.append(Upsample(block_in, True)) - curr_res = curr_res * 2 - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, out_channels, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x): - # upsampling - h = x - for k, i_level in enumerate(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.res_blocks[i_level][i_block](h, None) - if i_level != self.num_resolutions - 1: - h = self.upsample_blocks[k](h) - h = self.norm_out(h) - h = silu(h) - h = self.conv_out(h) - return h - - -class LatentRescaler(nn.Module): - def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): - super().__init__() - # residual block, interpolate, residual block - self.factor = factor - self.conv_in = nn.Conv2d( - in_channels, mid_channels, kernel_size=3, stride=1, padding=1 - ) - self.res_block1 = nn.ModuleList( - [ - ResnetBlock( - in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0, - ) - for _ in range(depth) - ] - ) - self.attn = AttnBlock(mid_channels) - self.res_block2 = nn.ModuleList( - [ - ResnetBlock( - in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0, - ) - for _ in range(depth) - ] - ) - - self.conv_out = nn.Conv2d( - mid_channels, - out_channels, - kernel_size=1, - ) - - def forward(self, x): - x = self.conv_in(x) - for block in self.res_block1: - x = block(x, None) - x = torch.nn.functional.interpolate( - x, - size=( - int(round(x.shape[2] * self.factor)), - int(round(x.shape[3] * self.factor)), - ), - ) - x = self.attn(x) - for block in self.res_block2: - x = block(x, None) - x = self.conv_out(x) - return x - - -class MergedRescaleEncoder(nn.Module): - def __init__( - self, - in_channels, - ch, - resolution, - out_ch, - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - ch_mult=(1, 2, 4, 8), - rescale_factor=1.0, - rescale_module_depth=1, - ): - super().__init__() - intermediate_chn = ch * ch_mult[-1] - self.encoder = Encoder( - in_channels=in_channels, - num_res_blocks=num_res_blocks, - ch=ch, - ch_mult=ch_mult, - z_channels=intermediate_chn, - double_z=False, - resolution=resolution, - attn_resolutions=attn_resolutions, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - out_ch=None, - ) - self.rescaler = LatentRescaler( - factor=rescale_factor, - in_channels=intermediate_chn, - mid_channels=intermediate_chn, - out_channels=out_ch, - depth=rescale_module_depth, - ) - - def forward(self, x): - x = self.encoder(x) - x = self.rescaler(x) - return x - - -class MergedRescaleDecoder(nn.Module): - def __init__( - self, - z_channels, - out_ch, - resolution, - num_res_blocks, - attn_resolutions, - ch, - 
ch_mult=(1, 2, 4, 8), - dropout=0.0, - resamp_with_conv=True, - rescale_factor=1.0, - rescale_module_depth=1, - ): - super().__init__() - tmp_chn = z_channels * ch_mult[-1] - self.decoder = Decoder( - out_ch=out_ch, - z_channels=tmp_chn, - attn_resolutions=attn_resolutions, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - in_channels=None, - num_res_blocks=num_res_blocks, - ch_mult=ch_mult, - resolution=resolution, - ch=ch, - ) - self.rescaler = LatentRescaler( - factor=rescale_factor, - in_channels=z_channels, - mid_channels=tmp_chn, - out_channels=tmp_chn, - depth=rescale_module_depth, - ) - - def forward(self, x): - x = self.rescaler(x) - x = self.decoder(x) - return x - - -class Upsampler(nn.Module): - def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): - super().__init__() - assert out_size >= in_size - num_blocks = int(np.log2(out_size // in_size)) + 1 - factor_up = 1.0 + (out_size % in_size) - print( - f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" - ) - self.rescaler = LatentRescaler( - factor=factor_up, - in_channels=in_channels, - mid_channels=2 * in_channels, - out_channels=in_channels, - ) - self.decoder = Decoder( - out_ch=out_channels, - resolution=out_size, - z_channels=in_channels, - num_res_blocks=2, - attn_resolutions=[], - in_channels=None, - ch=in_channels, - ch_mult=[ch_mult for _ in range(num_blocks)], - ) - - def forward(self, x): - x = self.rescaler(x) - x = self.decoder(x) - return x - - -class Resize(nn.Module): - def __init__(self, in_channels=None, learned=False, mode="bilinear"): - super().__init__() - self.with_conv = learned - self.mode = mode - if self.with_conv: - print( - f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" - ) - raise NotImplementedError() - assert in_channels is not None - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=4, stride=2, padding=1 - ) - - def forward(self, x, scale_factor=1.0): - if scale_factor == 1.0: - return x - else: - x = torch.nn.functional.interpolate( - x, mode=self.mode, align_corners=False, scale_factor=scale_factor - ) - return x - - -class FirstStagePostProcessor(nn.Module): - def __init__( - self, - ch_mult: list, - in_channels, - pretrained_model: nn.Module = None, - reshape=False, - n_channels=None, - dropout=0.0, - pretrained_config=None, - ): - super().__init__() - if pretrained_config is None: - assert ( - pretrained_model is not None - ), 'Either "pretrained_model" or "pretrained_config" must not be None' - self.pretrained_model = pretrained_model - else: - assert ( - pretrained_config is not None - ), 'Either "pretrained_model" or "pretrained_config" must not be None' - self.instantiate_pretrained(pretrained_config) - - self.do_reshape = reshape - - if n_channels is None: - n_channels = self.pretrained_model.encoder.ch - - self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2) - self.proj = nn.Conv2d( - in_channels, n_channels, kernel_size=3, stride=1, padding=1 - ) - - blocks = [] - downs = [] - ch_in = n_channels - for m in ch_mult: - blocks.append( - ResnetBlock( - in_channels=ch_in, out_channels=m * n_channels, dropout=dropout - ) - ) - ch_in = m * n_channels - downs.append(Downsample(ch_in, with_conv=False)) - - self.model = nn.ModuleList(blocks) - self.downsampler = nn.ModuleList(downs) - - def instantiate_pretrained(self, config): - model = 
instantiate_from_config(config) - self.pretrained_model = model.eval() - # self.pretrained_model.train = False - for param in self.pretrained_model.parameters(): - param.requires_grad = False - - @torch.no_grad() - def encode_with_pretrained(self, x): - c = self.pretrained_model.encode(x) - if isinstance(c, DiagonalGaussianDistribution): - c = c.mode() - return c - - def forward(self, x): - z_fs = self.encode_with_pretrained(x) - z = self.proj_norm(z_fs) - z = self.proj(z) - z = silu(z) - - for submodel, downmodel in zip(self.model, self.downsampler): - z = submodel(z, temb=None) - z = downmodel(z) - - if self.do_reshape: - z = rearrange(z, "b c h w -> b (h w) c") - return z diff --git a/invokeai/backend/stable_diffusion/diffusionmodules/openaimodel.py b/invokeai/backend/stable_diffusion/diffusionmodules/openaimodel.py deleted file mode 100644 index 867a1a30ca..0000000000 --- a/invokeai/backend/stable_diffusion/diffusionmodules/openaimodel.py +++ /dev/null @@ -1,1009 +0,0 @@ -import math -from abc import abstractmethod -from functools import partial -from typing import Iterable - -import numpy as np -import torch as th -import torch.nn as nn -import torch.nn.functional as F -from ldm.modules.attention import SpatialTransformer -from ldm.modules.diffusionmodules.util import ( - avg_pool_nd, - checkpoint, - conv_nd, - linear, - normalization, - timestep_embedding, - zero_module, -) - - -# dummy replace -def convert_module_to_f16(x): - pass - - -def convert_module_to_f32(x): - pass - - -## go -class AttentionPool2d(nn.Module): - """ - Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py - """ - - def __init__( - self, - spacial_dim: int, - embed_dim: int, - num_heads_channels: int, - output_dim: int = None, - ): - super().__init__() - self.positional_embedding = nn.Parameter( - th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 - ) - self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) - self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) - self.num_heads = embed_dim // num_heads_channels - self.attention = QKVAttention(self.num_heads) - - def forward(self, x): - b, c, *_spatial = x.shape - x = x.reshape(b, c, -1) # NC(HW) - x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) - x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) - x = self.qkv_proj(x) - x = self.attention(x) - x = self.c_proj(x) - return x[:, :, 0] - - -class TimestepBlock(nn.Module): - """ - Any module where forward() takes timestep embeddings as a second argument. - """ - - @abstractmethod - def forward(self, x, emb): - """ - Apply the module to `x` given `emb` timestep embeddings. - """ - - -class TimestepEmbedSequential(nn.Sequential, TimestepBlock): - """ - A sequential module that passes timestep embeddings to the children that - support it as an extra input. - """ - - def forward(self, x, emb, context=None): - for layer in self: - if isinstance(layer, TimestepBlock): - x = layer(x, emb) - elif isinstance(layer, SpatialTransformer): - x = layer(x, context) - else: - x = layer(x) - return x - - -class Upsample(nn.Module): - """ - An upsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - upsampling occurs in the inner-two dimensions. 
- """ - - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - if use_conv: - self.conv = conv_nd( - dims, self.channels, self.out_channels, 3, padding=padding - ) - - def forward(self, x): - assert x.shape[1] == self.channels - if self.dims == 3: - x = F.interpolate( - x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" - ) - else: - x = F.interpolate(x, scale_factor=2, mode="nearest") - if self.use_conv: - x = self.conv(x) - return x - - -class TransposedUpsample(nn.Module): - """Learned 2x upsampling without padding""" - - def __init__(self, channels, out_channels=None, ks=5): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - - self.up = nn.ConvTranspose2d( - self.channels, self.out_channels, kernel_size=ks, stride=2 - ) - - def forward(self, x): - return self.up(x) - - -class Downsample(nn.Module): - """ - A downsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - downsampling occurs in the inner-two dimensions. - """ - - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - stride = 2 if dims != 3 else (1, 2, 2) - if use_conv: - self.op = conv_nd( - dims, - self.channels, - self.out_channels, - 3, - stride=stride, - padding=padding, - ) - else: - assert self.channels == self.out_channels - self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) - - def forward(self, x): - assert x.shape[1] == self.channels - return self.op(x) - - -class ResBlock(TimestepBlock): - """ - A residual block that can optionally change the number of channels. - :param channels: the number of input channels. - :param emb_channels: the number of timestep embedding channels. - :param dropout: the rate of dropout. - :param out_channels: if specified, the number of out channels. - :param use_conv: if True and out_channels is specified, use a spatial - convolution instead of a smaller 1x1 convolution to change the - channels in the skip connection. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param use_checkpoint: if True, use gradient checkpointing on this module. - :param up: if True, use this block for upsampling. - :param down: if True, use this block for downsampling. 
- """ - - def __init__( - self, - channels, - emb_channels, - dropout, - out_channels=None, - use_conv=False, - use_scale_shift_norm=False, - dims=2, - use_checkpoint=False, - up=False, - down=False, - ): - super().__init__() - self.channels = channels - self.emb_channels = emb_channels - self.dropout = dropout - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_checkpoint = use_checkpoint - self.use_scale_shift_norm = use_scale_shift_norm - - self.in_layers = nn.Sequential( - normalization(channels), - nn.SiLU(), - conv_nd(dims, channels, self.out_channels, 3, padding=1), - ) - - self.updown = up or down - - if up: - self.h_upd = Upsample(channels, False, dims) - self.x_upd = Upsample(channels, False, dims) - elif down: - self.h_upd = Downsample(channels, False, dims) - self.x_upd = Downsample(channels, False, dims) - else: - self.h_upd = self.x_upd = nn.Identity() - - self.emb_layers = nn.Sequential( - nn.SiLU(), - linear( - emb_channels, - 2 * self.out_channels if use_scale_shift_norm else self.out_channels, - ), - ) - self.out_layers = nn.Sequential( - normalization(self.out_channels), - nn.SiLU(), - nn.Dropout(p=dropout), - zero_module( - conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) - ), - ) - - if self.out_channels == channels: - self.skip_connection = nn.Identity() - elif use_conv: - self.skip_connection = conv_nd( - dims, channels, self.out_channels, 3, padding=1 - ) - else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) - - def forward(self, x, emb): - """ - Apply the block to a Tensor, conditioned on a timestep embedding. - :param x: an [N x C x ...] Tensor of features. - :param emb: an [N x emb_channels] Tensor of timestep embeddings. - :return: an [N x C x ...] Tensor of outputs. - """ - return checkpoint( - self._forward, (x, emb), self.parameters(), self.use_checkpoint - ) - - def _forward(self, x, emb): - if self.updown: - in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] - h = in_rest(x) - h = self.h_upd(h) - x = self.x_upd(x) - h = in_conv(h) - else: - h = self.in_layers(x) - emb_out = self.emb_layers(emb).type(h.dtype) - while len(emb_out.shape) < len(h.shape): - emb_out = emb_out[..., None] - if self.use_scale_shift_norm: - out_norm, out_rest = self.out_layers[0], self.out_layers[1:] - scale, shift = th.chunk(emb_out, 2, dim=1) - h = out_norm(h) * (1 + scale) + shift - h = out_rest(h) - else: - h = h + emb_out - h = self.out_layers(h) - return self.skip_connection(x) + h - - -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. - Originally ported from here, but adapted to the N-d case. - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
- """ - - def __init__( - self, - channels, - num_heads=1, - num_head_channels=-1, - use_checkpoint=False, - use_new_attention_order=False, - ): - super().__init__() - self.channels = channels - if num_head_channels == -1: - self.num_heads = num_heads - else: - assert ( - channels % num_head_channels == 0 - ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" - self.num_heads = channels // num_head_channels - self.use_checkpoint = use_checkpoint - self.norm = normalization(channels) - self.qkv = conv_nd(1, channels, channels * 3, 1) - if use_new_attention_order: - # split qkv before split heads - self.attention = QKVAttention(self.num_heads) - else: - # split heads before split qkv - self.attention = QKVAttentionLegacy(self.num_heads) - - self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) - - def forward(self, x): - return checkpoint( - self._forward, (x,), self.parameters(), True - ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! - # return pt_checkpoint(self._forward, x) # pytorch - - def _forward(self, x): - b, c, *spatial = x.shape - x = x.reshape(b, c, -1) - qkv = self.qkv(self.norm(x)) - h = self.attention(qkv) - h = self.proj_out(h) - return (x + h).reshape(b, c, *spatial) - - -def count_flops_attn(model, _x, y): - """ - A counter for the `thop` package to count the operations in an - attention operation. - Meant to be used like: - macs, params = thop.profile( - model, - inputs=(inputs, timestamps), - custom_ops={QKVAttention: QKVAttention.count_flops}, - ) - """ - b, c, *spatial = y[0].shape - num_spatial = int(np.prod(spatial)) - # We perform two matmuls with the same number of ops. - # The first computes the weight matrix, the second computes - # the combination of the value vectors. - matmul_ops = 2 * b * (num_spatial**2) * c - model.total_ops += th.DoubleTensor([matmul_ops]) - - -class QKVAttentionLegacy(nn.Module): - """ - A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping - """ - - def __init__(self, n_heads): - super().__init__() - self.n_heads = n_heads - - def forward(self, qkv): - """ - Apply QKV attention. - :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. - :return: an [N x (H * C) x T] tensor after attention. - """ - bs, width, length = qkv.shape - assert width % (3 * self.n_heads) == 0 - ch = width // (3 * self.n_heads) - q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( - "bct,bcs->bts", q * scale, k * scale - ) # More stable with f16 than dividing afterwards - weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, v) - return a.reshape(bs, -1, length) - - @staticmethod - def count_flops(model, _x, y): - return count_flops_attn(model, _x, y) - - -class QKVAttention(nn.Module): - """ - A module which performs QKV attention and splits in a different order. - """ - - def __init__(self, n_heads): - super().__init__() - self.n_heads = n_heads - - def forward(self, qkv): - """ - Apply QKV attention. - :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. - :return: an [N x (H * C) x T] tensor after attention. 
- """ - bs, width, length = qkv.shape - assert width % (3 * self.n_heads) == 0 - ch = width // (3 * self.n_heads) - q, k, v = qkv.chunk(3, dim=1) - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( - "bct,bcs->bts", - (q * scale).view(bs * self.n_heads, ch, length), - (k * scale).view(bs * self.n_heads, ch, length), - ) # More stable with f16 than dividing afterwards - weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) - return a.reshape(bs, -1, length) - - @staticmethod - def count_flops(model, _x, y): - return count_flops_attn(model, _x, y) - - -class UNetModel(nn.Module): - """ - The full UNet model with attention and timestep embedding. - :param in_channels: channels in the input Tensor. - :param model_channels: base channel count for the model. - :param out_channels: channels in the output Tensor. - :param num_res_blocks: number of residual blocks per downsample. - :param attention_resolutions: a collection of downsample rates at which - attention will take place. May be a set, list, or tuple. - For example, if this contains 4, then at 4x downsampling, attention - will be used. - :param dropout: the dropout probability. - :param channel_mult: channel multiplier for each level of the UNet. - :param conv_resample: if True, use learned convolutions for upsampling and - downsampling. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param num_classes: if specified (as an int), then this model will be - class-conditional with `num_classes` classes. - :param use_checkpoint: use gradient checkpointing to reduce memory usage. - :param num_heads: the number of attention heads in each attention layer. - :param num_heads_channels: if specified, ignore num_heads and instead use - a fixed channel width per attention head. - :param num_heads_upsample: works with num_heads to set a different number - of heads for upsampling. Deprecated. - :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. - :param resblock_updown: use residual blocks for up/downsampling. - :param use_new_attention_order: use a different attention pattern for potentially - increased efficiency. - """ - - def __init__( - self, - image_size, - in_channels, - model_channels, - out_channels, - num_res_blocks, - attention_resolutions, - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - num_classes=None, - use_checkpoint=False, - use_fp16=False, - num_heads=-1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - use_new_attention_order=False, - use_spatial_transformer=False, # custom transformer support - transformer_depth=1, # custom transformer support - context_dim=None, # custom transformer support - n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model - legacy=True, - ): - super().__init__() - if use_spatial_transformer: - assert ( - context_dim is not None - ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." - - if context_dim is not None: - assert ( - use_spatial_transformer - ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." 
- from omegaconf.listconfig import ListConfig - - if type(context_dim) == ListConfig: - context_dim = list(context_dim) - - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - if num_heads == -1: - assert ( - num_head_channels != -1 - ), "Either num_heads or num_head_channels has to be set" - - if num_head_channels == -1: - assert ( - num_heads != -1 - ), "Either num_heads or num_head_channels has to be set" - - self.image_size = image_size - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample - self.num_classes = num_classes - self.use_checkpoint = use_checkpoint - self.dtype = th.float16 if use_fp16 else th.float32 - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - self.predict_codebook_ids = n_embed is not None - - time_embed_dim = model_channels * 4 - self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - if self.num_classes is not None: - self.label_emb = nn.Embedding(num_classes, time_embed_dim) - - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] - ) - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - for level, mult in enumerate(channel_mult): - for _ in range(num_res_blocks): - layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=mult * model_channels, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = mult * model_channels - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - # num_heads = 1 - dim_head = ( - ch // num_heads - if use_spatial_transformer - else num_head_channels - ) - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) - if not use_spatial_transformer - else SpatialTransformer( - ch, - num_heads, - dim_head, - depth=transformer_depth, - context_dim=context_dim, - ) - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - down=True, - ) - if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - # num_heads = 1 - dim_head = ch // num_heads if use_spatial_transformer else num_head_channels - self.middle_block = TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - AttentionBlock( - 
ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) - if not use_spatial_transformer - else SpatialTransformer( - ch, - num_heads, - dim_head, - depth=transformer_depth, - context_dim=context_dim, - ), - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - ) - self._feature_size += ch - - self.output_blocks = nn.ModuleList([]) - for level, mult in list(enumerate(channel_mult))[::-1]: - for i in range(num_res_blocks + 1): - ich = input_block_chans.pop() - layers = [ - ResBlock( - ch + ich, - time_embed_dim, - dropout, - out_channels=model_channels * mult, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = model_channels * mult - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - # num_heads = 1 - dim_head = ( - ch // num_heads - if use_spatial_transformer - else num_head_channels - ) - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads_upsample, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) - if not use_spatial_transformer - else SpatialTransformer( - ch, - num_heads, - dim_head, - depth=transformer_depth, - context_dim=context_dim, - ) - ) - if level and i == num_res_blocks: - out_ch = ch - layers.append( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - up=True, - ) - if resblock_updown - else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) - ) - ds //= 2 - self.output_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), - ) - if self.predict_codebook_ids: - self.id_predictor = nn.Sequential( - normalization(ch), - conv_nd(dims, model_channels, n_embed, 1), - # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits - ) - - def convert_to_fp16(self): - """ - Convert the torso of the model to float16. - """ - self.input_blocks.apply(convert_module_to_f16) - self.middle_block.apply(convert_module_to_f16) - self.output_blocks.apply(convert_module_to_f16) - - def convert_to_fp32(self): - """ - Convert the torso of the model to float32. - """ - self.input_blocks.apply(convert_module_to_f32) - self.middle_block.apply(convert_module_to_f32) - self.output_blocks.apply(convert_module_to_f32) - - def forward(self, x, timesteps=None, context=None, y=None, **kwargs): - """ - Apply the model to an input batch. - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :param context: conditioning plugged in via crossattn - :param y: an [N] Tensor of labels, if class-conditional. - :return: an [N x C x ...] Tensor of outputs. 
- """ - assert (y is not None) == ( - self.num_classes is not None - ), "must specify y if and only if the model is class-conditional" - hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) - emb = self.time_embed(t_emb) - - if self.num_classes is not None: - assert y.shape == (x.shape[0],) - emb = emb + self.label_emb(y) - - h = x.type(self.dtype) - for module in self.input_blocks: - h = module(h, emb, context) - hs.append(h) - h = self.middle_block(h, emb, context) - for module in self.output_blocks: - h = th.cat([h, hs.pop()], dim=1) - h = module(h, emb, context) - h = h.type(x.dtype) - if self.predict_codebook_ids: - return self.id_predictor(h) - else: - return self.out(h) - - -class EncoderUNetModel(nn.Module): - """ - The half UNet model with attention and timestep embedding. - For usage, see UNet. - """ - - def __init__( - self, - image_size, - in_channels, - model_channels, - out_channels, - num_res_blocks, - attention_resolutions, - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - use_checkpoint=False, - use_fp16=False, - num_heads=1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - use_new_attention_order=False, - pool="adaptive", - *args, - **kwargs, - ): - super().__init__() - - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample - self.use_checkpoint = use_checkpoint - self.dtype = th.float16 if use_fp16 else th.float32 - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - - time_embed_dim = model_channels * 4 - self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] - ) - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - for level, mult in enumerate(channel_mult): - for _ in range(num_res_blocks): - layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=mult * model_channels, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = mult * model_channels - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=num_head_channels, - use_new_attention_order=use_new_attention_order, - ) - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - down=True, - ) - if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - self.middle_block = TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - 
dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=num_head_channels, - use_new_attention_order=use_new_attention_order, - ), - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - ) - self._feature_size += ch - self.pool = pool - if pool == "adaptive": - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - nn.AdaptiveAvgPool2d((1, 1)), - zero_module(conv_nd(dims, ch, out_channels, 1)), - nn.Flatten(), - ) - elif pool == "attention": - assert num_head_channels != -1 - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - AttentionPool2d( - (image_size // ds), ch, num_head_channels, out_channels - ), - ) - elif pool == "spatial": - self.out = nn.Sequential( - nn.Linear(self._feature_size, 2048), - nn.ReLU(), - nn.Linear(2048, self.out_channels), - ) - elif pool == "spatial_v2": - self.out = nn.Sequential( - nn.Linear(self._feature_size, 2048), - normalization(2048), - nn.SiLU(), - nn.Linear(2048, self.out_channels), - ) - else: - raise NotImplementedError(f"Unexpected {pool} pooling") - - def convert_to_fp16(self): - """ - Convert the torso of the model to float16. - """ - self.input_blocks.apply(convert_module_to_f16) - self.middle_block.apply(convert_module_to_f16) - - def convert_to_fp32(self): - """ - Convert the torso of the model to float32. - """ - self.input_blocks.apply(convert_module_to_f32) - self.middle_block.apply(convert_module_to_f32) - - def forward(self, x, timesteps): - """ - Apply the model to an input batch. - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :return: an [N x K] Tensor of outputs. - """ - emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) - - results = [] - h = x.type(self.dtype) - for module in self.input_blocks: - h = module(h, emb) - if self.pool.startswith("spatial"): - results.append(h.type(x.dtype).mean(dim=(2, 3))) - h = self.middle_block(h, emb) - if self.pool.startswith("spatial"): - results.append(h.type(x.dtype).mean(dim=(2, 3))) - h = th.cat(results, axis=-1) - return self.out(h) - else: - h = h.type(x.dtype) - return self.out(h) diff --git a/invokeai/backend/stable_diffusion/diffusionmodules/util.py b/invokeai/backend/stable_diffusion/diffusionmodules/util.py deleted file mode 100644 index b71b0f06f9..0000000000 --- a/invokeai/backend/stable_diffusion/diffusionmodules/util.py +++ /dev/null @@ -1,297 +0,0 @@ -# adopted from -# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py -# and -# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py -# and -# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py -# -# thanks! 
- - -import math -import os - -import numpy as np -import torch -import torch.nn as nn -from einops import repeat - -from ...util.util import instantiate_from_config - - -def make_beta_schedule( - schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 -): - if schedule == "linear": - betas = ( - torch.linspace( - linear_start**0.5, - linear_end**0.5, - n_timestep, - dtype=torch.float64, - ) - ** 2 - ) - - elif schedule == "cosine": - timesteps = ( - torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s - ) - alphas = timesteps / (1 + cosine_s) * np.pi / 2 - alphas = torch.cos(alphas).pow(2) - alphas = alphas / alphas[0] - betas = 1 - alphas[1:] / alphas[:-1] - betas = np.clip(betas, a_min=0, a_max=0.999) - - elif schedule == "sqrt_linear": - betas = torch.linspace( - linear_start, linear_end, n_timestep, dtype=torch.float64 - ) - elif schedule == "sqrt": - betas = ( - torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) - ** 0.5 - ) - else: - raise ValueError(f"schedule '{schedule}' unknown.") - return betas.numpy() - - -def make_ddim_timesteps( - ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True -): - if ddim_discr_method == "uniform": - c = num_ddpm_timesteps // num_ddim_timesteps - if c < 1: - c = 1 - ddim_timesteps = (np.arange(0, num_ddim_timesteps) * c).astype(int) - elif ddim_discr_method == "quad": - ddim_timesteps = ( - (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 - ).astype(int) - else: - raise NotImplementedError( - f'There is no ddim discretization method called "{ddim_discr_method}"' - ) - - # assert ddim_timesteps.shape[0] == num_ddim_timesteps - # add one to get the final alpha values right (the ones from first scale to data during sampling) - steps_out = ddim_timesteps + 1 - # steps_out = ddim_timesteps - - if verbose: - print(f"Selected timesteps for ddim sampler: {steps_out}") - return steps_out - - -def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): - # select alphas for computing the variance schedule - alphas = alphacums[ddim_timesteps] - alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) - - # according the the formula provided in https://arxiv.org/abs/2010.02502 - sigmas = eta * np.sqrt( - (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) - ) - if verbose: - print( - f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" - ) - print( - f"For the chosen value of eta, which is {eta}, " - f"this results in the following sigma_t schedule for ddim sampler {sigmas}" - ) - return sigmas, alphas, alphas_prev - - -def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, - which defines the cumulative product of (1-beta) over time from t = [0,1]. - :param num_diffusion_timesteps: the number of betas to produce. - :param alpha_bar: a lambda that takes an argument t from 0 to 1 and - produces the cumulative product of (1-beta) up to that - part of the diffusion process. - :param max_beta: the maximum beta to use; use values lower than 1 to - prevent singularities. 
- """ - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return np.array(betas) - - -def extract_into_tensor(a, t, x_shape): - b, *_ = t.shape - out = a.gather(-1, t) - return out.reshape(b, *((1,) * (len(x_shape) - 1))) - - -def checkpoint(func, inputs, params, flag): - """ - Evaluate a function without caching intermediate activations, allowing for - reduced memory at the expense of extra compute in the backward pass. - :param func: the function to evaluate. - :param inputs: the argument sequence to pass to `func`. - :param params: a sequence of parameters `func` depends on but does not - explicitly take as arguments. - :param flag: if False, disable gradient checkpointing. - """ - if False: # disabled checkpointing to allow requires_grad = False for main model - args = tuple(inputs) + tuple(params) - return CheckpointFunction.apply(func, len(inputs), *args) - else: - return func(*inputs) - - -class CheckpointFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, run_function, length, *args): - ctx.run_function = run_function - ctx.input_tensors = list(args[:length]) - ctx.input_params = list(args[length:]) - - with torch.no_grad(): - output_tensors = ctx.run_function(*ctx.input_tensors) - return output_tensors - - @staticmethod - def backward(ctx, *output_grads): - ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] - with torch.enable_grad(): - # Fixes a bug where the first op in run_function modifies the - # Tensor storage in place, which is not allowed for detach()'d - # Tensors. - shallow_copies = [x.view_as(x) for x in ctx.input_tensors] - output_tensors = ctx.run_function(*shallow_copies) - input_grads = torch.autograd.grad( - output_tensors, - ctx.input_tensors + ctx.input_params, - output_grads, - allow_unused=True, - ) - del ctx.input_tensors - del ctx.input_params - del output_tensors - return (None, None) + input_grads - - -def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): - """ - Create sinusoidal timestep embeddings. - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. - """ - if not repeat_only: - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) - * torch.arange(start=0, end=half, dtype=torch.float32) - / half - ).to(device=timesteps.device) - args = timesteps[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat( - [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 - ) - else: - embedding = repeat(timesteps, "b -> b d", d=dim) - return embedding - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - -def scale_module(module, scale): - """ - Scale the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().mul_(scale) - return module - - -def mean_flat(tensor): - """ - Take the mean over all non-batch dimensions. - """ - return tensor.mean(dim=list(range(1, len(tensor.shape)))) - - -def normalization(channels): - """ - Make a standard normalization layer. - :param channels: number of input channels. 
- :return: an nn.Module for normalization. - """ - return GroupNorm32(32, channels) - - -class GroupNorm32(nn.GroupNorm): - def forward(self, x): - return super().forward(x.float()).type(x.dtype) - - -def conv_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D convolution module. - """ - if dims == 1: - return nn.Conv1d(*args, **kwargs) - elif dims == 2: - return nn.Conv2d(*args, **kwargs) - elif dims == 3: - return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -def linear(*args, **kwargs): - """ - Create a linear module. - """ - return nn.Linear(*args, **kwargs) - - -def avg_pool_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D average pooling module. - """ - if dims == 1: - return nn.AvgPool1d(*args, **kwargs) - elif dims == 2: - return nn.AvgPool2d(*args, **kwargs) - elif dims == 3: - return nn.AvgPool3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -class HybridConditioner(nn.Module): - def __init__(self, c_concat_config, c_crossattn_config): - super().__init__() - self.concat_conditioner = instantiate_from_config(c_concat_config) - self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) - - def forward(self, c_concat, c_crossattn): - c_concat = self.concat_conditioner(c_concat) - c_crossattn = self.crossattn_conditioner(c_crossattn) - return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} - - -def noise_like(shape, device, repeat=False): - repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( - shape[0], *((1,) * (len(shape) - 1)) - ) - noise = lambda: torch.randn(shape, device=device) - return repeat_noise() if repeat else noise() diff --git a/invokeai/backend/stable_diffusion/distributions/__init__.py b/invokeai/backend/stable_diffusion/distributions/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/invokeai/backend/stable_diffusion/distributions/distributions.py b/invokeai/backend/stable_diffusion/distributions/distributions.py deleted file mode 100644 index 016be35523..0000000000 --- a/invokeai/backend/stable_diffusion/distributions/distributions.py +++ /dev/null @@ -1,102 +0,0 @@ -import numpy as np -import torch - - -class AbstractDistribution: - def sample(self): - raise NotImplementedError() - - def mode(self): - raise NotImplementedError() - - -class DiracDistribution(AbstractDistribution): - def __init__(self, value): - self.value = value - - def sample(self): - return self.value - - def mode(self): - return self.value - - -class DiagonalGaussianDistribution(object): - def __init__(self, parameters, deterministic=False): - self.parameters = parameters - self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) - self.logvar = torch.clamp(self.logvar, -30.0, 20.0) - self.deterministic = deterministic - self.std = torch.exp(0.5 * self.logvar) - self.var = torch.exp(self.logvar) - if self.deterministic: - self.var = self.std = torch.zeros_like(self.mean).to( - device=self.parameters.device - ) - - def sample(self): - x = self.mean + self.std * torch.randn(self.mean.shape).to( - device=self.parameters.device - ) - return x - - def kl(self, other=None): - if self.deterministic: - return torch.Tensor([0.0]) - else: - if other is None: - return 0.5 * torch.sum( - torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, - dim=[1, 2, 3], - ) - else: - return 0.5 * torch.sum( - torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - - 1.0 - - self.logvar - + other.logvar, - dim=[1, 2, 3], - ) - - def nll(self, 
sample, dims=[1, 2, 3]): - if self.deterministic: - return torch.Tensor([0.0]) - logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum( - logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, - dim=dims, - ) - - def mode(self): - return self.mean - - -def normal_kl(mean1, logvar1, mean2, logvar2): - """ - source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 - Compute the KL divergence between two gaussians. - Shapes are automatically broadcasted, so batches can be compared to - scalars, among other use cases. - """ - tensor = None - for obj in (mean1, logvar1, mean2, logvar2): - if isinstance(obj, torch.Tensor): - tensor = obj - break - assert tensor is not None, "at least one argument must be a Tensor" - - # Force variances to be Tensors. Broadcasting helps convert scalars to - # Tensors, but it does not work for torch.exp(). - logvar1, logvar2 = [ - x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) - for x in (logvar1, logvar2) - ] - - return 0.5 * ( - -1.0 - + logvar2 - - logvar1 - + torch.exp(logvar1 - logvar2) - + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) - ) diff --git a/invokeai/backend/stable_diffusion/ema.py b/invokeai/backend/stable_diffusion/ema.py deleted file mode 100644 index 880ca3d205..0000000000 --- a/invokeai/backend/stable_diffusion/ema.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch -from torch import nn - - -class LitEma(nn.Module): - def __init__(self, model, decay=0.9999, use_num_upates=True): - super().__init__() - if decay < 0.0 or decay > 1.0: - raise ValueError("Decay must be between 0 and 1") - - self.m_name2s_name = {} - self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) - self.register_buffer( - "num_updates", - torch.tensor(0, dtype=torch.int) - if use_num_upates - else torch.tensor(-1, dtype=torch.int), - ) - - for name, p in model.named_parameters(): - if p.requires_grad: - # remove as '.'-character is not allowed in buffers - s_name = name.replace(".", "") - self.m_name2s_name.update({name: s_name}) - self.register_buffer(s_name, p.clone().detach().data) - - self.collected_params = [] - - def forward(self, model): - decay = self.decay - - if self.num_updates >= 0: - self.num_updates += 1 - decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) - - one_minus_decay = 1.0 - decay - - with torch.no_grad(): - m_param = dict(model.named_parameters()) - shadow_params = dict(self.named_buffers()) - - for key in m_param: - if m_param[key].requires_grad: - sname = self.m_name2s_name[key] - shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) - shadow_params[sname].sub_( - one_minus_decay * (shadow_params[sname] - m_param[key]) - ) - else: - assert not key in self.m_name2s_name - - def copy_to(self, model): - m_param = dict(model.named_parameters()) - shadow_params = dict(self.named_buffers()) - for key in m_param: - if m_param[key].requires_grad: - m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) - else: - assert not key in self.m_name2s_name - - def store(self, parameters): - """ - Save the current parameters for restoring later. - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - temporarily stored. - """ - self.collected_params = [param.clone() for param in parameters] - - def restore(self, parameters): - """ - Restore the parameters stored with the `store` method. 
- Useful to validate the model with EMA parameters without affecting the - original optimization process. Store the parameters before the - `copy_to` method. After validation (or model saving), use this to - restore the former parameters. - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored parameters. - """ - for c_param, param in zip(self.collected_params, parameters): - param.data.copy_(c_param.data) diff --git a/invokeai/backend/stable_diffusion/encoders/__init__.py b/invokeai/backend/stable_diffusion/encoders/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/invokeai/backend/stable_diffusion/encoders/modules.py b/invokeai/backend/stable_diffusion/encoders/modules.py deleted file mode 100644 index 54afd12bc9..0000000000 --- a/invokeai/backend/stable_diffusion/encoders/modules.py +++ /dev/null @@ -1,858 +0,0 @@ -import math -from functools import partial -from typing import Optional - -import clip -import kornia -import torch -import torch.nn as nn -from einops import repeat -from transformers import CLIPTextModel, CLIPTokenizer - -from ...util import choose_torch_device -from ..globals import global_cache_dir -from ..x_transformer import ( # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test - Encoder, - TransformerWrapper, -) - - -def _expand_mask(mask, dtype, tgt_len=None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill( - inverted_mask.to(torch.bool), torch.finfo(dtype).min - ) - - -def _build_causal_attention_mask(bsz, seq_len, dtype): - # lazily create causal attention mask, with full attention between the vision tokens - # pytorch uses additive attention mask; fill with -inf - mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype) - mask.fill_(torch.tensor(torch.finfo(dtype).min)) - mask.triu_(1) # zero out the lower diagonal - mask = mask.unsqueeze(1) # expand mask - return mask - - -class AbstractEncoder(nn.Module): - def __init__(self): - super().__init__() - - def encode(self, *args, **kwargs): - raise NotImplementedError - - -class ClassEmbedder(nn.Module): - def __init__(self, embed_dim, n_classes=1000, key="class"): - super().__init__() - self.key = key - self.embedding = nn.Embedding(n_classes, embed_dim) - - def forward(self, batch, key=None): - if key is None: - key = self.key - # this is for use in crossattn - c = batch[key][:, None] - c = self.embedding(c) - return c - - -class TransformerEmbedder(AbstractEncoder): - """Some transformer encoder layers""" - - def __init__( - self, - n_embed, - n_layer, - vocab_size, - max_seq_len=77, - device=choose_torch_device(), - ): - super().__init__() - self.device = device - self.transformer = TransformerWrapper( - num_tokens=vocab_size, - max_seq_len=max_seq_len, - attn_layers=Encoder(dim=n_embed, depth=n_layer), - ) - - def forward(self, tokens): - tokens = tokens.to(self.device) # meh - z = self.transformer(tokens, return_embeddings=True) - return z - - def encode(self, x): - return self(x) - - -class BERTTokenizer(AbstractEncoder): - """Uses a pretrained BERT tokenizer by huggingface. 
Vocab size: 30522 (?)""" - - def __init__(self, device=choose_torch_device(), vq_interface=True, max_length=77): - super().__init__() - from transformers import BertTokenizerFast - - cache = global_cache_dir("hub") - try: - self.tokenizer = BertTokenizerFast.from_pretrained( - "bert-base-uncased", cache_dir=cache, local_files_only=True - ) - except OSError: - raise SystemExit( - "* Couldn't load Bert tokenizer files. Try running scripts/preload_models.py from an internet-conected machine." - ) - self.device = device - self.vq_interface = vq_interface - self.max_length = max_length - - def forward(self, text): - batch_encoding = self.tokenizer( - text, - truncation=True, - max_length=self.max_length, - return_length=True, - return_overflowing_tokens=False, - padding="max_length", - return_tensors="pt", - ) - tokens = batch_encoding["input_ids"].to(self.device) - return tokens - - @torch.no_grad() - def encode(self, text): - tokens = self(text) - if not self.vq_interface: - return tokens - return None, None, [None, None, tokens] - - def decode(self, text): - return text - - -class BERTEmbedder(AbstractEncoder): - """Uses the BERT tokenizr model and add some transformer encoder layers""" - - def __init__( - self, - n_embed, - n_layer, - vocab_size=30522, - max_seq_len=77, - device=choose_torch_device(), - use_tokenizer=True, - embedding_dropout=0.0, - ): - super().__init__() - self.use_tknz_fn = use_tokenizer - if self.use_tknz_fn: - self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) - self.device = device - self.transformer = TransformerWrapper( - num_tokens=vocab_size, - max_seq_len=max_seq_len, - attn_layers=Encoder(dim=n_embed, depth=n_layer), - emb_dropout=embedding_dropout, - ) - - def forward(self, text, embedding_manager=None): - if self.use_tknz_fn: - tokens = self.tknz_fn(text) # .to(self.device) - else: - tokens = text - z = self.transformer( - tokens, return_embeddings=True, embedding_manager=embedding_manager - ) - return z - - def encode(self, text, **kwargs): - # output of length 77 - return self(text, **kwargs) - - -class SpatialRescaler(nn.Module): - def __init__( - self, - n_stages=1, - method="bilinear", - multiplier=0.5, - in_channels=3, - out_channels=None, - bias=False, - ): - super().__init__() - self.n_stages = n_stages - assert self.n_stages >= 0 - assert method in [ - "nearest", - "linear", - "bilinear", - "trilinear", - "bicubic", - "area", - ] - self.multiplier = multiplier - self.interpolator = partial(torch.nn.functional.interpolate, mode=method) - self.remap_output = out_channels is not None - if self.remap_output: - print( - f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing." 
- ) - self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias) - - def forward(self, x): - for stage in range(self.n_stages): - x = self.interpolator(x, scale_factor=self.multiplier) - - if self.remap_output: - x = self.channel_mapper(x) - return x - - def encode(self, x): - return self(x) - - -class FrozenCLIPEmbedder(AbstractEncoder): - """Uses the CLIP transformer encoder for text (from Hugging Face)""" - - tokenizer: CLIPTokenizer - transformer: CLIPTextModel - - def __init__( - self, - version: str = "openai/clip-vit-large-patch14", - max_length: int = 77, - tokenizer: Optional[CLIPTokenizer] = None, - transformer: Optional[CLIPTextModel] = None, - ): - super().__init__() - cache = global_cache_dir("hub") - self.tokenizer = tokenizer or CLIPTokenizer.from_pretrained( - version, cache_dir=cache, local_files_only=True - ) - self.transformer = transformer or CLIPTextModel.from_pretrained( - version, cache_dir=cache, local_files_only=True - ) - self.max_length = max_length - self.freeze() - - def embedding_forward( - self, - input_ids=None, - position_ids=None, - inputs_embeds=None, - embedding_manager=None, - ) -> torch.Tensor: - seq_length = ( - input_ids.shape[-1] - if input_ids is not None - else inputs_embeds.shape[-2] - ) - - if position_ids is None: - position_ids = self.position_ids[:, :seq_length] - - if inputs_embeds is None: - inputs_embeds = self.token_embedding(input_ids) - - if embedding_manager is not None: - inputs_embeds = embedding_manager(input_ids, inputs_embeds) - - position_embeddings = self.position_embedding(position_ids) - embeddings = inputs_embeds + position_embeddings - - return embeddings - - self.transformer.text_model.embeddings.forward = embedding_forward.__get__( - self.transformer.text_model.embeddings - ) - - def encoder_forward( - self, - inputs_embeds, - attention_mask=None, - causal_attention_mask=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - hidden_states = inputs_embeds - for idx, encoder_layer in enumerate(self.layers): - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - return hidden_states - - self.transformer.text_model.encoder.forward = encoder_forward.__get__( - self.transformer.text_model.encoder - ) - - def text_encoder_forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - embedding_manager=None, - ): - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - 
return_dict if return_dict is not None else self.config.use_return_dict - ) - - if input_ids is None: - raise ValueError("You have to specify either input_ids") - - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - - hidden_states = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - embedding_manager=embedding_manager, - ) - - bsz, seq_len = input_shape - # CLIP's text model uses causal mask, prepare it here. - # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 - causal_attention_mask = _build_causal_attention_mask( - bsz, seq_len, hidden_states.dtype - ).to(hidden_states.device) - - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, hidden_states.dtype) - - last_hidden_state = self.encoder( - inputs_embeds=hidden_states, - attention_mask=attention_mask, - causal_attention_mask=causal_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - last_hidden_state = self.final_layer_norm(last_hidden_state) - - return last_hidden_state - - self.transformer.text_model.forward = text_encoder_forward.__get__( - self.transformer.text_model - ) - - def transformer_forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - embedding_manager=None, - ): - return self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - embedding_manager=embedding_manager, - ) - - self.transformer.forward = transformer_forward.__get__(self.transformer) - - def freeze(self): - self.transformer = self.transformer.eval() - for param in self.parameters(): - param.requires_grad = False - - def forward(self, text, **kwargs): - batch_encoding = self.tokenizer( - text, - truncation=True, - max_length=self.max_length, - return_length=True, - return_overflowing_tokens=False, - padding="max_length", - return_tensors="pt", - ) - tokens = batch_encoding["input_ids"].to(self.device) - z = self.transformer(input_ids=tokens, **kwargs) - - return z - - def encode(self, text, **kwargs): - return self(text, **kwargs) - - @property - def device(self): - return self.transformer.device - - @device.setter - def device(self, device): - self.transformer.to(device=device) - - -class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder): - fragment_weights_key = "fragment_weights" - return_tokens_key = "return_tokens" - - def set_textual_inversion_manager(self, manager): # TextualInversionManager): - # TODO all of the weighting and expanding stuff needs be moved out of this class - self.textual_inversion_manager = manager - - def forward(self, text: list, **kwargs): - # TODO all of the weighting and expanding stuff needs be moved out of this class - """ - - :param text: A batch of prompt strings, or, a batch of lists of fragments of prompt strings to which different - weights shall be applied. - :param kwargs: If the keyword arg "fragment_weights" is passed, it shall contain a batch of lists of weights - for the prompt fragments. In this case text must contain batches of lists of prompt fragments. 
- :return: A tensor of shape (B, 77, 768) containing weighted embeddings - """ - if self.fragment_weights_key not in kwargs: - # fallback to base class implementation - return super().forward(text, **kwargs) - - fragment_weights = kwargs[self.fragment_weights_key] - # self.transformer doesn't like receiving "fragment_weights" as an argument - kwargs.pop(self.fragment_weights_key) - - should_return_tokens = False - if self.return_tokens_key in kwargs: - should_return_tokens = kwargs.get(self.return_tokens_key, False) - # self.transformer doesn't like having extra kwargs - kwargs.pop(self.return_tokens_key) - - batch_z = None - batch_tokens = None - for fragments, weights in zip(text, fragment_weights): - # First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively - # applying a multiplier to the CFG scale on a per-token basis). - # For tokens weighted<1, intuitively we want SD to become not merely *less* interested in the concept - # captured by the fragment but actually *dis*interested in it (a 0.01 interest in "red" is still an active - # interest, however small, in redness; what the user probably intends when they attach the number 0.01 to - # "red" is to tell SD that it should almost completely *ignore* redness). - # To do this, the embedding is lerped away from base_embedding in the direction of an embedding for a prompt - # string from which the low-weighted fragment has been simply removed. The closer the weight is to zero, the - # closer the resulting embedding is to an embedding for a prompt that simply lacks this fragment. - - # handle weights >=1 - tokens, per_token_weights = self.get_tokens_and_weights(fragments, weights) - base_embedding = self.build_weighted_embedding_tensor( - tokens, per_token_weights, **kwargs - ) - - # this is our starting point - embeddings = base_embedding.unsqueeze(0) - per_embedding_weights = [1.0] - - # now handle weights <1 - # Do this by building extra embeddings tensors that lack the words being <1 weighted. These will be lerped - # with the embeddings tensors that have the words, such that if the weight of a word is 0.5, the resulting - # embedding will be exactly half-way between the unweighted prompt and the prompt with the <1 weighted words - # removed. - # eg for "mountain:1 man:0.5", intuitively the "man" should be "half-gone". therefore, append an embedding - # for "mountain" (i.e. without "man") to the already-produced embedding for "mountain man", and weight it - # such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain". 
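# (Editorial sketch, not part of the removed file.) A small worked example of the
# tan()-based lerp-weight curve used in the loop below; `epsilon` matches the value
# defined there, and the sample weights are illustrative only:
import math

epsilon = 1e-9
for w in (1.0, 0.75, 0.5, 0.25, 0.01):
    lerp_weight = math.tan((1.0 - max(epsilon, w)) * math.pi / 2.0)
    print(f"fragment_weight={w:.2f} -> embedding_lerp_weight={lerp_weight:.3f}")
# fragment_weight=1.00 -> 0.000   (stay entirely at base_embedding)
# fragment_weight=0.50 -> 1.000   (after normalization, exactly half-way to the
#                                  embedding that omits this fragment)
# fragment_weight=0.01 -> ~63.66  (the fragment is almost completely ignored)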
- for index, fragment_weight in enumerate(weights): - if fragment_weight < 1: - fragments_without_this = fragments[:index] + fragments[index + 1 :] - weights_without_this = weights[:index] + weights[index + 1 :] - tokens, per_token_weights = self.get_tokens_and_weights( - fragments_without_this, weights_without_this - ) - embedding_without_this = self.build_weighted_embedding_tensor( - tokens, per_token_weights, **kwargs - ) - - embeddings = torch.cat( - (embeddings, embedding_without_this.unsqueeze(0)), dim=1 - ) - # weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0 - # if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding - # therefore: - # fragment_weight = 1: we are at base_z => lerp weight 0 - # fragment_weight = 0.5: we are halfway between base_z and here => lerp weight 1 - # fragment_weight = 0: we're now entirely overriding base_z ==> lerp weight inf - # so let's use tan(), because: - # tan is 0.0 at 0, - # 1.0 at PI/4, and - # inf at PI/2 - # -> tan((1-weight)*PI/2) should give us ideal lerp weights - epsilon = 1e-9 - fragment_weight = max(epsilon, fragment_weight) # inf is bad - embedding_lerp_weight = math.tan( - (1.0 - fragment_weight) * math.pi / 2 - ) - # todo handle negative weight? - - per_embedding_weights.append(embedding_lerp_weight) - - lerped_embeddings = self.apply_embedding_weights( - embeddings, per_embedding_weights, normalize=True - ).squeeze(0) - - # print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}") - - # append to batch - batch_z = ( - lerped_embeddings.unsqueeze(0) - if batch_z is None - else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1) - ) - batch_tokens = ( - tokens.unsqueeze(0) - if batch_tokens is None - else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1) - ) - - # should have shape (B, 77, 768) - # print(f"assembled all tokens into tensor of shape {batch_z.shape}") - - if should_return_tokens: - return batch_z, batch_tokens - else: - return batch_z - - def get_token_ids( - self, fragments: list[str], include_start_and_end_markers: bool = True - ) -> list[list[int]]: - """ - Convert a list of strings like `["a cat", "sitting", "on a mat"]` into a list of lists of token ids like - `[[bos, 0, 1, eos], [bos, 2, eos], [bos, 3, 0, 4, eos]]`. bos/eos markers are skipped if - `include_start_and_end_markers` is `False`. Each list will be restricted to the maximum permitted length - (typically 75 tokens + eos/bos markers). - - :param fragments: The strings to convert. 
- :param include_start_and_end_markers: - :return: - """ - - # for args documentation see ENCODE_KWARGS_DOCSTRING in tokenization_utils_base.py (in `transformers` lib) - token_ids_list = self.tokenizer( - fragments, - truncation=True, - max_length=self.max_length, - return_overflowing_tokens=False, - padding="do_not_pad", - return_tensors=None, # just give me lists of ints - )["input_ids"] - - result = [] - for token_ids in token_ids_list: - # trim eos/bos - token_ids = token_ids[1:-1] - # pad for textual inversions with vector length >1 - token_ids = self.textual_inversion_manager.expand_textual_inversion_token_ids_if_necessary( - token_ids - ) - # restrict length to max_length-2 (leaving room for bos/eos) - token_ids = token_ids[0 : self.max_length - 2] - # add back eos/bos if requested - if include_start_and_end_markers: - token_ids = ( - [self.tokenizer.bos_token_id] - + token_ids - + [self.tokenizer.eos_token_id] - ) - - result.append(token_ids) - - return result - - @classmethod - def apply_embedding_weights( - self, - embeddings: torch.Tensor, - per_embedding_weights: list[float], - normalize: bool, - ) -> torch.Tensor: - per_embedding_weights = torch.tensor( - per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device - ) - if normalize: - per_embedding_weights = per_embedding_weights / torch.sum( - per_embedding_weights - ) - reshaped_weights = per_embedding_weights.reshape( - per_embedding_weights.shape - + ( - 1, - 1, - ) - ) - # reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1,1,)).expand(embeddings.shape) - return torch.sum(embeddings * reshaped_weights, dim=1) - # lerped embeddings has shape (77, 768) - - def get_tokens_and_weights( - self, fragments: list[str], weights: list[float] - ) -> (torch.Tensor, torch.Tensor): - """ - - :param fragments: - :param weights: Per-fragment weights (CFG scaling). No need for these to be normalized. They will not be normalized here and that's fine. - :return: - """ - # empty is meaningful - if len(fragments) == 0 and len(weights) == 0: - fragments = [""] - weights = [1] - per_fragment_token_ids = self.get_token_ids( - fragments, include_start_and_end_markers=False - ) - all_token_ids = [] - per_token_weights = [] - # print("all fragments:", fragments, weights) - for index, fragment in enumerate(per_fragment_token_ids): - weight = float(weights[index]) - # print("processing fragment", fragment, weight) - this_fragment_token_ids = per_fragment_token_ids[index] - # print("fragment", fragment, "processed to", this_fragment_token_ids) - # append - all_token_ids += this_fragment_token_ids - # fill out weights tensor with one float per token - per_token_weights += [weight] * len(this_fragment_token_ids) - - # leave room for bos/eos - max_token_count_without_bos_eos_markers = self.max_length - 2 - if len(all_token_ids) > max_token_count_without_bos_eos_markers: - excess_token_count = ( - len(all_token_ids) - max_token_count_without_bos_eos_markers - ) - # TODO build nice description string of how the truncation was applied - # this should be done by calling self.tokenizer.convert_ids_to_tokens() then passing the result to - # self.tokenizer.convert_tokens_to_string() for the token_ids on each side of the truncation limit. 
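# (Editorial sketch, not part of the removed file.) One way the TODO above could be
# realized with the two tokenizer calls it names; `all_token_ids`,
# `max_token_count_without_bos_eos_markers` and `self.tokenizer` are the names used in
# the surrounding method, and the exact message format is hypothetical:
kept_text = self.tokenizer.convert_tokens_to_string(
    self.tokenizer.convert_ids_to_tokens(
        all_token_ids[:max_token_count_without_bos_eos_markers]
    )
)
dropped_text = self.tokenizer.convert_tokens_to_string(
    self.tokenizer.convert_ids_to_tokens(
        all_token_ids[max_token_count_without_bos_eos_markers:]
    )
)
print(f">> Truncation kept '{kept_text}' and dropped '{dropped_text}'")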
- print( - f">> Prompt is {excess_token_count} token(s) too long and has been truncated" - ) - all_token_ids = all_token_ids[0:max_token_count_without_bos_eos_markers] - per_token_weights = per_token_weights[ - 0:max_token_count_without_bos_eos_markers - ] - - # pad out to a 77-entry array: [bos_token, , eos_token, pad_token…] - # (77 = self.max_length) - all_token_ids = ( - [self.tokenizer.bos_token_id] - + all_token_ids - + [self.tokenizer.eos_token_id] - ) - per_token_weights = [1.0] + per_token_weights + [1.0] - pad_length = self.max_length - len(all_token_ids) - all_token_ids += [self.tokenizer.pad_token_id] * pad_length - per_token_weights += [1.0] * pad_length - - all_token_ids_tensor = torch.tensor(all_token_ids, dtype=torch.long).to( - self.device - ) - per_token_weights_tensor = torch.tensor( - per_token_weights, dtype=torch.float32 - ).to(self.device) - # print(f"assembled all_token_ids_tensor with shape {all_token_ids_tensor.shape}") - return all_token_ids_tensor, per_token_weights_tensor - - def build_weighted_embedding_tensor( - self, - token_ids: torch.Tensor, - per_token_weights: torch.Tensor, - weight_delta_from_empty=True, - **kwargs, - ) -> torch.Tensor: - """ - Build a tensor representing the passed-in tokens, each of which has a weight. - :param token_ids: A tensor of shape (77) containing token ids (integers) - :param per_token_weights: A tensor of shape (77) containing weights (floats) - :param method: Whether to multiply the whole feature vector for each token or just its distance from an "empty" feature vector - :param kwargs: passed on to self.transformer() - :return: A tensor of shape (1, 77, 768) representing the requested weighted embeddings. - """ - # print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}") - if token_ids.shape != torch.Size([self.max_length]): - raise ValueError( - f"token_ids has shape {token_ids.shape} - expected [{self.max_length}]" - ) - - z = self.transformer(input_ids=token_ids.unsqueeze(0), **kwargs) - - batch_weights_expanded = per_token_weights.reshape( - per_token_weights.shape + (1,) - ).expand(z.shape) - - if weight_delta_from_empty: - empty_tokens = self.tokenizer( - [""] * z.shape[0], - truncation=True, - max_length=self.max_length, - padding="max_length", - return_tensors="pt", - )["input_ids"].to(self.device) - empty_z = self.transformer(input_ids=empty_tokens, **kwargs) - z_delta_from_empty = z - empty_z - weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded) - - # weighted_z_delta_from_empty = (weighted_z-empty_z) - # print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() ) - - # print("using empty-delta method, first 5 rows:") - # print(weighted_z[:5]) - - return weighted_z - - else: - original_mean = z.mean() - z *= batch_weights_expanded - after_weighting_mean = z.mean() - # correct the mean. not sure if this is right but it's what the automatic1111 fork of SD does - mean_correction_factor = original_mean / after_weighting_mean - z *= mean_correction_factor - return z - - -class FrozenCLIPTextEmbedder(nn.Module): - """ - Uses the CLIP transformer encoder for text. 
- """ - - def __init__( - self, - version="ViT-L/14", - device=choose_torch_device(), - max_length=77, - n_repeat=1, - normalize=True, - ): - super().__init__() - self.model, _ = clip.load(version, jit=False, device=device) - self.device = device - self.max_length = max_length - self.n_repeat = n_repeat - self.normalize = normalize - - def freeze(self): - self.model = self.model.eval() - for param in self.parameters(): - param.requires_grad = False - - def forward(self, text): - tokens = clip.tokenize(text).to(self.device) - z = self.model.encode_text(tokens) - if self.normalize: - z = z / torch.linalg.norm(z, dim=1, keepdim=True) - return z - - def encode(self, text): - z = self(text) - if z.ndim == 2: - z = z[:, None, :] - z = repeat(z, "b 1 d -> b k d", k=self.n_repeat) - return z - - -class FrozenClipImageEmbedder(nn.Module): - """ - Uses the CLIP image encoder. - """ - - def __init__( - self, - model, - jit=False, - device=choose_torch_device(), - antialias=False, - ): - super().__init__() - self.model, _ = clip.load(name=model, device=device, jit=jit) - - self.antialias = antialias - - self.register_buffer( - "mean", - torch.Tensor([0.48145466, 0.4578275, 0.40821073]), - persistent=False, - ) - self.register_buffer( - "std", - torch.Tensor([0.26862954, 0.26130258, 0.27577711]), - persistent=False, - ) - - def preprocess(self, x): - # normalize to [0,1] - x = kornia.geometry.resize( - x, - (224, 224), - interpolation="bicubic", - align_corners=True, - antialias=self.antialias, - ) - x = (x + 1.0) / 2.0 - # renormalize according to clip - x = kornia.enhance.normalize(x, self.mean, self.std) - return x - - def forward(self, x): - # x is assumed to be in range [-1,1] - return self.model.encode_image(self.preprocess(x)) - - -if __name__ == "__main__": - from ...util.util import count_params - - model = FrozenCLIPEmbedder() - count_params(model, verbose=True) diff --git a/invokeai/backend/stable_diffusion/losses/__init__.py b/invokeai/backend/stable_diffusion/losses/__init__.py deleted file mode 100644 index d86294210c..0000000000 --- a/invokeai/backend/stable_diffusion/losses/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator diff --git a/invokeai/backend/stable_diffusion/losses/contperceptual.py b/invokeai/backend/stable_diffusion/losses/contperceptual.py deleted file mode 100644 index 1e3e6a00c4..0000000000 --- a/invokeai/backend/stable_diffusion/losses/contperceptual.py +++ /dev/null @@ -1,159 +0,0 @@ -import torch -import torch.nn as nn -from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
- - -class LPIPSWithDiscriminator(nn.Module): - def __init__( - self, - disc_start, - logvar_init=0.0, - kl_weight=1.0, - pixelloss_weight=1.0, - disc_num_layers=3, - disc_in_channels=3, - disc_factor=1.0, - disc_weight=1.0, - perceptual_weight=1.0, - use_actnorm=False, - disc_conditional=False, - disc_loss="hinge", - ): - super().__init__() - assert disc_loss in ["hinge", "vanilla"] - self.kl_weight = kl_weight - self.pixel_weight = pixelloss_weight - self.perceptual_loss = LPIPS().eval() - self.perceptual_weight = perceptual_weight - # output log variance - self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) - - self.discriminator = NLayerDiscriminator( - input_nc=disc_in_channels, - n_layers=disc_num_layers, - use_actnorm=use_actnorm, - ).apply(weights_init) - self.discriminator_iter_start = disc_start - self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss - self.disc_factor = disc_factor - self.discriminator_weight = disc_weight - self.disc_conditional = disc_conditional - - def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): - if last_layer is not None: - nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] - else: - nll_grads = torch.autograd.grad( - nll_loss, self.last_layer[0], retain_graph=True - )[0] - g_grads = torch.autograd.grad( - g_loss, self.last_layer[0], retain_graph=True - )[0] - - d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) - d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() - d_weight = d_weight * self.discriminator_weight - return d_weight - - def forward( - self, - inputs, - reconstructions, - posteriors, - optimizer_idx, - global_step, - last_layer=None, - cond=None, - split="train", - weights=None, - ): - rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) - if self.perceptual_weight > 0: - p_loss = self.perceptual_loss( - inputs.contiguous(), reconstructions.contiguous() - ) - rec_loss = rec_loss + self.perceptual_weight * p_loss - - nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar - weighted_nll_loss = nll_loss - if weights is not None: - weighted_nll_loss = weights * nll_loss - weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] - nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] - kl_loss = posteriors.kl() - kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] - - # now the GAN part - if optimizer_idx == 0: - # generator update - if cond is None: - assert not self.disc_conditional - logits_fake = self.discriminator(reconstructions.contiguous()) - else: - assert self.disc_conditional - logits_fake = self.discriminator( - torch.cat((reconstructions.contiguous(), cond), dim=1) - ) - g_loss = -torch.mean(logits_fake) - - if self.disc_factor > 0.0: - try: - d_weight = self.calculate_adaptive_weight( - nll_loss, g_loss, last_layer=last_layer - ) - except RuntimeError: - assert not self.training - d_weight = torch.tensor(0.0) - else: - d_weight = torch.tensor(0.0) - - disc_factor = adopt_weight( - self.disc_factor, - global_step, - threshold=self.discriminator_iter_start, - ) - loss = ( - weighted_nll_loss - + self.kl_weight * kl_loss - + d_weight * disc_factor * g_loss - ) - - log = { - "{}/total_loss".format(split): loss.clone().detach().mean(), - "{}/logvar".format(split): self.logvar.detach(), - "{}/kl_loss".format(split): kl_loss.detach().mean(), - "{}/nll_loss".format(split): nll_loss.detach().mean(), - "{}/rec_loss".format(split): 
rec_loss.detach().mean(), - "{}/d_weight".format(split): d_weight.detach(), - "{}/disc_factor".format(split): torch.tensor(disc_factor), - "{}/g_loss".format(split): g_loss.detach().mean(), - } - return loss, log - - if optimizer_idx == 1: - # second pass for discriminator update - if cond is None: - logits_real = self.discriminator(inputs.contiguous().detach()) - logits_fake = self.discriminator(reconstructions.contiguous().detach()) - else: - logits_real = self.discriminator( - torch.cat((inputs.contiguous().detach(), cond), dim=1) - ) - logits_fake = self.discriminator( - torch.cat((reconstructions.contiguous().detach(), cond), dim=1) - ) - - disc_factor = adopt_weight( - self.disc_factor, - global_step, - threshold=self.discriminator_iter_start, - ) - d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) - - log = { - "{}/disc_loss".format(split): d_loss.clone().detach().mean(), - "{}/logits_real".format(split): logits_real.detach().mean(), - "{}/logits_fake".format(split): logits_fake.detach().mean(), - } - return d_loss, log diff --git a/invokeai/backend/stable_diffusion/losses/vqperceptual.py b/invokeai/backend/stable_diffusion/losses/vqperceptual.py deleted file mode 100644 index 50413d37b8..0000000000 --- a/invokeai/backend/stable_diffusion/losses/vqperceptual.py +++ /dev/null @@ -1,222 +0,0 @@ -import torch -import torch.nn.functional as F -from einops import repeat -from taming.modules.discriminator.model import NLayerDiscriminator, weights_init -from taming.modules.losses.lpips import LPIPS -from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss -from torch import nn - - -def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): - assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] - loss_real = torch.mean(F.relu(1.0 - logits_real), dim=[1, 2, 3]) - loss_fake = torch.mean(F.relu(1.0 + logits_fake), dim=[1, 2, 3]) - loss_real = (weights * loss_real).sum() / weights.sum() - loss_fake = (weights * loss_fake).sum() / weights.sum() - d_loss = 0.5 * (loss_real + loss_fake) - return d_loss - - -def adopt_weight(weight, global_step, threshold=0, value=0.0): - if global_step < threshold: - weight = value - return weight - - -def measure_perplexity(predicted_indices, n_embed): - # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py - # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally - encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) - avg_probs = encodings.mean(0) - perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() - cluster_use = torch.sum(avg_probs > 0) - return perplexity, cluster_use - - -def l1(x, y): - return torch.abs(x - y) - - -def l2(x, y): - return torch.pow((x - y), 2) - - -class VQLPIPSWithDiscriminator(nn.Module): - def __init__( - self, - disc_start, - codebook_weight=1.0, - pixelloss_weight=1.0, - disc_num_layers=3, - disc_in_channels=3, - disc_factor=1.0, - disc_weight=1.0, - perceptual_weight=1.0, - use_actnorm=False, - disc_conditional=False, - disc_ndf=64, - disc_loss="hinge", - n_classes=None, - perceptual_loss="lpips", - pixel_loss="l1", - ): - super().__init__() - assert disc_loss in ["hinge", "vanilla"] - assert perceptual_loss in ["lpips", "clips", "dists"] - assert pixel_loss in ["l1", "l2"] - self.codebook_weight = codebook_weight - self.pixel_weight = pixelloss_weight - if perceptual_loss == "lpips": - print(f"{self.__class__.__name__}: Running with LPIPS.") - self.perceptual_loss = LPIPS().eval() - else: - raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") - self.perceptual_weight = perceptual_weight - - if pixel_loss == "l1": - self.pixel_loss = l1 - else: - self.pixel_loss = l2 - - self.discriminator = NLayerDiscriminator( - input_nc=disc_in_channels, - n_layers=disc_num_layers, - use_actnorm=use_actnorm, - ndf=disc_ndf, - ).apply(weights_init) - self.discriminator_iter_start = disc_start - if disc_loss == "hinge": - self.disc_loss = hinge_d_loss - elif disc_loss == "vanilla": - self.disc_loss = vanilla_d_loss - else: - raise ValueError(f"Unknown GAN loss '{disc_loss}'.") - print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") - self.disc_factor = disc_factor - self.discriminator_weight = disc_weight - self.disc_conditional = disc_conditional - self.n_classes = n_classes - - def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): - if last_layer is not None: - nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] - else: - nll_grads = torch.autograd.grad( - nll_loss, self.last_layer[0], retain_graph=True - )[0] - g_grads = torch.autograd.grad( - g_loss, self.last_layer[0], retain_graph=True - )[0] - - d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) - d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() - d_weight = d_weight * self.discriminator_weight - return d_weight - - def forward( - self, - codebook_loss, - inputs, - reconstructions, - optimizer_idx, - global_step, - last_layer=None, - cond=None, - split="train", - predicted_indices=None, - ): - if not exists(codebook_loss): - codebook_loss = torch.tensor([0.0]).to(inputs.device) - # rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) - rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) - if self.perceptual_weight > 0: - p_loss = self.perceptual_loss( - inputs.contiguous(), reconstructions.contiguous() - ) - rec_loss = rec_loss + self.perceptual_weight * p_loss - else: - p_loss = torch.tensor([0.0]) - - nll_loss = rec_loss - # nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] - nll_loss = torch.mean(nll_loss) - - # now the GAN part - if optimizer_idx == 0: - # generator update - if cond is None: - assert not self.disc_conditional - logits_fake = 
self.discriminator(reconstructions.contiguous()) - else: - assert self.disc_conditional - logits_fake = self.discriminator( - torch.cat((reconstructions.contiguous(), cond), dim=1) - ) - g_loss = -torch.mean(logits_fake) - - try: - d_weight = self.calculate_adaptive_weight( - nll_loss, g_loss, last_layer=last_layer - ) - except RuntimeError: - assert not self.training - d_weight = torch.tensor(0.0) - - disc_factor = adopt_weight( - self.disc_factor, - global_step, - threshold=self.discriminator_iter_start, - ) - loss = ( - nll_loss - + d_weight * disc_factor * g_loss - + self.codebook_weight * codebook_loss.mean() - ) - - log = { - "{}/total_loss".format(split): loss.clone().detach().mean(), - "{}/quant_loss".format(split): codebook_loss.detach().mean(), - "{}/nll_loss".format(split): nll_loss.detach().mean(), - "{}/rec_loss".format(split): rec_loss.detach().mean(), - "{}/p_loss".format(split): p_loss.detach().mean(), - "{}/d_weight".format(split): d_weight.detach(), - "{}/disc_factor".format(split): torch.tensor(disc_factor), - "{}/g_loss".format(split): g_loss.detach().mean(), - } - if predicted_indices is not None: - assert self.n_classes is not None - with torch.no_grad(): - perplexity, cluster_usage = measure_perplexity( - predicted_indices, self.n_classes - ) - log[f"{split}/perplexity"] = perplexity - log[f"{split}/cluster_usage"] = cluster_usage - return loss, log - - if optimizer_idx == 1: - # second pass for discriminator update - if cond is None: - logits_real = self.discriminator(inputs.contiguous().detach()) - logits_fake = self.discriminator(reconstructions.contiguous().detach()) - else: - logits_real = self.discriminator( - torch.cat((inputs.contiguous().detach(), cond), dim=1) - ) - logits_fake = self.discriminator( - torch.cat((reconstructions.contiguous().detach(), cond), dim=1) - ) - - disc_factor = adopt_weight( - self.disc_factor, - global_step, - threshold=self.discriminator_iter_start, - ) - d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) - - log = { - "{}/disc_loss".format(split): d_loss.clone().detach().mean(), - "{}/logits_real".format(split): logits_real.detach().mean(), - "{}/logits_fake".format(split): logits_fake.detach().mean(), - } - return d_loss, log diff --git a/invokeai/backend/stable_diffusion/modules/__init__.py b/invokeai/backend/stable_diffusion/modules/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/invokeai/backend/stable_diffusion/x_transformer.py b/invokeai/backend/stable_diffusion/x_transformer.py deleted file mode 100644 index b541d77ee2..0000000000 --- a/invokeai/backend/stable_diffusion/x_transformer.py +++ /dev/null @@ -1,729 +0,0 @@ -"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" -from collections import namedtuple -from functools import partial -from inspect import isfunction - -import torch -import torch.nn.functional as F -from einops import rearrange, reduce, repeat -from torch import einsum, nn - -# constants - -DEFAULT_DIM_HEAD = 64 - -Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"]) - -LayerIntermediates = namedtuple("Intermediates", ["hiddens", "attn_intermediates"]) - - -class AbsolutePositionalEmbedding(nn.Module): - def __init__(self, dim, max_seq_len): - super().__init__() - self.emb = nn.Embedding(max_seq_len, dim) - self.init_() - - def init_(self): - nn.init.normal_(self.emb.weight, std=0.02) - - def forward(self, x): - n = torch.arange(x.shape[1], device=x.device) - return self.emb(n)[None, :, :] - - 
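# (Editorial sketch, not part of the removed file.) Minimal usage illustration of the
# learned positional embedding defined above; the sizes 77 and 768 are assumed values,
# not taken from this module:
import torch
import torch.nn as nn

pos_emb_table = nn.Embedding(77, 768)               # what AbsolutePositionalEmbedding wraps
token_embeddings = torch.randn(2, 10, 768)          # (batch, seq_len, dim)
positions = torch.arange(token_embeddings.shape[1])
with_positions = token_embeddings + pos_emb_table(positions)[None, :, :]  # broadcast over batch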
-class FixedPositionalEmbedding(nn.Module): - def __init__(self, dim): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer("inv_freq", inv_freq) - - def forward(self, x, seq_dim=1, offset=0): - t = ( - torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) - + offset - ) - sinusoid_inp = torch.einsum("i , j -> i j", t, self.inv_freq) - emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) - return emb[None, :, :] - - -# helpers - - -def exists(val): - return val is not None - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def always(val): - def inner(*args, **kwargs): - return val - - return inner - - -def not_equals(val): - def inner(x): - return x != val - - return inner - - -def equals(val): - def inner(x): - return x == val - - return inner - - -def max_neg_value(tensor): - return -torch.finfo(tensor.dtype).max - - -# keyword argument helpers - - -def pick_and_pop(keys, d): - values = list(map(lambda key: d.pop(key), keys)) - return dict(zip(keys, values)) - - -def group_dict_by_key(cond, d): - return_val = [dict(), dict()] - for key in d.keys(): - match = bool(cond(key)) - ind = int(not match) - return_val[ind][key] = d[key] - return (*return_val,) - - -def string_begins_with(prefix, str): - return str.startswith(prefix) - - -def group_by_key_prefix(prefix, d): - return group_dict_by_key(partial(string_begins_with, prefix), d) - - -def groupby_prefix_and_trim(prefix, d): - kwargs_with_prefix, kwargs = group_dict_by_key( - partial(string_begins_with, prefix), d - ) - kwargs_without_prefix = dict( - map( - lambda x: (x[0][len(prefix) :], x[1]), - tuple(kwargs_with_prefix.items()), - ) - ) - return kwargs_without_prefix, kwargs - - -# classes -class Scale(nn.Module): - def __init__(self, value, fn): - super().__init__() - self.value = value - self.fn = fn - - def forward(self, x, **kwargs): - x, *rest = self.fn(x, **kwargs) - return (x * self.value, *rest) - - -class Rezero(nn.Module): - def __init__(self, fn): - super().__init__() - self.fn = fn - self.g = nn.Parameter(torch.zeros(1)) - - def forward(self, x, **kwargs): - x, *rest = self.fn(x, **kwargs) - return (x * self.g, *rest) - - -class ScaleNorm(nn.Module): - def __init__(self, dim, eps=1e-5): - super().__init__() - self.scale = dim**-0.5 - self.eps = eps - self.g = nn.Parameter(torch.ones(1)) - - def forward(self, x): - norm = torch.norm(x, dim=-1, keepdim=True) * self.scale - return x / norm.clamp(min=self.eps) * self.g - - -class RMSNorm(nn.Module): - def __init__(self, dim, eps=1e-8): - super().__init__() - self.scale = dim**-0.5 - self.eps = eps - self.g = nn.Parameter(torch.ones(dim)) - - def forward(self, x): - norm = torch.norm(x, dim=-1, keepdim=True) * self.scale - return x / norm.clamp(min=self.eps) * self.g - - -class Residual(nn.Module): - def forward(self, x, residual): - return x + residual - - -class GRUGating(nn.Module): - def __init__(self, dim): - super().__init__() - self.gru = nn.GRUCell(dim, dim) - - def forward(self, x, residual): - gated_output = self.gru( - rearrange(x, "b n d -> (b n) d"), - rearrange(residual, "b n d -> (b n) d"), - ) - - return gated_output.reshape_as(x) - - -# feedforward - - -class GEGLU(nn.Module): - def __init__(self, dim_in, dim_out): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) - - def forward(self, x): - x, gate = self.proj(x).chunk(2, dim=-1) - return x * F.gelu(gate) - - -class FeedForward(nn.Module): - def 
__init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): - super().__init__() - inner_dim = int(dim * mult) - dim_out = default(dim_out, dim) - project_in = ( - nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) - if not glu - else GEGLU(dim, inner_dim) - ) - - self.net = nn.Sequential( - project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) - ) - - def forward(self, x): - return self.net(x) - - -# attention. -class Attention(nn.Module): - def __init__( - self, - dim, - dim_head=DEFAULT_DIM_HEAD, - heads=8, - causal=False, - mask=None, - talking_heads=False, - sparse_topk=None, - use_entmax15=False, - num_mem_kv=0, - dropout=0.0, - on_attn=False, - ): - super().__init__() - if use_entmax15: - raise NotImplementedError( - "Check out entmax activation instead of softmax activation!" - ) - self.scale = dim_head**-0.5 - self.heads = heads - self.causal = causal - self.mask = mask - - inner_dim = dim_head * heads - - self.to_q = nn.Linear(dim, inner_dim, bias=False) - self.to_k = nn.Linear(dim, inner_dim, bias=False) - self.to_v = nn.Linear(dim, inner_dim, bias=False) - self.dropout = nn.Dropout(dropout) - - # talking heads - self.talking_heads = talking_heads - if talking_heads: - self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) - self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) - - # explicit topk sparse attention - self.sparse_topk = sparse_topk - - # entmax - # self.attn_fn = entmax15 if use_entmax15 else F.softmax - self.attn_fn = F.softmax - - # add memory key / values - self.num_mem_kv = num_mem_kv - if num_mem_kv > 0: - self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) - self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) - - # attention on attention - self.attn_on_attn = on_attn - self.to_out = ( - nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) - if on_attn - else nn.Linear(inner_dim, dim) - ) - - def forward( - self, - x, - context=None, - mask=None, - context_mask=None, - rel_pos=None, - sinusoidal_emb=None, - prev_attn=None, - mem=None, - ): - b, n, _, h, talking_heads, device = ( - *x.shape, - self.heads, - self.talking_heads, - x.device, - ) - kv_input = default(context, x) - - q_input = x - k_input = kv_input - v_input = kv_input - - if exists(mem): - k_input = torch.cat((mem, k_input), dim=-2) - v_input = torch.cat((mem, v_input), dim=-2) - - if exists(sinusoidal_emb): - # in shortformer, the query would start at a position offset depending on the past cached memory - offset = k_input.shape[-2] - q_input.shape[-2] - q_input = q_input + sinusoidal_emb(q_input, offset=offset) - k_input = k_input + sinusoidal_emb(k_input) - - q = self.to_q(q_input) - k = self.to_k(k_input) - v = self.to_v(v_input) - - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) - - input_mask = None - if any(map(exists, (mask, context_mask))): - q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) - k_mask = q_mask if not exists(context) else context_mask - k_mask = default( - k_mask, - lambda: torch.ones((b, k.shape[-2]), device=device).bool(), - ) - q_mask = rearrange(q_mask, "b i -> b () i ()") - k_mask = rearrange(k_mask, "b j -> b () () j") - input_mask = q_mask * k_mask - - if self.num_mem_kv > 0: - mem_k, mem_v = map( - lambda t: repeat(t, "h n d -> b h n d", b=b), - (self.mem_k, self.mem_v), - ) - k = torch.cat((mem_k, k), dim=-2) - v = torch.cat((mem_v, v), dim=-2) - if exists(input_mask): - input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) - - 
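# (Editorial comment, not part of the removed file.) The einsum on the next line computes
# scaled dot-product attention scores:
#     dots[b, h, i, j] = scale * sum_d q[b, h, i, d] * k[b, h, j, d]
# i.e. one similarity score per head for every (query position i, key position j) pair;
# the masking, talking-heads mixing and softmax below all operate on this (b, h, i, j) tensor.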
dots = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale - mask_value = max_neg_value(dots) - - if exists(prev_attn): - dots = dots + prev_attn - - pre_softmax_attn = dots - - if talking_heads: - dots = einsum( - "b h i j, h k -> b k i j", dots, self.pre_softmax_proj - ).contiguous() - - if exists(rel_pos): - dots = rel_pos(dots) - - if exists(input_mask): - dots.masked_fill_(~input_mask, mask_value) - del input_mask - - if self.causal: - i, j = dots.shape[-2:] - r = torch.arange(i, device=device) - mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j") - mask = F.pad(mask, (j - i, 0), value=False) - dots.masked_fill_(mask, mask_value) - del mask - - if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: - top, _ = dots.topk(self.sparse_topk, dim=-1) - vk = top[..., -1].unsqueeze(-1).expand_as(dots) - mask = dots < vk - dots.masked_fill_(mask, mask_value) - del mask - - attn = self.attn_fn(dots, dim=-1) - post_softmax_attn = attn - - attn = self.dropout(attn) - - if talking_heads: - attn = einsum( - "b h i j, h k -> b k i j", attn, self.post_softmax_proj - ).contiguous() - - out = einsum("b h i j, b h j d -> b h i d", attn, v) - out = rearrange(out, "b h n d -> b n (h d)") - - intermediates = Intermediates( - pre_softmax_attn=pre_softmax_attn, - post_softmax_attn=post_softmax_attn, - ) - - return self.to_out(out), intermediates - - -class AttentionLayers(nn.Module): - def __init__( - self, - dim, - depth, - heads=8, - causal=False, - cross_attend=False, - only_cross=False, - use_scalenorm=False, - use_rmsnorm=False, - use_rezero=False, - rel_pos_num_buckets=32, - rel_pos_max_distance=128, - position_infused_attn=False, - custom_layers=None, - sandwich_coef=None, - par_ratio=None, - residual_attn=False, - cross_residual_attn=False, - macaron=False, - pre_norm=True, - gate_residual=False, - **kwargs, - ): - super().__init__() - ff_kwargs, kwargs = groupby_prefix_and_trim("ff_", kwargs) - attn_kwargs, _ = groupby_prefix_and_trim("attn_", kwargs) - - dim_head = attn_kwargs.get("dim_head", DEFAULT_DIM_HEAD) - - self.dim = dim - self.depth = depth - self.layers = nn.ModuleList([]) - - self.has_pos_emb = position_infused_attn - self.pia_pos_emb = ( - FixedPositionalEmbedding(dim) if position_infused_attn else None - ) - self.rotary_pos_emb = always(None) - - assert ( - rel_pos_num_buckets <= rel_pos_max_distance - ), "number of relative position buckets must be less than the relative position max distance" - self.rel_pos = None - - self.pre_norm = pre_norm - - self.residual_attn = residual_attn - self.cross_residual_attn = cross_residual_attn - - norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm - norm_class = RMSNorm if use_rmsnorm else norm_class - norm_fn = partial(norm_class, dim) - - norm_fn = nn.Identity if use_rezero else norm_fn - branch_fn = Rezero if use_rezero else None - - if cross_attend and not only_cross: - default_block = ("a", "c", "f") - elif cross_attend and only_cross: - default_block = ("c", "f") - else: - default_block = ("a", "f") - - if macaron: - default_block = ("f",) + default_block - - if exists(custom_layers): - layer_types = custom_layers - elif exists(par_ratio): - par_depth = depth * len(default_block) - assert 1 < par_ratio <= par_depth, "par ratio out of range" - default_block = tuple(filter(not_equals("f"), default_block)) - par_attn = par_depth // par_ratio - depth_cut = ( - par_depth * 2 // 3 - ) # 2 / 3 attention layer cutoff suggested by PAR paper - par_width = (depth_cut + depth_cut // par_attn) // par_attn - 
assert ( - len(default_block) <= par_width - ), "default block is too large for par_ratio" - par_block = default_block + ("f",) * (par_width - len(default_block)) - par_head = par_block * par_attn - layer_types = par_head + ("f",) * (par_depth - len(par_head)) - elif exists(sandwich_coef): - assert ( - sandwich_coef > 0 and sandwich_coef <= depth - ), "sandwich coefficient should be less than the depth" - layer_types = ( - ("a",) * sandwich_coef - + default_block * (depth - sandwich_coef) - + ("f",) * sandwich_coef - ) - else: - layer_types = default_block * depth - - self.layer_types = layer_types - self.num_attn_layers = len(list(filter(equals("a"), layer_types))) - - for layer_type in self.layer_types: - if layer_type == "a": - layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) - elif layer_type == "c": - layer = Attention(dim, heads=heads, **attn_kwargs) - elif layer_type == "f": - layer = FeedForward(dim, **ff_kwargs) - layer = layer if not macaron else Scale(0.5, layer) - else: - raise Exception(f"invalid layer type {layer_type}") - - if isinstance(layer, Attention) and exists(branch_fn): - layer = branch_fn(layer) - - if gate_residual: - residual_fn = GRUGating(dim) - else: - residual_fn = Residual() - - self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn])) - - def forward( - self, - x, - context=None, - mask=None, - context_mask=None, - mems=None, - return_hiddens=False, - **kwargs, - ): - hiddens = [] - intermediates = [] - prev_attn = None - prev_cross_attn = None - - mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers - - for ind, (layer_type, (norm, block, residual_fn)) in enumerate( - zip(self.layer_types, self.layers) - ): - is_last = ind == (len(self.layers) - 1) - - if layer_type == "a": - hiddens.append(x) - layer_mem = mems.pop(0) - - residual = x - - if self.pre_norm: - x = norm(x) - - if layer_type == "a": - out, inter = block( - x, - mask=mask, - sinusoidal_emb=self.pia_pos_emb, - rel_pos=self.rel_pos, - prev_attn=prev_attn, - mem=layer_mem, - ) - elif layer_type == "c": - out, inter = block( - x, - context=context, - mask=mask, - context_mask=context_mask, - prev_attn=prev_cross_attn, - ) - elif layer_type == "f": - out = block(x) - - x = residual_fn(out, residual) - - if layer_type in ("a", "c"): - intermediates.append(inter) - - if layer_type == "a" and self.residual_attn: - prev_attn = inter.pre_softmax_attn - elif layer_type == "c" and self.cross_residual_attn: - prev_cross_attn = inter.pre_softmax_attn - - if not self.pre_norm and not is_last: - x = norm(x) - - if return_hiddens: - intermediates = LayerIntermediates( - hiddens=hiddens, attn_intermediates=intermediates - ) - - return x, intermediates - - return x - - -class Encoder(AttentionLayers): - def __init__(self, **kwargs): - assert "causal" not in kwargs, "cannot set causality on encoder" - super().__init__(causal=False, **kwargs) - - -class TransformerWrapper(nn.Module): - def __init__( - self, - *, - num_tokens, - max_seq_len, - attn_layers, - emb_dim=None, - max_mem_len=0.0, - emb_dropout=0.0, - num_memory_tokens=None, - tie_embedding=False, - use_pos_emb=True, - ): - super().__init__() - assert isinstance( - attn_layers, AttentionLayers - ), "attention layers must be one of Encoder or Decoder" - - dim = attn_layers.dim - emb_dim = default(emb_dim, dim) - - self.max_seq_len = max_seq_len - self.max_mem_len = max_mem_len - self.num_tokens = num_tokens - - self.token_emb = nn.Embedding(num_tokens, emb_dim) - self.pos_emb = ( - 
AbsolutePositionalEmbedding(emb_dim, max_seq_len) - if (use_pos_emb and not attn_layers.has_pos_emb) - else always(0) - ) - self.emb_dropout = nn.Dropout(emb_dropout) - - self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() - self.attn_layers = attn_layers - self.norm = nn.LayerNorm(dim) - - self.init_() - - self.to_logits = ( - nn.Linear(dim, num_tokens) - if not tie_embedding - else lambda t: t @ self.token_emb.weight.t() - ) - - # memory tokens (like [cls]) from Memory Transformers paper - num_memory_tokens = default(num_memory_tokens, 0) - self.num_memory_tokens = num_memory_tokens - if num_memory_tokens > 0: - self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) - - # let funnel encoder know number of memory tokens, if specified - if hasattr(attn_layers, "num_memory_tokens"): - attn_layers.num_memory_tokens = num_memory_tokens - - def init_(self): - nn.init.normal_(self.token_emb.weight, std=0.02) - - def forward( - self, - x, - return_embeddings=False, - mask=None, - return_mems=False, - return_attn=False, - mems=None, - embedding_manager=None, - **kwargs, - ): - b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens - - embedded_x = self.token_emb(x) - - if embedding_manager: - x = embedding_manager(x, embedded_x) - else: - x = embedded_x - - x = x + self.pos_emb(x) - x = self.emb_dropout(x) - - x = self.project_emb(x) - - if num_mem > 0: - mem = repeat(self.memory_tokens, "n d -> b n d", b=b) - x = torch.cat((mem, x), dim=1) - - # auto-handle masking after appending memory tokens - if exists(mask): - mask = F.pad(mask, (num_mem, 0), value=True) - - x, intermediates = self.attn_layers( - x, mask=mask, mems=mems, return_hiddens=True, **kwargs - ) - x = self.norm(x) - - mem, x = x[:, :num_mem], x[:, num_mem:] - - out = self.to_logits(x) if not return_embeddings else x - - if return_mems: - hiddens = intermediates.hiddens - new_mems = ( - list( - map( - lambda pair: torch.cat(pair, dim=-2), - zip(mems, hiddens), - ) - ) - if exists(mems) - else hiddens - ) - new_mems = list( - map(lambda t: t[..., -self.max_mem_len :, :].detach(), new_mems) - ) - return out, new_mems - - if return_attn: - attn_maps = list( - map( - lambda t: t.post_softmax_attn, - intermediates.attn_intermediates, - ) - ) - return out, attn_maps - - return out diff --git a/pyproject.toml b/pyproject.toml index 6e6fc8419b..4c2d903316 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,35 +52,25 @@ dependencies = [ "flask_cors==3.0.10", "flask_socketio==5.3.0", "flaskwebgui==1.0.3", - "getpass_asterisk", "gfpgan==1.3.8", "huggingface-hub>=0.11.1", - "imageio", - "imageio-ffmpeg", - "k-diffusion", # replacing "k-diffusion @ https://github.com/Birch-san/k-diffusion/archive/refs/heads/mps.zip", - "kornia", "npyscreen", "numpy<1.24", "omegaconf", "opencv-python", "picklescan", "pillow", - "pudb", "prompt-toolkit", "pypatchmatch", "pyreadline3", - "python-multipart==0.0.5", - "pytorch-lightning==1.7.7", "realesrgan", "requests==2.28.2", + "rich~=13.3", "safetensors~=0.3.0", "scikit-image>=0.19", "send2trash", - "streamlit", - "taming-transformers-rom1504", "test-tube>=0.7.5", "torch>=1.13.1", - "torch-fidelity", "torchvision>=0.14.1", "torchmetrics", "transformers~=4.26", @@ -95,6 +85,9 @@ dependencies = [ "mkdocs-git-revision-date-localized-plugin", "mkdocs-redirects==1.2.0", ] +"dev" = [ + "pudb", +] "test" = ["pytest>6.0.0", "pytest-cov"] "xformers" = [ "xformers~=0.0.16; sys_platform!='darwin'",