Setup flux model loading in the UI

This commit is contained in:
Brandon Rising 2024-08-12 14:04:23 -04:00 committed by Brandon
parent 1fa6bddc89
commit 5f59a828f9
22 changed files with 463 additions and 18 deletions

View File

@ -40,6 +40,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
# region Model Field Types
MainModel = "MainModelField"
FluxMainModel = "FluxMainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
ONNXModel = "ONNXModelField"
@ -126,12 +127,14 @@ class FieldDescriptions:
noise = "Noise tensor"
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
unet = "UNet (scheduler, LoRAs)"
transformer = "Transformer"
vae = "VAE"
cond = "Conditioning tensor"
controlnet_model = "ControlNet model to load"
vae_model = "VAE model to load"
lora_model = "LoRA model to load"
main_model = "Main model (UNet, VAE, CLIP) to load"
flux_model = "Flux model (Transformer, VAE, CLIP) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"

View File

@ -1,14 +1,13 @@
from pathlib import Path
from typing import Literal
from pydantic import Field
import accelerate
import torch
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from einops import rearrange, repeat
from flux.model import Flux
from flux.modules.autoencoder import AutoEncoder
from flux.sampling import denoise, get_noise, get_schedule, unpack
from flux.util import configs as flux_configs
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from invokeai.app.invocations.model import ModelIdentifierField
from optimum.quanto import qfloat8
from PIL import Image
from safetensors.torch import load_file
from transformers.models.auto import AutoModelForTextEncoding
@ -21,6 +20,7 @@ from invokeai.app.invocations.fields import (
InputField,
WithBoard,
WithMetadata,
UIType,
)
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
@ -52,6 +52,11 @@ class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Text-to-image generation using a FLUX model."""
flux_model: ModelIdentifierField = InputField(
description="The Flux model",
input=Input.Any,
ui_type=UIType.FluxMainModel
)
model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
quantization_type: Literal["raw", "NF4", "llm_int8"] = InputField(
default="raw", description="The type of quantization to use for the transformer model."

View File

@ -60,6 +60,12 @@ class CLIPField(BaseModel):
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
class VAEField(BaseModel):
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
@ -122,6 +128,49 @@ class ModelIdentifierInvocation(BaseInvocation):
return ModelIdentifierOutput(model=self.model)
@invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput):
"""Flux base model loader output"""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation("flux_model_loader", title="Flux Main Model", tags=["model", "flux"], category="model", version="1.0.3")
class FluxModelLoaderInvocation(BaseInvocation):
"""Loads a flux base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.flux_model,
ui_type=UIType.FluxMainModel,
input=Input.Direct,
)
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
model_key = self.model.key
# TODO: not found exceptions
if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}")
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, scheduler=scheduler),
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
vae=VAEField(vae=vae),
)
@invocation(
"main_model_loader",
title="Main Model",

View File

@ -52,6 +52,7 @@ class BaseModelType(str, Enum):
StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
Flux = "flux"
# Kandinsky2_1 = "kandinsky-2.1"
@ -74,6 +75,7 @@ class SubModelType(str, Enum):
"""Submodel type."""
UNet = "unet"
Transformer = "transformer"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
Tokenizer = "tokenizer"

View File

@ -95,6 +95,7 @@ class ModelProbe(object):
}
CLASS2TYPE = {
"FluxPipeline": ModelType.Main,
"StableDiffusionPipeline": ModelType.Main,
"StableDiffusionInpaintPipeline": ModelType.Main,
"StableDiffusionXLPipeline": ModelType.Main,
@ -626,6 +627,10 @@ class FolderProbeBase(ProbeBase):
class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
with open(f"{self.model_path}/model_index.json", "r") as file:
conf = json.load(file)
if "_class_name" in conf and conf.get("_class_name") == "FluxPipeline":
return BaseModelType.Flux
with open(self.model_path / "unet" / "config.json", "r") as file:
unet_conf = json.load(file)
if unet_conf["cross_attention_dim"] == 768:

View File

@ -13,6 +13,7 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {
'sd-2': 'teal',
sdxl: 'invokeBlue',
'sdxl-refiner': 'invokeBlue',
flux: 'invokeBlue',
};
const ModelBaseBadge = ({ base }: Props) => {

View File

@ -14,6 +14,8 @@ import {
isEnumFieldInputTemplate,
isFloatFieldInputInstance,
isFloatFieldInputTemplate,
isFluxMainModelFieldInputInstance,
isFluxMainModelFieldInputTemplate,
isImageFieldInputInstance,
isImageFieldInputTemplate,
isIntegerFieldInputInstance,
@ -48,6 +50,7 @@ import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent';
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
@ -69,6 +72,7 @@ type InputFieldProps = {
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
const fieldInstance = useFieldInputInstance(nodeId, fieldName);
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
window.console.log("Hit 0")
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
@ -145,6 +149,9 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isFluxMainModelFieldInputInstance(fieldInstance) && isFluxMainModelFieldInputTemplate(fieldTemplate)) {
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) {
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;

View File

@ -0,0 +1,55 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useFluxModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate>;
const FluxMainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useFluxModels();
const _onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
);
};
export default memo(FluxMainModelFieldInputComponent);

View File

@ -61,7 +61,7 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
// #endregion
// #region Model-related schemas
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner', 'flux']);
const zModelType = z.enum([
'main',
'vae',
@ -76,6 +76,7 @@ const zModelType = z.enum([
]);
const zSubModelType = z.enum([
'unet',
'transformer',
'text_encoder',
'text_encoder_2',
'tokenizer',

View File

@ -31,6 +31,7 @@ export const MODEL_TYPES = [
'ControlNetModelField',
'LoRAModelField',
'MainModelField',
'FluxMainModelField',
'SDXLMainModelField',
'SDXLRefinerModelField',
'VaeModelField',
@ -61,6 +62,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
LatentsField: 'pink.500',
LoRAModelField: 'teal.500',
MainModelField: 'teal.500',
FluxMainModelField: 'teal.500',
SDXLMainModelField: 'teal.500',
SDXLRefinerModelField: 'teal.500',
SpandrelImageToImageModelField: 'teal.500',
@ -68,6 +70,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
T2IAdapterField: 'teal.500',
T2IAdapterModelField: 'teal.500',
UNetField: 'red.500',
TransformerField: 'red.500',
VAEField: 'blue.500',
VAEModelField: 'teal.500',
};

View File

@ -115,6 +115,10 @@ const zSDXLMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLMainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zFluxMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxMainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLRefinerModelField'),
originalType: zStatelessFieldType.optional(),
@ -158,6 +162,7 @@ const zStatefulFieldType = z.union([
zModelIdentifierFieldType,
zMainModelFieldType,
zSDXLMainModelFieldType,
zFluxMainModelFieldType,
zSDXLRefinerModelFieldType,
zVAEModelFieldType,
zLoRAModelFieldType,
@ -447,6 +452,29 @@ export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMain
zSDXLMainModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region FluxMainModelField
const zFluxMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
const zFluxMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zFluxMainModelFieldValue,
});
const zFluxMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zFluxMainModelFieldType,
originalType: zFieldType.optional(),
default: zFluxMainModelFieldValue,
});
const zFluxMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zFluxMainModelFieldType,
});
export type FluxMainModelFieldInputInstance = z.infer<typeof zFluxMainModelFieldInputInstance>;
export type FluxMainModelFieldInputTemplate = z.infer<typeof zFluxMainModelFieldInputTemplate>;
export const isFluxMainModelFieldInputInstance = (val: unknown): val is FluxMainModelFieldInputInstance =>
zFluxMainModelFieldInputInstance.safeParse(val).success;
export const isFluxMainModelFieldInputTemplate = (val: unknown): val is FluxMainModelFieldInputTemplate =>
zFluxMainModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region SDXLRefinerModelField
/** @alias */ // tells knip to ignore this duplicate export
@ -693,6 +721,7 @@ export const zStatefulFieldValue = z.union([
zModelIdentifierFieldValue,
zMainModelFieldValue,
zSDXLMainModelFieldValue,
zFluxMainModelFieldValue,
zSDXLRefinerModelFieldValue,
zVAEModelFieldValue,
zLoRAModelFieldValue,
@ -720,6 +749,7 @@ const zStatefulFieldInputInstance = z.union([
zBoardFieldInputInstance,
zModelIdentifierFieldInputInstance,
zMainModelFieldInputInstance,
zFluxMainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance,
zSDXLRefinerModelFieldInputInstance,
zVAEModelFieldInputInstance,
@ -749,6 +779,7 @@ const zStatefulFieldInputTemplate = z.union([
zBoardFieldInputTemplate,
zModelIdentifierFieldInputTemplate,
zMainModelFieldInputTemplate,
zFluxMainModelFieldInputTemplate,
zSDXLMainModelFieldInputTemplate,
zSDXLRefinerModelFieldInputTemplate,
zVAEModelFieldInputTemplate,
@ -779,6 +810,7 @@ const zStatefulFieldOutputTemplate = z.union([
zBoardFieldOutputTemplate,
zModelIdentifierFieldOutputTemplate,
zMainModelFieldOutputTemplate,
zFluxMainModelFieldOutputTemplate,
zSDXLMainModelFieldOutputTemplate,
zSDXLRefinerModelFieldOutputTemplate,
zVAEModelFieldOutputTemplate,

View File

@ -114,6 +114,11 @@ const FIELD_TYPE_V1_TO_STATEFUL_FIELD_TYPE_V2: {
isCollection: false,
isCollectionOrScalar: false,
},
FluxMainModelField: {
name: 'FluxMainModelField',
isCollection: false,
isCollectionOrScalar: false,
},
SDXLMainModelField: {
name: 'SDXLMainModelField',
isCollection: false,

View File

@ -27,7 +27,7 @@ const zScheduler = z.enum([
'kdpm_2_a',
'lcm',
]);
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner', 'flux']);
const zMainModel = z.object({
model_name: z.string().min(1),
base_model: zBaseModel,
@ -89,6 +89,7 @@ const zFieldTypeV1 = z.enum([
'ONNXModelField',
'Scheduler',
'SDXLMainModelField',
'FluxMainModelField',
'SDXLRefinerModelField',
'string',
'StringCollection',
@ -417,6 +418,11 @@ const zSDXLMainModelInputFieldValue = zInputFieldValueBase.extend({
value: zMainOrOnnxModel.optional(),
});
const zFluxMainModelInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('FluxMainModelField'),
value: zMainModel.optional(),
});
const zSDXLRefinerModelInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('SDXLRefinerModelField'),
value: zMainOrOnnxModel.optional(), // TODO: should narrow this down to a refiner model
@ -572,6 +578,7 @@ const zInputFieldValue = z.discriminatedUnion('type', [
zMainModelInputFieldValue,
zSchedulerInputFieldValue,
zSDXLMainModelInputFieldValue,
zFluxMainModelInputFieldValue,
zSDXLRefinerModelInputFieldValue,
zStringCollectionInputFieldValue,
zStringPolymorphicInputFieldValue,

View File

@ -44,7 +44,7 @@ export const zSchedulerField = z.enum([
// #endregion
// #region Model-related schemas
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner', 'flux']);
const zModelName = z.string().min(3);
export const zModelIdentifier = z.object({
model_name: zModelName,

View File

@ -203,6 +203,20 @@ const zSDXLMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
});
// #endregion
// #region FluxMainModelField
const zFluxMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxMainModelField'),
});
const zFluxMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
const zFluxMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
type: zFluxMainModelFieldType,
value: zFluxMainModelFieldValue,
});
const zFluxMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
type: zFluxMainModelFieldType,
});
// #endregion
// #region SDXLRefinerModelField
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLRefinerModelField'),
@ -338,6 +352,7 @@ const zStatefulFieldType = z.union([
zBoardFieldType,
zMainModelFieldType,
zSDXLMainModelFieldType,
zFluxMainModelFieldType,
zSDXLRefinerModelFieldType,
zVAEModelFieldType,
zLoRAModelFieldType,
@ -377,6 +392,7 @@ const zStatefulFieldInputInstance = z.union([
zBoardFieldInputInstance,
zMainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance,
zFluxMainModelFieldInputInstance,
zSDXLRefinerModelFieldInputInstance,
zVAEModelFieldInputInstance,
zLoRAModelFieldInputInstance,
@ -401,6 +417,7 @@ const zStatefulFieldOutputInstance = z.union([
zBoardFieldOutputInstance,
zMainModelFieldOutputInstance,
zSDXLMainModelFieldOutputInstance,
zFluxMainModelFieldOutputInstance,
zSDXLRefinerModelFieldOutputInstance,
zVAEModelFieldOutputInstance,
zLoRAModelFieldOutputInstance,

View File

@ -15,6 +15,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
MainModelField: undefined,
SchedulerField: 'euler',
SDXLMainModelField: undefined,
FluxMainModelField: undefined,
SDXLRefinerModelField: undefined,
StringField: '',
T2IAdapterModelField: undefined,

View File

@ -8,6 +8,7 @@ import type {
FieldInputTemplate,
FieldType,
FloatFieldInputTemplate,
FluxMainModelFieldInputTemplate,
ImageFieldInputTemplate,
IntegerFieldInputTemplate,
IPAdapterModelFieldInputTemplate,
@ -180,6 +181,20 @@ const buildSDXLMainModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLMainMo
return template;
};
const buildFluxMainModelFieldInputTemplate: FieldInputTemplateBuilder<FluxMainModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: FluxMainModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLRefinerModelFieldInputTemplate> = ({
schemaObject,
baseField,
@ -386,6 +401,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
MainModelField: buildMainModelFieldInputTemplate,
SchedulerField: buildSchedulerFieldInputTemplate,
SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,
FluxMainModelField: buildFluxMainModelFieldInputTemplate,
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
StringField: buildStringFieldInputTemplate,
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,

View File

@ -29,6 +29,7 @@ const MODEL_FIELD_TYPES = [
'ModelIdentifier',
'MainModelField',
'SDXLMainModelField',
'FluxMainModelField',
'SDXLRefinerModelField',
'VAEModelField',
'LoRAModelField',

View File

@ -9,6 +9,7 @@ export const MODEL_TYPE_MAP = {
'sd-2': 'Stable Diffusion 2.x',
sdxl: 'Stable Diffusion XL',
'sdxl-refiner': 'Stable Diffusion XL Refiner',
flux: 'Flux',
};
/**
@ -20,6 +21,7 @@ export const MODEL_TYPE_SHORT_MAP = {
'sd-2': 'SD2.X',
sdxl: 'SDXL',
'sdxl-refiner': 'SDXLR',
flux: 'FLUX',
};
/**
@ -46,6 +48,10 @@ export const CLIP_SKIP_MAP = {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
},
flux: {
maxClip: 0,
markers: [],
},
};
/**

View File

@ -5,6 +5,7 @@ import type { AnyModelConfig } from 'services/api/types';
import {
isControlNetModelConfig,
isControlNetOrT2IAdapterModelConfig,
isFluxMainModelModelConfig,
isIPAdapterModelConfig,
isLoRAModelConfig,
isNonRefinerMainModelConfig,
@ -35,6 +36,7 @@ const buildModelsHook =
export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);

File diff suppressed because one or more lines are too long

View File

@ -118,6 +118,10 @@ export const isSDXLMainModelModelConfig = (config: AnyModelConfig): config is Ma
return config.type === 'main' && config.base === 'sdxl';
};
export const isFluxMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base === 'flux';
};
export const isNonSDXLMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2');
};