mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
wip: basic wrapper for generating sd3 images
This commit is contained in:
parent
554809c647
commit
f65d50a4dd
@ -12,14 +12,7 @@ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.denoise_latents import DEFAULT_PRECISION
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
@ -60,6 +60,15 @@ class CLIPField(BaseModel):
|
||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||
|
||||
|
||||
# Groups the identifiers of the three tokenizer / text-encoder submodel pairs
# that the SD3 model loader hands to downstream nodes as a single connection.
# Each attribute only carries the information needed to load the submodel.
class SD3CLIPField(BaseModel):
    # Pair 1
    tokenizer_1: ModelIdentifierField = Field(description="Info to load tokenizer 1 submodel")
    text_encoder_1: ModelIdentifierField = Field(description="Info to load text_encoder 1 submodel")

    # Pair 2
    tokenizer_2: ModelIdentifierField = Field(description="Info to load tokenizer 2 submodel")
    text_encoder_2: ModelIdentifierField = Field(description="Info to load text_encoder 2 submodel")

    # Pair 3
    tokenizer_3: ModelIdentifierField = Field(description="Info to load tokenizer 3 submodel")
    text_encoder_3: ModelIdentifierField = Field(description="Info to load text_encoder 3 submodel")
|
||||
|
||||
|
||||
class VAEField(BaseModel):
|
||||
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
|
||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||
|
@ -1,8 +1,47 @@
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, TransformerField, VAEField
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
|
||||
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
|
||||
from git import Optional
|
||||
from pydantic import field_validator
|
||||
from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Input,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
|
||||
from invokeai.app.invocations.denoise_latents import get_scheduler
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import ModelIdentifierField, SD3CLIPField, TransformerField, VAEField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
from invokeai.backend.model_manager.config import SubModelType
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
# Module-level caches that StableDiffusion3Invocation.invoke fills on first use
# and reuses on later calls. Every slot starts out empty (None).

# The assembled diffusers pipeline.
sd3_pipeline: Optional[StableDiffusion3Pipeline] = None

# Handle for the loaded transformer submodel.
transformer_info: Optional[LoadedModel] = None

# Handles for the three loaded tokenizers.
tokenizer_1_info: Optional[LoadedModel] = None
tokenizer_2_info: Optional[LoadedModel] = None
tokenizer_3_info: Optional[LoadedModel] = None

# Handles for the three loaded text encoders.
text_encoder_1_info: Optional[LoadedModel] = None
text_encoder_2_info: Optional[LoadedModel] = None
text_encoder_3_info: Optional[LoadedModel] = None
|
||||
|
||||
|
||||
class FakeVae:
    """Minimal stand-in object exposing only a ``config.block_out_channels`` attribute.

    It is passed as the ``vae`` argument when the SD3 pipeline is assembled, so
    no real VAE has to be supplied for latent-only generation.
    """

    class FakeVaeConfig:
        """Configuration stub holding a single-entry ``block_out_channels`` list."""

        def __init__(self) -> None:
            # A fresh list per instance; the single placeholder entry is 0.
            channels = [0]
            self.block_out_channels = channels

    def __init__(self) -> None:
        stub_config = FakeVae.FakeVaeConfig()
        self.config = stub_config
|
||||
|
||||
|
||||
@invocation_output("sd3_model_loader_output")
|
||||
@ -10,9 +49,7 @@ class SD3ModelLoaderOutput(BaseInvocationOutput):
|
||||
"""Stable Diffuion 3 base model loader output"""
|
||||
|
||||
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
|
||||
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
|
||||
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||
clip3: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 3")
|
||||
clip: SD3CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@ -30,18 +67,154 @@ class SD3ModelLoaderInvocation(BaseInvocation):
|
||||
|
||||
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
||||
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
|
||||
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
tokenizer3 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
|
||||
text_encoder3 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
|
||||
tokenizer_1 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
text_encoder_1 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
tokenizer_2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||
text_encoder_2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
tokenizer_3 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
|
||||
text_encoder_3 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
|
||||
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
|
||||
return SD3ModelLoaderOutput(
|
||||
transformer=TransformerField(transformer=transformer, scheduler=scheduler),
|
||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
|
||||
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
|
||||
clip3=CLIPField(tokenizer=tokenizer3, text_encoder=text_encoder3, loras=[], skipped_layers=0),
|
||||
clip=SD3CLIPField(
|
||||
tokenizer_1=tokenizer_1,
|
||||
text_encoder_1=text_encoder_1,
|
||||
tokenizer_2=tokenizer_2,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_3=tokenizer_3,
|
||||
text_encoder_3=text_encoder_3,
|
||||
),
|
||||
vae=VAEField(vae=vae),
|
||||
)
|
||||
|
||||
|
||||
@invocation(
    "sd3_image_generator", title="Stable Diffusion 3", tags=["latent", "sd3"], category="latents", version="1.0.0"
)
class StableDiffusion3Invocation(BaseInvocation):
    """Generates an image using Stable Diffusion 3."""

    # --- model connections -------------------------------------------------
    transformer: TransformerField = InputField(
        description=FieldDescriptions.transformer,
        input=Input.Connection,
        title="Transformer",
        ui_order=0,
    )
    clip: SD3CLIPField = InputField(
        description=FieldDescriptions.clip,
        input=Input.Connection,
        title="CLIP",
        ui_order=1,
    )
    # NOTE(review): `noise` is accepted as an input but is not forwarded to the
    # pipeline in `invoke` below -- confirm whether it should seed the latents.
    noise: Optional[LatentsField] = InputField(
        default=None,
        description=FieldDescriptions.noise,
        input=Input.Connection,
        ui_order=2,
    )

    # --- sampling parameters -----------------------------------------------
    scheduler: SCHEDULER_NAME_VALUES = InputField(
        default="euler_f",
        description=FieldDescriptions.scheduler,
        ui_type=UIType.Scheduler,
    )
    positive_prompt: str = InputField(default="", title="Positive Prompt")
    negative_prompt: str = InputField(default="", title="Negative Prompt")
    steps: int = InputField(default=20, gt=0, description=FieldDescriptions.steps)
    guidance_scale: float = InputField(default=7.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")

    seed: int = InputField(
        default=0,
        ge=0,
        le=SEED_MAX,
        description=FieldDescriptions.seed,
    )
    # NOTE(review): `width` and `height` are validated here but are not passed
    # to the pipeline call in `invoke`; the pipeline's own default size is used.
    width: int = InputField(
        default=1024,
        multiple_of=LATENT_SCALE_FACTOR,
        gt=0,
        description=FieldDescriptions.width,
    )
    height: int = InputField(
        default=1024,
        multiple_of=LATENT_SCALE_FACTOR,
        gt=0,
        description=FieldDescriptions.height,
    )

    @field_validator("seed", mode="before")
    @classmethod
    def modulo_seed(cls, v: int) -> int:
        """Return the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
        return v % (SEED_MAX + 1)

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        """Run the SD3 pipeline and return the resulting latents.

        Args:
            context: Invocation context used to load the submodels and to save
                the resulting latents tensor.

        Returns:
            A ``LatentsOutput`` built from the saved latents and ``self.seed``.
        """
        global sd3_pipeline, transformer_info
        global tokenizer_1_info, tokenizer_2_info, tokenizer_3_info
        global text_encoder_1_info, text_encoder_2_info, text_encoder_3_info

        # Submodels are loaded once and kept in module-level globals.
        # NOTE(review): the cache is not keyed on the selected model, so a
        # different `transformer` / `clip` input on a later call is ignored.
        if transformer_info is None:
            transformer_info = context.models.load(self.transformer.transformer)
        if tokenizer_1_info is None:
            tokenizer_1_info = context.models.load(self.clip.tokenizer_1)
        if tokenizer_2_info is None:
            tokenizer_2_info = context.models.load(self.clip.tokenizer_2)
        if tokenizer_3_info is None:
            tokenizer_3_info = context.models.load(self.clip.tokenizer_3)
        if text_encoder_1_info is None:
            text_encoder_1_info = context.models.load(self.clip.text_encoder_1)
        if text_encoder_2_info is None:
            text_encoder_2_info = context.models.load(self.clip.text_encoder_2)
        if text_encoder_3_info is None:
            text_encoder_3_info = context.models.load(self.clip.text_encoder_3)

        with (
            tokenizer_1_info as tokenizer_1,
            tokenizer_2_info as tokenizer_2,
            tokenizer_3_info as tokenizer_3,
            text_encoder_1_info as text_encoder_1,
            text_encoder_2_info as text_encoder_2,
            text_encoder_3_info as text_encoder_3,
            transformer_info as transformer,
        ):
            # Narrow the loaded objects to the concrete types the pipeline expects.
            assert isinstance(transformer, SD3Transformer2DModel)
            assert isinstance(text_encoder_1, CLIPTextModelWithProjection)
            assert isinstance(text_encoder_2, CLIPTextModelWithProjection)
            assert isinstance(text_encoder_3, T5EncoderModel)
            assert isinstance(tokenizer_1, CLIPTokenizer)
            assert isinstance(tokenizer_2, CLIPTokenizer)
            assert isinstance(tokenizer_3, T5TokenizerFast)

            scheduler = get_scheduler(
                context=context,
                scheduler_info=self.transformer.scheduler,
                scheduler_name=self.scheduler,
                seed=self.seed,
            )

            if not isinstance(sd3_pipeline, StableDiffusion3Pipeline):
                sd3_pipeline = StableDiffusion3Pipeline(
                    transformer=transformer,
                    vae=FakeVae(),
                    text_encoder=text_encoder_1,
                    text_encoder_2=text_encoder_2,
                    text_encoder_3=text_encoder_3,
                    tokenizer=tokenizer_1,
                    tokenizer_2=tokenizer_2,
                    tokenizer_3=tokenizer_3,
                    scheduler=scheduler,
                )

            # Install the scheduler chosen for this call on the (possibly cached)
            # pipeline. Writing into `sd3_pipeline.components[...]` only mutates
            # the dict returned by that property, not the pipeline itself, so
            # the attribute is assigned directly instead.
            sd3_pipeline.scheduler = scheduler
            sd3_pipeline.to(TorchDevice.choose_torch_device().type)

            # Seed the sampler so the seed reported in the output is the one
            # that actually produced the latents.
            generator = torch.Generator(device="cpu").manual_seed(self.seed)

            results = sd3_pipeline(
                self.positive_prompt,
                negative_prompt=self.negative_prompt,
                num_inference_steps=self.steps,
                guidance_scale=self.guidance_scale,
                generator=generator,
                output_type="latent",
            )

            # Take the first result and restore a leading batch dimension.
            latents = cast(torch.Tensor, results.images[0])
            latents = latents.unsqueeze(0)

        latents_name = context.tensors.save(latents)
        return LatentsOutput.build(latents_name, latents=latents, seed=self.seed)
|
||||
|
@ -7,6 +7,7 @@ from diffusers import (
|
||||
DPMSolverSinglestepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
@ -29,6 +30,7 @@ SCHEDULER_MAP = {
|
||||
"euler": (EulerDiscreteScheduler, {"use_karras_sigmas": False}),
|
||||
"euler_k": (EulerDiscreteScheduler, {"use_karras_sigmas": True}),
|
||||
"euler_a": (EulerAncestralDiscreteScheduler, {}),
|
||||
"euler_f": (FlowMatchEulerDiscreteScheduler, {}),
|
||||
"kdpm_2": (KDPM2DiscreteScheduler, {}),
|
||||
"kdpm_2_a": (KDPM2AncestralDiscreteScheduler, {}),
|
||||
"dpmpp_2s": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": False}),
|
||||
|
@ -631,6 +631,7 @@ export const schema = {
|
||||
'euler',
|
||||
'euler_k',
|
||||
'euler_a',
|
||||
'euler_f',
|
||||
'kdpm_2',
|
||||
'kdpm_2_a',
|
||||
'dpmpp_2s',
|
||||
@ -694,6 +695,7 @@ export const schema = {
|
||||
'euler',
|
||||
'euler_k',
|
||||
'euler_a',
|
||||
'euler_f',
|
||||
'kdpm_2',
|
||||
'kdpm_2_a',
|
||||
'dpmpp_2s',
|
||||
|
@ -47,6 +47,7 @@ export const zSchedulerField = z.enum([
|
||||
'heun_k',
|
||||
'lms_k',
|
||||
'euler_a',
|
||||
'euler_f',
|
||||
'kdpm_2_a',
|
||||
'lcm',
|
||||
'tcd',
|
||||
|
@ -39,6 +39,7 @@ export const MODEL_TYPES = [
|
||||
'TransformerField',
|
||||
'VAEField',
|
||||
'CLIPField',
|
||||
'SD3CLIPField',
|
||||
'T2IAdapterModelField',
|
||||
];
|
||||
|
||||
@ -49,6 +50,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
|
||||
BoardField: 'purple.500',
|
||||
BooleanField: 'green.500',
|
||||
CLIPField: 'green.500',
|
||||
SD3CLIPField: 'green.500',
|
||||
ColorField: 'pink.300',
|
||||
ConditioningField: 'cyan.500',
|
||||
ControlField: 'teal.500',
|
||||
|
@ -80,6 +80,7 @@ export const SCHEDULER_OPTIONS: ComboboxOption[] = [
|
||||
{ value: 'heun_k', label: 'Heun Karras' },
|
||||
{ value: 'lms_k', label: 'LMS Karras' },
|
||||
{ value: 'euler_a', label: 'Euler Ancestral' },
|
||||
{ value: 'euler_f', label: 'Euler Flow Match' },
|
||||
{ value: 'kdpm_2_a', label: 'KDPM 2 Ancestral' },
|
||||
{ value: 'lcm', label: 'LCM' },
|
||||
{ value: 'tcd', label: 'TCD' },
|
||||
|
File diff suppressed because one or more lines are too long
@ -53,6 +53,7 @@ dependencies = [
|
||||
"torchsde==0.2.6",
|
||||
"torchvision",
|
||||
"transformers",
|
||||
"sentencepiece",
|
||||
|
||||
# Core application dependencies, pinned for reproducible builds.
|
||||
"fastapi-events==0.11.0",
|
||||
|
Loading…
Reference in New Issue
Block a user