Select dev/schnell based on the state dict, use the correct max sequence length and inference-time shift for dev/schnell, and separate the FLUX VAE params into a separate config.
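The thread running through the diff below: FLUX schnell and FLUX dev expect different T5 max sequence lengths, and the commit moves that choice out of the text encoder (which hard-coded 256) and into the model loader's config. A minimal sketch of the underlying mapping, with illustrative names not taken from the diff:

# Illustrative mapping only; the diff encodes this choice via per-model
# legacy config files rather than a dict like this.
FLUX_MAX_SEQ_LEN = {"flux-schnell": 256, "flux-dev": 512}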
invokeai/app/invocations/flux_text_encoder.py

@@ -1,4 +1,5 @@
 import torch
+from typing import Literal
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
@@ -23,11 +24,12 @@ class FluxTextEncoderInvocation(BaseInvocation):
         description=FieldDescriptions.clip,
         input=Input.Connection,
     )
-    t5Encoder: T5EncoderField = InputField(
+    t5_encoder: T5EncoderField = InputField(
         title="T5Encoder",
         description=FieldDescriptions.t5Encoder,
         input=Input.Connection,
     )
+    max_seq_len: Literal[256, 512] = InputField(description="Max sequence length for the desired flux model")
     positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
 
     # TODO(ryand): Should we create a new return type for this invocation? This ConditioningOutput is clearly not
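The renamed t5_encoder field follows the snake_case convention of the surrounding fields, and the new max_seq_len input is constrained to exactly the two valid FLUX values. A standalone sketch of how that Literal[256, 512] annotation validates, assuming pydantic-backed invocation fields as used elsewhere in InvokeAI (the _SeqLen model is hypothetical):

from typing import Literal

from pydantic import BaseModel, ValidationError

class _SeqLen(BaseModel):
    max_seq_len: Literal[256, 512]

_SeqLen(max_seq_len=256)  # accepted: schnell-sized
_SeqLen(max_seq_len=512)  # accepted: dev-sized
try:
    _SeqLen(max_seq_len=128)
except ValidationError:
    print("rejected: only 256 or 512 are valid")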
@@ -43,21 +45,15 @@ class FluxTextEncoderInvocation(BaseInvocation):
         return ConditioningOutput.build(conditioning_name)
 
     def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
-        # TODO: Determine the T5 max sequence length based on the model.
-        # if self.model == "flux-schnell":
-        max_seq_len = 256
-        # # elif self.model == "flux-dev":
-        # #     max_seq_len = 512
-        # else:
-        #     raise ValueError(f"Unknown model: {self.model}")
+        max_seq_len = self.max_seq_len
 
         # Load CLIP.
         clip_tokenizer_info = context.models.load(self.clip.tokenizer)
         clip_text_encoder_info = context.models.load(self.clip.text_encoder)
 
         # Load T5.
-        t5_tokenizer_info = context.models.load(self.t5Encoder.tokenizer)
-        t5_text_encoder_info = context.models.load(self.t5Encoder.text_encoder)
+        t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
+        t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
 
         with (
             clip_text_encoder_info as clip_text_encoder,
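With this hunk the encoder stops guessing the sequence length (the dead commented-out dev/schnell branch is deleted) and uses whatever the loader supplies. For context, a hedged sketch of how such a max_seq_len typically feeds T5 tokenization downstream; parameter names follow the transformers tokenizer API, and the helper itself is hypothetical, not part of the diff:

from transformers import T5Tokenizer

def tokenize_prompt(tokenizer: T5Tokenizer, prompt: str, max_seq_len: int):
    # Pad/truncate the prompt to the model's expected T5 sequence length
    # (256 for schnell, 512 for dev).
    return tokenizer(
        prompt,
        truncation=True,
        max_length=max_seq_len,
        padding="max_length",
        return_tensors="pt",
    )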
invokeai/app/invocations/flux_text_to_image.py

@@ -19,6 +19,7 @@ from invokeai.backend.flux.modules.autoencoder import AutoEncoder
 from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, unpack
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.model_manager.config import CheckpointConfigBase
 
 
 @invocation(
@@ -89,7 +90,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         img, img_ids = self._prepare_latent_img_patches(x)
 
         # HACK(ryand): Find a better way to determine if this is a schnell model or not.
-        is_schnell = "schnell" in transformer_info.config.path if transformer_info.config else ""
+        is_schnell = "schnell" in transformer_info.config.config_path if transformer_info.config and isinstance(transformer_info.config, CheckpointConfigBase) else ""
         timesteps = get_schedule(
             num_steps=self.num_steps,
             image_seq_len=img.shape[1],
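The updated check reads config_path (the legacy config file path, present only on checkpoint configs, hence the isinstance guard) instead of the model's install path. Note that despite the commit title, the detection here keys off the config path rather than the state dict. A minimal sketch of the heuristic; the example paths are assumptions, not taken from the diff:

def is_schnell_config(config_path: str) -> bool:
    # The diff treats any checkpoint whose legacy config path mentions
    # "schnell" as a FLUX schnell model; everything else is dev.
    return "schnell" in config_path

assert is_schnell_config("flux/flux1-schnell.yaml")
assert not is_schnell_config("flux/flux1-dev.yaml")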
invokeai/app/invocations/model.py

@@ -1,4 +1,5 @@
 import copy
+import yaml
 from time import sleep
 from typing import Dict, List, Literal, Optional
 
@@ -16,6 +17,7 @@ from invokeai.app.services.model_records import ModelRecordChanges
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.shared.models import FreeUConfig
 from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
+from invokeai.backend.model_manager.config import CheckpointConfigBase
 
 
 class ModelIdentifierField(BaseModel):
@@ -154,8 +156,9 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
 
     transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
     clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
-    t5Encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5Encoder, title="T5 Encoder")
+    t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5Encoder, title="T5 Encoder")
     vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
+    max_seq_len: Literal[256, 512] = OutputField(description=FieldDescriptions.vae, title="Max Seq Length")
 
 
 @invocation("flux_model_loader", title="Flux Main Model", tags=["model", "flux"], category="model", version="1.0.3")
@@ -189,12 +192,22 @@ class FluxModelLoaderInvocation(BaseInvocation):
             ModelType.VAE,
             BaseModelType.Flux,
         )
+        transformer_config = context.models.get_config(transformer)
+        assert isinstance(transformer_config, CheckpointConfigBase)
+        legacy_config_path = context.config.get().legacy_conf_path / transformer_config.config_path
+        config_path = legacy_config_path.as_posix()
+        with open(config_path, "r") as stream:
+            try:
+                flux_conf = yaml.safe_load(stream)
+            except:
+                raise
 
         return FluxModelLoaderOutput(
             transformer=TransformerField(transformer=transformer),
             clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
-            t5Encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
+            t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
             vae=VAEField(vae=vae),
+            max_seq_len=flux_conf['max_seq_len']
         )
 
     def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
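The loader now derives max_seq_len from the transformer's legacy YAML config and passes it through the new output field. Two small oddities in the committed code are worth flagging: the try/except that only re-raises around yaml.safe_load is a no-op wrapper, and the new max_seq_len OutputField reuses description=FieldDescriptions.vae, which looks like a copy-paste slip. A self-contained sketch of the config step, under the assumption that the legacy YAML carries a top-level max_seq_len key (the file contents shown are hypothetical):

import yaml

# Hypothetical contents of a legacy config such as flux/flux1-schnell.yaml.
example_yaml = "max_seq_len: 256\n"

flux_conf = yaml.safe_load(example_yaml)
assert flux_conf["max_seq_len"] in (256, 512)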