Compare commits

...

25 Commits

Author SHA1 Message Date
Ryan Dick
87261bdbc9
FLUX memory management improvements (#6791)
## Summary

This PR contains several improvements to memory management for FLUX
workflows.

It is now possible to achieve better FLUX model caching performance, but
this still requires users to manually configure their `ram`/`vram`
settings. For example, a `vram` setting of 16.0 should allow all
quantized FLUX models to be kept in GPU memory.
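
For reference, these settings live in the server's `invokeai.yaml`. A minimal sketch (the keys match the `ram`/`vram` settings named above; the values are illustrative, not recommendations):

```yaml
# invokeai.yaml -- model cache limits, in GB
ram: 16.0   # maximum size of the CPU-RAM model cache
vram: 16.0  # maximum size of the GPU-VRAM model cache
```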

Changes:
- Check the size of a model on disk and free the requisite space in the
model cache before loading it. (This behaviour existed previously, but
was removed in https://github.com/invoke-ai/InvokeAI/pull/6072/files;
the removal does not appear to have been intentional.)
- Removed the hack to free 24GB of space in the cache before loading the
FLUX model.
- Split the T5 embedding and CLIP embedding steps into separate
functions so that the two models don't both have to be held in RAM at
the same time (see the sketch after this list).
- Fix a bug in `InvokeLinear8bitLt` that was causing some tensors to be
left on the GPU when the model was offloaded to the CPU. (This class is
getting very messy due to the non-standard state_dict handling in
`bnb.nn.Linear8bitLt`.)
- Tidy up some dtype handling in `FluxTextToImageInvocation` to avoid
situations where we hold references to two copies of the same tensor
unnecessarily.
- (minor) Misc cleanup of `ModelCache`: improve docs and remove unused
vars.
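
A minimal sketch of the locally-scoped encoding pattern described above (the `load_t5`/`load_clip` callables are hypothetical stand-ins for the real model-loading API; the actual implementation is in the `FluxTextEncoderInvocation` diff below):

```python
import gc

def encode_sequentially(load_t5, load_clip, prompt: str):
    """Encode with T5 and CLIP without holding both models in RAM at once."""

    def t5_encode():
        t5 = load_t5()     # all T5 references live only in this frame
        return t5(prompt)  # they go out of scope when this returns

    def clip_encode():
        clip = load_clip()
        return clip(prompt)

    t5_embeds = t5_encode()
    gc.collect()  # T5 is now collectible, so CLIP can load into the freed RAM
    clip_embeds = clip_encode()
    return t5_embeds, clip_embeds
```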

Future:
We should revisit our default `ram`/`vram` configs. The current defaults
are very conservative, and users could see major performance
improvements from tuning these values.

## QA Instructions

I tested the FLUX workflow with the following configurations and
verified that the cache hit rates and memory usage matched the expected
behaviour:
- `ram = 16` and `vram = 16`
- `ram = 16` and `vram = 1`
- `ram = 1` and `vram = 1`

Note that the changes in this PR are not isolated to FLUX. Since we now
check the size of models on disk, we may see slight changes in model
cache offload patterns for other models as well.
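
Conceptually, the restored load path boils down to the following (a condensed sketch; `make_room` matches the cache API touched in this PR, while `load_from_disk` and the size estimate are illustrative stand-ins for the loader internals):

```python
import os

def load_model_with_room(cache, model_path: str, load_from_disk):
    # Reserve space using the cheap on-disk size as an estimate of the
    # in-memory size, instead of unconditionally freeing a fixed 24GB.
    size_on_disk = os.path.getsize(model_path)
    cache.make_room(size_on_disk)  # evict cached models until the new one fits
    return load_from_disk(model_path)
```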

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
2024-08-29 15:17:45 -04:00
Ryan Dick
4e4b6c6dbc Tidy variable management and dtype handling in FluxTextToImageInvocation. 2024-08-29 19:08:18 +00:00
Ryan Dick
5e8cf9fb6a Remove hack to clear cache from the FluxTextToImageInvocation. We now clear the cache based on the on-disk model size. 2024-08-29 19:08:18 +00:00
Ryan Dick
c738fe051f Split T5 encoding and CLIP encoding into separate functions to ensure that all model references are locally-scoped so that the two models don't have to be held in memory at the same time. 2024-08-29 19:08:18 +00:00
Ryan Dick
29fe1533f2 Fix bug in InvokeLinear8bitLt that was causing old state information to persist after loading from a state dict. This manifested as state tensors being left on the GPU even when a model had been offloaded to the CPU cache. 2024-08-29 19:08:18 +00:00
Ryan Dick
77090070bd Check the size of a model on disk and make room for it in the cache before loading it. 2024-08-29 19:08:18 +00:00
Ryan Dick
6ba9b1b6b0 Tidy up GIG -> GB and remove unused GIG constant. 2024-08-29 19:08:18 +00:00
Ryan Dick
c578b8df1e Improve ModelCache docs. 2024-08-29 19:08:18 +00:00
Ryan Dick
cad9a41433 Remove unused ModelCache.exists(...) function. 2024-08-29 19:08:18 +00:00
Ryan Dick
5fefb3b0f4 Remove unused param from ModelCache. 2024-08-29 19:08:18 +00:00
Ryan Dick
5284a870b0 Remove unused constructor params from ModelCache. 2024-08-29 19:08:18 +00:00
Ryan Dick
e064377c05 Remove default model cache sizes from model_cache_default.py. These defaults were misleading, because the config defaults take precedence over them. 2024-08-29 19:08:18 +00:00
Mary Hipp
3e569c8312 feat(ui): add fields for CLIP embed models and Flux VAE models in workflows 2024-08-29 11:52:51 -04:00
maryhipp
16825ee6e9 feat(nodes): bump version of flux model node, update default workflow 2024-08-29 11:52:51 -04:00
Mary Hipp
3f5340fa53 feat(nodes): add submodels as inputs to FLUX main model node instead of hardcoded names 2024-08-29 11:52:51 -04:00
chainchompa
f2a1a39b33
Add selectedStylePreset to app parameters (#6787)
## Summary
- Add selectedStylePreset to app parameters

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-08-28 10:53:07 -04:00
chainchompa
326de55d3e remove api changes and only preselect style preset 2024-08-28 09:53:29 -04:00
chainchompa
b2df909570 added selectedStylePreset to preload presets when app loads 2024-08-28 09:50:44 -04:00
chainchompa
026ac36b06 Revert "added selectedStylePreset to preload presets when app loads"
This reverts commit e97fd85904e32ebe0a35cc66e75f010970e7dc63.
2024-08-28 09:44:08 -04:00
chainchompa
92125e5fd2 bug fixes 2024-08-27 16:13:38 -04:00
chainchompa
c0c139da88 formatting ruff 2024-08-27 15:46:51 -04:00
chainchompa
404ad6a7fd cleanup 2024-08-27 15:42:42 -04:00
chainchompa
fc39086fb4 call stylePresetSelected 2024-08-27 15:34:31 -04:00
chainchompa
cd215700fe added route for selecting style preset 2024-08-27 15:34:07 -04:00
chainchompa
e97fd85904 added selectedStylePreset to preload presets when app loads 2024-08-27 15:33:24 -04:00
25 changed files with 465 additions and 226 deletions

View File

@@ -45,11 +45,13 @@ class UIType(str, Enum, metaclass=MetaEnum):
     SDXLRefinerModel = "SDXLRefinerModelField"
     ONNXModel = "ONNXModelField"
     VAEModel = "VAEModelField"
+    FluxVAEModel = "FluxVAEModelField"
     LoRAModel = "LoRAModelField"
     ControlNetModel = "ControlNetModelField"
     IPAdapterModel = "IPAdapterModelField"
     T2IAdapterModel = "T2IAdapterModelField"
     T5EncoderModel = "T5EncoderModelField"
+    CLIPEmbedModel = "CLIPEmbedModelField"
     SpandrelImageToImageModel = "SpandrelImageToImageModelField"
     # endregion
@@ -128,6 +130,7 @@ class FieldDescriptions:
     noise = "Noise tensor"
     clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
     t5_encoder = "T5 tokenizer and text encoder"
+    clip_embed_model = "CLIP Embed loader"
     unet = "UNet (scheduler, LoRAs)"
     transformer = "Transformer"
     vae = "VAE"

View File

@@ -40,7 +40,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
-        t5_embeddings, clip_embeddings = self._encode_prompt(context)
+        # Note: The T5 and CLIP encoding are done in separate functions to ensure that all model references are locally
+        # scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary).
+        t5_embeddings = self._t5_encode(context)
+        clip_embeddings = self._clip_encode(context)
         conditioning_data = ConditioningFieldData(
             conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
         )
@@ -48,12 +51,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
         conditioning_name = context.conditioning.save(conditioning_data)
         return FluxConditioningOutput.build(conditioning_name)

-    def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
-        # Load CLIP.
-        clip_tokenizer_info = context.models.load(self.clip.tokenizer)
-        clip_text_encoder_info = context.models.load(self.clip.text_encoder)
-
-        # Load T5.
+    def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
         t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
         t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
@@ -70,6 +68,15 @@ class FluxTextEncoderInvocation(BaseInvocation):
             prompt_embeds = t5_encoder(prompt)

+        assert isinstance(prompt_embeds, torch.Tensor)
+        return prompt_embeds
+
+    def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
+        clip_tokenizer_info = context.models.load(self.clip.tokenizer)
+        clip_text_encoder_info = context.models.load(self.clip.text_encoder)
+
+        prompt = [self.prompt]
+
         with (
             clip_text_encoder_info as clip_text_encoder,
             clip_tokenizer_info as clip_tokenizer,
@@ -81,6 +88,5 @@ class FluxTextEncoderInvocation(BaseInvocation):
             pooled_prompt_embeds = clip_encoder(prompt)

-        assert isinstance(prompt_embeds, torch.Tensor)
         assert isinstance(pooled_prompt_embeds, torch.Tensor)
-        return prompt_embeds, pooled_prompt_embeds
+        return pooled_prompt_embeds

View File

@@ -58,13 +58,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        # Load the conditioning data.
-        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
-        assert len(cond_data.conditionings) == 1
-        flux_conditioning = cond_data.conditionings[0]
-        assert isinstance(flux_conditioning, FLUXConditioningInfo)
-
-        latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
+        latents = self._run_diffusion(context)
         image = self._run_vae_decoding(context, latents)
         image_dto = context.images.save(image=image)
         return ImageOutput.build(image_dto)
@@ -72,12 +66,20 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     def _run_diffusion(
         self,
         context: InvocationContext,
-        clip_embeddings: torch.Tensor,
-        t5_embeddings: torch.Tensor,
     ):
-        transformer_info = context.models.load(self.transformer.transformer)
         inference_dtype = torch.bfloat16

+        # Load the conditioning data.
+        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
+        assert len(cond_data.conditionings) == 1
+        flux_conditioning = cond_data.conditionings[0]
+        assert isinstance(flux_conditioning, FLUXConditioningInfo)
+        flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
+        t5_embeddings = flux_conditioning.t5_embeds
+        clip_embeddings = flux_conditioning.clip_embeds
+
+        transformer_info = context.models.load(self.transformer.transformer)
+
         # Prepare input noise.
         x = get_noise(
             num_samples=1,
@@ -88,24 +90,19 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             seed=self.seed,
         )

-        img, img_ids = prepare_latent_img_patches(x)
+        x, img_ids = prepare_latent_img_patches(x)

         is_schnell = "schnell" in transformer_info.config.config_path

         timesteps = get_schedule(
             num_steps=self.num_steps,
-            image_seq_len=img.shape[1],
+            image_seq_len=x.shape[1],
             shift=not is_schnell,
         )

         bs, t5_seq_len, _ = t5_embeddings.shape
         txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())

-        # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
-        # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
-        # if the cache is not empty.
-        context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
-
         with transformer_info as transformer:
             assert isinstance(transformer, Flux)
@@ -140,7 +137,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         x = denoise(
             model=transformer,
-            img=img,
+            img=x,
             img_ids=img_ids,
             txt=t5_embeddings,
             txt_ids=txt_ids,

View File

@@ -157,7 +157,7 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
     title="Flux Main Model",
     tags=["model", "flux"],
     category="model",
-    version="1.0.3",
+    version="1.0.4",
     classification=Classification.Prototype,
 )
 class FluxModelLoaderInvocation(BaseInvocation):
@@ -169,23 +169,35 @@ class FluxModelLoaderInvocation(BaseInvocation):
         input=Input.Direct,
     )

-    t5_encoder: ModelIdentifierField = InputField(
-        description=FieldDescriptions.t5_encoder,
-        ui_type=UIType.T5EncoderModel,
+    t5_encoder_model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
+    )
+
+    clip_embed_model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.clip_embed_model,
+        ui_type=UIType.CLIPEmbedModel,
         input=Input.Direct,
+        title="CLIP Embed",
+    )
+
+    vae_model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
     )

     def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
-        model_key = self.model.key
-
-        if not context.models.exists(model_key):
-            raise ValueError(f"Unknown model: {model_key}")
-
-        transformer = self._get_model(context, SubModelType.Transformer)
-        tokenizer = self._get_model(context, SubModelType.Tokenizer)
-        tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
-        clip_encoder = self._get_model(context, SubModelType.TextEncoder)
-        t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
-        vae = self._get_model(context, SubModelType.VAE)
+        for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
+            if not context.models.exists(key):
+                raise ValueError(f"Unknown model: {key}")
+
+        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
+        vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
+
+        tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
+        clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
+
+        tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
+        t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})

         transformer_config = context.models.get_config(transformer)
         assert isinstance(transformer_config, CheckpointConfigBase)
@@ -197,52 +209,6 @@ class FluxModelLoaderInvocation(BaseInvocation):
             max_seq_len=max_seq_lengths[transformer_config.config_path],
         )

-    def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
-        match submodel:
-            case SubModelType.Transformer:
-                return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
-            case SubModelType.VAE:
-                return self._pull_model_from_mm(
-                    context,
-                    SubModelType.VAE,
-                    "FLUX.1-schnell_ae",
-                    ModelType.VAE,
-                    BaseModelType.Flux,
-                )
-            case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
-                return self._pull_model_from_mm(
-                    context,
-                    submodel,
-                    "clip-vit-large-patch14",
-                    ModelType.CLIPEmbed,
-                    BaseModelType.Any,
-                )
-            case submodel if submodel in [SubModelType.Tokenizer2, SubModelType.TextEncoder2]:
-                return self._pull_model_from_mm(
-                    context,
-                    submodel,
-                    self.t5_encoder.name,
-                    ModelType.T5Encoder,
-                    BaseModelType.Any,
-                )
-            case _:
-                raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
-
-    def _pull_model_from_mm(
-        self,
-        context: InvocationContext,
-        submodel: SubModelType,
-        name: str,
-        type: ModelType,
-        base: BaseModelType,
-    ):
-        if models := context.models.search_by_attrs(name=name, base=base, type=type):
-            if len(models) != 1:
-                raise Exception(f"Multiple models detected for selected model with name {name}")
-            return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
-        else:
-            raise ValueError(f"Please install the {base}:{type} model named {name} via starter models")

 @invocation(
     "main_model_loader",

View File

@@ -2,13 +2,13 @@
   "name": "FLUX Text to Image",
   "author": "InvokeAI",
   "description": "A simple text-to-image workflow using FLUX dev or schnell models. Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
-  "version": "1.0.0",
+  "version": "1.0.4",
   "contact": "",
   "tags": "text2image, flux",
   "notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
   "exposedFields": [
     {
-      "nodeId": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
+      "nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
       "fieldName": "model"
     },
     {
@@ -20,8 +20,8 @@
       "fieldName": "num_steps"
     },
     {
-      "nodeId": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
-      "fieldName": "t5_encoder"
+      "nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
+      "fieldName": "t5_encoder_model"
     }
   ],
   "meta": {
@@ -30,12 +30,12 @@
   },
   "nodes": [
     {
-      "id": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
+      "id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
       "type": "invocation",
       "data": {
-        "id": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
+        "id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
         "type": "flux_model_loader",
-        "version": "1.0.3",
+        "version": "1.0.4",
         "label": "",
         "notes": "",
         "isOpen": true,
@@ -44,31 +44,25 @@
         "inputs": {
           "model": {
             "name": "model",
-            "label": "Model (Starter Models can be found in Model Manager)",
-            "value": {
-              "key": "f04a7a2f-c74d-4538-8d5e-879a53501662",
-              "hash": "random:4875da7a9508444ffa706f61961c260d0c6729f6181a86b31fad06df1277b850",
-              "name": "FLUX Dev (Quantized)",
-              "base": "flux",
-              "type": "main"
-            }
+            "label": ""
           },
-          "t5_encoder": {
-            "name": "t5_encoder",
-            "label": "T 5 Encoder (Starter Models can be found in Model Manager)",
-            "value": {
-              "key": "20dcd9ec-5fbb-4012-8401-049e707da5e5",
-              "hash": "random:f986be43ff3502169e4adbdcee158afb0e0a65a1edc4cab16ae59963630cfd8f",
-              "name": "t5_bnb_int8_quantized_encoder",
-              "base": "any",
-              "type": "t5_encoder"
-            }
+          "t5_encoder_model": {
+            "name": "t5_encoder_model",
+            "label": ""
+          },
+          "clip_embed_model": {
+            "name": "clip_embed_model",
+            "label": ""
+          },
+          "vae_model": {
+            "name": "vae_model",
+            "label": ""
           }
         }
       },
       "position": {
-        "x": 337.09365228062825,
-        "y": 40.63469521079861
+        "x": 381.1882713063478,
+        "y": -95.89663532854017
       }
     },
     {
@@ -207,45 +201,45 @@
   ],
   "edges": [
     {
-      "id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33amax_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
+      "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
       "type": "default",
-      "source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
+      "source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
       "target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
       "sourceHandle": "max_seq_len",
       "targetHandle": "t5_max_seq_len"
     },
     {
-      "id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33avae-159bdf1b-79e7-4174-b86e-d40e646964c8vae",
+      "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-159bdf1b-79e7-4174-b86e-d40e646964c8vae",
       "type": "default",
-      "source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
+      "source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
       "target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
       "sourceHandle": "vae",
       "targetHandle": "vae"
     },
     {
-      "id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33atransformer-159bdf1b-79e7-4174-b86e-d40e646964c8transformer",
+      "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
       "type": "default",
-      "source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
-      "target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
-      "sourceHandle": "transformer",
-      "targetHandle": "transformer"
-    },
-    {
-      "id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33at5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
-      "type": "default",
-      "source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
+      "source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
       "target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
       "sourceHandle": "t5_encoder",
       "targetHandle": "t5_encoder"
     },
     {
-      "id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33aclip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
+      "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90clip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
       "type": "default",
-      "source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
+      "source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
       "target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
       "sourceHandle": "clip",
       "targetHandle": "clip"
     },
+    {
+      "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-159bdf1b-79e7-4174-b86e-d40e646964c8transformer",
+      "type": "default",
+      "source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
+      "target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
+      "sourceHandle": "transformer",
+      "targetHandle": "transformer"
+    },
     {
       "id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-159bdf1b-79e7-4174-b86e-d40e646964c8positive_text_conditioning",
       "type": "default",

View File

@@ -111,16 +111,7 @@ def denoise(
     step_callback: Callable[[], None],
     guidance: float = 4.0,
 ):
-    dtype = model.txt_in.bias.dtype
-
-    # TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller.
-    img = img.to(dtype=dtype)
-    img_ids = img_ids.to(dtype=dtype)
-    txt = txt.to(dtype=dtype)
-    txt_ids = txt_ids.to(dtype=dtype)
-    vec = vec.to(dtype=dtype)
-
-    # this is ignored for schnell
+    # guidance_vec is ignored for schnell.
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
     for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
@@ -168,9 +159,9 @@ def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
     img = repeat(img, "1 ... -> bs ...", bs=bs)

     # Generate patch position ids.
-    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
-    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
-    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
+    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
+    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None]
+    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :]
     img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

     return img, img_ids

View File

@@ -72,6 +72,7 @@ class ModelLoader(ModelLoaderBase):
             pass

         config.path = str(self._get_model_path(config))
+        self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type))
         loaded_model = self._load_model(config, submodel_type)

         self._ram_cache.put(

View File

@@ -193,15 +193,6 @@ class ModelCacheBase(ABC, Generic[T]):
         """
         pass

-    @abstractmethod
-    def exists(
-        self,
-        key: str,
-        submodel_type: Optional[SubModelType] = None,
-    ) -> bool:
-        """Return true if the model identified by key and submodel_type is in the cache."""
-        pass
-
     @abstractmethod
     def cache_size(self) -> int:
         """Get the total size of the models currently cached."""

View File

@@ -1,22 +1,6 @@
 # Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
 # TODO: Add Stalker's proper name to copyright
-"""
-Manage a RAM cache of diffusion/transformer models for fast switching.
-They are moved between GPU VRAM and CPU RAM as necessary. If the cache
-grows larger than a preset maximum, then the least recently used
-model will be cleared and (re)loaded from disk when next needed.
-
-The cache returns context manager generators designed to load the
-model into the GPU within the context, and unload outside the
-context. Use like this:
-
-    cache = ModelCache(max_cache_size=7.5)
-    with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
-        cache.get_model('stabilityai/stable-diffusion-2') as SD2:
-        do_something_in_GPU(SD1,SD2)
-"""

 import gc
 import math
@@ -40,45 +24,64 @@ from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger

-# Maximum size of the cache, in gigs
-# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
-DEFAULT_MAX_CACHE_SIZE = 6.0
-
-# amount of GPU memory to hold in reserve for use by generations (GB)
-DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
-
-# actual size of a gig
-GIG = 1073741824
+# Size of a GB in bytes.
+GB = 2**30

 # Size of a MB in bytes.
 MB = 2**20


 class ModelCache(ModelCacheBase[AnyModel]):
-    """Implementation of ModelCacheBase."""
+    """A cache for managing models in memory.
+
+    The cache is based on two levels of model storage:
+    - execution_device: The device where most models are executed (typically "cuda", "mps", or "cpu").
+    - storage_device: The device where models are offloaded when not in active use (typically "cpu").
+
+    The model cache is based on the following assumptions:
+    - storage_device_mem_size > execution_device_mem_size
+    - disk_to_storage_device_transfer_time >> storage_device_to_execution_device_transfer_time
+
+    A copy of all models in the cache is always kept on the storage_device. A subset of the models also have a copy on
+    the execution_device.
+
+    Models are moved between the storage_device and the execution_device as necessary. Cache size limits are enforced
+    on both the storage_device and the execution_device. The execution_device cache uses a smallest-first offload
+    policy. The storage_device cache uses a least-recently-used (LRU) offload policy.
+
+    Note: Neither of these offload policies has really been compared against alternatives. It's likely that different
+    policies would be better, although the optimal policies are likely heavily dependent on usage patterns and HW
+    configuration.
+
+    The cache returns context manager generators designed to load the model into the execution device (often GPU) within
+    the context, and unload outside the context.
+
+    Example usage:
+    ```
+    cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=6.0)
+    with cache.get_model('runwayml/stable-diffusion-1-5') as SD1:
+        do_something_on_gpu(SD1)
+    ```
+    """

     def __init__(
         self,
-        max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
-        max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
+        max_cache_size: float,
+        max_vram_cache_size: float,
         execution_device: torch.device = torch.device("cuda"),
         storage_device: torch.device = torch.device("cpu"),
-        precision: torch.dtype = torch.float16,
-        sequential_offload: bool = False,
         lazy_offloading: bool = True,
-        sha_chunksize: int = 16777216,
         log_memory_usage: bool = False,
         logger: Optional[Logger] = None,
     ):
         """
         Initialize the model RAM cache.

-        :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
+        :param max_cache_size: Maximum size of the storage_device cache in GBs.
+        :param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
         :param execution_device: Torch device to load active model into [torch.device('cuda')]
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
-        :param precision: Precision for loaded models [torch.float16]
-        :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
-        :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
+        :param lazy_offloading: Keep model in VRAM until another model needs to be loaded.
         :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
             operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
             snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
@@ -86,7 +89,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
         """
         # allow lazy offloading only when vram cache enabled
         self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
-        self._precision: torch.dtype = precision
         self._max_cache_size: float = max_cache_size
         self._max_vram_cache_size: float = max_vram_cache_size
         self._execution_device: torch.device = execution_device
@@ -145,15 +147,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
             total += cache_record.size
         return total

-    def exists(
-        self,
-        key: str,
-        submodel_type: Optional[SubModelType] = None,
-    ) -> bool:
-        """Return true if the model identified by key and submodel_type is in the cache."""
-        key = self._make_cache_key(key, submodel_type)
-        return key in self._cached_models
-
     def put(
         self,
         key: str,
@@ -203,7 +196,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         # more stats
         if self.stats:
             stats_name = stats_name or key
-            self.stats.cache_size = int(self._max_cache_size * GIG)
+            self.stats.cache_size = int(self._max_cache_size * GB)
             self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
             self.stats.in_cache = len(self._cached_models)
             self.stats.loaded_model_sizes[stats_name] = max(
@@ -231,10 +224,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
         return model_key

     def offload_unlocked_models(self, size_required: int) -> None:
-        """Move any unused models from VRAM."""
-        reserved = self._max_vram_cache_size * GIG
+        """Offload models from the execution_device to make room for size_required.
+
+        :param size_required: The amount of space to clear in the execution_device cache, in bytes.
+        """
+        reserved = self._max_vram_cache_size * GB
         vram_in_use = torch.cuda.memory_allocated() + size_required
-        self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
+        self.logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
         for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
             if vram_in_use <= reserved:
                 break
@@ -245,7 +241,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
             cache_entry.loaded = False
             vram_in_use = torch.cuda.memory_allocated() + size_required
             self.logger.debug(
-                f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
+                f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
             )

         TorchDevice.empty_cache()
@@ -303,7 +299,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
             self.logger.debug(
                 f"Moved model '{cache_entry.key}' from {source_device} to"
                 f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
-                f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
+                f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
                 f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
             )
@@ -326,14 +322,14 @@ class ModelCache(ModelCacheBase[AnyModel]):
                 f"Moving model '{cache_entry.key}' from {source_device} to"
                 f" {target_device} caused an unexpected change in VRAM usage. The model's"
                 " estimated size may be incorrect. Estimated model size:"
-                f" {(cache_entry.size/GIG):.3f} GB.\n"
+                f" {(cache_entry.size/GB):.3f} GB.\n"
                 f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
             )

     def print_cuda_stats(self) -> None:
         """Log CUDA diagnostics."""
-        vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
-        ram = "%4.2fG" % (self.cache_size() / GIG)
+        vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
+        ram = "%4.2fG" % (self.cache_size() / GB)

         in_ram_models = 0
         in_vram_models = 0
@@ -353,17 +349,20 @@ class ModelCache(ModelCacheBase[AnyModel]):
         )

     def make_room(self, size: int) -> None:
-        """Make enough room in the cache to accommodate a new model of indicated size."""
-        # calculate how much memory this model will require
-        # multiplier = 2 if self.precision==torch.float32 else 1
+        """Make enough room in the cache to accommodate a new model of indicated size.
+
+        Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
+        external references to the model, there's nothing that the cache can do about it, and those models will not be
+        garbage-collected.
+        """
         bytes_needed = size
-        maximum_size = self.max_cache_size * GIG  # stored in GB, convert to bytes
+        maximum_size = self.max_cache_size * GB  # stored in GB, convert to bytes
         current_size = self.cache_size()

         if current_size + bytes_needed > maximum_size:
             self.logger.debug(
-                f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
-                f" {(bytes_needed/GIG):.2f} GB"
+                f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
+                f" {(bytes_needed/GB):.2f} GB"
             )

         self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
@@ -380,7 +379,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
             if not cache_entry.locked:
                 self.logger.debug(
-                    f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
+                    f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
                 )
                 current_size -= cache_entry.size
                 models_cleared += 1

View File

@@ -54,8 +54,10 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
         # See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
         scb = state_dict.pop(prefix + "SCB", None)
-        # weight_format is unused, but we pop it so we can validate that there are no unexpected keys.
-        _weight_format = state_dict.pop(prefix + "weight_format", None)
+
+        # Currently, we only support weight_format=0.
+        weight_format = state_dict.pop(prefix + "weight_format", None)
+        assert weight_format == 0

         # TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
         # rather than raising an exception to correctly implement this API.
@@ -89,6 +91,14 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
         )
         self.bias = bias if bias is None else torch.nn.Parameter(bias)

+        # Reset the state. The persisted fields are based on the initialization behaviour in
+        # `bnb.nn.Linear8bitLt.__init__()`.
+        new_state = bnb.MatmulLtState()
+        new_state.threshold = self.state.threshold
+        new_state.has_fp16_weights = False
+        new_state.use_pool = self.state.use_pool
+        self.state = new_state
+

 def _convert_linear_layers_to_llm_8bit(
     module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""

View File

@@ -43,6 +43,11 @@ class FLUXConditioningInfo:
     clip_embeds: torch.Tensor
     t5_embeds: torch.Tensor

+    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
+        self.clip_embeds = self.clip_embeds.to(device=device, dtype=dtype)
+        self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
+        return self
+

 @dataclass
 class ConditioningFieldData:

View File

@@ -3,10 +3,9 @@ Initialization file for invokeai.backend.util
 """

 from invokeai.backend.util.logging import InvokeAILogger
-from invokeai.backend.util.util import GIG, Chdir, directory_size
+from invokeai.backend.util.util import Chdir, directory_size

 __all__ = [
-    "GIG",
     "directory_size",
     "Chdir",
     "InvokeAILogger",

View File

@@ -7,9 +7,6 @@ from pathlib import Path
 from PIL import Image

-# actual size of a gig
-GIG = 1073741824
-

 def slugify(value: str, allow_unicode: bool = False) -> str:
     """

View File

@@ -14,6 +14,7 @@ import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
 import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
 import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
 import { StylePresetModal } from 'features/stylePresets/components/StylePresetForm/StylePresetModal';
+import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
 import { configChanged } from 'features/system/store/configSlice';
 import { languageSelector } from 'features/system/store/systemSelectors';
 import InvokeTabs from 'features/ui/components/InvokeTabs';
@@ -39,10 +40,17 @@ interface Props {
     action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
   };
   selectedWorkflowId?: string;
+  selectedStylePresetId?: string;
   destination?: InvokeTabName | undefined;
 }

-const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, destination }: Props) => {
+const App = ({
+  config = DEFAULT_CONFIG,
+  selectedImage,
+  selectedWorkflowId,
+  selectedStylePresetId,
+  destination,
+}: Props) => {
   const language = useAppSelector(languageSelector);
   const logger = useLogger('system');
   const dispatch = useAppDispatch();
@@ -81,6 +89,12 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, destination }: Props) => {
     }
   }, [selectedWorkflowId, getAndLoadWorkflow]);

+  useEffect(() => {
+    if (selectedStylePresetId) {
+      dispatch(activeStylePresetIdChanged(selectedStylePresetId));
+    }
+  }, [dispatch, selectedStylePresetId]);
+
   useEffect(() => {
     if (destination) {
       dispatch(setActiveTab(destination));

View File

@@ -45,6 +45,7 @@ interface Props extends PropsWithChildren {
     action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
   };
   selectedWorkflowId?: string;
+  selectedStylePresetId?: string;
   destination?: InvokeTabName;
   customStarUi?: CustomStarUi;
   socketOptions?: Partial<ManagerOptions & SocketOptions>;
@@ -66,6 +67,7 @@ const InvokeAIUI = ({
   queueId,
   selectedImage,
   selectedWorkflowId,
+  selectedStylePresetId,
   destination,
   customStarUi,
   socketOptions,
@@ -227,6 +229,7 @@ const InvokeAIUI = ({
             config={config}
             selectedImage={selectedImage}
             selectedWorkflowId={selectedWorkflowId}
+            selectedStylePresetId={selectedStylePresetId}
             destination={destination}
           />
         </AppDndContext>

View File

@@ -6,6 +6,8 @@ import {
   isBoardFieldInputTemplate,
   isBooleanFieldInputInstance,
   isBooleanFieldInputTemplate,
+  isCLIPEmbedModelFieldInputInstance,
+  isCLIPEmbedModelFieldInputTemplate,
   isColorFieldInputInstance,
   isColorFieldInputTemplate,
   isControlNetModelFieldInputInstance,
@@ -16,6 +18,8 @@ import {
   isFloatFieldInputTemplate,
   isFluxMainModelFieldInputInstance,
   isFluxMainModelFieldInputTemplate,
+  isFluxVAEModelFieldInputInstance,
+  isFluxVAEModelFieldInputTemplate,
   isImageFieldInputInstance,
   isImageFieldInputTemplate,
   isIntegerFieldInputInstance,
@@ -49,10 +53,12 @@ import { memo } from 'react';

 import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
 import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
+import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent';
 import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
 import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
 import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
 import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent';
+import FluxVAEModelFieldInputComponent from './inputs/FluxVAEModelFieldInputComponent';
 import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
 import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
 import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
@@ -122,6 +128,13 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
   if (isT5EncoderModelFieldInputInstance(fieldInstance) && isT5EncoderModelFieldInputTemplate(fieldTemplate)) {
     return <T5EncoderModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
   }
+
+  if (isCLIPEmbedModelFieldInputInstance(fieldInstance) && isCLIPEmbedModelFieldInputTemplate(fieldTemplate)) {
+    return <CLIPEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
+  }
+
+  if (isFluxVAEModelFieldInputInstance(fieldInstance) && isFluxVAEModelFieldInputTemplate(fieldTemplate)) {
+    return <FluxVAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
+  }

   if (isLoRAModelFieldInputInstance(fieldInstance) && isLoRAModelFieldInputTemplate(fieldTemplate)) {
     return <LoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;

View File

@@ -0,0 +1,60 @@
+import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
+import { fieldCLIPEmbedValueChanged } from 'features/nodes/store/nodesSlice';
+import type { CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
+import { memo, useCallback } from 'react';
+import { useTranslation } from 'react-i18next';
+import { useClipEmbedModels } from 'services/api/hooks/modelsByType';
+import type { ClipEmbedModelConfig } from 'services/api/types';
+
+import type { FieldComponentProps } from './types';
+
+type Props = FieldComponentProps<CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate>;
+
+const CLIPEmbedModelFieldInputComponent = (props: Props) => {
+  const { nodeId, field } = props;
+  const { t } = useTranslation();
+  const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
+  const dispatch = useAppDispatch();
+  const [modelConfigs, { isLoading }] = useClipEmbedModels();
+
+  const _onChange = useCallback(
+    (value: ClipEmbedModelConfig | null) => {
+      if (!value) {
+        return;
+      }
+      dispatch(
+        fieldCLIPEmbedValueChanged({
+          nodeId,
+          fieldName: field.name,
+          value,
+        })
+      );
+    },
+    [dispatch, field.name, nodeId]
+  );
+
+  const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
+    modelConfigs,
+    onChange: _onChange,
+    isLoading,
+    selectedModel: field.value,
+  });
+
+  return (
+    <Flex w="full" alignItems="center" gap={2}>
+      <Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
+        <FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
+          <Combobox
+            value={value}
+            placeholder={placeholder}
+            options={options}
+            onChange={onChange}
+            noOptionsMessage={noOptionsMessage}
+          />
+        </FormControl>
+      </Tooltip>
+    </Flex>
+  );
+};
+
+export default memo(CLIPEmbedModelFieldInputComponent);

View File

@@ -0,0 +1,60 @@
+import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
+import { fieldFluxVAEModelValueChanged } from 'features/nodes/store/nodesSlice';
+import type { FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate } from 'features/nodes/types/field';
+import { memo, useCallback } from 'react';
+import { useTranslation } from 'react-i18next';
+import { useFluxVAEModels } from 'services/api/hooks/modelsByType';
+import type { VAEModelConfig } from 'services/api/types';
+
+import type { FieldComponentProps } from './types';
+
+type Props = FieldComponentProps<FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate>;
+
+const FluxVAEModelFieldInputComponent = (props: Props) => {
+  const { nodeId, field } = props;
+  const { t } = useTranslation();
+  const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
+  const dispatch = useAppDispatch();
+  const [modelConfigs, { isLoading }] = useFluxVAEModels();
+
+  const _onChange = useCallback(
+    (value: VAEModelConfig | null) => {
+      if (!value) {
+        return;
+      }
+      dispatch(
+        fieldFluxVAEModelValueChanged({
+          nodeId,
+          fieldName: field.name,
+          value,
+        })
+      );
+    },
+    [dispatch, field.name, nodeId]
+  );
+
+  const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
+    modelConfigs,
+    onChange: _onChange,
+    isLoading,
+    selectedModel: field.value,
+  });
+
+  return (
+    <Flex w="full" alignItems="center" gap={2}>
+      <Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
+        <FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
+          <Combobox
+            value={value}
+            placeholder={placeholder}
+            options={options}
+            onChange={onChange}
+            noOptionsMessage={noOptionsMessage}
+          />
+        </FormControl>
+      </Tooltip>
+    </Flex>
+  );
+};
+
+export default memo(FluxVAEModelFieldInputComponent);

View File

@@ -6,11 +6,13 @@ import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
 import type {
   BoardFieldValue,
   BooleanFieldValue,
+  CLIPEmbedModelFieldValue,
   ColorFieldValue,
   ControlNetModelFieldValue,
   EnumFieldValue,
   FieldValue,
   FloatFieldValue,
+  FluxVAEModelFieldValue,
   ImageFieldValue,
   IntegerFieldValue,
   IPAdapterModelFieldValue,
@@ -29,10 +31,12 @@ import type {
 import {
   zBoardFieldValue,
   zBooleanFieldValue,
+  zCLIPEmbedModelFieldValue,
   zColorFieldValue,
   zControlNetModelFieldValue,
   zEnumFieldValue,
   zFloatFieldValue,
+  zFluxVAEModelFieldValue,
   zImageFieldValue,
   zIntegerFieldValue,
   zIPAdapterModelFieldValue,
@@ -346,6 +350,12 @@ export const nodesSlice = createSlice({
     fieldT5EncoderValueChanged: (state, action: FieldValueAction<T5EncoderModelFieldValue>) => {
       fieldValueReducer(state, action, zT5EncoderModelFieldValue);
     },
+    fieldCLIPEmbedValueChanged: (state, action: FieldValueAction<CLIPEmbedModelFieldValue>) => {
+      fieldValueReducer(state, action, zCLIPEmbedModelFieldValue);
+    },
+    fieldFluxVAEModelValueChanged: (state, action: FieldValueAction<FluxVAEModelFieldValue>) => {
+      fieldValueReducer(state, action, zFluxVAEModelFieldValue);
+    },
     fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
       fieldValueReducer(state, action, zEnumFieldValue);
     },
@@ -408,6 +418,8 @@ export const {
   fieldStringValueChanged,
   fieldVaeModelValueChanged,
   fieldT5EncoderValueChanged,
+  fieldCLIPEmbedValueChanged,
+  fieldFluxVAEModelValueChanged,
   nodeEditorReset,
   nodeIsIntermediateChanged,
   nodeIsOpenChanged,
@@ -521,6 +533,8 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
   fieldStringValueChanged,
   fieldVaeModelValueChanged,
   fieldT5EncoderValueChanged,
+  fieldCLIPEmbedValueChanged,
+  fieldFluxVAEModelValueChanged,
   nodesChanged,
   nodeIsIntermediateChanged,
   nodeIsOpenChanged,
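
Both new reducers reuse the slice's validate-then-assign pattern: the payload is checked against the field's zod schema before it is written to state. `fieldValueReducer` itself is outside this diff, so the sketch below only illustrates the assumed pattern, with a stand-in schema instead of the real `zModelIdentifierField`:

import { z } from 'zod';

// Stand-in for zFluxVAEModelFieldValue (really zModelIdentifierField.optional(); see field.ts below).
const zValue = z.object({ key: z.string(), name: z.string() }).optional();

type FieldState = { value?: z.infer<typeof zValue> };

// Assumed shape of the validate-then-assign helper the reducers delegate to.
const setFieldValue = (field: FieldState, incoming: unknown): void => {
  const result = zValue.safeParse(incoming);
  if (!result.success) {
    return; // drop payloads that fail validation instead of corrupting state
  }
  field.value = result.data;
};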


@@ -151,6 +151,14 @@ const zT5EncoderModelFieldType = zFieldTypeBase.extend({
   name: z.literal('T5EncoderModelField'),
   originalType: zStatelessFieldType.optional(),
 });
+const zCLIPEmbedModelFieldType = zFieldTypeBase.extend({
+  name: z.literal('CLIPEmbedModelField'),
+  originalType: zStatelessFieldType.optional(),
+});
+const zFluxVAEModelFieldType = zFieldTypeBase.extend({
+  name: z.literal('FluxVAEModelField'),
+  originalType: zStatelessFieldType.optional(),
+});
 const zSchedulerFieldType = zFieldTypeBase.extend({
   name: z.literal('SchedulerField'),
   originalType: zStatelessFieldType.optional(),
@@ -175,6 +183,8 @@ const zStatefulFieldType = z.union([
   zT2IAdapterModelFieldType,
   zSpandrelImageToImageModelFieldType,
   zT5EncoderModelFieldType,
+  zCLIPEmbedModelFieldType,
+  zFluxVAEModelFieldType,
   zColorFieldType,
   zSchedulerFieldType,
 ]);
@@ -667,7 +677,53 @@ export const isT5EncoderModelFieldInputInstance = (val: unknown): val is T5Encod
 export const isT5EncoderModelFieldInputTemplate = (val: unknown): val is T5EncoderModelFieldInputTemplate =>
   zT5EncoderModelFieldInputTemplate.safeParse(val).success;
-// #endregio
+// #endregion
+
+// #region FluxVAEModelField
+export const zFluxVAEModelFieldValue = zModelIdentifierField.optional();
+const zFluxVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({
+  value: zFluxVAEModelFieldValue,
+});
+const zFluxVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({
+  type: zFluxVAEModelFieldType,
+  originalType: zFieldType.optional(),
+  default: zFluxVAEModelFieldValue,
+});
+export type FluxVAEModelFieldValue = z.infer<typeof zFluxVAEModelFieldValue>;
+export type FluxVAEModelFieldInputInstance = z.infer<typeof zFluxVAEModelFieldInputInstance>;
+export type FluxVAEModelFieldInputTemplate = z.infer<typeof zFluxVAEModelFieldInputTemplate>;
+export const isFluxVAEModelFieldInputInstance = (val: unknown): val is FluxVAEModelFieldInputInstance =>
+  zFluxVAEModelFieldInputInstance.safeParse(val).success;
+export const isFluxVAEModelFieldInputTemplate = (val: unknown): val is FluxVAEModelFieldInputTemplate =>
+  zFluxVAEModelFieldInputTemplate.safeParse(val).success;
+// #endregion
+
+// #region CLIPEmbedModelField
+export const zCLIPEmbedModelFieldValue = zModelIdentifierField.optional();
+const zCLIPEmbedModelFieldInputInstance = zFieldInputInstanceBase.extend({
+  value: zCLIPEmbedModelFieldValue,
+});
+const zCLIPEmbedModelFieldInputTemplate = zFieldInputTemplateBase.extend({
+  type: zCLIPEmbedModelFieldType,
+  originalType: zFieldType.optional(),
+  default: zCLIPEmbedModelFieldValue,
+});
+export type CLIPEmbedModelFieldValue = z.infer<typeof zCLIPEmbedModelFieldValue>;
+export type CLIPEmbedModelFieldInputInstance = z.infer<typeof zCLIPEmbedModelFieldInputInstance>;
+export type CLIPEmbedModelFieldInputTemplate = z.infer<typeof zCLIPEmbedModelFieldInputTemplate>;
+export const isCLIPEmbedModelFieldInputInstance = (val: unknown): val is CLIPEmbedModelFieldInputInstance =>
+  zCLIPEmbedModelFieldInputInstance.safeParse(val).success;
+export const isCLIPEmbedModelFieldInputTemplate = (val: unknown): val is CLIPEmbedModelFieldInputTemplate =>
+  zCLIPEmbedModelFieldInputTemplate.safeParse(val).success;
+// #endregion
+
 // #region SchedulerField
@@ -758,6 +814,8 @@ export const zStatefulFieldValue = z.union([
   zT2IAdapterModelFieldValue,
   zSpandrelImageToImageModelFieldValue,
   zT5EncoderModelFieldValue,
+  zFluxVAEModelFieldValue,
+  zCLIPEmbedModelFieldValue,
   zColorFieldValue,
   zSchedulerFieldValue,
 ]);
@@ -788,6 +846,8 @@ const zStatefulFieldInputInstance = z.union([
   zT2IAdapterModelFieldInputInstance,
   zSpandrelImageToImageModelFieldInputInstance,
   zT5EncoderModelFieldInputInstance,
+  zFluxVAEModelFieldInputInstance,
+  zCLIPEmbedModelFieldInputInstance,
   zColorFieldInputInstance,
   zSchedulerFieldInputInstance,
 ]);
@@ -819,6 +879,8 @@ const zStatefulFieldInputTemplate = z.union([
   zT2IAdapterModelFieldInputTemplate,
   zSpandrelImageToImageModelFieldInputTemplate,
   zT5EncoderModelFieldInputTemplate,
+  zFluxVAEModelFieldInputTemplate,
+  zCLIPEmbedModelFieldInputTemplate,
   zColorFieldInputTemplate,
   zSchedulerFieldInputTemplate,
   zStatelessFieldInputTemplate,
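
The exported guards are thin wrappers over `safeParse`, which makes them cheap to use for narrowing `unknown` values, e.g. when loading a stored workflow. A usage sketch; `serialized` is a hypothetical input:

import { isCLIPEmbedModelFieldInputTemplate, isFluxVAEModelFieldInputTemplate } from 'features/nodes/types/field';

declare const serialized: string; // hypothetical serialized template

const candidate: unknown = JSON.parse(serialized);
if (isFluxVAEModelFieldInputTemplate(candidate)) {
  // Narrowed: candidate.type.name === 'FluxVAEModelField' and candidate.default
  // is an optional model identifier.
  console.log(candidate.default);
} else if (isCLIPEmbedModelFieldInputTemplate(candidate)) {
  console.log(candidate.default);
}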


@@ -23,6 +23,8 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
   VAEModelField: undefined,
   ControlNetModelField: undefined,
   T5EncoderModelField: undefined,
+  FluxVAEModelField: undefined,
+  CLIPEmbedModelField: undefined,
 };

 export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => {


@@ -2,6 +2,7 @@ import { FieldParseError } from 'features/nodes/types/error';
 import type {
   BoardFieldInputTemplate,
   BooleanFieldInputTemplate,
+  CLIPEmbedModelFieldInputTemplate,
   ColorFieldInputTemplate,
   ControlNetModelFieldInputTemplate,
   EnumFieldInputTemplate,
@@ -9,6 +10,7 @@ import type {
   FieldType,
   FloatFieldInputTemplate,
   FluxMainModelFieldInputTemplate,
+  FluxVAEModelFieldInputTemplate,
   ImageFieldInputTemplate,
   IntegerFieldInputTemplate,
   IPAdapterModelFieldInputTemplate,
@@ -238,6 +240,34 @@ const buildT5EncoderModelFieldInputTemplate: FieldInputTemplateBuilder<T5Encoder
   return template;
 };

+const buildCLIPEmbedModelFieldInputTemplate: FieldInputTemplateBuilder<CLIPEmbedModelFieldInputTemplate> = ({
+  schemaObject,
+  baseField,
+  fieldType,
+}) => {
+  const template: CLIPEmbedModelFieldInputTemplate = {
+    ...baseField,
+    type: fieldType,
+    default: schemaObject.default ?? undefined,
+  };
+  return template;
+};
+
+const buildFluxVAEModelFieldInputTemplate: FieldInputTemplateBuilder<FluxVAEModelFieldInputTemplate> = ({
+  schemaObject,
+  baseField,
+  fieldType,
+}) => {
+  const template: FluxVAEModelFieldInputTemplate = {
+    ...baseField,
+    type: fieldType,
+    default: schemaObject.default ?? undefined,
+  };
+  return template;
+};
+
 const buildLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<LoRAModelFieldInputTemplate> = ({
   schemaObject,
   baseField,
@@ -423,6 +453,8 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
   SpandrelImageToImageModelField: buildSpandrelImageToImageModelFieldInputTemplate,
   VAEModelField: buildVAEModelFieldInputTemplate,
   T5EncoderModelField: buildT5EncoderModelFieldInputTemplate,
+  CLIPEmbedModelField: buildCLIPEmbedModelFieldInputTemplate,
+  FluxVAEModelField: buildFluxVAEModelFieldInputTemplate,
 } as const;

 export const buildFieldInputTemplate = (
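
Both new builders are deliberately thin: spread `baseField`, stamp the field type, and take `default` off the schema object. Registering them in `TEMPLATE_BUILDER_MAP` under the field type's literal name lets the (truncated) `buildFieldInputTemplate` dispatch by name. A self-contained sketch of that keyed-builder pattern, using simplified stand-in types rather than the repo's real ones:

type FieldTypeName = 'FluxVAEModelField' | 'CLIPEmbedModelField';
type FieldType = { name: FieldTypeName };
type Template = { type: FieldType; default?: unknown };
type Builder = (args: { fieldType: FieldType; schemaDefault?: unknown }) => Template;

// Name-keyed registry, mirroring TEMPLATE_BUILDER_MAP above.
const BUILDERS: Record<FieldTypeName, Builder> = {
  FluxVAEModelField: ({ fieldType, schemaDefault }) => ({ type: fieldType, default: schemaDefault }),
  CLIPEmbedModelField: ({ fieldType, schemaDefault }) => ({ type: fieldType, default: schemaDefault }),
};

// Dispatch on the literal name, as buildFieldInputTemplate presumably does.
const build = (fieldType: FieldType, schemaDefault?: unknown): Template =>
  BUILDERS[fieldType.name]({ fieldType, schemaDefault });

console.log(build({ name: 'FluxVAEModelField' })); // { type: { name: 'FluxVAEModelField' }, default: undefined }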


@@ -7,6 +7,7 @@ import {
   isControlNetModelConfig,
   isControlNetOrT2IAdapterModelConfig,
   isFluxMainModelModelConfig,
+  isFluxVAEModelConfig,
   isIPAdapterModelConfig,
   isLoRAModelConfig,
   isNonRefinerMainModelConfig,
@@ -52,3 +53,4 @@ export const useSpandrelImageToImageModels = buildModelsHook(isSpandrelImageToIm
 export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig);
 export const useEmbeddingModels = buildModelsHook(isTIModelConfig);
 export const useVAEModels = buildModelsHook(isVAEModelConfig);
+export const useFluxVAEModels = buildModelsHook(isFluxVAEModelConfig);
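
With this, `useFluxVAEModels` filters the fetched model list down to FLUX-base VAEs and returns the same `[modelConfigs, { isLoading }]` tuple the combobox component above consumes. A usage sketch with a hypothetical component:

import { useFluxVAEModels } from 'services/api/hooks/modelsByType';

// Hypothetical component showing how the hook's tuple is typically consumed.
const FluxVAEModelCount = () => {
  const [modelConfigs, { isLoading }] = useFluxVAEModels();
  if (isLoading) {
    return <span>Loading FLUX VAEs…</span>;
  }
  return <span>{modelConfigs.length} FLUX VAE model(s) installed</span>;
};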


@@ -5770,8 +5770,22 @@ export type components = {
       use_cache?: boolean;
       /** @description Flux model (Transformer) to load */
       model: components["schemas"]["ModelIdentifierField"];
-      /** @description T5 tokenizer and text encoder */
-      t5_encoder: components["schemas"]["ModelIdentifierField"];
+      /**
+       * T5 Encoder
+       * @description T5 tokenizer and text encoder
+       */
+      t5_encoder_model: components["schemas"]["ModelIdentifierField"];
+      /**
+       * CLIP Embed
+       * @description CLIP Embed loader
+       */
+      clip_embed_model: components["schemas"]["ModelIdentifierField"];
+      /**
+       * VAE
+       * @description VAE model to load
+       * @default null
+       */
+      vae_model?: components["schemas"]["ModelIdentifierField"];
       /**
        * type
        * @default flux_model_loader
@@ -15097,7 +15111,7 @@ export type components = {
      * used, and the type will be ignored. They are included here for backwards compatibility.
      * @enum {string}
      */
-    UIType: "MainModelField" | "FluxMainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "T5EncoderModelField" | "SpandrelImageToImageModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict";
+    UIType: "MainModelField" | "FluxMainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "FluxVAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "T5EncoderModelField" | "CLIPEmbedModelField" | "SpandrelImageToImageModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict";
     /** UNetField */
     UNetField: {
       /** @description Info to load unet submodel */


@@ -51,7 +51,7 @@ export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
 export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
 export type IPAdapterModelConfig = S['IPAdapterInvokeAIConfig'] | S['IPAdapterCheckpointConfig'];
 export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
-type ClipEmbedModelConfig = S['CLIPEmbedDiffusersConfig'];
+export type ClipEmbedModelConfig = S['CLIPEmbedDiffusersConfig'];
 export type T5EncoderModelConfig = S['T5EncoderConfig'];
 export type T5EncoderBnbQuantizedLlmInt8bModelConfig = S['T5EncoderBnbQuantizedLlmInt8bConfig'];
 export type SpandrelImageToImageModelConfig = S['SpandrelImageToImageConfig'];
@@ -82,6 +82,10 @@ export const isVAEModelConfig = (config: AnyModelConfig): config is VAEModelConf
   return config.type === 'vae';
 };

+export const isFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => {
+  return config.type === 'vae' && config.base === 'flux';
+};
+
 export const isControlNetModelConfig = (config: AnyModelConfig): config is ControlNetModelConfig => {
   return config.type === 'controlnet';
 };
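
Note that `isFluxVAEModelConfig` narrows on both `type` and `base` yet still returns `config is VAEModelConfig`: FLUX VAEs share the general VAE config type rather than getting a dedicated one. A quick sketch of the two predicates side by side (`allConfigs` is hypothetical):

import type { AnyModelConfig } from 'services/api/types';
import { isFluxVAEModelConfig, isVAEModelConfig } from 'services/api/types';

declare const allConfigs: AnyModelConfig[]; // e.g. the models list from the API

const allVAEs = allConfigs.filter(isVAEModelConfig); // every VAE, any base
const fluxVAEs = allConfigs.filter(isFluxVAEModelConfig); // only VAEs with base === 'flux'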