mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Setup flux model loading in the UI
This commit is contained in:
parent
1fa6bddc89
commit
5f59a828f9
@ -40,6 +40,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
|||||||
|
|
||||||
# region Model Field Types
|
# region Model Field Types
|
||||||
MainModel = "MainModelField"
|
MainModel = "MainModelField"
|
||||||
|
FluxMainModel = "FluxMainModelField"
|
||||||
SDXLMainModel = "SDXLMainModelField"
|
SDXLMainModel = "SDXLMainModelField"
|
||||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||||
ONNXModel = "ONNXModelField"
|
ONNXModel = "ONNXModelField"
|
||||||
@ -126,12 +127,14 @@ class FieldDescriptions:
|
|||||||
noise = "Noise tensor"
|
noise = "Noise tensor"
|
||||||
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
||||||
unet = "UNet (scheduler, LoRAs)"
|
unet = "UNet (scheduler, LoRAs)"
|
||||||
|
transformer = "Transformer"
|
||||||
vae = "VAE"
|
vae = "VAE"
|
||||||
cond = "Conditioning tensor"
|
cond = "Conditioning tensor"
|
||||||
controlnet_model = "ControlNet model to load"
|
controlnet_model = "ControlNet model to load"
|
||||||
vae_model = "VAE model to load"
|
vae_model = "VAE model to load"
|
||||||
lora_model = "LoRA model to load"
|
lora_model = "LoRA model to load"
|
||||||
main_model = "Main model (UNet, VAE, CLIP) to load"
|
main_model = "Main model (UNet, VAE, CLIP) to load"
|
||||||
|
flux_model = "Flux model (Transformer, VAE, CLIP) to load"
|
||||||
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
||||||
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
||||||
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
||||||
|
@ -1,14 +1,13 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
import accelerate
|
import accelerate
|
||||||
import torch
|
import torch
|
||||||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
||||||
from einops import rearrange, repeat
|
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
|
||||||
from flux.model import Flux
|
from invokeai.app.invocations.model import ModelIdentifierField
|
||||||
from flux.modules.autoencoder import AutoEncoder
|
from optimum.quanto import qfloat8
|
||||||
from flux.sampling import denoise, get_noise, get_schedule, unpack
|
|
||||||
from flux.util import configs as flux_configs
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from transformers.models.auto import AutoModelForTextEncoding
|
from transformers.models.auto import AutoModelForTextEncoding
|
||||||
@ -21,6 +20,7 @@ from invokeai.app.invocations.fields import (
|
|||||||
InputField,
|
InputField,
|
||||||
WithBoard,
|
WithBoard,
|
||||||
WithMetadata,
|
WithMetadata,
|
||||||
|
UIType,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.primitives import ImageOutput
|
from invokeai.app.invocations.primitives import ImageOutput
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
@ -52,6 +52,11 @@ class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
|
|||||||
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||||
"""Text-to-image generation using a FLUX model."""
|
"""Text-to-image generation using a FLUX model."""
|
||||||
|
|
||||||
|
flux_model: ModelIdentifierField = InputField(
|
||||||
|
description="The Flux model",
|
||||||
|
input=Input.Any,
|
||||||
|
ui_type=UIType.FluxMainModel
|
||||||
|
)
|
||||||
model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
|
model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
|
||||||
quantization_type: Literal["raw", "NF4", "llm_int8"] = InputField(
|
quantization_type: Literal["raw", "NF4", "llm_int8"] = InputField(
|
||||||
default="raw", description="The type of quantization to use for the transformer model."
|
default="raw", description="The type of quantization to use for the transformer model."
|
||||||
|
@ -60,6 +60,12 @@ class CLIPField(BaseModel):
|
|||||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerField(BaseModel):
|
||||||
|
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
|
||||||
|
scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
|
||||||
|
|
||||||
|
|
||||||
class VAEField(BaseModel):
|
class VAEField(BaseModel):
|
||||||
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
|
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
|
||||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||||
@ -122,6 +128,49 @@ class ModelIdentifierInvocation(BaseInvocation):
|
|||||||
return ModelIdentifierOutput(model=self.model)
|
return ModelIdentifierOutput(model=self.model)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("flux_model_loader_output")
|
||||||
|
class FluxModelLoaderOutput(BaseInvocationOutput):
|
||||||
|
"""Flux base model loader output"""
|
||||||
|
|
||||||
|
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
|
||||||
|
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
|
||||||
|
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||||
|
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
|
|
||||||
|
|
||||||
|
@invocation("flux_model_loader", title="Flux Main Model", tags=["model", "flux"], category="model", version="1.0.3")
|
||||||
|
class FluxModelLoaderInvocation(BaseInvocation):
|
||||||
|
"""Loads a flux base model, outputting its submodels."""
|
||||||
|
|
||||||
|
model: ModelIdentifierField = InputField(
|
||||||
|
description=FieldDescriptions.flux_model,
|
||||||
|
ui_type=UIType.FluxMainModel,
|
||||||
|
input=Input.Direct,
|
||||||
|
)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
|
||||||
|
model_key = self.model.key
|
||||||
|
|
||||||
|
# TODO: not found exceptions
|
||||||
|
if not context.models.exists(model_key):
|
||||||
|
raise Exception(f"Unknown model: {model_key}")
|
||||||
|
|
||||||
|
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
||||||
|
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
|
||||||
|
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||||
|
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||||
|
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||||
|
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||||
|
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||||
|
|
||||||
|
return FluxModelLoaderOutput(
|
||||||
|
transformer=TransformerField(transformer=transformer, scheduler=scheduler),
|
||||||
|
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
|
||||||
|
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
|
||||||
|
vae=VAEField(vae=vae),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"main_model_loader",
|
"main_model_loader",
|
||||||
title="Main Model",
|
title="Main Model",
|
||||||
|
@ -52,6 +52,7 @@ class BaseModelType(str, Enum):
|
|||||||
StableDiffusion2 = "sd-2"
|
StableDiffusion2 = "sd-2"
|
||||||
StableDiffusionXL = "sdxl"
|
StableDiffusionXL = "sdxl"
|
||||||
StableDiffusionXLRefiner = "sdxl-refiner"
|
StableDiffusionXLRefiner = "sdxl-refiner"
|
||||||
|
Flux = "flux"
|
||||||
# Kandinsky2_1 = "kandinsky-2.1"
|
# Kandinsky2_1 = "kandinsky-2.1"
|
||||||
|
|
||||||
|
|
||||||
@ -74,6 +75,7 @@ class SubModelType(str, Enum):
|
|||||||
"""Submodel type."""
|
"""Submodel type."""
|
||||||
|
|
||||||
UNet = "unet"
|
UNet = "unet"
|
||||||
|
Transformer = "transformer"
|
||||||
TextEncoder = "text_encoder"
|
TextEncoder = "text_encoder"
|
||||||
TextEncoder2 = "text_encoder_2"
|
TextEncoder2 = "text_encoder_2"
|
||||||
Tokenizer = "tokenizer"
|
Tokenizer = "tokenizer"
|
||||||
|
@ -95,6 +95,7 @@ class ModelProbe(object):
|
|||||||
}
|
}
|
||||||
|
|
||||||
CLASS2TYPE = {
|
CLASS2TYPE = {
|
||||||
|
"FluxPipeline": ModelType.Main,
|
||||||
"StableDiffusionPipeline": ModelType.Main,
|
"StableDiffusionPipeline": ModelType.Main,
|
||||||
"StableDiffusionInpaintPipeline": ModelType.Main,
|
"StableDiffusionInpaintPipeline": ModelType.Main,
|
||||||
"StableDiffusionXLPipeline": ModelType.Main,
|
"StableDiffusionXLPipeline": ModelType.Main,
|
||||||
@ -626,6 +627,10 @@ class FolderProbeBase(ProbeBase):
|
|||||||
|
|
||||||
class PipelineFolderProbe(FolderProbeBase):
|
class PipelineFolderProbe(FolderProbeBase):
|
||||||
def get_base_type(self) -> BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
with open(f"{self.model_path}/model_index.json", "r") as file:
|
||||||
|
conf = json.load(file)
|
||||||
|
if "_class_name" in conf and conf.get("_class_name") == "FluxPipeline":
|
||||||
|
return BaseModelType.Flux
|
||||||
with open(self.model_path / "unet" / "config.json", "r") as file:
|
with open(self.model_path / "unet" / "config.json", "r") as file:
|
||||||
unet_conf = json.load(file)
|
unet_conf = json.load(file)
|
||||||
if unet_conf["cross_attention_dim"] == 768:
|
if unet_conf["cross_attention_dim"] == 768:
|
||||||
|
@ -13,6 +13,7 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {
|
|||||||
'sd-2': 'teal',
|
'sd-2': 'teal',
|
||||||
sdxl: 'invokeBlue',
|
sdxl: 'invokeBlue',
|
||||||
'sdxl-refiner': 'invokeBlue',
|
'sdxl-refiner': 'invokeBlue',
|
||||||
|
flux: 'invokeBlue',
|
||||||
};
|
};
|
||||||
|
|
||||||
const ModelBaseBadge = ({ base }: Props) => {
|
const ModelBaseBadge = ({ base }: Props) => {
|
||||||
|
@ -14,6 +14,8 @@ import {
|
|||||||
isEnumFieldInputTemplate,
|
isEnumFieldInputTemplate,
|
||||||
isFloatFieldInputInstance,
|
isFloatFieldInputInstance,
|
||||||
isFloatFieldInputTemplate,
|
isFloatFieldInputTemplate,
|
||||||
|
isFluxMainModelFieldInputInstance,
|
||||||
|
isFluxMainModelFieldInputTemplate,
|
||||||
isImageFieldInputInstance,
|
isImageFieldInputInstance,
|
||||||
isImageFieldInputTemplate,
|
isImageFieldInputTemplate,
|
||||||
isIntegerFieldInputInstance,
|
isIntegerFieldInputInstance,
|
||||||
@ -48,6 +50,7 @@ import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
|
|||||||
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
|
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
|
||||||
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
|
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
|
||||||
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
|
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
|
||||||
|
import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent';
|
||||||
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
|
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
|
||||||
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
|
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
|
||||||
import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
|
import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
|
||||||
@ -69,6 +72,7 @@ type InputFieldProps = {
|
|||||||
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||||
const fieldInstance = useFieldInputInstance(nodeId, fieldName);
|
const fieldInstance = useFieldInputInstance(nodeId, fieldName);
|
||||||
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
|
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
|
||||||
|
window.console.log("Hit 0")
|
||||||
|
|
||||||
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
|
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
|
||||||
return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||||
@ -145,6 +149,9 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
|||||||
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
|
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
|
||||||
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||||
}
|
}
|
||||||
|
if (isFluxMainModelFieldInputInstance(fieldInstance) && isFluxMainModelFieldInputTemplate(fieldTemplate)) {
|
||||||
|
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||||
|
}
|
||||||
|
|
||||||
if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) {
|
if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) {
|
||||||
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||||
|
@ -0,0 +1,55 @@
|
|||||||
|
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||||
|
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
|
import type { FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useFluxModels } from 'services/api/hooks/modelsByType';
|
||||||
|
import type { MainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
|
import type { FieldComponentProps } from './types';
|
||||||
|
|
||||||
|
type Props = FieldComponentProps<FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate>;
|
||||||
|
|
||||||
|
const FluxMainModelFieldInputComponent = (props: Props) => {
|
||||||
|
const { nodeId, field } = props;
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const [modelConfigs, { isLoading }] = useFluxModels();
|
||||||
|
const _onChange = useCallback(
|
||||||
|
(value: MainModelConfig | null) => {
|
||||||
|
if (!value) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
dispatch(
|
||||||
|
fieldMainModelValueChanged({
|
||||||
|
nodeId,
|
||||||
|
fieldName: field.name,
|
||||||
|
value,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[dispatch, field.name, nodeId]
|
||||||
|
);
|
||||||
|
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||||
|
modelConfigs,
|
||||||
|
onChange: _onChange,
|
||||||
|
isLoading,
|
||||||
|
selectedModel: field.value,
|
||||||
|
});
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex w="full" alignItems="center" gap={2}>
|
||||||
|
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
|
||||||
|
<Combobox
|
||||||
|
value={value}
|
||||||
|
placeholder={placeholder}
|
||||||
|
options={options}
|
||||||
|
onChange={onChange}
|
||||||
|
noOptionsMessage={noOptionsMessage}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(FluxMainModelFieldInputComponent);
|
@ -61,7 +61,7 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
|
|||||||
// #endregion
|
// #endregion
|
||||||
|
|
||||||
// #region Model-related schemas
|
// #region Model-related schemas
|
||||||
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
|
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner', 'flux']);
|
||||||
const zModelType = z.enum([
|
const zModelType = z.enum([
|
||||||
'main',
|
'main',
|
||||||
'vae',
|
'vae',
|
||||||
@ -76,6 +76,7 @@ const zModelType = z.enum([
|
|||||||
]);
|
]);
|
||||||
const zSubModelType = z.enum([
|
const zSubModelType = z.enum([
|
||||||
'unet',
|
'unet',
|
||||||
|
'transformer',
|
||||||
'text_encoder',
|
'text_encoder',
|
||||||
'text_encoder_2',
|
'text_encoder_2',
|
||||||
'tokenizer',
|
'tokenizer',
|
||||||
|
@ -31,6 +31,7 @@ export const MODEL_TYPES = [
|
|||||||
'ControlNetModelField',
|
'ControlNetModelField',
|
||||||
'LoRAModelField',
|
'LoRAModelField',
|
||||||
'MainModelField',
|
'MainModelField',
|
||||||
|
'FluxMainModelField',
|
||||||
'SDXLMainModelField',
|
'SDXLMainModelField',
|
||||||
'SDXLRefinerModelField',
|
'SDXLRefinerModelField',
|
||||||
'VaeModelField',
|
'VaeModelField',
|
||||||
@ -61,6 +62,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
|
|||||||
LatentsField: 'pink.500',
|
LatentsField: 'pink.500',
|
||||||
LoRAModelField: 'teal.500',
|
LoRAModelField: 'teal.500',
|
||||||
MainModelField: 'teal.500',
|
MainModelField: 'teal.500',
|
||||||
|
FluxMainModelField: 'teal.500',
|
||||||
SDXLMainModelField: 'teal.500',
|
SDXLMainModelField: 'teal.500',
|
||||||
SDXLRefinerModelField: 'teal.500',
|
SDXLRefinerModelField: 'teal.500',
|
||||||
SpandrelImageToImageModelField: 'teal.500',
|
SpandrelImageToImageModelField: 'teal.500',
|
||||||
@ -68,6 +70,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
|
|||||||
T2IAdapterField: 'teal.500',
|
T2IAdapterField: 'teal.500',
|
||||||
T2IAdapterModelField: 'teal.500',
|
T2IAdapterModelField: 'teal.500',
|
||||||
UNetField: 'red.500',
|
UNetField: 'red.500',
|
||||||
|
TransformerField: 'red.500',
|
||||||
VAEField: 'blue.500',
|
VAEField: 'blue.500',
|
||||||
VAEModelField: 'teal.500',
|
VAEModelField: 'teal.500',
|
||||||
};
|
};
|
||||||
|
@ -115,6 +115,10 @@ const zSDXLMainModelFieldType = zFieldTypeBase.extend({
|
|||||||
name: z.literal('SDXLMainModelField'),
|
name: z.literal('SDXLMainModelField'),
|
||||||
originalType: zStatelessFieldType.optional(),
|
originalType: zStatelessFieldType.optional(),
|
||||||
});
|
});
|
||||||
|
const zFluxMainModelFieldType = zFieldTypeBase.extend({
|
||||||
|
name: z.literal('FluxMainModelField'),
|
||||||
|
originalType: zStatelessFieldType.optional(),
|
||||||
|
});
|
||||||
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
|
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
|
||||||
name: z.literal('SDXLRefinerModelField'),
|
name: z.literal('SDXLRefinerModelField'),
|
||||||
originalType: zStatelessFieldType.optional(),
|
originalType: zStatelessFieldType.optional(),
|
||||||
@ -158,6 +162,7 @@ const zStatefulFieldType = z.union([
|
|||||||
zModelIdentifierFieldType,
|
zModelIdentifierFieldType,
|
||||||
zMainModelFieldType,
|
zMainModelFieldType,
|
||||||
zSDXLMainModelFieldType,
|
zSDXLMainModelFieldType,
|
||||||
|
zFluxMainModelFieldType,
|
||||||
zSDXLRefinerModelFieldType,
|
zSDXLRefinerModelFieldType,
|
||||||
zVAEModelFieldType,
|
zVAEModelFieldType,
|
||||||
zLoRAModelFieldType,
|
zLoRAModelFieldType,
|
||||||
@ -447,6 +452,29 @@ export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMain
|
|||||||
zSDXLMainModelFieldInputTemplate.safeParse(val).success;
|
zSDXLMainModelFieldInputTemplate.safeParse(val).success;
|
||||||
// #endregion
|
// #endregion
|
||||||
|
|
||||||
|
// #region FluxMainModelField
|
||||||
|
|
||||||
|
const zFluxMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
|
||||||
|
const zFluxMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||||
|
value: zFluxMainModelFieldValue,
|
||||||
|
});
|
||||||
|
const zFluxMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||||
|
type: zFluxMainModelFieldType,
|
||||||
|
originalType: zFieldType.optional(),
|
||||||
|
default: zFluxMainModelFieldValue,
|
||||||
|
});
|
||||||
|
const zFluxMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||||
|
type: zFluxMainModelFieldType,
|
||||||
|
});
|
||||||
|
export type FluxMainModelFieldInputInstance = z.infer<typeof zFluxMainModelFieldInputInstance>;
|
||||||
|
export type FluxMainModelFieldInputTemplate = z.infer<typeof zFluxMainModelFieldInputTemplate>;
|
||||||
|
export const isFluxMainModelFieldInputInstance = (val: unknown): val is FluxMainModelFieldInputInstance =>
|
||||||
|
zFluxMainModelFieldInputInstance.safeParse(val).success;
|
||||||
|
export const isFluxMainModelFieldInputTemplate = (val: unknown): val is FluxMainModelFieldInputTemplate =>
|
||||||
|
zFluxMainModelFieldInputTemplate.safeParse(val).success;
|
||||||
|
|
||||||
|
// #endregion
|
||||||
|
|
||||||
// #region SDXLRefinerModelField
|
// #region SDXLRefinerModelField
|
||||||
|
|
||||||
/** @alias */ // tells knip to ignore this duplicate export
|
/** @alias */ // tells knip to ignore this duplicate export
|
||||||
@ -693,6 +721,7 @@ export const zStatefulFieldValue = z.union([
|
|||||||
zModelIdentifierFieldValue,
|
zModelIdentifierFieldValue,
|
||||||
zMainModelFieldValue,
|
zMainModelFieldValue,
|
||||||
zSDXLMainModelFieldValue,
|
zSDXLMainModelFieldValue,
|
||||||
|
zFluxMainModelFieldValue,
|
||||||
zSDXLRefinerModelFieldValue,
|
zSDXLRefinerModelFieldValue,
|
||||||
zVAEModelFieldValue,
|
zVAEModelFieldValue,
|
||||||
zLoRAModelFieldValue,
|
zLoRAModelFieldValue,
|
||||||
@ -720,6 +749,7 @@ const zStatefulFieldInputInstance = z.union([
|
|||||||
zBoardFieldInputInstance,
|
zBoardFieldInputInstance,
|
||||||
zModelIdentifierFieldInputInstance,
|
zModelIdentifierFieldInputInstance,
|
||||||
zMainModelFieldInputInstance,
|
zMainModelFieldInputInstance,
|
||||||
|
zFluxMainModelFieldInputInstance,
|
||||||
zSDXLMainModelFieldInputInstance,
|
zSDXLMainModelFieldInputInstance,
|
||||||
zSDXLRefinerModelFieldInputInstance,
|
zSDXLRefinerModelFieldInputInstance,
|
||||||
zVAEModelFieldInputInstance,
|
zVAEModelFieldInputInstance,
|
||||||
@ -749,6 +779,7 @@ const zStatefulFieldInputTemplate = z.union([
|
|||||||
zBoardFieldInputTemplate,
|
zBoardFieldInputTemplate,
|
||||||
zModelIdentifierFieldInputTemplate,
|
zModelIdentifierFieldInputTemplate,
|
||||||
zMainModelFieldInputTemplate,
|
zMainModelFieldInputTemplate,
|
||||||
|
zFluxMainModelFieldInputTemplate,
|
||||||
zSDXLMainModelFieldInputTemplate,
|
zSDXLMainModelFieldInputTemplate,
|
||||||
zSDXLRefinerModelFieldInputTemplate,
|
zSDXLRefinerModelFieldInputTemplate,
|
||||||
zVAEModelFieldInputTemplate,
|
zVAEModelFieldInputTemplate,
|
||||||
@ -779,6 +810,7 @@ const zStatefulFieldOutputTemplate = z.union([
|
|||||||
zBoardFieldOutputTemplate,
|
zBoardFieldOutputTemplate,
|
||||||
zModelIdentifierFieldOutputTemplate,
|
zModelIdentifierFieldOutputTemplate,
|
||||||
zMainModelFieldOutputTemplate,
|
zMainModelFieldOutputTemplate,
|
||||||
|
zFluxMainModelFieldOutputTemplate,
|
||||||
zSDXLMainModelFieldOutputTemplate,
|
zSDXLMainModelFieldOutputTemplate,
|
||||||
zSDXLRefinerModelFieldOutputTemplate,
|
zSDXLRefinerModelFieldOutputTemplate,
|
||||||
zVAEModelFieldOutputTemplate,
|
zVAEModelFieldOutputTemplate,
|
||||||
|
@ -114,6 +114,11 @@ const FIELD_TYPE_V1_TO_STATEFUL_FIELD_TYPE_V2: {
|
|||||||
isCollection: false,
|
isCollection: false,
|
||||||
isCollectionOrScalar: false,
|
isCollectionOrScalar: false,
|
||||||
},
|
},
|
||||||
|
FluxMainModelField: {
|
||||||
|
name: 'FluxMainModelField',
|
||||||
|
isCollection: false,
|
||||||
|
isCollectionOrScalar: false,
|
||||||
|
},
|
||||||
SDXLMainModelField: {
|
SDXLMainModelField: {
|
||||||
name: 'SDXLMainModelField',
|
name: 'SDXLMainModelField',
|
||||||
isCollection: false,
|
isCollection: false,
|
||||||
|
@ -27,7 +27,7 @@ const zScheduler = z.enum([
|
|||||||
'kdpm_2_a',
|
'kdpm_2_a',
|
||||||
'lcm',
|
'lcm',
|
||||||
]);
|
]);
|
||||||
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
|
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner', 'flux']);
|
||||||
const zMainModel = z.object({
|
const zMainModel = z.object({
|
||||||
model_name: z.string().min(1),
|
model_name: z.string().min(1),
|
||||||
base_model: zBaseModel,
|
base_model: zBaseModel,
|
||||||
@ -89,6 +89,7 @@ const zFieldTypeV1 = z.enum([
|
|||||||
'ONNXModelField',
|
'ONNXModelField',
|
||||||
'Scheduler',
|
'Scheduler',
|
||||||
'SDXLMainModelField',
|
'SDXLMainModelField',
|
||||||
|
'FluxMainModelField',
|
||||||
'SDXLRefinerModelField',
|
'SDXLRefinerModelField',
|
||||||
'string',
|
'string',
|
||||||
'StringCollection',
|
'StringCollection',
|
||||||
@ -417,6 +418,11 @@ const zSDXLMainModelInputFieldValue = zInputFieldValueBase.extend({
|
|||||||
value: zMainOrOnnxModel.optional(),
|
value: zMainOrOnnxModel.optional(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const zFluxMainModelInputFieldValue = zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('FluxMainModelField'),
|
||||||
|
value: zMainModel.optional(),
|
||||||
|
});
|
||||||
|
|
||||||
const zSDXLRefinerModelInputFieldValue = zInputFieldValueBase.extend({
|
const zSDXLRefinerModelInputFieldValue = zInputFieldValueBase.extend({
|
||||||
type: z.literal('SDXLRefinerModelField'),
|
type: z.literal('SDXLRefinerModelField'),
|
||||||
value: zMainOrOnnxModel.optional(), // TODO: should narrow this down to a refiner model
|
value: zMainOrOnnxModel.optional(), // TODO: should narrow this down to a refiner model
|
||||||
@ -572,6 +578,7 @@ const zInputFieldValue = z.discriminatedUnion('type', [
|
|||||||
zMainModelInputFieldValue,
|
zMainModelInputFieldValue,
|
||||||
zSchedulerInputFieldValue,
|
zSchedulerInputFieldValue,
|
||||||
zSDXLMainModelInputFieldValue,
|
zSDXLMainModelInputFieldValue,
|
||||||
|
zFluxMainModelInputFieldValue,
|
||||||
zSDXLRefinerModelInputFieldValue,
|
zSDXLRefinerModelInputFieldValue,
|
||||||
zStringCollectionInputFieldValue,
|
zStringCollectionInputFieldValue,
|
||||||
zStringPolymorphicInputFieldValue,
|
zStringPolymorphicInputFieldValue,
|
||||||
|
@ -44,7 +44,7 @@ export const zSchedulerField = z.enum([
|
|||||||
// #endregion
|
// #endregion
|
||||||
|
|
||||||
// #region Model-related schemas
|
// #region Model-related schemas
|
||||||
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
|
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner', 'flux']);
|
||||||
const zModelName = z.string().min(3);
|
const zModelName = z.string().min(3);
|
||||||
export const zModelIdentifier = z.object({
|
export const zModelIdentifier = z.object({
|
||||||
model_name: zModelName,
|
model_name: zModelName,
|
||||||
|
@ -203,6 +203,20 @@ const zSDXLMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
|||||||
});
|
});
|
||||||
// #endregion
|
// #endregion
|
||||||
|
|
||||||
|
// #region FluxMainModelField
|
||||||
|
const zFluxMainModelFieldType = zFieldTypeBase.extend({
|
||||||
|
name: z.literal('FluxMainModelField'),
|
||||||
|
});
|
||||||
|
const zFluxMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
|
||||||
|
const zFluxMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||||
|
type: zFluxMainModelFieldType,
|
||||||
|
value: zFluxMainModelFieldValue,
|
||||||
|
});
|
||||||
|
const zFluxMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||||
|
type: zFluxMainModelFieldType,
|
||||||
|
});
|
||||||
|
// #endregion
|
||||||
|
|
||||||
// #region SDXLRefinerModelField
|
// #region SDXLRefinerModelField
|
||||||
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
|
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
|
||||||
name: z.literal('SDXLRefinerModelField'),
|
name: z.literal('SDXLRefinerModelField'),
|
||||||
@ -338,6 +352,7 @@ const zStatefulFieldType = z.union([
|
|||||||
zBoardFieldType,
|
zBoardFieldType,
|
||||||
zMainModelFieldType,
|
zMainModelFieldType,
|
||||||
zSDXLMainModelFieldType,
|
zSDXLMainModelFieldType,
|
||||||
|
zFluxMainModelFieldType,
|
||||||
zSDXLRefinerModelFieldType,
|
zSDXLRefinerModelFieldType,
|
||||||
zVAEModelFieldType,
|
zVAEModelFieldType,
|
||||||
zLoRAModelFieldType,
|
zLoRAModelFieldType,
|
||||||
@ -377,6 +392,7 @@ const zStatefulFieldInputInstance = z.union([
|
|||||||
zBoardFieldInputInstance,
|
zBoardFieldInputInstance,
|
||||||
zMainModelFieldInputInstance,
|
zMainModelFieldInputInstance,
|
||||||
zSDXLMainModelFieldInputInstance,
|
zSDXLMainModelFieldInputInstance,
|
||||||
|
zFluxMainModelFieldInputInstance,
|
||||||
zSDXLRefinerModelFieldInputInstance,
|
zSDXLRefinerModelFieldInputInstance,
|
||||||
zVAEModelFieldInputInstance,
|
zVAEModelFieldInputInstance,
|
||||||
zLoRAModelFieldInputInstance,
|
zLoRAModelFieldInputInstance,
|
||||||
@ -401,6 +417,7 @@ const zStatefulFieldOutputInstance = z.union([
|
|||||||
zBoardFieldOutputInstance,
|
zBoardFieldOutputInstance,
|
||||||
zMainModelFieldOutputInstance,
|
zMainModelFieldOutputInstance,
|
||||||
zSDXLMainModelFieldOutputInstance,
|
zSDXLMainModelFieldOutputInstance,
|
||||||
|
zFluxMainModelFieldOutputInstance,
|
||||||
zSDXLRefinerModelFieldOutputInstance,
|
zSDXLRefinerModelFieldOutputInstance,
|
||||||
zVAEModelFieldOutputInstance,
|
zVAEModelFieldOutputInstance,
|
||||||
zLoRAModelFieldOutputInstance,
|
zLoRAModelFieldOutputInstance,
|
||||||
|
@ -15,6 +15,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
|
|||||||
MainModelField: undefined,
|
MainModelField: undefined,
|
||||||
SchedulerField: 'euler',
|
SchedulerField: 'euler',
|
||||||
SDXLMainModelField: undefined,
|
SDXLMainModelField: undefined,
|
||||||
|
FluxMainModelField: undefined,
|
||||||
SDXLRefinerModelField: undefined,
|
SDXLRefinerModelField: undefined,
|
||||||
StringField: '',
|
StringField: '',
|
||||||
T2IAdapterModelField: undefined,
|
T2IAdapterModelField: undefined,
|
||||||
|
@ -8,6 +8,7 @@ import type {
|
|||||||
FieldInputTemplate,
|
FieldInputTemplate,
|
||||||
FieldType,
|
FieldType,
|
||||||
FloatFieldInputTemplate,
|
FloatFieldInputTemplate,
|
||||||
|
FluxMainModelFieldInputTemplate,
|
||||||
ImageFieldInputTemplate,
|
ImageFieldInputTemplate,
|
||||||
IntegerFieldInputTemplate,
|
IntegerFieldInputTemplate,
|
||||||
IPAdapterModelFieldInputTemplate,
|
IPAdapterModelFieldInputTemplate,
|
||||||
@ -180,6 +181,20 @@ const buildSDXLMainModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLMainMo
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildFluxMainModelFieldInputTemplate: FieldInputTemplateBuilder<FluxMainModelFieldInputTemplate> = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
fieldType,
|
||||||
|
}) => {
|
||||||
|
const template: FluxMainModelFieldInputTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: fieldType,
|
||||||
|
default: schemaObject.default ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLRefinerModelFieldInputTemplate> = ({
|
const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLRefinerModelFieldInputTemplate> = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
baseField,
|
baseField,
|
||||||
@ -386,6 +401,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
|
|||||||
MainModelField: buildMainModelFieldInputTemplate,
|
MainModelField: buildMainModelFieldInputTemplate,
|
||||||
SchedulerField: buildSchedulerFieldInputTemplate,
|
SchedulerField: buildSchedulerFieldInputTemplate,
|
||||||
SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,
|
SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,
|
||||||
|
FluxMainModelField: buildFluxMainModelFieldInputTemplate,
|
||||||
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
|
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
|
||||||
StringField: buildStringFieldInputTemplate,
|
StringField: buildStringFieldInputTemplate,
|
||||||
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
|
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
|
||||||
|
@ -29,6 +29,7 @@ const MODEL_FIELD_TYPES = [
|
|||||||
'ModelIdentifier',
|
'ModelIdentifier',
|
||||||
'MainModelField',
|
'MainModelField',
|
||||||
'SDXLMainModelField',
|
'SDXLMainModelField',
|
||||||
|
'FluxMainModelField',
|
||||||
'SDXLRefinerModelField',
|
'SDXLRefinerModelField',
|
||||||
'VAEModelField',
|
'VAEModelField',
|
||||||
'LoRAModelField',
|
'LoRAModelField',
|
||||||
|
@ -9,6 +9,7 @@ export const MODEL_TYPE_MAP = {
|
|||||||
'sd-2': 'Stable Diffusion 2.x',
|
'sd-2': 'Stable Diffusion 2.x',
|
||||||
sdxl: 'Stable Diffusion XL',
|
sdxl: 'Stable Diffusion XL',
|
||||||
'sdxl-refiner': 'Stable Diffusion XL Refiner',
|
'sdxl-refiner': 'Stable Diffusion XL Refiner',
|
||||||
|
flux: 'Flux',
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -20,6 +21,7 @@ export const MODEL_TYPE_SHORT_MAP = {
|
|||||||
'sd-2': 'SD2.X',
|
'sd-2': 'SD2.X',
|
||||||
sdxl: 'SDXL',
|
sdxl: 'SDXL',
|
||||||
'sdxl-refiner': 'SDXLR',
|
'sdxl-refiner': 'SDXLR',
|
||||||
|
flux: 'FLUX',
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -46,6 +48,10 @@ export const CLIP_SKIP_MAP = {
|
|||||||
maxClip: 24,
|
maxClip: 24,
|
||||||
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
|
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
|
||||||
},
|
},
|
||||||
|
flux: {
|
||||||
|
maxClip: 0,
|
||||||
|
markers: [],
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -5,6 +5,7 @@ import type { AnyModelConfig } from 'services/api/types';
|
|||||||
import {
|
import {
|
||||||
isControlNetModelConfig,
|
isControlNetModelConfig,
|
||||||
isControlNetOrT2IAdapterModelConfig,
|
isControlNetOrT2IAdapterModelConfig,
|
||||||
|
isFluxMainModelModelConfig,
|
||||||
isIPAdapterModelConfig,
|
isIPAdapterModelConfig,
|
||||||
isLoRAModelConfig,
|
isLoRAModelConfig,
|
||||||
isNonRefinerMainModelConfig,
|
isNonRefinerMainModelConfig,
|
||||||
@ -35,6 +36,7 @@ const buildModelsHook =
|
|||||||
export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
|
export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
|
||||||
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
|
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
|
||||||
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
|
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
|
||||||
|
export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);
|
||||||
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
|
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
|
||||||
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
|
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
|
||||||
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
|
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
|
||||||
|
File diff suppressed because one or more lines are too long
@ -118,6 +118,10 @@ export const isSDXLMainModelModelConfig = (config: AnyModelConfig): config is Ma
|
|||||||
return config.type === 'main' && config.base === 'sdxl';
|
return config.type === 'main' && config.base === 'sdxl';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const isFluxMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
||||||
|
return config.type === 'main' && config.base === 'flux';
|
||||||
|
};
|
||||||
|
|
||||||
export const isNonSDXLMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
export const isNonSDXLMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
||||||
return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2');
|
return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2');
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user