Select dev/schnell based on state dict, use correct max seq len based on dev/schnell, and shift in inference, separate vae flux params into separate config

This commit is contained in:
Brandon Rising
2024-08-19 14:41:28 -04:00
committed by Brandon
parent 4bd7fda694
commit a63f842a13
9 changed files with 170 additions and 66 deletions

View File

@ -1,4 +1,5 @@
import torch
from typing import Literal
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
@ -23,11 +24,12 @@ class FluxTextEncoderInvocation(BaseInvocation):
description=FieldDescriptions.clip,
input=Input.Connection,
)
t5Encoder: T5EncoderField = InputField(
t5_encoder: T5EncoderField = InputField(
title="T5Encoder",
description=FieldDescriptions.t5Encoder,
input=Input.Connection,
)
max_seq_len: Literal[256, 512] = InputField(description="Max sequence length for the desired flux model")
positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
# TODO(ryand): Should we create a new return type for this invocation? This ConditioningOutput is clearly not
@ -43,21 +45,15 @@ class FluxTextEncoderInvocation(BaseInvocation):
return ConditioningOutput.build(conditioning_name)
def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
# TODO: Determine the T5 max sequence length based on the model.
# if self.model == "flux-schnell":
max_seq_len = 256
# # elif self.model == "flux-dev":
# # max_seq_len = 512
# else:
# raise ValueError(f"Unknown model: {self.model}")
max_seq_len = self.max_seq_len
# Load CLIP.
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
# Load T5.
t5_tokenizer_info = context.models.load(self.t5Encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5Encoder.text_encoder)
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
with (
clip_text_encoder_info as clip_text_encoder,

View File

@ -19,6 +19,7 @@ from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, unpack
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.model_manager.config import CheckpointConfigBase
@invocation(
@ -89,7 +90,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
img, img_ids = self._prepare_latent_img_patches(x)
# HACK(ryand): Find a better way to determine if this is a schnell model or not.
is_schnell = "schnell" in transformer_info.config.path if transformer_info.config else ""
is_schnell = "schnell" in transformer_info.config.config_path if transformer_info.config and isinstance(transformer_info.config, CheckpointConfigBase) else ""
timesteps = get_schedule(
num_steps=self.num_steps,
image_seq_len=img.shape[1],

View File

@ -1,4 +1,5 @@
import copy
import yaml
from time import sleep
from typing import Dict, List, Literal, Optional
@ -16,6 +17,7 @@ from invokeai.app.services.model_records import ModelRecordChanges
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.config import CheckpointConfigBase
class ModelIdentifierField(BaseModel):
@ -154,8 +156,9 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
t5Encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5Encoder, title="T5 Encoder")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5Encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
max_seq_len: Literal[256, 512] = OutputField(description=FieldDescriptions.vae, title="Max Seq Length")
@invocation("flux_model_loader", title="Flux Main Model", tags=["model", "flux"], category="model", version="1.0.3")
@ -189,12 +192,22 @@ class FluxModelLoaderInvocation(BaseInvocation):
ModelType.VAE,
BaseModelType.Flux,
)
transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
legacy_config_path = context.config.get().legacy_conf_path / transformer_config.config_path
config_path = legacy_config_path.as_posix()
with open(config_path, "r") as stream:
try:
flux_conf = yaml.safe_load(stream)
except:
raise
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5Encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=flux_conf['max_seq_len']
)
def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField: