From 03e22c257b4084f4d1c18b442ca96914129f34c2 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Wed, 17 Jul 2024 03:37:11 +0300 Subject: [PATCH] Convert conditioning_mode to enum --- .../diffusion/conditioning_data.py | 26 ++++++++++++++----- .../stable_diffusion/diffusion_backend.py | 11 ++++---- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 80b671df65..8a52310e6f 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -1,12 +1,18 @@ +from __future__ import annotations + import math from dataclasses import dataclass -from typing import List, Optional, Union +from enum import Enum +from typing import TYPE_CHECKING, List, Optional, Union import torch -from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData +if TYPE_CHECKING: + from invokeai.backend.ip_adapter.ip_adapter import IPAdapter + from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs + @dataclass class BasicConditioningInfo: @@ -96,6 +102,12 @@ class TextConditioningRegions: assert self.masks.shape[1] == len(self.ranges) +class ConditioningMode(Enum): + Both = "both" + Negative = "negative" + Positive = "positive" + + class TextConditioningData: def __init__( self, @@ -124,21 +136,23 @@ class TextConditioningData: assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo) return isinstance(self.cond_text, SDXLConditioningInfo) - def to_unet_kwargs(self, unet_kwargs, conditioning_mode): + def to_unet_kwargs(self, unet_kwargs: UNetKwargs, conditioning_mode: ConditioningMode): _, _, h, w = unet_kwargs.sample.shape device = unet_kwargs.sample.device dtype = unet_kwargs.sample.dtype # TODO: combine regions with conditionings - if conditioning_mode == "both": + if conditioning_mode == ConditioningMode.Both: conditionings = [self.uncond_text, self.cond_text] c_regions = [self.uncond_regions, self.cond_regions] - elif conditioning_mode == "positive": + elif conditioning_mode == ConditioningMode.Positive: conditionings = [self.cond_text] c_regions = [self.cond_regions] - else: + elif conditioning_mode == ConditioningMode.Negative: conditionings = [self.uncond_text] c_regions = [self.uncond_regions] + else: + raise ValueError(f"Unexpected conditioning mode: {conditioning_mode}") encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch( [c.embeds for c in conditionings] diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py index d4c784e1d6..c1035c2a97 100644 --- a/invokeai/backend/stable_diffusion/diffusion_backend.py +++ b/invokeai/backend/stable_diffusion/diffusion_backend.py @@ -7,6 +7,7 @@ from tqdm.auto import tqdm from invokeai.app.services.config.config_default import get_config from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager @@ -68,10 +69,10 @@ class StableDiffusionBackend: # This might change in the future as new requirements come up, but for now, # this is the rough plan. if self._sequential_guidance: - ctx.negative_noise_pred = self.run_unet(ctx, ext_manager, "negative") - ctx.positive_noise_pred = self.run_unet(ctx, ext_manager, "positive") + ctx.negative_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Negative) + ctx.positive_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Positive) else: - both_noise_pred = self.run_unet(ctx, ext_manager, "both") + both_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Both) ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2) # ext: override apply_cfg @@ -101,9 +102,9 @@ class StableDiffusionBackend: return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale) # return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred) - def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: str): + def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode): sample = ctx.latent_model_input - if conditioning_mode == "both": + if conditioning_mode == ConditioningMode.Both: sample = torch.cat([sample] * 2) ctx.unet_kwargs = UNetKwargs(