diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 6094c868d9..11d9d40047 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -71,6 +71,9 @@ class FieldDescriptions: safe_mode = "Whether or not to use safe mode" scribble_mode = "Whether or not to use scribble mode" scale_factor = "The factor by which to scale" + blend_alpha = ( + "Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B." + ) num_1 = "The first number" num_2 = "The second number" mask = "The mask to use for the operation" diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 69e4ffcaae..71d5ba779c 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -233,7 +233,7 @@ class SDXLPromptInvocationBase: dtype_for_device_getter=torch_dtype, truncate_long_prompts=True, # TODO: returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip - requires_pooled=True, + requires_pooled=get_pooled, ) conjunction = Compel.parse_prompt_string(prompt) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index e12cc18f42..f65a95999d 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -4,6 +4,7 @@ from contextlib import ExitStack from typing import List, Literal, Optional, Union import einops +import numpy as np import torch import torchvision.transforms as T from diffusers.image_processor import VaeImageProcessor @@ -720,3 +721,81 @@ class ImageToLatentsInvocation(BaseInvocation): latents = latents.to("cpu") context.services.latents.save(name, latents) return build_latents_output(latents_name=name, latents=latents, seed=None) + + +@title("Blend Latents") +@tags("latents", "blend") +class BlendLatentsInvocation(BaseInvocation): + """Blend two latents using a given alpha. Latents must have same size.""" + + type: Literal["lblend"] = "lblend" + + # Inputs + latents_a: LatentsField = InputField( + description=FieldDescriptions.latents, + input=Input.Connection, + ) + latents_b: LatentsField = InputField( + description=FieldDescriptions.latents, + input=Input.Connection, + ) + alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha) + + def invoke(self, context: InvocationContext) -> LatentsOutput: + latents_a = context.services.latents.get(self.latents_a.latents_name) + latents_b = context.services.latents.get(self.latents_b.latents_name) + + if latents_a.shape != latents_b.shape: + raise "Latents to blend must be the same size." + + # TODO: + device = choose_torch_device() + + def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): + """ + Spherical linear interpolation + Args: + t (float/np.ndarray): Float value between 0.0 and 1.0 + v0 (np.ndarray): Starting vector + v1 (np.ndarray): Final vector + DOT_THRESHOLD (float): Threshold for considering the two vectors as + colineal. Not recommended to alter this. + Returns: + v2 (np.ndarray): Interpolation vector between v0 and v1 + """ + inputs_are_torch = False + if not isinstance(v0, np.ndarray): + inputs_are_torch = True + v0 = v0.detach().cpu().numpy() + if not isinstance(v1, np.ndarray): + inputs_are_torch = True + v1 = v1.detach().cpu().numpy() + + dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) + if np.abs(dot) > DOT_THRESHOLD: + v2 = (1 - t) * v0 + t * v1 + else: + theta_0 = np.arccos(dot) + sin_theta_0 = np.sin(theta_0) + theta_t = theta_0 * t + sin_theta_t = np.sin(theta_t) + s0 = np.sin(theta_0 - theta_t) / sin_theta_0 + s1 = sin_theta_t / sin_theta_0 + v2 = s0 * v0 + s1 * v1 + + if inputs_are_torch: + v2 = torch.from_numpy(v2).to(device) + + return v2 + + # blend + blended_latents = slerp(self.alpha, latents_a, latents_b) + + # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 + blended_latents = blended_latents.to("cpu") + torch.cuda.empty_cache() + + name = f"{context.graph_execution_state_id}__{self.id}" + # context.services.latents.set(name, resized_latents) + context.services.latents.save(name, blended_latents) + return build_latents_output(latents_name=name, latents=blended_latents) diff --git a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py index 5fd3669911..8118e28abb 100644 --- a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py +++ b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py @@ -20,11 +20,36 @@ import re from contextlib import nullcontext from io import BytesIO -from typing import Optional, Union from pathlib import Path +from typing import Optional, Union import requests import torch +from diffusers.models import ( + AutoencoderKL, + ControlNetModel, + PriorTransformer, + UNet2DConditionModel, +) +from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel +from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer +from diffusers.schedulers import ( + DDIMScheduler, + DDPMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + UnCLIPScheduler, +) +from diffusers.utils import is_accelerate_available, is_omegaconf_available +from diffusers.utils.import_utils import BACKENDS_MAPPING +from picklescan.scanner import scan_file_path from transformers import ( AutoFeatureExtractor, BertTokenizerFast, @@ -37,35 +62,8 @@ from transformers import ( CLIPVisionModelWithProjection, ) -from diffusers.models import ( - AutoencoderKL, - ControlNetModel, - PriorTransformer, - UNet2DConditionModel, -) -from diffusers.schedulers import ( - DDIMScheduler, - DDPMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - HeunDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, - UnCLIPScheduler, -) -from diffusers.utils import is_accelerate_available, is_omegaconf_available, is_safetensors_available -from diffusers.utils.import_utils import BACKENDS_MAPPING -from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel -from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer - -from invokeai.backend.util.logging import InvokeAILogger from invokeai.app.services.config import InvokeAIAppConfig - -from picklescan.scanner import scan_file_path +from invokeai.backend.util.logging import InvokeAILogger from .models import BaseModelType, ModelVariantType try: @@ -1221,9 +1219,6 @@ def download_from_original_stable_diffusion_ckpt( raise ValueError(BACKENDS_MAPPING["omegaconf"][1]) if from_safetensors: - if not is_safetensors_available(): - raise ValueError(BACKENDS_MAPPING["safetensors"][1]) - from safetensors.torch import load_file as safe_load checkpoint = safe_load(checkpoint_path, device="cpu") @@ -1662,9 +1657,6 @@ def download_controlnet_from_original_ckpt( from omegaconf import OmegaConf if from_safetensors: - if not is_safetensors_available(): - raise ValueError(BACKENDS_MAPPING["safetensors"][1]) - from safetensors import safe_open checkpoint = {} @@ -1741,7 +1733,7 @@ def convert_ckpt_to_diffusers( pipe.save_pretrained( dump_path, - safe_serialization=use_safetensors and is_safetensors_available(), + safe_serialization=use_safetensors, ) @@ -1757,7 +1749,4 @@ def convert_controlnet_to_diffusers( """ pipe = download_controlnet_from_original_ckpt(checkpoint_path, **kwargs) - pipe.save_pretrained( - dump_path, - safe_serialization=is_safetensors_available(), - ) + pipe.save_pretrained(dump_path, safe_serialization=True) diff --git a/invokeai/backend/model_management/models/vae.py b/invokeai/backend/model_management/models/vae.py index cf7622a9aa..f5dc11b27b 100644 --- a/invokeai/backend/model_management/models/vae.py +++ b/invokeai/backend/model_management/models/vae.py @@ -5,7 +5,6 @@ from typing import Optional import safetensors import torch -from diffusers.utils import is_safetensors_available from omegaconf import OmegaConf from invokeai.app.services.config import InvokeAIAppConfig @@ -175,5 +174,5 @@ def _convert_vae_ckpt_and_cache( vae_config=config, image_size=image_size, ) - vae_model.save_pretrained(output_path, safe_serialization=is_safetensors_available()) + vae_model.save_pretrained(output_path, safe_serialization=True) return output_path diff --git a/pyproject.toml b/pyproject.toml index 980cf498b7..02e53f066a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ "controlnet-aux>=0.0.6", "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26 "datasets", - "diffusers[torch]~=0.19.3", + "diffusers[torch]~=0.20.0", "dnspython~=2.4.0", "dynamicprompts", "easing-functions", @@ -49,7 +49,7 @@ dependencies = [ "fastapi==0.88.0", "fastapi-events==0.8.0", "fastapi-socketio==0.0.10", - "huggingface-hub>=0.11.1", + "huggingface-hub~=0.16.4", "invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids "matplotlib", # needed for plotting of Penner easing functions "mediapipe", # needed for "mediapipeface" controlnet model