mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
29 Commits
v4.2.9.dev
...
main
Author | SHA1 | Date | |
---|---|---|---|
|
87261bdbc9 | ||
|
4e4b6c6dbc | ||
|
5e8cf9fb6a | ||
|
c738fe051f | ||
|
29fe1533f2 | ||
|
77090070bd | ||
|
6ba9b1b6b0 | ||
|
c578b8df1e | ||
|
cad9a41433 | ||
|
5fefb3b0f4 | ||
|
5284a870b0 | ||
|
e064377c05 | ||
|
3e569c8312 | ||
|
16825ee6e9 | ||
|
3f5340fa53 | ||
|
f2a1a39b33 | ||
|
326de55d3e | ||
|
b2df909570 | ||
|
026ac36b06 | ||
|
92125e5fd2 | ||
|
c0c139da88 | ||
|
404ad6a7fd | ||
|
fc39086fb4 | ||
|
cd215700fe | ||
|
e97fd85904 | ||
|
0a263fa5b1 | ||
|
fae3836a8d | ||
|
b3d2eb4178 | ||
|
576f1cbb75 |
@ -45,11 +45,13 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||
ONNXModel = "ONNXModelField"
|
||||
VAEModel = "VAEModelField"
|
||||
FluxVAEModel = "FluxVAEModelField"
|
||||
LoRAModel = "LoRAModelField"
|
||||
ControlNetModel = "ControlNetModelField"
|
||||
IPAdapterModel = "IPAdapterModelField"
|
||||
T2IAdapterModel = "T2IAdapterModelField"
|
||||
T5EncoderModel = "T5EncoderModelField"
|
||||
CLIPEmbedModel = "CLIPEmbedModelField"
|
||||
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
|
||||
# endregion
|
||||
|
||||
@ -128,6 +130,7 @@ class FieldDescriptions:
|
||||
noise = "Noise tensor"
|
||||
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
||||
t5_encoder = "T5 tokenizer and text encoder"
|
||||
clip_embed_model = "CLIP Embed loader"
|
||||
unet = "UNet (scheduler, LoRAs)"
|
||||
transformer = "Transformer"
|
||||
vae = "VAE"
|
||||
|
@ -40,7 +40,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
|
||||
t5_embeddings, clip_embeddings = self._encode_prompt(context)
|
||||
# Note: The T5 and CLIP encoding are done in separate functions to ensure that all model references are locally
|
||||
# scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary).
|
||||
t5_embeddings = self._t5_encode(context)
|
||||
clip_embeddings = self._clip_encode(context)
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
|
||||
)
|
||||
@ -48,12 +51,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
return FluxConditioningOutput.build(conditioning_name)
|
||||
|
||||
def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Load CLIP.
|
||||
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
|
||||
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
|
||||
|
||||
# Load T5.
|
||||
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
|
||||
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
|
||||
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
|
||||
|
||||
@ -70,6 +68,15 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
|
||||
prompt_embeds = t5_encoder(prompt)
|
||||
|
||||
assert isinstance(prompt_embeds, torch.Tensor)
|
||||
return prompt_embeds
|
||||
|
||||
def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
|
||||
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
|
||||
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
|
||||
|
||||
prompt = [self.prompt]
|
||||
|
||||
with (
|
||||
clip_text_encoder_info as clip_text_encoder,
|
||||
clip_tokenizer_info as clip_tokenizer,
|
||||
@ -81,6 +88,5 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
|
||||
pooled_prompt_embeds = clip_encoder(prompt)
|
||||
|
||||
assert isinstance(prompt_embeds, torch.Tensor)
|
||||
assert isinstance(pooled_prompt_embeds, torch.Tensor)
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
return pooled_prompt_embeds
|
||||
|
@ -58,13 +58,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Load the conditioning data.
|
||||
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
|
||||
assert len(cond_data.conditionings) == 1
|
||||
flux_conditioning = cond_data.conditionings[0]
|
||||
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
||||
|
||||
latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
|
||||
latents = self._run_diffusion(context)
|
||||
image = self._run_vae_decoding(context, latents)
|
||||
image_dto = context.images.save(image=image)
|
||||
return ImageOutput.build(image_dto)
|
||||
@ -72,12 +66,20 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
def _run_diffusion(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
clip_embeddings: torch.Tensor,
|
||||
t5_embeddings: torch.Tensor,
|
||||
):
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
inference_dtype = torch.bfloat16
|
||||
|
||||
# Load the conditioning data.
|
||||
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
|
||||
assert len(cond_data.conditionings) == 1
|
||||
flux_conditioning = cond_data.conditionings[0]
|
||||
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
||||
flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
|
||||
t5_embeddings = flux_conditioning.t5_embeds
|
||||
clip_embeddings = flux_conditioning.clip_embeds
|
||||
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
|
||||
# Prepare input noise.
|
||||
x = get_noise(
|
||||
num_samples=1,
|
||||
@ -88,24 +90,19 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
seed=self.seed,
|
||||
)
|
||||
|
||||
img, img_ids = prepare_latent_img_patches(x)
|
||||
x, img_ids = prepare_latent_img_patches(x)
|
||||
|
||||
is_schnell = "schnell" in transformer_info.config.config_path
|
||||
|
||||
timesteps = get_schedule(
|
||||
num_steps=self.num_steps,
|
||||
image_seq_len=img.shape[1],
|
||||
image_seq_len=x.shape[1],
|
||||
shift=not is_schnell,
|
||||
)
|
||||
|
||||
bs, t5_seq_len, _ = t5_embeddings.shape
|
||||
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
|
||||
|
||||
# HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
|
||||
# disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
|
||||
# if the cache is not empty.
|
||||
context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
|
||||
|
||||
with transformer_info as transformer:
|
||||
assert isinstance(transformer, Flux)
|
||||
|
||||
@ -140,7 +137,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
x = denoise(
|
||||
model=transformer,
|
||||
img=img,
|
||||
img=x,
|
||||
img_ids=img_ids,
|
||||
txt=t5_embeddings,
|
||||
txt_ids=txt_ids,
|
||||
|
@ -157,7 +157,7 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
|
||||
title="Flux Main Model",
|
||||
tags=["model", "flux"],
|
||||
category="model",
|
||||
version="1.0.3",
|
||||
version="1.0.4",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxModelLoaderInvocation(BaseInvocation):
|
||||
@ -169,23 +169,35 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
||||
input=Input.Direct,
|
||||
)
|
||||
|
||||
t5_encoder: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.t5_encoder,
|
||||
ui_type=UIType.T5EncoderModel,
|
||||
t5_encoder_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
|
||||
)
|
||||
|
||||
clip_embed_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.clip_embed_model,
|
||||
ui_type=UIType.CLIPEmbedModel,
|
||||
input=Input.Direct,
|
||||
title="CLIP Embed",
|
||||
)
|
||||
|
||||
vae_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
|
||||
model_key = self.model.key
|
||||
for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
|
||||
if not context.models.exists(key):
|
||||
raise ValueError(f"Unknown model: {key}")
|
||||
|
||||
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
||||
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
|
||||
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
|
||||
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
|
||||
if not context.models.exists(model_key):
|
||||
raise ValueError(f"Unknown model: {model_key}")
|
||||
transformer = self._get_model(context, SubModelType.Transformer)
|
||||
tokenizer = self._get_model(context, SubModelType.Tokenizer)
|
||||
tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
|
||||
clip_encoder = self._get_model(context, SubModelType.TextEncoder)
|
||||
t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
|
||||
vae = self._get_model(context, SubModelType.VAE)
|
||||
transformer_config = context.models.get_config(transformer)
|
||||
assert isinstance(transformer_config, CheckpointConfigBase)
|
||||
|
||||
@ -197,52 +209,6 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
||||
max_seq_len=max_seq_lengths[transformer_config.config_path],
|
||||
)
|
||||
|
||||
def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
|
||||
match submodel:
|
||||
case SubModelType.Transformer:
|
||||
return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
||||
case SubModelType.VAE:
|
||||
return self._pull_model_from_mm(
|
||||
context,
|
||||
SubModelType.VAE,
|
||||
"FLUX.1-schnell_ae",
|
||||
ModelType.VAE,
|
||||
BaseModelType.Flux,
|
||||
)
|
||||
case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
|
||||
return self._pull_model_from_mm(
|
||||
context,
|
||||
submodel,
|
||||
"clip-vit-large-patch14",
|
||||
ModelType.CLIPEmbed,
|
||||
BaseModelType.Any,
|
||||
)
|
||||
case submodel if submodel in [SubModelType.Tokenizer2, SubModelType.TextEncoder2]:
|
||||
return self._pull_model_from_mm(
|
||||
context,
|
||||
submodel,
|
||||
self.t5_encoder.name,
|
||||
ModelType.T5Encoder,
|
||||
BaseModelType.Any,
|
||||
)
|
||||
case _:
|
||||
raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
|
||||
|
||||
def _pull_model_from_mm(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
submodel: SubModelType,
|
||||
name: str,
|
||||
type: ModelType,
|
||||
base: BaseModelType,
|
||||
):
|
||||
if models := context.models.search_by_attrs(name=name, base=base, type=type):
|
||||
if len(models) != 1:
|
||||
raise Exception(f"Multiple models detected for selected model with name {name}")
|
||||
return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
|
||||
else:
|
||||
raise ValueError(f"Please install the {base}:{type} model named {name} via starter models")
|
||||
|
||||
|
||||
@invocation(
|
||||
"main_model_loader",
|
||||
|
@ -2,13 +2,13 @@
|
||||
"name": "FLUX Text to Image",
|
||||
"author": "InvokeAI",
|
||||
"description": "A simple text-to-image workflow using FLUX dev or schnell models. Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
|
||||
"version": "1.0.0",
|
||||
"version": "1.0.4",
|
||||
"contact": "",
|
||||
"tags": "text2image, flux",
|
||||
"notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
|
||||
"exposedFields": [
|
||||
{
|
||||
"nodeId": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "model"
|
||||
},
|
||||
{
|
||||
@ -20,8 +20,8 @@
|
||||
"fieldName": "num_steps"
|
||||
},
|
||||
{
|
||||
"nodeId": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"fieldName": "t5_encoder"
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "t5_encoder_model"
|
||||
}
|
||||
],
|
||||
"meta": {
|
||||
@ -30,12 +30,12 @@
|
||||
},
|
||||
"nodes": [
|
||||
{
|
||||
"id": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"type": "flux_model_loader",
|
||||
"version": "1.0.3",
|
||||
"version": "1.0.4",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
@ -44,31 +44,25 @@
|
||||
"inputs": {
|
||||
"model": {
|
||||
"name": "model",
|
||||
"label": "Model (Starter Models can be found in Model Manager)",
|
||||
"value": {
|
||||
"key": "f04a7a2f-c74d-4538-8d5e-879a53501662",
|
||||
"hash": "random:4875da7a9508444ffa706f61961c260d0c6729f6181a86b31fad06df1277b850",
|
||||
"name": "FLUX Dev (Quantized)",
|
||||
"base": "flux",
|
||||
"type": "main"
|
||||
}
|
||||
"label": ""
|
||||
},
|
||||
"t5_encoder": {
|
||||
"name": "t5_encoder",
|
||||
"label": "T 5 Encoder (Starter Models can be found in Model Manager)",
|
||||
"value": {
|
||||
"key": "20dcd9ec-5fbb-4012-8401-049e707da5e5",
|
||||
"hash": "random:f986be43ff3502169e4adbdcee158afb0e0a65a1edc4cab16ae59963630cfd8f",
|
||||
"name": "t5_bnb_int8_quantized_encoder",
|
||||
"base": "any",
|
||||
"type": "t5_encoder"
|
||||
}
|
||||
"t5_encoder_model": {
|
||||
"name": "t5_encoder_model",
|
||||
"label": ""
|
||||
},
|
||||
"clip_embed_model": {
|
||||
"name": "clip_embed_model",
|
||||
"label": ""
|
||||
},
|
||||
"vae_model": {
|
||||
"name": "vae_model",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 337.09365228062825,
|
||||
"y": 40.63469521079861
|
||||
"x": 381.1882713063478,
|
||||
"y": -95.89663532854017
|
||||
}
|
||||
},
|
||||
{
|
||||
@ -207,45 +201,45 @@
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33amax_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
|
||||
"type": "default",
|
||||
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "max_seq_len",
|
||||
"targetHandle": "t5_max_seq_len"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33avae-159bdf1b-79e7-4174-b86e-d40e646964c8vae",
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-159bdf1b-79e7-4174-b86e-d40e646964c8vae",
|
||||
"type": "default",
|
||||
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33atransformer-159bdf1b-79e7-4174-b86e-d40e646964c8transformer",
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
|
||||
"type": "default",
|
||||
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"sourceHandle": "transformer",
|
||||
"targetHandle": "transformer"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33at5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
|
||||
"type": "default",
|
||||
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "t5_encoder",
|
||||
"targetHandle": "t5_encoder"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33aclip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90clip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
|
||||
"type": "default",
|
||||
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "clip",
|
||||
"targetHandle": "clip"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-159bdf1b-79e7-4174-b86e-d40e646964c8transformer",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"sourceHandle": "transformer",
|
||||
"targetHandle": "transformer"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-159bdf1b-79e7-4174-b86e-d40e646964c8positive_text_conditioning",
|
||||
"type": "default",
|
||||
|
@ -111,16 +111,7 @@ def denoise(
|
||||
step_callback: Callable[[], None],
|
||||
guidance: float = 4.0,
|
||||
):
|
||||
dtype = model.txt_in.bias.dtype
|
||||
|
||||
# TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller.
|
||||
img = img.to(dtype=dtype)
|
||||
img_ids = img_ids.to(dtype=dtype)
|
||||
txt = txt.to(dtype=dtype)
|
||||
txt_ids = txt_ids.to(dtype=dtype)
|
||||
vec = vec.to(dtype=dtype)
|
||||
|
||||
# this is ignored for schnell
|
||||
# guidance_vec is ignored for schnell.
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
@ -168,9 +159,9 @@ def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor,
|
||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
# Generate patch position ids.
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
return img, img_ids
|
||||
|
@ -72,6 +72,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
pass
|
||||
|
||||
config.path = str(self._get_model_path(config))
|
||||
self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type))
|
||||
loaded_model = self._load_model(config, submodel_type)
|
||||
|
||||
self._ram_cache.put(
|
||||
|
@ -193,15 +193,6 @@ class ModelCacheBase(ABC, Generic[T]):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def exists(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> bool:
|
||||
"""Return true if the model identified by key and submodel_type is in the cache."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cache_size(self) -> int:
|
||||
"""Get the total size of the models currently cached."""
|
||||
|
@ -1,22 +1,6 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
|
||||
# TODO: Add Stalker's proper name to copyright
|
||||
"""
|
||||
Manage a RAM cache of diffusion/transformer models for fast switching.
|
||||
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
|
||||
grows larger than a preset maximum, then the least recently used
|
||||
model will be cleared and (re)loaded from disk when next needed.
|
||||
|
||||
The cache returns context manager generators designed to load the
|
||||
model into the GPU within the context, and unload outside the
|
||||
context. Use like this:
|
||||
|
||||
cache = ModelCache(max_cache_size=7.5)
|
||||
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
|
||||
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
|
||||
do_something_in_GPU(SD1,SD2)
|
||||
|
||||
|
||||
"""
|
||||
""" """
|
||||
|
||||
import gc
|
||||
import math
|
||||
@ -40,45 +24,64 @@ from invokeai.backend.model_manager.load.model_util import calc_model_size_by_da
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
# Maximum size of the cache, in gigs
|
||||
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
|
||||
DEFAULT_MAX_CACHE_SIZE = 6.0
|
||||
|
||||
# amount of GPU memory to hold in reserve for use by generations (GB)
|
||||
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
|
||||
|
||||
# actual size of a gig
|
||||
GIG = 1073741824
|
||||
# Size of a GB in bytes.
|
||||
GB = 2**30
|
||||
|
||||
# Size of a MB in bytes.
|
||||
MB = 2**20
|
||||
|
||||
|
||||
class ModelCache(ModelCacheBase[AnyModel]):
|
||||
"""Implementation of ModelCacheBase."""
|
||||
"""A cache for managing models in memory.
|
||||
|
||||
The cache is based on two levels of model storage:
|
||||
- execution_device: The device where most models are executed (typically "cuda", "mps", or "cpu").
|
||||
- storage_device: The device where models are offloaded when not in active use (typically "cpu").
|
||||
|
||||
The model cache is based on the following assumptions:
|
||||
- storage_device_mem_size > execution_device_mem_size
|
||||
- disk_to_storage_device_transfer_time >> storage_device_to_execution_device_transfer_time
|
||||
|
||||
A copy of all models in the cache is always kept on the storage_device. A subset of the models also have a copy on
|
||||
the execution_device.
|
||||
|
||||
Models are moved between the storage_device and the execution_device as necessary. Cache size limits are enforced
|
||||
on both the storage_device and the execution_device. The execution_device cache uses a smallest-first offload
|
||||
policy. The storage_device cache uses a least-recently-used (LRU) offload policy.
|
||||
|
||||
Note: Neither of these offload policies has really been compared against alternatives. It's likely that different
|
||||
policies would be better, although the optimal policies are likely heavily dependent on usage patterns and HW
|
||||
configuration.
|
||||
|
||||
The cache returns context manager generators designed to load the model into the execution device (often GPU) within
|
||||
the context, and unload outside the context.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=6.0)
|
||||
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1:
|
||||
do_something_on_gpu(SD1)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
|
||||
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
|
||||
max_cache_size: float,
|
||||
max_vram_cache_size: float,
|
||||
execution_device: torch.device = torch.device("cuda"),
|
||||
storage_device: torch.device = torch.device("cpu"),
|
||||
precision: torch.dtype = torch.float16,
|
||||
sequential_offload: bool = False,
|
||||
lazy_offloading: bool = True,
|
||||
sha_chunksize: int = 16777216,
|
||||
log_memory_usage: bool = False,
|
||||
logger: Optional[Logger] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the model RAM cache.
|
||||
|
||||
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
|
||||
:param max_cache_size: Maximum size of the storage_device cache in GBs.
|
||||
:param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
|
||||
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
||||
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
||||
:param precision: Precision for loaded models [torch.float16]
|
||||
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
|
||||
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
|
||||
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded.
|
||||
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
|
||||
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
|
||||
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
||||
@ -86,7 +89,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
"""
|
||||
# allow lazy offloading only when vram cache enabled
|
||||
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
||||
self._precision: torch.dtype = precision
|
||||
self._max_cache_size: float = max_cache_size
|
||||
self._max_vram_cache_size: float = max_vram_cache_size
|
||||
self._execution_device: torch.device = execution_device
|
||||
@ -145,15 +147,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
total += cache_record.size
|
||||
return total
|
||||
|
||||
def exists(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> bool:
|
||||
"""Return true if the model identified by key and submodel_type is in the cache."""
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
return key in self._cached_models
|
||||
|
||||
def put(
|
||||
self,
|
||||
key: str,
|
||||
@ -203,7 +196,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
# more stats
|
||||
if self.stats:
|
||||
stats_name = stats_name or key
|
||||
self.stats.cache_size = int(self._max_cache_size * GIG)
|
||||
self.stats.cache_size = int(self._max_cache_size * GB)
|
||||
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
|
||||
self.stats.in_cache = len(self._cached_models)
|
||||
self.stats.loaded_model_sizes[stats_name] = max(
|
||||
@ -231,10 +224,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
return model_key
|
||||
|
||||
def offload_unlocked_models(self, size_required: int) -> None:
|
||||
"""Move any unused models from VRAM."""
|
||||
reserved = self._max_vram_cache_size * GIG
|
||||
"""Offload models from the execution_device to make room for size_required.
|
||||
|
||||
:param size_required: The amount of space to clear in the execution_device cache, in bytes.
|
||||
"""
|
||||
reserved = self._max_vram_cache_size * GB
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
|
||||
self.logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
|
||||
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
|
||||
if vram_in_use <= reserved:
|
||||
break
|
||||
@ -245,7 +241,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
cache_entry.loaded = False
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(
|
||||
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
|
||||
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
|
||||
)
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
@ -303,7 +299,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
self.logger.debug(
|
||||
f"Moved model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
|
||||
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
|
||||
f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
@ -326,14 +322,14 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
f"Moving model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} caused an unexpected change in VRAM usage. The model's"
|
||||
" estimated size may be incorrect. Estimated model size:"
|
||||
f" {(cache_entry.size/GIG):.3f} GB.\n"
|
||||
f" {(cache_entry.size/GB):.3f} GB.\n"
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
def print_cuda_stats(self) -> None:
|
||||
"""Log CUDA diagnostics."""
|
||||
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
|
||||
ram = "%4.2fG" % (self.cache_size() / GIG)
|
||||
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
|
||||
ram = "%4.2fG" % (self.cache_size() / GB)
|
||||
|
||||
in_ram_models = 0
|
||||
in_vram_models = 0
|
||||
@ -353,17 +349,20 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
)
|
||||
|
||||
def make_room(self, size: int) -> None:
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size."""
|
||||
# calculate how much memory this model will require
|
||||
# multiplier = 2 if self.precision==torch.float32 else 1
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size.
|
||||
|
||||
Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
|
||||
external references to the model, there's nothing that the cache can do about it, and those models will not be
|
||||
garbage-collected.
|
||||
"""
|
||||
bytes_needed = size
|
||||
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
|
||||
maximum_size = self.max_cache_size * GB # stored in GB, convert to bytes
|
||||
current_size = self.cache_size()
|
||||
|
||||
if current_size + bytes_needed > maximum_size:
|
||||
self.logger.debug(
|
||||
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
|
||||
f" {(bytes_needed/GIG):.2f} GB"
|
||||
f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
|
||||
f" {(bytes_needed/GB):.2f} GB"
|
||||
)
|
||||
|
||||
self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
|
||||
@ -380,7 +379,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
|
||||
if not cache_entry.locked:
|
||||
self.logger.debug(
|
||||
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
||||
f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
|
||||
)
|
||||
current_size -= cache_entry.size
|
||||
models_cleared += 1
|
||||
|
@ -54,8 +54,10 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
|
||||
|
||||
# See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
|
||||
scb = state_dict.pop(prefix + "SCB", None)
|
||||
# weight_format is unused, but we pop it so we can validate that there are no unexpected keys.
|
||||
_weight_format = state_dict.pop(prefix + "weight_format", None)
|
||||
|
||||
# Currently, we only support weight_format=0.
|
||||
weight_format = state_dict.pop(prefix + "weight_format", None)
|
||||
assert weight_format == 0
|
||||
|
||||
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
|
||||
# rather than raising an exception to correctly implement this API.
|
||||
@ -89,6 +91,14 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
|
||||
)
|
||||
self.bias = bias if bias is None else torch.nn.Parameter(bias)
|
||||
|
||||
# Reset the state. The persisted fields are based on the initialization behaviour in
|
||||
# `bnb.nn.Linear8bitLt.__init__()`.
|
||||
new_state = bnb.MatmulLtState()
|
||||
new_state.threshold = self.state.threshold
|
||||
new_state.has_fp16_weights = False
|
||||
new_state.use_pool = self.state.use_pool
|
||||
self.state = new_state
|
||||
|
||||
|
||||
def _convert_linear_layers_to_llm_8bit(
|
||||
module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""
|
||||
|
@ -43,6 +43,11 @@ class FLUXConditioningInfo:
|
||||
clip_embeds: torch.Tensor
|
||||
t5_embeds: torch.Tensor
|
||||
|
||||
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
||||
self.clip_embeds = self.clip_embeds.to(device=device, dtype=dtype)
|
||||
self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConditioningFieldData:
|
||||
|
@ -3,10 +3,9 @@ Initialization file for invokeai.backend.util
|
||||
"""
|
||||
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.backend.util.util import GIG, Chdir, directory_size
|
||||
from invokeai.backend.util.util import Chdir, directory_size
|
||||
|
||||
__all__ = [
|
||||
"GIG",
|
||||
"directory_size",
|
||||
"Chdir",
|
||||
"InvokeAILogger",
|
||||
|
@ -7,9 +7,6 @@ from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
|
||||
# actual size of a gig
|
||||
GIG = 1073741824
|
||||
|
||||
|
||||
def slugify(value: str, allow_unicode: bool = False) -> str:
|
||||
"""
|
||||
|
@ -696,6 +696,8 @@
|
||||
"availableModels": "Available Models",
|
||||
"baseModel": "Base Model",
|
||||
"cancel": "Cancel",
|
||||
"clipEmbed": "CLIP Embed",
|
||||
"clipVision": "CLIP Vision",
|
||||
"config": "Config",
|
||||
"convert": "Convert",
|
||||
"convertingModelBegin": "Converting Model. Please wait.",
|
||||
@ -783,6 +785,7 @@
|
||||
"settings": "Settings",
|
||||
"simpleModelPlaceholder": "URL or path to a local file or diffusers folder",
|
||||
"source": "Source",
|
||||
"spandrelImageToImage": "Image to Image (Spandrel)",
|
||||
"starterModels": "Starter Models",
|
||||
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
|
||||
"syncModels": "Sync Models",
|
||||
@ -791,6 +794,7 @@
|
||||
"loraTriggerPhrases": "LoRA Trigger Phrases",
|
||||
"mainModelTriggerPhrases": "Main Model Trigger Phrases",
|
||||
"typePhraseHere": "Type phrase here",
|
||||
"t5Encoder": "T5 Encoder",
|
||||
"upcastAttention": "Upcast Attention",
|
||||
"uploadImage": "Upload Image",
|
||||
"urlOrLocalPath": "URL or Local Path",
|
||||
|
@ -14,6 +14,7 @@ import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageMo
|
||||
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
|
||||
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
|
||||
import { StylePresetModal } from 'features/stylePresets/components/StylePresetForm/StylePresetModal';
|
||||
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
|
||||
import { configChanged } from 'features/system/store/configSlice';
|
||||
import { languageSelector } from 'features/system/store/systemSelectors';
|
||||
import InvokeTabs from 'features/ui/components/InvokeTabs';
|
||||
@ -39,10 +40,17 @@ interface Props {
|
||||
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
|
||||
};
|
||||
selectedWorkflowId?: string;
|
||||
selectedStylePresetId?: string;
|
||||
destination?: InvokeTabName | undefined;
|
||||
}
|
||||
|
||||
const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, destination }: Props) => {
|
||||
const App = ({
|
||||
config = DEFAULT_CONFIG,
|
||||
selectedImage,
|
||||
selectedWorkflowId,
|
||||
selectedStylePresetId,
|
||||
destination,
|
||||
}: Props) => {
|
||||
const language = useAppSelector(languageSelector);
|
||||
const logger = useLogger('system');
|
||||
const dispatch = useAppDispatch();
|
||||
@ -81,6 +89,12 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, desti
|
||||
}
|
||||
}, [selectedWorkflowId, getAndLoadWorkflow]);
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedStylePresetId) {
|
||||
dispatch(activeStylePresetIdChanged(selectedStylePresetId));
|
||||
}
|
||||
}, [dispatch, selectedStylePresetId]);
|
||||
|
||||
useEffect(() => {
|
||||
if (destination) {
|
||||
dispatch(setActiveTab(destination));
|
||||
|
@ -45,6 +45,7 @@ interface Props extends PropsWithChildren {
|
||||
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
|
||||
};
|
||||
selectedWorkflowId?: string;
|
||||
selectedStylePresetId?: string;
|
||||
destination?: InvokeTabName;
|
||||
customStarUi?: CustomStarUi;
|
||||
socketOptions?: Partial<ManagerOptions & SocketOptions>;
|
||||
@ -66,6 +67,7 @@ const InvokeAIUI = ({
|
||||
queueId,
|
||||
selectedImage,
|
||||
selectedWorkflowId,
|
||||
selectedStylePresetId,
|
||||
destination,
|
||||
customStarUi,
|
||||
socketOptions,
|
||||
@ -227,6 +229,7 @@ const InvokeAIUI = ({
|
||||
config={config}
|
||||
selectedImage={selectedImage}
|
||||
selectedWorkflowId={selectedWorkflowId}
|
||||
selectedStylePresetId={selectedStylePresetId}
|
||||
destination={destination}
|
||||
/>
|
||||
</AppDndContext>
|
||||
|
@ -175,12 +175,12 @@ const ModelList = () => {
|
||||
{/* T5 Encoders List */}
|
||||
{isLoadingT5EncoderModels && <FetchingModelsLoader loadingMessage="Loading T5 Encoder Models..." />}
|
||||
{!isLoadingT5EncoderModels && filteredT5EncoderModels.length > 0 && (
|
||||
<ModelListWrapper title="T5 Encoder" modelList={filteredT5EncoderModels} key="t5-encoder" />
|
||||
<ModelListWrapper title={t('modelManager.t5Encoder')} modelList={filteredT5EncoderModels} key="t5-encoder" />
|
||||
)}
|
||||
{/* Clip Embed List */}
|
||||
{isLoadingClipEmbedModels && <FetchingModelsLoader loadingMessage="Loading Clip Embed Models..." />}
|
||||
{!isLoadingClipEmbedModels && filteredClipEmbedModels.length > 0 && (
|
||||
<ModelListWrapper title="Clip Embed" modelList={filteredClipEmbedModels} key="clip-embed" />
|
||||
<ModelListWrapper title={t('modelManager.clipEmbed')} modelList={filteredClipEmbedModels} key="clip-embed" />
|
||||
)}
|
||||
{/* Spandrel Image to Image List */}
|
||||
{isLoadingSpandrelImageToImageModels && (
|
||||
@ -188,7 +188,7 @@ const ModelList = () => {
|
||||
)}
|
||||
{!isLoadingSpandrelImageToImageModels && filteredSpandrelImageToImageModels.length > 0 && (
|
||||
<ModelListWrapper
|
||||
title="Image-to-Image"
|
||||
title={t('modelManager.spandrelImageToImage')}
|
||||
modelList={filteredSpandrelImageToImageModels}
|
||||
key="spandrel-image-to-image"
|
||||
/>
|
||||
|
@ -19,11 +19,10 @@ export const ModelTypeFilter = memo(() => {
|
||||
controlnet: 'ControlNet',
|
||||
vae: 'VAE',
|
||||
t2i_adapter: t('common.t2iAdapter'),
|
||||
t5_encoder: 'T5Encoder',
|
||||
clip_embed: 'Clip Embed',
|
||||
t5_encoder: t('modelManager.t5Encoder'),
|
||||
clip_embed: t('modelManager.clipEmbed'),
|
||||
ip_adapter: t('common.ipAdapter'),
|
||||
clip_vision: 'Clip Vision',
|
||||
spandrel_image_to_image: 'Image-to-Image',
|
||||
spandrel_image_to_image: t('modelManager.spandrelImageToImage'),
|
||||
}),
|
||||
[t]
|
||||
);
|
||||
|
@ -6,6 +6,8 @@ import {
|
||||
isBoardFieldInputTemplate,
|
||||
isBooleanFieldInputInstance,
|
||||
isBooleanFieldInputTemplate,
|
||||
isCLIPEmbedModelFieldInputInstance,
|
||||
isCLIPEmbedModelFieldInputTemplate,
|
||||
isColorFieldInputInstance,
|
||||
isColorFieldInputTemplate,
|
||||
isControlNetModelFieldInputInstance,
|
||||
@ -16,6 +18,8 @@ import {
|
||||
isFloatFieldInputTemplate,
|
||||
isFluxMainModelFieldInputInstance,
|
||||
isFluxMainModelFieldInputTemplate,
|
||||
isFluxVAEModelFieldInputInstance,
|
||||
isFluxVAEModelFieldInputTemplate,
|
||||
isImageFieldInputInstance,
|
||||
isImageFieldInputTemplate,
|
||||
isIntegerFieldInputInstance,
|
||||
@ -49,10 +53,12 @@ import { memo } from 'react';
|
||||
|
||||
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
|
||||
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
|
||||
import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent';
|
||||
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
|
||||
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
|
||||
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
|
||||
import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent';
|
||||
import FluxVAEModelFieldInputComponent from './inputs/FluxVAEModelFieldInputComponent';
|
||||
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
|
||||
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
|
||||
import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
|
||||
@ -122,6 +128,13 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
if (isT5EncoderModelFieldInputInstance(fieldInstance) && isT5EncoderModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <T5EncoderModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
if (isCLIPEmbedModelFieldInputInstance(fieldInstance) && isCLIPEmbedModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <CLIPEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isFluxVAEModelFieldInputInstance(fieldInstance) && isFluxVAEModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <FluxVAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isLoRAModelFieldInputInstance(fieldInstance) && isLoRAModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <LoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
|
@ -0,0 +1,60 @@
|
||||
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldCLIPEmbedValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useClipEmbedModels } from 'services/api/hooks/modelsByType';
|
||||
import type { ClipEmbedModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate>;
|
||||
|
||||
const CLIPEmbedModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const { t } = useTranslation();
|
||||
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useClipEmbedModels();
|
||||
const _onChange = useCallback(
|
||||
(value: ClipEmbedModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldCLIPEmbedValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
|
||||
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(CLIPEmbedModelFieldInputComponent);
|
@ -0,0 +1,60 @@
|
||||
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldFluxVAEModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useFluxVAEModels } from 'services/api/hooks/modelsByType';
|
||||
import type { VAEModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate>;
|
||||
|
||||
const FluxVAEModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const { t } = useTranslation();
|
||||
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useFluxVAEModels();
|
||||
const _onChange = useCallback(
|
||||
(value: VAEModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldFluxVAEModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
|
||||
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(FluxVAEModelFieldInputComponent);
|
@ -6,11 +6,13 @@ import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
|
||||
import type {
|
||||
BoardFieldValue,
|
||||
BooleanFieldValue,
|
||||
CLIPEmbedModelFieldValue,
|
||||
ColorFieldValue,
|
||||
ControlNetModelFieldValue,
|
||||
EnumFieldValue,
|
||||
FieldValue,
|
||||
FloatFieldValue,
|
||||
FluxVAEModelFieldValue,
|
||||
ImageFieldValue,
|
||||
IntegerFieldValue,
|
||||
IPAdapterModelFieldValue,
|
||||
@ -29,10 +31,12 @@ import type {
|
||||
import {
|
||||
zBoardFieldValue,
|
||||
zBooleanFieldValue,
|
||||
zCLIPEmbedModelFieldValue,
|
||||
zColorFieldValue,
|
||||
zControlNetModelFieldValue,
|
||||
zEnumFieldValue,
|
||||
zFloatFieldValue,
|
||||
zFluxVAEModelFieldValue,
|
||||
zImageFieldValue,
|
||||
zIntegerFieldValue,
|
||||
zIPAdapterModelFieldValue,
|
||||
@ -346,6 +350,12 @@ export const nodesSlice = createSlice({
|
||||
fieldT5EncoderValueChanged: (state, action: FieldValueAction<T5EncoderModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zT5EncoderModelFieldValue);
|
||||
},
|
||||
fieldCLIPEmbedValueChanged: (state, action: FieldValueAction<CLIPEmbedModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zCLIPEmbedModelFieldValue);
|
||||
},
|
||||
fieldFluxVAEModelValueChanged: (state, action: FieldValueAction<FluxVAEModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zFluxVAEModelFieldValue);
|
||||
},
|
||||
fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
|
||||
fieldValueReducer(state, action, zEnumFieldValue);
|
||||
},
|
||||
@ -408,6 +418,8 @@ export const {
|
||||
fieldStringValueChanged,
|
||||
fieldVaeModelValueChanged,
|
||||
fieldT5EncoderValueChanged,
|
||||
fieldCLIPEmbedValueChanged,
|
||||
fieldFluxVAEModelValueChanged,
|
||||
nodeEditorReset,
|
||||
nodeIsIntermediateChanged,
|
||||
nodeIsOpenChanged,
|
||||
@ -521,6 +533,8 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
|
||||
fieldStringValueChanged,
|
||||
fieldVaeModelValueChanged,
|
||||
fieldT5EncoderValueChanged,
|
||||
fieldCLIPEmbedValueChanged,
|
||||
fieldFluxVAEModelValueChanged,
|
||||
nodesChanged,
|
||||
nodeIsIntermediateChanged,
|
||||
nodeIsOpenChanged,
|
||||
|
@ -151,6 +151,14 @@ const zT5EncoderModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('T5EncoderModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zCLIPEmbedModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('CLIPEmbedModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zFluxVAEModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('FluxVAEModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zSchedulerFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SchedulerField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
@ -175,6 +183,8 @@ const zStatefulFieldType = z.union([
|
||||
zT2IAdapterModelFieldType,
|
||||
zSpandrelImageToImageModelFieldType,
|
||||
zT5EncoderModelFieldType,
|
||||
zCLIPEmbedModelFieldType,
|
||||
zFluxVAEModelFieldType,
|
||||
zColorFieldType,
|
||||
zSchedulerFieldType,
|
||||
]);
|
||||
@ -667,7 +677,53 @@ export const isT5EncoderModelFieldInputInstance = (val: unknown): val is T5Encod
|
||||
export const isT5EncoderModelFieldInputTemplate = (val: unknown): val is T5EncoderModelFieldInputTemplate =>
|
||||
zT5EncoderModelFieldInputTemplate.safeParse(val).success;
|
||||
|
||||
// #endregio
|
||||
// #endregion
|
||||
|
||||
// #region FluxVAEModelField
|
||||
|
||||
export const zFluxVAEModelFieldValue = zModelIdentifierField.optional();
|
||||
const zFluxVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zFluxVAEModelFieldValue,
|
||||
});
|
||||
const zFluxVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zFluxVAEModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zFluxVAEModelFieldValue,
|
||||
});
|
||||
|
||||
export type FluxVAEModelFieldValue = z.infer<typeof zFluxVAEModelFieldValue>;
|
||||
|
||||
export type FluxVAEModelFieldInputInstance = z.infer<typeof zFluxVAEModelFieldInputInstance>;
|
||||
export type FluxVAEModelFieldInputTemplate = z.infer<typeof zFluxVAEModelFieldInputTemplate>;
|
||||
export const isFluxVAEModelFieldInputInstance = (val: unknown): val is FluxVAEModelFieldInputInstance =>
|
||||
zFluxVAEModelFieldInputInstance.safeParse(val).success;
|
||||
export const isFluxVAEModelFieldInputTemplate = (val: unknown): val is FluxVAEModelFieldInputTemplate =>
|
||||
zFluxVAEModelFieldInputTemplate.safeParse(val).success;
|
||||
|
||||
// #endregion
|
||||
|
||||
// #region CLIPEmbedModelField
|
||||
|
||||
export const zCLIPEmbedModelFieldValue = zModelIdentifierField.optional();
|
||||
const zCLIPEmbedModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zCLIPEmbedModelFieldValue,
|
||||
});
|
||||
const zCLIPEmbedModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zCLIPEmbedModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zCLIPEmbedModelFieldValue,
|
||||
});
|
||||
|
||||
export type CLIPEmbedModelFieldValue = z.infer<typeof zCLIPEmbedModelFieldValue>;
|
||||
|
||||
export type CLIPEmbedModelFieldInputInstance = z.infer<typeof zCLIPEmbedModelFieldInputInstance>;
|
||||
export type CLIPEmbedModelFieldInputTemplate = z.infer<typeof zCLIPEmbedModelFieldInputTemplate>;
|
||||
export const isCLIPEmbedModelFieldInputInstance = (val: unknown): val is CLIPEmbedModelFieldInputInstance =>
|
||||
zCLIPEmbedModelFieldInputInstance.safeParse(val).success;
|
||||
export const isCLIPEmbedModelFieldInputTemplate = (val: unknown): val is CLIPEmbedModelFieldInputTemplate =>
|
||||
zCLIPEmbedModelFieldInputTemplate.safeParse(val).success;
|
||||
|
||||
// #endregion
|
||||
|
||||
// #region SchedulerField
|
||||
|
||||
@ -758,6 +814,8 @@ export const zStatefulFieldValue = z.union([
|
||||
zT2IAdapterModelFieldValue,
|
||||
zSpandrelImageToImageModelFieldValue,
|
||||
zT5EncoderModelFieldValue,
|
||||
zFluxVAEModelFieldValue,
|
||||
zCLIPEmbedModelFieldValue,
|
||||
zColorFieldValue,
|
||||
zSchedulerFieldValue,
|
||||
]);
|
||||
@ -788,6 +846,8 @@ const zStatefulFieldInputInstance = z.union([
|
||||
zT2IAdapterModelFieldInputInstance,
|
||||
zSpandrelImageToImageModelFieldInputInstance,
|
||||
zT5EncoderModelFieldInputInstance,
|
||||
zFluxVAEModelFieldInputInstance,
|
||||
zCLIPEmbedModelFieldInputInstance,
|
||||
zColorFieldInputInstance,
|
||||
zSchedulerFieldInputInstance,
|
||||
]);
|
||||
@ -819,6 +879,8 @@ const zStatefulFieldInputTemplate = z.union([
|
||||
zT2IAdapterModelFieldInputTemplate,
|
||||
zSpandrelImageToImageModelFieldInputTemplate,
|
||||
zT5EncoderModelFieldInputTemplate,
|
||||
zFluxVAEModelFieldInputTemplate,
|
||||
zCLIPEmbedModelFieldInputTemplate,
|
||||
zColorFieldInputTemplate,
|
||||
zSchedulerFieldInputTemplate,
|
||||
zStatelessFieldInputTemplate,
|
||||
|
@ -23,6 +23,8 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
|
||||
VAEModelField: undefined,
|
||||
ControlNetModelField: undefined,
|
||||
T5EncoderModelField: undefined,
|
||||
FluxVAEModelField: undefined,
|
||||
CLIPEmbedModelField: undefined,
|
||||
};
|
||||
|
||||
export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => {
|
||||
|
@ -2,6 +2,7 @@ import { FieldParseError } from 'features/nodes/types/error';
|
||||
import type {
|
||||
BoardFieldInputTemplate,
|
||||
BooleanFieldInputTemplate,
|
||||
CLIPEmbedModelFieldInputTemplate,
|
||||
ColorFieldInputTemplate,
|
||||
ControlNetModelFieldInputTemplate,
|
||||
EnumFieldInputTemplate,
|
||||
@ -9,6 +10,7 @@ import type {
|
||||
FieldType,
|
||||
FloatFieldInputTemplate,
|
||||
FluxMainModelFieldInputTemplate,
|
||||
FluxVAEModelFieldInputTemplate,
|
||||
ImageFieldInputTemplate,
|
||||
IntegerFieldInputTemplate,
|
||||
IPAdapterModelFieldInputTemplate,
|
||||
@ -238,6 +240,34 @@ const buildT5EncoderModelFieldInputTemplate: FieldInputTemplateBuilder<T5Encoder
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildCLIPEmbedModelFieldInputTemplate: FieldInputTemplateBuilder<CLIPEmbedModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: CLIPEmbedModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFluxVAEModelFieldInputTemplate: FieldInputTemplateBuilder<FluxVAEModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: FluxVAEModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<LoRAModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -423,6 +453,8 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
|
||||
SpandrelImageToImageModelField: buildSpandrelImageToImageModelFieldInputTemplate,
|
||||
VAEModelField: buildVAEModelFieldInputTemplate,
|
||||
T5EncoderModelField: buildT5EncoderModelFieldInputTemplate,
|
||||
CLIPEmbedModelField: buildCLIPEmbedModelFieldInputTemplate,
|
||||
FluxVAEModelField: buildFluxVAEModelFieldInputTemplate,
|
||||
} as const;
|
||||
|
||||
export const buildFieldInputTemplate = (
|
||||
|
@ -7,6 +7,7 @@ import {
|
||||
isControlNetModelConfig,
|
||||
isControlNetOrT2IAdapterModelConfig,
|
||||
isFluxMainModelModelConfig,
|
||||
isFluxVAEModelConfig,
|
||||
isIPAdapterModelConfig,
|
||||
isLoRAModelConfig,
|
||||
isNonRefinerMainModelConfig,
|
||||
@ -52,3 +53,4 @@ export const useSpandrelImageToImageModels = buildModelsHook(isSpandrelImageToIm
|
||||
export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig);
|
||||
export const useEmbeddingModels = buildModelsHook(isTIModelConfig);
|
||||
export const useVAEModels = buildModelsHook(isVAEModelConfig);
|
||||
export const useFluxVAEModels = buildModelsHook(isFluxVAEModelConfig);
|
||||
|
@ -5770,8 +5770,22 @@ export type components = {
|
||||
use_cache?: boolean;
|
||||
/** @description Flux model (Transformer) to load */
|
||||
model: components["schemas"]["ModelIdentifierField"];
|
||||
/** @description T5 tokenizer and text encoder */
|
||||
t5_encoder: components["schemas"]["ModelIdentifierField"];
|
||||
/**
|
||||
* T5 Encoder
|
||||
* @description T5 tokenizer and text encoder
|
||||
*/
|
||||
t5_encoder_model: components["schemas"]["ModelIdentifierField"];
|
||||
/**
|
||||
* CLIP Embed
|
||||
* @description CLIP Embed loader
|
||||
*/
|
||||
clip_embed_model: components["schemas"]["ModelIdentifierField"];
|
||||
/**
|
||||
* VAE
|
||||
* @description VAE model to load
|
||||
* @default null
|
||||
*/
|
||||
vae_model?: components["schemas"]["ModelIdentifierField"];
|
||||
/**
|
||||
* type
|
||||
* @default flux_model_loader
|
||||
@ -15097,7 +15111,7 @@ export type components = {
|
||||
* used, and the type will be ignored. They are included here for backwards compatibility.
|
||||
* @enum {string}
|
||||
*/
|
||||
UIType: "MainModelField" | "FluxMainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "T5EncoderModelField" | "SpandrelImageToImageModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict";
|
||||
UIType: "MainModelField" | "FluxMainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "FluxVAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "T5EncoderModelField" | "CLIPEmbedModelField" | "SpandrelImageToImageModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict";
|
||||
/** UNetField */
|
||||
UNetField: {
|
||||
/** @description Info to load unet submodel */
|
||||
|
@ -51,7 +51,7 @@ export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
|
||||
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
|
||||
export type IPAdapterModelConfig = S['IPAdapterInvokeAIConfig'] | S['IPAdapterCheckpointConfig'];
|
||||
export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
|
||||
type ClipEmbedModelConfig = S['CLIPEmbedDiffusersConfig'];
|
||||
export type ClipEmbedModelConfig = S['CLIPEmbedDiffusersConfig'];
|
||||
export type T5EncoderModelConfig = S['T5EncoderConfig'];
|
||||
export type T5EncoderBnbQuantizedLlmInt8bModelConfig = S['T5EncoderBnbQuantizedLlmInt8bConfig'];
|
||||
export type SpandrelImageToImageModelConfig = S['SpandrelImageToImageConfig'];
|
||||
@ -82,6 +82,10 @@ export const isVAEModelConfig = (config: AnyModelConfig): config is VAEModelConf
|
||||
return config.type === 'vae';
|
||||
};
|
||||
|
||||
export const isFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => {
|
||||
return config.type === 'vae' && config.base === 'flux';
|
||||
};
|
||||
|
||||
export const isControlNetModelConfig = (config: AnyModelConfig): config is ControlNetModelConfig => {
|
||||
return config.type === 'controlnet';
|
||||
};
|
||||
|
@ -1 +1 @@
|
||||
__version__ = "4.2.8post1"
|
||||
__version__ = "4.2.9rc1"
|
||||
|
@ -130,8 +130,6 @@ dependencies = [
|
||||
|
||||
[project.scripts]
|
||||
"invokeai-web" = "invokeai.app.run_app:run_app"
|
||||
"invokeai-import-images" = "invokeai.frontend.install.import_images:main"
|
||||
"invokeai-db-maintenance" = "invokeai.backend.util.db_maintenance:main"
|
||||
|
||||
[project.urls]
|
||||
"Homepage" = "https://invoke-ai.github.io/InvokeAI/"
|
||||
|
Loading…
x
Reference in New Issue
Block a user