Convert conditioning_mode to enum

This commit is contained in:
Sergey Borisov 2024-07-17 03:37:11 +03:00
parent ae6d4fbc78
commit 03e22c257b
2 changed files with 26 additions and 11 deletions

View File

@ -1,12 +1,18 @@
from __future__ import annotations
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Union from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Union
import torch import torch
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
if TYPE_CHECKING:
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs
@dataclass @dataclass
class BasicConditioningInfo: class BasicConditioningInfo:
@ -96,6 +102,12 @@ class TextConditioningRegions:
assert self.masks.shape[1] == len(self.ranges) assert self.masks.shape[1] == len(self.ranges)
class ConditioningMode(Enum):
Both = "both"
Negative = "negative"
Positive = "positive"
class TextConditioningData: class TextConditioningData:
def __init__( def __init__(
self, self,
@ -124,21 +136,23 @@ class TextConditioningData:
assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo) assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
return isinstance(self.cond_text, SDXLConditioningInfo) return isinstance(self.cond_text, SDXLConditioningInfo)
def to_unet_kwargs(self, unet_kwargs, conditioning_mode): def to_unet_kwargs(self, unet_kwargs: UNetKwargs, conditioning_mode: ConditioningMode):
_, _, h, w = unet_kwargs.sample.shape _, _, h, w = unet_kwargs.sample.shape
device = unet_kwargs.sample.device device = unet_kwargs.sample.device
dtype = unet_kwargs.sample.dtype dtype = unet_kwargs.sample.dtype
# TODO: combine regions with conditionings # TODO: combine regions with conditionings
if conditioning_mode == "both": if conditioning_mode == ConditioningMode.Both:
conditionings = [self.uncond_text, self.cond_text] conditionings = [self.uncond_text, self.cond_text]
c_regions = [self.uncond_regions, self.cond_regions] c_regions = [self.uncond_regions, self.cond_regions]
elif conditioning_mode == "positive": elif conditioning_mode == ConditioningMode.Positive:
conditionings = [self.cond_text] conditionings = [self.cond_text]
c_regions = [self.cond_regions] c_regions = [self.cond_regions]
else: elif conditioning_mode == ConditioningMode.Negative:
conditionings = [self.uncond_text] conditionings = [self.uncond_text]
c_regions = [self.uncond_regions] c_regions = [self.uncond_regions]
else:
raise ValueError(f"Unexpected conditioning mode: {conditioning_mode}")
encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch( encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(
[c.embeds for c in conditionings] [c.embeds for c in conditionings]

View File

@ -7,6 +7,7 @@ from tqdm.auto import tqdm
from invokeai.app.services.config.config_default import get_config from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
@ -68,10 +69,10 @@ class StableDiffusionBackend:
# This might change in the future as new requirements come up, but for now, # This might change in the future as new requirements come up, but for now,
# this is the rough plan. # this is the rough plan.
if self._sequential_guidance: if self._sequential_guidance:
ctx.negative_noise_pred = self.run_unet(ctx, ext_manager, "negative") ctx.negative_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Negative)
ctx.positive_noise_pred = self.run_unet(ctx, ext_manager, "positive") ctx.positive_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Positive)
else: else:
both_noise_pred = self.run_unet(ctx, ext_manager, "both") both_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Both)
ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2) ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)
# ext: override apply_cfg # ext: override apply_cfg
@ -101,9 +102,9 @@ class StableDiffusionBackend:
return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale) return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
# return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred) # return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: str): def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode):
sample = ctx.latent_model_input sample = ctx.latent_model_input
if conditioning_mode == "both": if conditioning_mode == ConditioningMode.Both:
sample = torch.cat([sample] * 2) sample = torch.cat([sample] * 2)
ctx.unet_kwargs = UNetKwargs( ctx.unet_kwargs = UNetKwargs(