mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Convert conditioning_mode to enum
This commit is contained in:
parent
ae6d4fbc78
commit
03e22c257b
@ -1,12 +1,18 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Union
|
from enum import Enum
|
||||||
|
from typing import TYPE_CHECKING, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
|
||||||
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
|
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||||
|
from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BasicConditioningInfo:
|
class BasicConditioningInfo:
|
||||||
@ -96,6 +102,12 @@ class TextConditioningRegions:
|
|||||||
assert self.masks.shape[1] == len(self.ranges)
|
assert self.masks.shape[1] == len(self.ranges)
|
||||||
|
|
||||||
|
|
||||||
|
class ConditioningMode(Enum):
|
||||||
|
Both = "both"
|
||||||
|
Negative = "negative"
|
||||||
|
Positive = "positive"
|
||||||
|
|
||||||
|
|
||||||
class TextConditioningData:
|
class TextConditioningData:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -124,21 +136,23 @@ class TextConditioningData:
|
|||||||
assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
|
assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
|
||||||
return isinstance(self.cond_text, SDXLConditioningInfo)
|
return isinstance(self.cond_text, SDXLConditioningInfo)
|
||||||
|
|
||||||
def to_unet_kwargs(self, unet_kwargs, conditioning_mode):
|
def to_unet_kwargs(self, unet_kwargs: UNetKwargs, conditioning_mode: ConditioningMode):
|
||||||
_, _, h, w = unet_kwargs.sample.shape
|
_, _, h, w = unet_kwargs.sample.shape
|
||||||
device = unet_kwargs.sample.device
|
device = unet_kwargs.sample.device
|
||||||
dtype = unet_kwargs.sample.dtype
|
dtype = unet_kwargs.sample.dtype
|
||||||
|
|
||||||
# TODO: combine regions with conditionings
|
# TODO: combine regions with conditionings
|
||||||
if conditioning_mode == "both":
|
if conditioning_mode == ConditioningMode.Both:
|
||||||
conditionings = [self.uncond_text, self.cond_text]
|
conditionings = [self.uncond_text, self.cond_text]
|
||||||
c_regions = [self.uncond_regions, self.cond_regions]
|
c_regions = [self.uncond_regions, self.cond_regions]
|
||||||
elif conditioning_mode == "positive":
|
elif conditioning_mode == ConditioningMode.Positive:
|
||||||
conditionings = [self.cond_text]
|
conditionings = [self.cond_text]
|
||||||
c_regions = [self.cond_regions]
|
c_regions = [self.cond_regions]
|
||||||
else:
|
elif conditioning_mode == ConditioningMode.Negative:
|
||||||
conditionings = [self.uncond_text]
|
conditionings = [self.uncond_text]
|
||||||
c_regions = [self.uncond_regions]
|
c_regions = [self.uncond_regions]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unexpected conditioning mode: {conditioning_mode}")
|
||||||
|
|
||||||
encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(
|
encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(
|
||||||
[c.embeds for c in conditionings]
|
[c.embeds for c in conditionings]
|
||||||
|
@ -7,6 +7,7 @@ from tqdm.auto import tqdm
|
|||||||
|
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config.config_default import get_config
|
||||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs
|
||||||
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
|
||||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||||
|
|
||||||
|
|
||||||
@ -68,10 +69,10 @@ class StableDiffusionBackend:
|
|||||||
# This might change in the future as new requirements come up, but for now,
|
# This might change in the future as new requirements come up, but for now,
|
||||||
# this is the rough plan.
|
# this is the rough plan.
|
||||||
if self._sequential_guidance:
|
if self._sequential_guidance:
|
||||||
ctx.negative_noise_pred = self.run_unet(ctx, ext_manager, "negative")
|
ctx.negative_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Negative)
|
||||||
ctx.positive_noise_pred = self.run_unet(ctx, ext_manager, "positive")
|
ctx.positive_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Positive)
|
||||||
else:
|
else:
|
||||||
both_noise_pred = self.run_unet(ctx, ext_manager, "both")
|
both_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Both)
|
||||||
ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)
|
ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)
|
||||||
|
|
||||||
# ext: override apply_cfg
|
# ext: override apply_cfg
|
||||||
@ -101,9 +102,9 @@ class StableDiffusionBackend:
|
|||||||
return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
|
return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
|
||||||
# return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
|
# return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
|
||||||
|
|
||||||
def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: str):
|
def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode):
|
||||||
sample = ctx.latent_model_input
|
sample = ctx.latent_model_input
|
||||||
if conditioning_mode == "both":
|
if conditioning_mode == ConditioningMode.Both:
|
||||||
sample = torch.cat([sample] * 2)
|
sample = torch.cat([sample] * 2)
|
||||||
|
|
||||||
ctx.unet_kwargs = UNetKwargs(
|
ctx.unet_kwargs = UNetKwargs(
|
||||||
|
Loading…
Reference in New Issue
Block a user