mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Convert conditioning_mode to enum
This commit is contained in:
parent
ae6d4fbc78
commit
03e22c257b
@ -1,12 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs
|
||||
|
||||
|
||||
@dataclass
|
||||
class BasicConditioningInfo:
|
||||
@ -96,6 +102,12 @@ class TextConditioningRegions:
|
||||
assert self.masks.shape[1] == len(self.ranges)
|
||||
|
||||
|
||||
class ConditioningMode(Enum):
|
||||
Both = "both"
|
||||
Negative = "negative"
|
||||
Positive = "positive"
|
||||
|
||||
|
||||
class TextConditioningData:
|
||||
def __init__(
|
||||
self,
|
||||
@ -124,21 +136,23 @@ class TextConditioningData:
|
||||
assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
|
||||
return isinstance(self.cond_text, SDXLConditioningInfo)
|
||||
|
||||
def to_unet_kwargs(self, unet_kwargs, conditioning_mode):
|
||||
def to_unet_kwargs(self, unet_kwargs: UNetKwargs, conditioning_mode: ConditioningMode):
|
||||
_, _, h, w = unet_kwargs.sample.shape
|
||||
device = unet_kwargs.sample.device
|
||||
dtype = unet_kwargs.sample.dtype
|
||||
|
||||
# TODO: combine regions with conditionings
|
||||
if conditioning_mode == "both":
|
||||
if conditioning_mode == ConditioningMode.Both:
|
||||
conditionings = [self.uncond_text, self.cond_text]
|
||||
c_regions = [self.uncond_regions, self.cond_regions]
|
||||
elif conditioning_mode == "positive":
|
||||
elif conditioning_mode == ConditioningMode.Positive:
|
||||
conditionings = [self.cond_text]
|
||||
c_regions = [self.cond_regions]
|
||||
else:
|
||||
elif conditioning_mode == ConditioningMode.Negative:
|
||||
conditionings = [self.uncond_text]
|
||||
c_regions = [self.uncond_regions]
|
||||
else:
|
||||
raise ValueError(f"Unexpected conditioning mode: {conditioning_mode}")
|
||||
|
||||
encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(
|
||||
[c.embeds for c in conditionings]
|
||||
|
@ -7,6 +7,7 @@ from tqdm.auto import tqdm
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
|
||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||
|
||||
|
||||
@ -68,10 +69,10 @@ class StableDiffusionBackend:
|
||||
# This might change in the future as new requirements come up, but for now,
|
||||
# this is the rough plan.
|
||||
if self._sequential_guidance:
|
||||
ctx.negative_noise_pred = self.run_unet(ctx, ext_manager, "negative")
|
||||
ctx.positive_noise_pred = self.run_unet(ctx, ext_manager, "positive")
|
||||
ctx.negative_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Negative)
|
||||
ctx.positive_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Positive)
|
||||
else:
|
||||
both_noise_pred = self.run_unet(ctx, ext_manager, "both")
|
||||
both_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Both)
|
||||
ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)
|
||||
|
||||
# ext: override apply_cfg
|
||||
@ -101,9 +102,9 @@ class StableDiffusionBackend:
|
||||
return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
|
||||
# return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
|
||||
|
||||
def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: str):
|
||||
def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode):
|
||||
sample = ctx.latent_model_input
|
||||
if conditioning_mode == "both":
|
||||
if conditioning_mode == ConditioningMode.Both:
|
||||
sample = torch.cat([sample] * 2)
|
||||
|
||||
ctx.unet_kwargs = UNetKwargs(
|
||||
|
Loading…
Reference in New Issue
Block a user