Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00

Compare commits (25 commits)
Commits:

- 87261bdbc9
- 4e4b6c6dbc
- 5e8cf9fb6a
- c738fe051f
- 29fe1533f2
- 77090070bd
- 6ba9b1b6b0
- c578b8df1e
- cad9a41433
- 5fefb3b0f4
- 5284a870b0
- e064377c05
- 3e569c8312
- 16825ee6e9
- 3f5340fa53
- f2a1a39b33
- 326de55d3e
- b2df909570
- 026ac36b06
- 92125e5fd2
- c0c139da88
- 404ad6a7fd
- fc39086fb4
- cd215700fe
- e97fd85904
@@ -45,11 +45,13 @@ class UIType(str, Enum, metaclass=MetaEnum):
     SDXLRefinerModel = "SDXLRefinerModelField"
     ONNXModel = "ONNXModelField"
     VAEModel = "VAEModelField"
+    FluxVAEModel = "FluxVAEModelField"
     LoRAModel = "LoRAModelField"
     ControlNetModel = "ControlNetModelField"
     IPAdapterModel = "IPAdapterModelField"
     T2IAdapterModel = "T2IAdapterModelField"
     T5EncoderModel = "T5EncoderModelField"
+    CLIPEmbedModel = "CLIPEmbedModelField"
     SpandrelImageToImageModel = "SpandrelImageToImageModelField"
     # endregion

@@ -128,6 +130,7 @@ class FieldDescriptions:
     noise = "Noise tensor"
     clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
     t5_encoder = "T5 tokenizer and text encoder"
+    clip_embed_model = "CLIP Embed loader"
     unet = "UNet (scheduler, LoRAs)"
     transformer = "Transformer"
     vae = "VAE"
@@ -40,7 +40,10 @@ class FluxTextEncoderInvocation(BaseInvocation):

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
-        t5_embeddings, clip_embeddings = self._encode_prompt(context)
+        # Note: The T5 and CLIP encoding are done in separate functions to ensure that all model references are locally
+        # scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary).
+        t5_embeddings = self._t5_encode(context)
+        clip_embeddings = self._clip_encode(context)
         conditioning_data = ConditioningFieldData(
             conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
         )
@@ -48,12 +51,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
         conditioning_name = context.conditioning.save(conditioning_data)
         return FluxConditioningOutput.build(conditioning_name)

-    def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
-        # Load CLIP.
-        clip_tokenizer_info = context.models.load(self.clip.tokenizer)
-        clip_text_encoder_info = context.models.load(self.clip.text_encoder)
-
-        # Load T5.
+    def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
         t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
         t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)

@@ -70,6 +68,15 @@ class FluxTextEncoderInvocation(BaseInvocation):

             prompt_embeds = t5_encoder(prompt)

+        assert isinstance(prompt_embeds, torch.Tensor)
+        return prompt_embeds
+
+    def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
+        clip_tokenizer_info = context.models.load(self.clip.tokenizer)
+        clip_text_encoder_info = context.models.load(self.clip.text_encoder)
+
+        prompt = [self.prompt]
+
         with (
             clip_text_encoder_info as clip_text_encoder,
             clip_tokenizer_info as clip_tokenizer,
@@ -81,6 +88,5 @@ class FluxTextEncoderInvocation(BaseInvocation):

             pooled_prompt_embeds = clip_encoder(prompt)

-        assert isinstance(prompt_embeds, torch.Tensor)
         assert isinstance(pooled_prompt_embeds, torch.Tensor)
-        return prompt_embeds, pooled_prompt_embeds
+        return pooled_prompt_embeds
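The hunks above split _encode_prompt() into _t5_encode() and _clip_encode() so that each model reference lives only inside its own function scope, letting the T5 model be freed and garbage-collected before the CLIP model is loaded. A minimal sketch of that pattern in plain Python, assuming hypothetical load_t5/load_clip callables (stand-ins for the real model loading, not part of the InvokeAI API):

    import gc

    import torch


    def encode_with_scoped_models(prompt: str, load_t5, load_clip) -> tuple[torch.Tensor, torch.Tensor]:
        def _t5_encode() -> torch.Tensor:
            t5 = load_t5()  # hypothetical loader; the reference dies when this function returns
            return t5(prompt)

        def _clip_encode() -> torch.Tensor:
            clip = load_clip()  # hypothetical loader
            return clip(prompt)

        t5_embeddings = _t5_encode()
        gc.collect()  # the T5 model is now unreferenced and its memory can be reclaimed before CLIP loads
        clip_embeddings = _clip_encode()
        return t5_embeddings, clip_embeddings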
@@ -58,13 +58,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        # Load the conditioning data.
-        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
-        assert len(cond_data.conditionings) == 1
-        flux_conditioning = cond_data.conditionings[0]
-        assert isinstance(flux_conditioning, FLUXConditioningInfo)
-
-        latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
+        latents = self._run_diffusion(context)
         image = self._run_vae_decoding(context, latents)
         image_dto = context.images.save(image=image)
         return ImageOutput.build(image_dto)
@@ -72,12 +66,20 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     def _run_diffusion(
         self,
         context: InvocationContext,
-        clip_embeddings: torch.Tensor,
-        t5_embeddings: torch.Tensor,
     ):
-        transformer_info = context.models.load(self.transformer.transformer)
         inference_dtype = torch.bfloat16

+        # Load the conditioning data.
+        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
+        assert len(cond_data.conditionings) == 1
+        flux_conditioning = cond_data.conditionings[0]
+        assert isinstance(flux_conditioning, FLUXConditioningInfo)
+        flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
+        t5_embeddings = flux_conditioning.t5_embeds
+        clip_embeddings = flux_conditioning.clip_embeds
+
+        transformer_info = context.models.load(self.transformer.transformer)
+
         # Prepare input noise.
         x = get_noise(
             num_samples=1,
@@ -88,24 +90,19 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             seed=self.seed,
         )

-        img, img_ids = prepare_latent_img_patches(x)
+        x, img_ids = prepare_latent_img_patches(x)

         is_schnell = "schnell" in transformer_info.config.config_path

         timesteps = get_schedule(
             num_steps=self.num_steps,
-            image_seq_len=img.shape[1],
+            image_seq_len=x.shape[1],
             shift=not is_schnell,
         )

         bs, t5_seq_len, _ = t5_embeddings.shape
         txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())

-        # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
-        # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
-        # if the cache is not empty.
-        context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
-
         with transformer_info as transformer:
             assert isinstance(transformer, Flux)

@@ -140,7 +137,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):

             x = denoise(
                 model=transformer,
-                img=img,
+                img=x,
                 img_ids=img_ids,
                 txt=t5_embeddings,
                 txt_ids=txt_ids,
@@ -157,7 +157,7 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
     title="Flux Main Model",
     tags=["model", "flux"],
     category="model",
-    version="1.0.3",
+    version="1.0.4",
     classification=Classification.Prototype,
 )
 class FluxModelLoaderInvocation(BaseInvocation):
@@ -169,23 +169,35 @@ class FluxModelLoaderInvocation(BaseInvocation):
         input=Input.Direct,
     )

-    t5_encoder: ModelIdentifierField = InputField(
-        description=FieldDescriptions.t5_encoder,
-        ui_type=UIType.T5EncoderModel,
+    t5_encoder_model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
+    )
+
+    clip_embed_model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.clip_embed_model,
+        ui_type=UIType.CLIPEmbedModel,
         input=Input.Direct,
+        title="CLIP Embed",
+    )
+
+    vae_model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
     )

     def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
-        model_key = self.model.key
-
-        if not context.models.exists(model_key):
-            raise ValueError(f"Unknown model: {model_key}")
-        transformer = self._get_model(context, SubModelType.Transformer)
-        tokenizer = self._get_model(context, SubModelType.Tokenizer)
-        tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
-        clip_encoder = self._get_model(context, SubModelType.TextEncoder)
-        t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
-        vae = self._get_model(context, SubModelType.VAE)
+        for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
+            if not context.models.exists(key):
+                raise ValueError(f"Unknown model: {key}")
+
+        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
+        vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
+
+        tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
+        clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
+
+        tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
+        t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
+
         transformer_config = context.models.get_config(transformer)
         assert isinstance(transformer_config, CheckpointConfigBase)

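The rewritten invoke() above derives one model identifier per submodel by copying the user-selected fields with Pydantic's model_copy(update=...). A small standalone sketch of that mechanism, using an illustrative Ident model in place of the real ModelIdentifierField:

    from enum import Enum
    from typing import Optional

    from pydantic import BaseModel


    class SubModelType(str, Enum):
        Transformer = "transformer"
        VAE = "vae"


    class Ident(BaseModel):  # illustrative stand-in for ModelIdentifierField
        key: str
        submodel_type: Optional[SubModelType] = None


    base = Ident(key="flux-dev")
    transformer = base.model_copy(update={"submodel_type": SubModelType.Transformer})
    vae = base.model_copy(update={"submodel_type": SubModelType.VAE})
    assert base.submodel_type is None  # the original field is left untouched
    assert transformer.submodel_type is SubModelType.Transformer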
@@ -197,52 +209,6 @@ class FluxModelLoaderInvocation(BaseInvocation):
             max_seq_len=max_seq_lengths[transformer_config.config_path],
         )

-    def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
-        match submodel:
-            case SubModelType.Transformer:
-                return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
-            case SubModelType.VAE:
-                return self._pull_model_from_mm(
-                    context,
-                    SubModelType.VAE,
-                    "FLUX.1-schnell_ae",
-                    ModelType.VAE,
-                    BaseModelType.Flux,
-                )
-            case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
-                return self._pull_model_from_mm(
-                    context,
-                    submodel,
-                    "clip-vit-large-patch14",
-                    ModelType.CLIPEmbed,
-                    BaseModelType.Any,
-                )
-            case submodel if submodel in [SubModelType.Tokenizer2, SubModelType.TextEncoder2]:
-                return self._pull_model_from_mm(
-                    context,
-                    submodel,
-                    self.t5_encoder.name,
-                    ModelType.T5Encoder,
-                    BaseModelType.Any,
-                )
-            case _:
-                raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
-
-    def _pull_model_from_mm(
-        self,
-        context: InvocationContext,
-        submodel: SubModelType,
-        name: str,
-        type: ModelType,
-        base: BaseModelType,
-    ):
-        if models := context.models.search_by_attrs(name=name, base=base, type=type):
-            if len(models) != 1:
-                raise Exception(f"Multiple models detected for selected model with name {name}")
-            return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
-        else:
-            raise ValueError(f"Please install the {base}:{type} model named {name} via starter models")
-

 @invocation(
     "main_model_loader",
@@ -2,13 +2,13 @@
   "name": "FLUX Text to Image",
   "author": "InvokeAI",
   "description": "A simple text-to-image workflow using FLUX dev or schnell models. Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
-  "version": "1.0.0",
+  "version": "1.0.4",
   "contact": "",
   "tags": "text2image, flux",
   "notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
   "exposedFields": [
     {
-      "nodeId": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
+      "nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
       "fieldName": "model"
     },
     {
@@ -20,8 +20,8 @@
       "fieldName": "num_steps"
     },
     {
-      "nodeId": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
-      "fieldName": "t5_encoder"
+      "nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
+      "fieldName": "t5_encoder_model"
     }
   ],
   "meta": {
@@ -30,12 +30,12 @@
   },
   "nodes": [
     {
-      "id": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
+      "id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
       "type": "invocation",
       "data": {
-        "id": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
+        "id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
         "type": "flux_model_loader",
-        "version": "1.0.3",
+        "version": "1.0.4",
         "label": "",
         "notes": "",
         "isOpen": true,
@@ -44,31 +44,25 @@
         "inputs": {
           "model": {
             "name": "model",
-            "label": "Model (Starter Models can be found in Model Manager)",
-            "value": {
-              "key": "f04a7a2f-c74d-4538-8d5e-879a53501662",
-              "hash": "random:4875da7a9508444ffa706f61961c260d0c6729f6181a86b31fad06df1277b850",
-              "name": "FLUX Dev (Quantized)",
-              "base": "flux",
-              "type": "main"
-            }
+            "label": ""
           },
-          "t5_encoder": {
-            "name": "t5_encoder",
-            "label": "T 5 Encoder (Starter Models can be found in Model Manager)",
-            "value": {
-              "key": "20dcd9ec-5fbb-4012-8401-049e707da5e5",
-              "hash": "random:f986be43ff3502169e4adbdcee158afb0e0a65a1edc4cab16ae59963630cfd8f",
-              "name": "t5_bnb_int8_quantized_encoder",
-              "base": "any",
-              "type": "t5_encoder"
-            }
+          "t5_encoder_model": {
+            "name": "t5_encoder_model",
+            "label": ""
+          },
+          "clip_embed_model": {
+            "name": "clip_embed_model",
+            "label": ""
+          },
+          "vae_model": {
+            "name": "vae_model",
+            "label": ""
           }
         }
       },
       "position": {
-        "x": 337.09365228062825,
-        "y": 40.63469521079861
+        "x": 381.1882713063478,
+        "y": -95.89663532854017
       }
     },
     {
@@ -207,45 +201,45 @@
   ],
   "edges": [
     {
-      "id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33amax_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
+      "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
       "type": "default",
-      "source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
+      "source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
       "target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
       "sourceHandle": "max_seq_len",
       "targetHandle": "t5_max_seq_len"
     },
     {
-      "id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33avae-159bdf1b-79e7-4174-b86e-d40e646964c8vae",
+      "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-159bdf1b-79e7-4174-b86e-d40e646964c8vae",
       "type": "default",
-      "source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
+      "source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
       "target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
       "sourceHandle": "vae",
       "targetHandle": "vae"
     },
     {
-      "id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33atransformer-159bdf1b-79e7-4174-b86e-d40e646964c8transformer",
+      "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
       "type": "default",
-      "source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
-      "target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
-      "sourceHandle": "transformer",
-      "targetHandle": "transformer"
-    },
-    {
-      "id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33at5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
-      "type": "default",
-      "source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
+      "source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
       "target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
       "sourceHandle": "t5_encoder",
       "targetHandle": "t5_encoder"
     },
     {
-      "id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33aclip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
+      "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90clip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
       "type": "default",
-      "source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
+      "source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
       "target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
       "sourceHandle": "clip",
       "targetHandle": "clip"
     },
+    {
+      "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-159bdf1b-79e7-4174-b86e-d40e646964c8transformer",
+      "type": "default",
+      "source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
+      "target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
+      "sourceHandle": "transformer",
+      "targetHandle": "transformer"
+    },
     {
       "id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-159bdf1b-79e7-4174-b86e-d40e646964c8positive_text_conditioning",
       "type": "default",
@@ -111,16 +111,7 @@ def denoise(
     step_callback: Callable[[], None],
     guidance: float = 4.0,
 ):
-    dtype = model.txt_in.bias.dtype
-
-    # TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller.
-    img = img.to(dtype=dtype)
-    img_ids = img_ids.to(dtype=dtype)
-    txt = txt.to(dtype=dtype)
-    txt_ids = txt_ids.to(dtype=dtype)
-    vec = vec.to(dtype=dtype)
-
-    # this is ignored for schnell
+    # guidance_vec is ignored for schnell.
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
     for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
@@ -168,9 +159,9 @@ def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor,
     img = repeat(img, "1 ... -> bs ...", bs=bs)

     # Generate patch position ids.
-    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
-    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
-    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
+    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
+    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None]
+    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :]
     img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

     return img, img_ids
@@ -72,6 +72,7 @@ class ModelLoader(ModelLoaderBase):
             pass

         config.path = str(self._get_model_path(config))
+        self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type))
         loaded_model = self._load_model(config, submodel_type)

         self._ram_cache.put(
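The added make_room() call asks the RAM cache to clear space for the model's on-disk size before it is read, replacing the hard-coded 24 GB hack removed from the FLUX text-to-image invocation. A rough sketch of estimating that size from the filesystem (estimate_size_on_disk is an illustrative helper, not the actual get_size_fs implementation):

    from pathlib import Path


    def estimate_size_on_disk(model_path: Path) -> int:
        """Approximate size in bytes of a model stored as a single file or a directory of files."""
        if model_path.is_file():
            return model_path.stat().st_size
        return sum(p.stat().st_size for p in model_path.rglob("*") if p.is_file())

    # e.g. ram_cache.make_room(estimate_size_on_disk(Path(config.path)))  # hypothetical usage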
@@ -193,15 +193,6 @@ class ModelCacheBase(ABC, Generic[T]):
         """
         pass

-    @abstractmethod
-    def exists(
-        self,
-        key: str,
-        submodel_type: Optional[SubModelType] = None,
-    ) -> bool:
-        """Return true if the model identified by key and submodel_type is in the cache."""
-        pass
-
     @abstractmethod
     def cache_size(self) -> int:
         """Get the total size of the models currently cached."""
@@ -1,22 +1,6 @@
 # Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
 # TODO: Add Stalker's proper name to copyright
-"""
-Manage a RAM cache of diffusion/transformer models for fast switching.
-They are moved between GPU VRAM and CPU RAM as necessary. If the cache
-grows larger than a preset maximum, then the least recently used
-model will be cleared and (re)loaded from disk when next needed.
-
-The cache returns context manager generators designed to load the
-model into the GPU within the context, and unload outside the
-context. Use like this:
-
-    cache = ModelCache(max_cache_size=7.5)
-    with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
-         cache.get_model('stabilityai/stable-diffusion-2') as SD2:
-        do_something_in_GPU(SD1,SD2)
-
-
-"""
+""" """

 import gc
 import math
@@ -40,45 +24,64 @@ from invokeai.backend.model_manager.load.model_util import calc_model_size_by_da
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger

-# Maximum size of the cache, in gigs
-# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
-DEFAULT_MAX_CACHE_SIZE = 6.0
-
-# amount of GPU memory to hold in reserve for use by generations (GB)
-DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
-
-# actual size of a gig
-GIG = 1073741824
+# Size of a GB in bytes.
+GB = 2**30

 # Size of a MB in bytes.
 MB = 2**20


 class ModelCache(ModelCacheBase[AnyModel]):
-    """Implementation of ModelCacheBase."""
+    """A cache for managing models in memory.
+
+    The cache is based on two levels of model storage:
+    - execution_device: The device where most models are executed (typically "cuda", "mps", or "cpu").
+    - storage_device: The device where models are offloaded when not in active use (typically "cpu").
+
+    The model cache is based on the following assumptions:
+    - storage_device_mem_size > execution_device_mem_size
+    - disk_to_storage_device_transfer_time >> storage_device_to_execution_device_transfer_time
+
+    A copy of all models in the cache is always kept on the storage_device. A subset of the models also have a copy on
+    the execution_device.
+
+    Models are moved between the storage_device and the execution_device as necessary. Cache size limits are enforced
+    on both the storage_device and the execution_device. The execution_device cache uses a smallest-first offload
+    policy. The storage_device cache uses a least-recently-used (LRU) offload policy.
+
+    Note: Neither of these offload policies has really been compared against alternatives. It's likely that different
+    policies would be better, although the optimal policies are likely heavily dependent on usage patterns and HW
+    configuration.
+
+    The cache returns context manager generators designed to load the model into the execution device (often GPU) within
+    the context, and unload outside the context.
+
+    Example usage:
+    ```
+    cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=6.0)
+    with cache.get_model('runwayml/stable-diffusion-1-5') as SD1:
+        do_something_on_gpu(SD1)
+    ```
+    """

     def __init__(
         self,
-        max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
-        max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
+        max_cache_size: float,
+        max_vram_cache_size: float,
         execution_device: torch.device = torch.device("cuda"),
         storage_device: torch.device = torch.device("cpu"),
-        precision: torch.dtype = torch.float16,
-        sequential_offload: bool = False,
         lazy_offloading: bool = True,
-        sha_chunksize: int = 16777216,
         log_memory_usage: bool = False,
         logger: Optional[Logger] = None,
     ):
         """
         Initialize the model RAM cache.

-        :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
+        :param max_cache_size: Maximum size of the storage_device cache in GBs.
+        :param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
         :param execution_device: Torch device to load active model into [torch.device('cuda')]
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
-        :param precision: Precision for loaded models [torch.float16]
-        :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
-        :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
+        :param lazy_offloading: Keep model in VRAM until another model needs to be loaded.
         :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
             operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
             snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
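To make the offload policies in the new docstring concrete, here is a minimal sketch of a smallest-first selection over cache records, mirroring the sorted-by-size loop in offload_unlocked_models below; the CacheRecord dataclass here is illustrative, not the cache's internal record type:

    from dataclasses import dataclass


    @dataclass
    class CacheRecord:  # illustrative stand-in for the cache's internal record
        key: str
        size: int  # bytes
        locked: bool
        loaded: bool  # True if a copy currently lives on the execution device


    def pick_models_to_offload(records: list[CacheRecord], vram_in_use: int, vram_budget: int) -> list[str]:
        """Smallest-first policy: drop the cheapest unlocked execution-device copies until the budget fits."""
        victims: list[str] = []
        for record in sorted(records, key=lambda r: r.size):
            if vram_in_use <= vram_budget:
                break
            if record.loaded and not record.locked:
                victims.append(record.key)
                vram_in_use -= record.size
        return victims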
@@ -86,7 +89,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
         """
         # allow lazy offloading only when vram cache enabled
         self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
-        self._precision: torch.dtype = precision
         self._max_cache_size: float = max_cache_size
         self._max_vram_cache_size: float = max_vram_cache_size
         self._execution_device: torch.device = execution_device
@@ -145,15 +147,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
             total += cache_record.size
         return total

-    def exists(
-        self,
-        key: str,
-        submodel_type: Optional[SubModelType] = None,
-    ) -> bool:
-        """Return true if the model identified by key and submodel_type is in the cache."""
-        key = self._make_cache_key(key, submodel_type)
-        return key in self._cached_models
-
     def put(
         self,
         key: str,
@@ -203,7 +196,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         # more stats
         if self.stats:
             stats_name = stats_name or key
-            self.stats.cache_size = int(self._max_cache_size * GIG)
+            self.stats.cache_size = int(self._max_cache_size * GB)
             self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
             self.stats.in_cache = len(self._cached_models)
             self.stats.loaded_model_sizes[stats_name] = max(
@@ -231,10 +224,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
         return model_key

     def offload_unlocked_models(self, size_required: int) -> None:
-        """Move any unused models from VRAM."""
-        reserved = self._max_vram_cache_size * GIG
+        """Offload models from the execution_device to make room for size_required.
+
+        :param size_required: The amount of space to clear in the execution_device cache, in bytes.
+        """
+        reserved = self._max_vram_cache_size * GB
         vram_in_use = torch.cuda.memory_allocated() + size_required
-        self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
+        self.logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
         for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
             if vram_in_use <= reserved:
                 break
@@ -245,7 +241,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
             cache_entry.loaded = False
             vram_in_use = torch.cuda.memory_allocated() + size_required
             self.logger.debug(
-                f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
+                f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
             )

         TorchDevice.empty_cache()
@@ -303,7 +299,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
             self.logger.debug(
                 f"Moved model '{cache_entry.key}' from {source_device} to"
                 f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
-                f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
+                f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
                 f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
             )

@@ -326,14 +322,14 @@ class ModelCache(ModelCacheBase[AnyModel]):
                 f"Moving model '{cache_entry.key}' from {source_device} to"
                 f" {target_device} caused an unexpected change in VRAM usage. The model's"
                 " estimated size may be incorrect. Estimated model size:"
-                f" {(cache_entry.size/GIG):.3f} GB.\n"
+                f" {(cache_entry.size/GB):.3f} GB.\n"
                 f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
             )

     def print_cuda_stats(self) -> None:
         """Log CUDA diagnostics."""
-        vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
-        ram = "%4.2fG" % (self.cache_size() / GIG)
+        vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
+        ram = "%4.2fG" % (self.cache_size() / GB)

         in_ram_models = 0
         in_vram_models = 0
@@ -353,17 +349,20 @@ class ModelCache(ModelCacheBase[AnyModel]):
         )

     def make_room(self, size: int) -> None:
-        """Make enough room in the cache to accommodate a new model of indicated size."""
-        # calculate how much memory this model will require
-        # multiplier = 2 if self.precision==torch.float32 else 1
+        """Make enough room in the cache to accommodate a new model of indicated size.
+
+        Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
+        external references to the model, there's nothing that the cache can do about it, and those models will not be
+        garbage-collected.
+        """
         bytes_needed = size
-        maximum_size = self.max_cache_size * GIG  # stored in GB, convert to bytes
+        maximum_size = self.max_cache_size * GB  # stored in GB, convert to bytes
         current_size = self.cache_size()

         if current_size + bytes_needed > maximum_size:
             self.logger.debug(
-                f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
-                f" {(bytes_needed/GIG):.2f} GB"
+                f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
+                f" {(bytes_needed/GB):.2f} GB"
             )

         self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
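The caveat added to make_room()'s docstring (external references keep a model alive even after the cache drops it) can be demonstrated with a few lines of plain Python; the Dummy class below simply stands in for a cached model object:

    import gc
    import weakref


    class Dummy:  # stand-in for a cached model
        pass


    cache = {"model-key": Dummy()}
    probe = weakref.ref(cache["model-key"])
    external = cache["model-key"]  # a reference the cache does not control
    del cache["model-key"]  # the cache drops its internal reference...
    gc.collect()
    assert probe() is not None  # ...but the object survives via `external`
    del external
    gc.collect()
    assert probe() is None  # only now is the object actually reclaimed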
@@ -54,8 +54,10 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):

         # See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
         scb = state_dict.pop(prefix + "SCB", None)
-        # weight_format is unused, but we pop it so we can validate that there are no unexpected keys.
-        _weight_format = state_dict.pop(prefix + "weight_format", None)
+
+        # Currently, we only support weight_format=0.
+        weight_format = state_dict.pop(prefix + "weight_format", None)
+        assert weight_format == 0

         # TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
         # rather than raising an exception to correctly implement this API.
@@ -89,6 +91,14 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
         )
         self.bias = bias if bias is None else torch.nn.Parameter(bias)

+        # Reset the state. The persisted fields are based on the initialization behaviour in
+        # `bnb.nn.Linear8bitLt.__init__()`.
+        new_state = bnb.MatmulLtState()
+        new_state.threshold = self.state.threshold
+        new_state.has_fp16_weights = False
+        new_state.use_pool = self.state.use_pool
+        self.state = new_state
+

 def _convert_linear_layers_to_llm_8bit(
     module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""
|
@ -43,6 +43,11 @@ class FLUXConditioningInfo:
|
|||||||
clip_embeds: torch.Tensor
|
clip_embeds: torch.Tensor
|
||||||
t5_embeds: torch.Tensor
|
t5_embeds: torch.Tensor
|
||||||
|
|
||||||
|
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
||||||
|
self.clip_embeds = self.clip_embeds.to(device=device, dtype=dtype)
|
||||||
|
self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConditioningFieldData:
|
class ConditioningFieldData:
|
||||||
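The new to() helper mirrors torch.Tensor.to for both embedding tensors and returns self, which is what allows the flux_conditioning.to(dtype=inference_dtype) call in the text-to-image hunk above. A short usage sketch with dummy tensors (the shapes are illustrative only, and FLUXConditioningInfo is the dataclass shown in this hunk):

    import torch

    info = FLUXConditioningInfo(
        clip_embeds=torch.zeros(1, 768),  # illustrative shape
        t5_embeds=torch.zeros(1, 512, 4096),  # illustrative shape
    )
    info = info.to(dtype=torch.bfloat16)  # converts both tensors in place and returns the same object
    assert info.clip_embeds.dtype == torch.bfloat16
    assert info.t5_embeds.dtype == torch.bfloat16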
|
@ -3,10 +3,9 @@ Initialization file for invokeai.backend.util
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from invokeai.backend.util.util import GIG, Chdir, directory_size
|
from invokeai.backend.util.util import Chdir, directory_size
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"GIG",
|
|
||||||
"directory_size",
|
"directory_size",
|
||||||
"Chdir",
|
"Chdir",
|
||||||
"InvokeAILogger",
|
"InvokeAILogger",
|
||||||
|
@ -7,9 +7,6 @@ from pathlib import Path
|
|||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
# actual size of a gig
|
|
||||||
GIG = 1073741824
|
|
||||||
|
|
||||||
|
|
||||||
def slugify(value: str, allow_unicode: bool = False) -> str:
|
def slugify(value: str, allow_unicode: bool = False) -> str:
|
||||||
"""
|
"""
|
||||||
|
@ -14,6 +14,7 @@ import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageMo
|
|||||||
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
|
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
|
||||||
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
|
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
|
||||||
import { StylePresetModal } from 'features/stylePresets/components/StylePresetForm/StylePresetModal';
|
import { StylePresetModal } from 'features/stylePresets/components/StylePresetForm/StylePresetModal';
|
||||||
|
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
|
||||||
import { configChanged } from 'features/system/store/configSlice';
|
import { configChanged } from 'features/system/store/configSlice';
|
||||||
import { languageSelector } from 'features/system/store/systemSelectors';
|
import { languageSelector } from 'features/system/store/systemSelectors';
|
||||||
import InvokeTabs from 'features/ui/components/InvokeTabs';
|
import InvokeTabs from 'features/ui/components/InvokeTabs';
|
||||||
@@ -39,10 +40,17 @@ interface Props {
     action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
   };
   selectedWorkflowId?: string;
+  selectedStylePresetId?: string;
   destination?: InvokeTabName | undefined;
 }

-const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, destination }: Props) => {
+const App = ({
+  config = DEFAULT_CONFIG,
+  selectedImage,
+  selectedWorkflowId,
+  selectedStylePresetId,
+  destination,
+}: Props) => {
   const language = useAppSelector(languageSelector);
   const logger = useLogger('system');
   const dispatch = useAppDispatch();
@@ -81,6 +89,12 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, desti
     }
   }, [selectedWorkflowId, getAndLoadWorkflow]);

+  useEffect(() => {
+    if (selectedStylePresetId) {
+      dispatch(activeStylePresetIdChanged(selectedStylePresetId));
+    }
+  }, [dispatch, selectedStylePresetId]);
+
   useEffect(() => {
     if (destination) {
       dispatch(setActiveTab(destination));
@@ -45,6 +45,7 @@ interface Props extends PropsWithChildren {
     action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
   };
   selectedWorkflowId?: string;
+  selectedStylePresetId?: string;
   destination?: InvokeTabName;
   customStarUi?: CustomStarUi;
   socketOptions?: Partial<ManagerOptions & SocketOptions>;
@@ -66,6 +67,7 @@ const InvokeAIUI = ({
   queueId,
   selectedImage,
   selectedWorkflowId,
+  selectedStylePresetId,
   destination,
   customStarUi,
   socketOptions,
@@ -227,6 +229,7 @@ const InvokeAIUI = ({
             config={config}
             selectedImage={selectedImage}
             selectedWorkflowId={selectedWorkflowId}
+            selectedStylePresetId={selectedStylePresetId}
             destination={destination}
           />
         </AppDndContext>
@@ -6,6 +6,8 @@ import {
   isBoardFieldInputTemplate,
   isBooleanFieldInputInstance,
   isBooleanFieldInputTemplate,
+  isCLIPEmbedModelFieldInputInstance,
+  isCLIPEmbedModelFieldInputTemplate,
   isColorFieldInputInstance,
   isColorFieldInputTemplate,
   isControlNetModelFieldInputInstance,
@@ -16,6 +18,8 @@ import {
   isFloatFieldInputTemplate,
   isFluxMainModelFieldInputInstance,
   isFluxMainModelFieldInputTemplate,
+  isFluxVAEModelFieldInputInstance,
+  isFluxVAEModelFieldInputTemplate,
   isImageFieldInputInstance,
   isImageFieldInputTemplate,
   isIntegerFieldInputInstance,
@@ -49,10 +53,12 @@ import { memo } from 'react';

 import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
 import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
+import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent';
 import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
 import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
 import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
 import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent';
+import FluxVAEModelFieldInputComponent from './inputs/FluxVAEModelFieldInputComponent';
 import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
 import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
 import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
@@ -122,6 +128,13 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
   if (isT5EncoderModelFieldInputInstance(fieldInstance) && isT5EncoderModelFieldInputTemplate(fieldTemplate)) {
     return <T5EncoderModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
   }
+  if (isCLIPEmbedModelFieldInputInstance(fieldInstance) && isCLIPEmbedModelFieldInputTemplate(fieldTemplate)) {
+    return <CLIPEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
+  }
+
+  if (isFluxVAEModelFieldInputInstance(fieldInstance) && isFluxVAEModelFieldInputTemplate(fieldTemplate)) {
+    return <FluxVAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
+  }

   if (isLoRAModelFieldInputInstance(fieldInstance) && isLoRAModelFieldInputTemplate(fieldTemplate)) {
     return <LoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
@@ -0,0 +1,60 @@
+import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
+import { fieldCLIPEmbedValueChanged } from 'features/nodes/store/nodesSlice';
+import type { CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
+import { memo, useCallback } from 'react';
+import { useTranslation } from 'react-i18next';
+import { useClipEmbedModels } from 'services/api/hooks/modelsByType';
+import type { ClipEmbedModelConfig } from 'services/api/types';
+
+import type { FieldComponentProps } from './types';
+
+type Props = FieldComponentProps<CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate>;
+
+const CLIPEmbedModelFieldInputComponent = (props: Props) => {
+  const { nodeId, field } = props;
+  const { t } = useTranslation();
+  const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
+  const dispatch = useAppDispatch();
+  const [modelConfigs, { isLoading }] = useClipEmbedModels();
+  const _onChange = useCallback(
+    (value: ClipEmbedModelConfig | null) => {
+      if (!value) {
+        return;
+      }
+      dispatch(
+        fieldCLIPEmbedValueChanged({
+          nodeId,
+          fieldName: field.name,
+          value,
+        })
+      );
+    },
+    [dispatch, field.name, nodeId]
+  );
+  const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
+    modelConfigs,
+    onChange: _onChange,
+    isLoading,
+    selectedModel: field.value,
+  });
+
+  return (
+    <Flex w="full" alignItems="center" gap={2}>
+      <Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
+        <FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
+          <Combobox
+            value={value}
+            placeholder={placeholder}
+            options={options}
+            onChange={onChange}
+            noOptionsMessage={noOptionsMessage}
+          />
+        </FormControl>
+      </Tooltip>
+    </Flex>
+  );
+};
+
+export default memo(CLIPEmbedModelFieldInputComponent);
@@ -0,0 +1,60 @@
+import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
+import { fieldFluxVAEModelValueChanged } from 'features/nodes/store/nodesSlice';
+import type { FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate } from 'features/nodes/types/field';
+import { memo, useCallback } from 'react';
+import { useTranslation } from 'react-i18next';
+import { useFluxVAEModels } from 'services/api/hooks/modelsByType';
+import type { VAEModelConfig } from 'services/api/types';
+
+import type { FieldComponentProps } from './types';
+
+type Props = FieldComponentProps<FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate>;
+
+const FluxVAEModelFieldInputComponent = (props: Props) => {
+  const { nodeId, field } = props;
+  const { t } = useTranslation();
+  const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
+  const dispatch = useAppDispatch();
+  const [modelConfigs, { isLoading }] = useFluxVAEModels();
+  const _onChange = useCallback(
+    (value: VAEModelConfig | null) => {
+      if (!value) {
+        return;
+      }
+      dispatch(
+        fieldFluxVAEModelValueChanged({
+          nodeId,
+          fieldName: field.name,
+          value,
+        })
+      );
+    },
+    [dispatch, field.name, nodeId]
+  );
+  const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
+    modelConfigs,
+    onChange: _onChange,
+    isLoading,
+    selectedModel: field.value,
+  });
+
+  return (
+    <Flex w="full" alignItems="center" gap={2}>
+      <Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
+        <FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
+          <Combobox
+            value={value}
+            placeholder={placeholder}
+            options={options}
+            onChange={onChange}
+            noOptionsMessage={noOptionsMessage}
+          />
+        </FormControl>
+      </Tooltip>
+    </Flex>
+  );
+};
+
+export default memo(FluxVAEModelFieldInputComponent);
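The two new input components above are intentionally parallel: a models hook supplies the configs, useGroupedModelCombobox adapts them to grouped combobox options, and selecting a config dispatches the matching field-value action into the nodes slice. Below is a condensed, framework-free sketch of that select-and-dispatch flow; the action type string assumes the slice is named 'nodes', and the Dispatch and model shapes are simplified stand-ins, not the real exports.

```ts
// Sketch only: mirrors the _onChange handlers in the two components above.
type ModelIdentifierLike = { key: string; hash: string; name: string; base: string; type: string };
type FieldValueChangedAction = {
  type: string;
  payload: { nodeId: string; fieldName: string; value: ModelIdentifierLike };
};

const onModelSelected = (
  dispatch: (action: FieldValueChangedAction) => void,
  nodeId: string,
  fieldName: string,
  config: ModelIdentifierLike | null
): void => {
  if (!config) {
    return; // same null guard as _onChange: clearing the combobox is a no-op
  }
  dispatch({
    type: 'nodes/fieldFluxVAEModelValueChanged', // assumes the slice is registered as 'nodes'
    payload: { nodeId, fieldName, value: config },
  });
};
```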
@@ -6,11 +6,13 @@ import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
 import type {
   BoardFieldValue,
   BooleanFieldValue,
+  CLIPEmbedModelFieldValue,
   ColorFieldValue,
   ControlNetModelFieldValue,
   EnumFieldValue,
   FieldValue,
   FloatFieldValue,
+  FluxVAEModelFieldValue,
   ImageFieldValue,
   IntegerFieldValue,
   IPAdapterModelFieldValue,
@@ -29,10 +31,12 @@ import type {
 import {
   zBoardFieldValue,
   zBooleanFieldValue,
+  zCLIPEmbedModelFieldValue,
   zColorFieldValue,
   zControlNetModelFieldValue,
   zEnumFieldValue,
   zFloatFieldValue,
+  zFluxVAEModelFieldValue,
   zImageFieldValue,
   zIntegerFieldValue,
   zIPAdapterModelFieldValue,
@@ -346,6 +350,12 @@ export const nodesSlice = createSlice({
     fieldT5EncoderValueChanged: (state, action: FieldValueAction<T5EncoderModelFieldValue>) => {
       fieldValueReducer(state, action, zT5EncoderModelFieldValue);
     },
+    fieldCLIPEmbedValueChanged: (state, action: FieldValueAction<CLIPEmbedModelFieldValue>) => {
+      fieldValueReducer(state, action, zCLIPEmbedModelFieldValue);
+    },
+    fieldFluxVAEModelValueChanged: (state, action: FieldValueAction<FluxVAEModelFieldValue>) => {
+      fieldValueReducer(state, action, zFluxVAEModelFieldValue);
+    },
     fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
       fieldValueReducer(state, action, zEnumFieldValue);
     },
@@ -408,6 +418,8 @@ export const {
   fieldStringValueChanged,
   fieldVaeModelValueChanged,
   fieldT5EncoderValueChanged,
+  fieldCLIPEmbedValueChanged,
+  fieldFluxVAEModelValueChanged,
   nodeEditorReset,
   nodeIsIntermediateChanged,
   nodeIsOpenChanged,
@@ -521,6 +533,8 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
   fieldStringValueChanged,
   fieldVaeModelValueChanged,
   fieldT5EncoderValueChanged,
+  fieldCLIPEmbedValueChanged,
+  fieldFluxVAEModelValueChanged,
   nodesChanged,
   nodeIsIntermediateChanged,
   nodeIsOpenChanged,
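The two new reducers delegate to fieldValueReducer, which is not part of this diff. A minimal sketch of what such a helper plausibly does, consistent with how it is called above (validate the payload against the field's zod schema, then write the value onto the node's input), assuming an Immer-backed Redux Toolkit slice and simplified state shapes:

```ts
import type { z } from 'zod';

// Simplified stand-ins for the real NodesState/FieldValueAction types.
type FieldValueAction<T> = { payload: { nodeId: string; fieldName: string; value: T } };
type NodesStateLike = {
  nodes: Array<{ id: string; data: { inputs: Record<string, { value?: unknown }> } }>;
};

// Sketch only: the real implementation lives outside this diff.
const fieldValueReducerSketch = <T>(
  state: NodesStateLike,
  action: FieldValueAction<T>,
  schema: z.ZodType<T>
): void => {
  const { nodeId, fieldName, value } = action.payload;
  const result = schema.safeParse(value);
  if (!result.success) {
    return; // reject values that do not satisfy the field's schema
  }
  const input = state.nodes.find((n) => n.id === nodeId)?.data.inputs[fieldName];
  if (input) {
    input.value = result.data; // direct mutation is safe inside createSlice (Immer)
  }
};
```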
@@ -151,6 +151,14 @@ const zT5EncoderModelFieldType = zFieldTypeBase.extend({
   name: z.literal('T5EncoderModelField'),
   originalType: zStatelessFieldType.optional(),
 });
+const zCLIPEmbedModelFieldType = zFieldTypeBase.extend({
+  name: z.literal('CLIPEmbedModelField'),
+  originalType: zStatelessFieldType.optional(),
+});
+const zFluxVAEModelFieldType = zFieldTypeBase.extend({
+  name: z.literal('FluxVAEModelField'),
+  originalType: zStatelessFieldType.optional(),
+});
 const zSchedulerFieldType = zFieldTypeBase.extend({
   name: z.literal('SchedulerField'),
   originalType: zStatelessFieldType.optional(),
@@ -175,6 +183,8 @@ const zStatefulFieldType = z.union([
   zT2IAdapterModelFieldType,
   zSpandrelImageToImageModelFieldType,
   zT5EncoderModelFieldType,
+  zCLIPEmbedModelFieldType,
+  zFluxVAEModelFieldType,
   zColorFieldType,
   zSchedulerFieldType,
 ]);
@@ -667,7 +677,53 @@ export const isT5EncoderModelFieldInputInstance = (val: unknown): val is T5Encod
 export const isT5EncoderModelFieldInputTemplate = (val: unknown): val is T5EncoderModelFieldInputTemplate =>
   zT5EncoderModelFieldInputTemplate.safeParse(val).success;

-// #endregio
+// #endregion
+
+// #region FluxVAEModelField
+
+export const zFluxVAEModelFieldValue = zModelIdentifierField.optional();
+const zFluxVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({
+  value: zFluxVAEModelFieldValue,
+});
+const zFluxVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({
+  type: zFluxVAEModelFieldType,
+  originalType: zFieldType.optional(),
+  default: zFluxVAEModelFieldValue,
+});
+
+export type FluxVAEModelFieldValue = z.infer<typeof zFluxVAEModelFieldValue>;
+
+export type FluxVAEModelFieldInputInstance = z.infer<typeof zFluxVAEModelFieldInputInstance>;
+export type FluxVAEModelFieldInputTemplate = z.infer<typeof zFluxVAEModelFieldInputTemplate>;
+export const isFluxVAEModelFieldInputInstance = (val: unknown): val is FluxVAEModelFieldInputInstance =>
+  zFluxVAEModelFieldInputInstance.safeParse(val).success;
+export const isFluxVAEModelFieldInputTemplate = (val: unknown): val is FluxVAEModelFieldInputTemplate =>
+  zFluxVAEModelFieldInputTemplate.safeParse(val).success;
+
+// #endregion
+
+// #region CLIPEmbedModelField
+
+export const zCLIPEmbedModelFieldValue = zModelIdentifierField.optional();
+const zCLIPEmbedModelFieldInputInstance = zFieldInputInstanceBase.extend({
+  value: zCLIPEmbedModelFieldValue,
+});
+const zCLIPEmbedModelFieldInputTemplate = zFieldInputTemplateBase.extend({
+  type: zCLIPEmbedModelFieldType,
+  originalType: zFieldType.optional(),
+  default: zCLIPEmbedModelFieldValue,
+});
+
+export type CLIPEmbedModelFieldValue = z.infer<typeof zCLIPEmbedModelFieldValue>;
+
+export type CLIPEmbedModelFieldInputInstance = z.infer<typeof zCLIPEmbedModelFieldInputInstance>;
+export type CLIPEmbedModelFieldInputTemplate = z.infer<typeof zCLIPEmbedModelFieldInputTemplate>;
+export const isCLIPEmbedModelFieldInputInstance = (val: unknown): val is CLIPEmbedModelFieldInputInstance =>
+  zCLIPEmbedModelFieldInputInstance.safeParse(val).success;
+export const isCLIPEmbedModelFieldInputTemplate = (val: unknown): val is CLIPEmbedModelFieldInputTemplate =>
+  zCLIPEmbedModelFieldInputTemplate.safeParse(val).success;
+
+// #endregion
+
 // #region SchedulerField

@@ -758,6 +814,8 @@ export const zStatefulFieldValue = z.union([
   zT2IAdapterModelFieldValue,
   zSpandrelImageToImageModelFieldValue,
   zT5EncoderModelFieldValue,
+  zFluxVAEModelFieldValue,
+  zCLIPEmbedModelFieldValue,
   zColorFieldValue,
   zSchedulerFieldValue,
 ]);
@@ -788,6 +846,8 @@ const zStatefulFieldInputInstance = z.union([
   zT2IAdapterModelFieldInputInstance,
   zSpandrelImageToImageModelFieldInputInstance,
   zT5EncoderModelFieldInputInstance,
+  zFluxVAEModelFieldInputInstance,
+  zCLIPEmbedModelFieldInputInstance,
   zColorFieldInputInstance,
   zSchedulerFieldInputInstance,
 ]);
@@ -819,6 +879,8 @@ const zStatefulFieldInputTemplate = z.union([
   zT2IAdapterModelFieldInputTemplate,
   zSpandrelImageToImageModelFieldInputTemplate,
   zT5EncoderModelFieldInputTemplate,
+  zFluxVAEModelFieldInputTemplate,
+  zCLIPEmbedModelFieldInputTemplate,
   zColorFieldInputTemplate,
   zSchedulerFieldInputTemplate,
   zStatelessFieldInputTemplate,
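Each new field type follows the file's established recipe: a z.literal type schema, an optional model-identifier value, instance and template schemas, inferred TS types, and safeParse-backed guards. An illustrative use of those guards when narrowing untyped data, for example a template deserialized from a workflow JSON, assuming the exports shown above:

```ts
import {
  isCLIPEmbedModelFieldInputTemplate,
  isFluxVAEModelFieldInputTemplate,
} from 'features/nodes/types/field';

// `template` could come from a parsed workflow file or an unvalidated API payload.
const describeTemplate = (template: unknown): string => {
  if (isFluxVAEModelFieldInputTemplate(template)) {
    return `field type: ${template.type.name}`; // narrowed; type.name is 'FluxVAEModelField'
  }
  if (isCLIPEmbedModelFieldInputTemplate(template)) {
    return `field type: ${template.type.name}`; // narrowed; type.name is 'CLIPEmbedModelField'
  }
  return 'not a FLUX VAE or CLIP Embed field template';
};
```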
@@ -23,6 +23,8 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
   VAEModelField: undefined,
   ControlNetModelField: undefined,
   T5EncoderModelField: undefined,
+  FluxVAEModelField: undefined,
+  CLIPEmbedModelField: undefined,
 };

 export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => {
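FIELD_VALUE_FALLBACK_MAP supplies the initial value for each stateful field when a node is instantiated, so the two new entries mean freshly added CLIP Embed and FLUX VAE inputs start out unset until a model is chosen. The body of buildFieldInputInstance is not in this diff; the following is only a plausible sketch of how the map is consumed, with simplified stand-in types:

```ts
// Simplified stand-ins; the real FieldInputTemplate/FieldInputInstance types are richer.
type TemplateLike = { name: string; type: { name: string }; default?: unknown };
type InstanceLike = { id: string; name: string; value: unknown };

const FALLBACKS: Record<string, unknown> = {
  FluxVAEModelField: undefined,
  CLIPEmbedModelField: undefined,
  // ...one entry per stateful field type, as in the map above
};

// Sketch: prefer the template's default, otherwise fall back per field type.
const buildInstanceSketch = (id: string, template: TemplateLike): InstanceLike => ({
  id,
  name: template.name,
  value: template.default ?? FALLBACKS[template.type.name],
});
```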
@@ -2,6 +2,7 @@ import { FieldParseError } from 'features/nodes/types/error';
 import type {
   BoardFieldInputTemplate,
   BooleanFieldInputTemplate,
+  CLIPEmbedModelFieldInputTemplate,
   ColorFieldInputTemplate,
   ControlNetModelFieldInputTemplate,
   EnumFieldInputTemplate,
@@ -9,6 +10,7 @@ import type {
   FieldType,
   FloatFieldInputTemplate,
   FluxMainModelFieldInputTemplate,
+  FluxVAEModelFieldInputTemplate,
   ImageFieldInputTemplate,
   IntegerFieldInputTemplate,
   IPAdapterModelFieldInputTemplate,
@@ -238,6 +240,34 @@ const buildT5EncoderModelFieldInputTemplate: FieldInputTemplateBuilder<T5Encoder
   return template;
 };

+const buildCLIPEmbedModelFieldInputTemplate: FieldInputTemplateBuilder<CLIPEmbedModelFieldInputTemplate> = ({
+  schemaObject,
+  baseField,
+  fieldType,
+}) => {
+  const template: CLIPEmbedModelFieldInputTemplate = {
+    ...baseField,
+    type: fieldType,
+    default: schemaObject.default ?? undefined,
+  };
+
+  return template;
+};
+
+const buildFluxVAEModelFieldInputTemplate: FieldInputTemplateBuilder<FluxVAEModelFieldInputTemplate> = ({
+  schemaObject,
+  baseField,
+  fieldType,
+}) => {
+  const template: FluxVAEModelFieldInputTemplate = {
+    ...baseField,
+    type: fieldType,
+    default: schemaObject.default ?? undefined,
+  };
+
+  return template;
+};
+
 const buildLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<LoRAModelFieldInputTemplate> = ({
   schemaObject,
   baseField,
@@ -423,6 +453,8 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
   SpandrelImageToImageModelField: buildSpandrelImageToImageModelFieldInputTemplate,
   VAEModelField: buildVAEModelFieldInputTemplate,
   T5EncoderModelField: buildT5EncoderModelFieldInputTemplate,
+  CLIPEmbedModelField: buildCLIPEmbedModelFieldInputTemplate,
+  FluxVAEModelField: buildFluxVAEModelFieldInputTemplate,
 } as const;

 export const buildFieldInputTemplate = (
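The two builders are registered in TEMPLATE_BUILDER_MAP, keyed by field-type name. The body of buildFieldInputTemplate is not shown in this diff; the following is a hedged sketch of the dispatch it presumably performs, reusing the builder shape added above with simplified stand-in types:

```ts
// Simplified stand-ins for the real schema/field types.
type FieldTypeLike = { name: string };
type BuilderArgs = {
  schemaObject: { default?: unknown };
  baseField: { name: string; title?: string };
  fieldType: FieldTypeLike;
};
type TemplateLike = Record<string, unknown>;
type Builder = (args: BuilderArgs) => TemplateLike;

const BUILDER_MAP: Record<string, Builder> = {
  // Same pattern as buildCLIPEmbedModelFieldInputTemplate / buildFluxVAEModelFieldInputTemplate above.
  CLIPEmbedModelField: ({ baseField, fieldType, schemaObject }) => ({
    ...baseField,
    type: fieldType,
    default: schemaObject.default ?? undefined,
  }),
};

// Sketch of the dispatch: look up the builder for the parsed field type, if any.
const buildTemplateSketch = (args: BuilderArgs): TemplateLike | undefined => {
  const builder = BUILDER_MAP[args.fieldType.name];
  return builder?.(args);
};
```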
@@ -7,6 +7,7 @@ import {
   isControlNetModelConfig,
   isControlNetOrT2IAdapterModelConfig,
   isFluxMainModelModelConfig,
+  isFluxVAEModelConfig,
   isIPAdapterModelConfig,
   isLoRAModelConfig,
   isNonRefinerMainModelConfig,
@@ -52,3 +53,4 @@ export const useSpandrelImageToImageModels = buildModelsHook(isSpandrelImageToIm
 export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig);
 export const useEmbeddingModels = buildModelsHook(isTIModelConfig);
 export const useVAEModels = buildModelsHook(isVAEModelConfig);
+export const useFluxVAEModels = buildModelsHook(isFluxVAEModelConfig);
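buildModelsHook is defined elsewhere in this file; judging from the call sites in the components above, it returns the filtered config list plus the query's loading flag as [modelConfigs, { isLoading }]. The essential part is the type-guard filter, sketched here without the query wiring and with a simplified config shape (the real guard, isFluxVAEModelConfig, is added in services/api/types later in this diff):

```ts
// Simplified config shape; the real AnyModelConfig union is generated from the API schema.
type AnyModelConfigLike = { key: string; type: string; base: string };

// Mirrors the predicate of isFluxVAEModelConfig.
const isFluxVAELike = (c: AnyModelConfigLike): boolean => c.type === 'vae' && c.base === 'flux';

// Non-React sketch of the filtering buildModelsHook presumably performs.
const filterFluxVAEs = (configs: AnyModelConfigLike[]): AnyModelConfigLike[] =>
  configs.filter(isFluxVAELike);

// e.g.
const fluxVaes = filterFluxVAEs([
  { key: 'a', type: 'vae', base: 'flux' },  // kept
  { key: 'b', type: 'vae', base: 'sd-1' },  // dropped
]);
```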
@@ -5770,8 +5770,22 @@ export type components = {
       use_cache?: boolean;
       /** @description Flux model (Transformer) to load */
       model: components["schemas"]["ModelIdentifierField"];
-      /** @description T5 tokenizer and text encoder */
-      t5_encoder: components["schemas"]["ModelIdentifierField"];
+      /**
+       * T5 Encoder
+       * @description T5 tokenizer and text encoder
+       */
+      t5_encoder_model: components["schemas"]["ModelIdentifierField"];
+      /**
+       * CLIP Embed
+       * @description CLIP Embed loader
+       */
+      clip_embed_model: components["schemas"]["ModelIdentifierField"];
+      /**
+       * VAE
+       * @description VAE model to load
+       * @default null
+       */
+      vae_model?: components["schemas"]["ModelIdentifierField"];
       /**
        * type
        * @default flux_model_loader
@@ -15097,7 +15111,7 @@ export type components = {
      * used, and the type will be ignored. They are included here for backwards compatibility.
      * @enum {string}
      */
-    UIType: "MainModelField" | "FluxMainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "T5EncoderModelField" | "SpandrelImageToImageModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict";
+    UIType: "MainModelField" | "FluxMainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "FluxVAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "T5EncoderModelField" | "CLIPEmbedModelField" | "SpandrelImageToImageModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict";
     /** UNetField */
     UNetField: {
       /** @description Info to load unet submodel */
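The regenerated OpenAPI types reflect the backend change: the FLUX model loader now takes separate T5, CLIP Embed, and (optional) VAE model identifiers, and the UIType enum gains the two new field names. An illustrative type built from the generated schema, using only the properties shown in the hunk above and assuming the generated types are exported as `components` from services/api/schema, as elsewhere in the frontend:

```ts
import type { components } from 'services/api/schema';

type ModelIdentifier = components['schemas']['ModelIdentifierField'];

// Shape of the loader's model inputs as exposed by the regenerated schema above.
// (The invocation's full schema key is not shown in this hunk, so it is omitted here.)
type FluxLoaderModelInputs = {
  model: ModelIdentifier;            // FLUX transformer
  t5_encoder_model: ModelIdentifier; // T5 tokenizer and text encoder
  clip_embed_model: ModelIdentifier; // CLIP Embed loader
  vae_model?: ModelIdentifier;       // VAE model to load (optional, default null)
};
```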
@@ -51,7 +51,7 @@ export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
 export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
 export type IPAdapterModelConfig = S['IPAdapterInvokeAIConfig'] | S['IPAdapterCheckpointConfig'];
 export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
-type ClipEmbedModelConfig = S['CLIPEmbedDiffusersConfig'];
+export type ClipEmbedModelConfig = S['CLIPEmbedDiffusersConfig'];
 export type T5EncoderModelConfig = S['T5EncoderConfig'];
 export type T5EncoderBnbQuantizedLlmInt8bModelConfig = S['T5EncoderBnbQuantizedLlmInt8bConfig'];
 export type SpandrelImageToImageModelConfig = S['SpandrelImageToImageConfig'];
@@ -82,6 +82,10 @@ export const isVAEModelConfig = (config: AnyModelConfig): config is VAEModelConf
   return config.type === 'vae';
 };

+export const isFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => {
+  return config.type === 'vae' && config.base === 'flux';
+};
+
 export const isControlNetModelConfig = (config: AnyModelConfig): config is ControlNetModelConfig => {
   return config.type === 'controlnet';
 };
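Note that FLUX VAEs do not get their own config type: the new guard reuses VAEModelConfig and distinguishes FLUX VAEs purely by metadata (type === 'vae' and base === 'flux'). A small usage example, assuming AnyModelConfig is exported from the same module as the guards:

```ts
import { isFluxVAEModelConfig, isVAEModelConfig } from 'services/api/types';
import type { AnyModelConfig } from 'services/api/types';

// Split a model list into FLUX-specific VAEs and the full set of VAEs.
const splitVaes = (configs: AnyModelConfig[]) => ({
  fluxVaes: configs.filter(isFluxVAEModelConfig), // base === 'flux' only
  allVaes: configs.filter(isVAEModelConfig),      // every VAE, including FLUX ones
});
```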