feat(ui): update model identifier to be key (wip)

- Update most model identifiers to be `{key: string}` instead of name/base/type. Doesn't change the model select components yet.
- Update model _parameters_, stored in redux, to be `{key: string, base: BaseModel}` - we need to store the base model to be able to check model compatibility. May want to store the whole config? Not sure...
This commit is contained in:
psychedelicious 2024-02-16 18:56:02 +11:00
parent 6df3c450e8
commit dab939f7d1
54 changed files with 267 additions and 453 deletions

View File

@ -10,13 +10,7 @@ export const ReduxInit = memo((props: PropsWithChildren) => {
const dispatch = useAppDispatch();
useGlobalModifiersInit();
useEffect(() => {
dispatch(
modelChanged({
model_name: 'test_model',
base_model: 'sd-1',
model_type: 'main',
})
);
dispatch(modelChanged({ key: 'test_model', base: 'sd-1' }));
}, []);
return props.children;

View File

@ -19,7 +19,7 @@ export const addEnqueueRequestedLinear = () => {
let graph;
if (model && model.base_model === 'sdxl') {
if (model && model.base === 'sdxl') {
if (action.payload.tabName === 'txt2img') {
graph = buildLinearSDXLTextToImageGraph(state);
} else {

View File

@ -30,8 +30,8 @@ export const addModelSelectedListener = () => {
const newModel = result.data;
const newBaseModel = newModel.base_model;
const didBaseModelChange = state.generation.model?.base_model !== newBaseModel;
const newBaseModel = newModel.base;
const didBaseModelChange = state.generation.model?.base !== newBaseModel;
if (didBaseModelChange) {
// we may need to reset some incompatible submodels
@ -39,7 +39,7 @@ export const addModelSelectedListener = () => {
// handle incompatible loras
forEach(state.lora.loras, (lora, id) => {
if (lora.base_model !== newBaseModel) {
if (lora.base !== newBaseModel) {
dispatch(loraRemoved(id));
modelsCleared += 1;
}
@ -47,14 +47,14 @@ export const addModelSelectedListener = () => {
// handle incompatible vae
const { vae } = state.generation;
if (vae && vae.base_model !== newBaseModel) {
if (vae && vae.base !== newBaseModel) {
dispatch(vaeSelected(null));
modelsCleared += 1;
}
// handle incompatible controlnets
selectControlAdapterAll(state.controlAdapters).forEach((ca) => {
if (ca.model?.base_model !== newBaseModel) {
if (ca.model?.base !== newBaseModel) {
dispatch(controlAdapterIsEnabledChanged({ id: ca.id, isEnabled: false }));
modelsCleared += 1;
}

View File

@ -34,14 +34,7 @@ export const addModelsLoadedListener = () => {
return;
}
const isCurrentModelAvailable = currentModel
? models.some(
(m) =>
m.model_name === currentModel.model_name &&
m.base_model === currentModel.base_model &&
m.model_type === currentModel.model_type
)
: false;
const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;
if (isCurrentModelAvailable) {
return;
@ -74,14 +67,7 @@ export const addModelsLoadedListener = () => {
return;
}
const isCurrentModelAvailable = currentModel
? models.some(
(m) =>
m.model_name === currentModel.model_name &&
m.base_model === currentModel.base_model &&
m.model_type === currentModel.model_type
)
: false;
const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;
if (!isCurrentModelAvailable) {
dispatch(refinerModelChanged(null));
@ -103,10 +89,7 @@ export const addModelsLoadedListener = () => {
return;
}
const isCurrentVAEAvailable = some(
action.payload.entities,
(m) => m?.model_name === currentVae?.model_name && m?.base_model === currentVae?.base_model
);
const isCurrentVAEAvailable = some(action.payload.entities, (m) => m?.key === currentVae?.key);
if (isCurrentVAEAvailable) {
return;
@ -140,10 +123,7 @@ export const addModelsLoadedListener = () => {
const loras = getState().lora.loras;
forEach(loras, (lora, id) => {
const isLoRAAvailable = some(
action.payload.entities,
(m) => m?.model_name === lora?.model_name && m?.base_model === lora?.base_model
);
const isLoRAAvailable = some(action.payload.entities, (m) => m?.key === lora?.key);
if (isLoRAAvailable) {
return;
@ -161,10 +141,7 @@ export const addModelsLoadedListener = () => {
log.info({ models: action.payload.entities }, `ControlNet models loaded (${action.payload.ids.length})`);
selectAllControlNets(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(
action.payload.entities,
(m) => m?.model_name === ca?.model?.model_name && m?.base_model === ca?.model?.base_model
);
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
if (isModelAvailable) {
return;
@ -182,10 +159,7 @@ export const addModelsLoadedListener = () => {
log.info({ models: action.payload.entities }, `T2I Adapter models loaded (${action.payload.ids.length})`);
selectAllT2IAdapters(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(
action.payload.entities,
(m) => m?.model_name === ca?.model?.model_name && m?.base_model === ca?.model?.base_model
);
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
if (isModelAvailable) {
return;
@ -203,10 +177,7 @@ export const addModelsLoadedListener = () => {
log.info({ models: action.payload.entities }, `IP Adapter models loaded (${action.payload.ids.length})`);
selectAllIPAdapters(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(
action.payload.entities,
(m) => m?.model_name === ca?.model?.model_name && m?.base_model === ca?.model?.base_model
);
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
if (isModelAvailable) {
return;

View File

@ -5,10 +5,10 @@ import type { GroupBase } from 'chakra-react-select';
import { groupBy, map, reduce } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfigEntity } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/endpoints/models';
import { getModelId } from 'services/api/endpoints/models';
type UseGroupedModelComboboxArg<T extends AnyModelConfigEntity> = {
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
modelEntities: EntityState<T, string> | undefined;
selectedModel?: Pick<T, 'base_model' | 'model_name' | 'model_type'> | null;
onChange: (value: T | null) => void;
@ -24,7 +24,7 @@ type UseGroupedModelComboboxReturn = {
noOptionsMessage: () => string;
};
export const useGroupedModelCombobox = <T extends AnyModelConfigEntity>(
export const useGroupedModelCombobox = <T extends AnyModelConfig>(
arg: UseGroupedModelComboboxArg<T>
): UseGroupedModelComboboxReturn => {
const { t } = useTranslation();

View File

@ -105,7 +105,7 @@ const selector = createMemoizedSelector(
number: i + 1,
})
);
} else if (ca.model.base_model !== model?.base_model) {
} else if (ca.model.base !== model?.base) {
// This should never happen, just a sanity check
reasons.push(
i18n.t('parameters.invoke.incompatibleBaseModelForControlAdapter', {

View File

@ -3,10 +3,10 @@ import type { EntityState } from '@reduxjs/toolkit';
import { map } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfigEntity } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/endpoints/models';
import { getModelId } from 'services/api/endpoints/models';
type UseModelComboboxArg<T extends AnyModelConfigEntity> = {
type UseModelComboboxArg<T extends AnyModelConfig> = {
modelEntities: EntityState<T, string> | undefined;
selectedModel?: Pick<T, 'base_model' | 'model_name' | 'model_type'> | null;
onChange: (value: T | null) => void;
@ -23,7 +23,7 @@ type UseModelComboboxReturn = {
noOptionsMessage: () => string;
};
export const useModelCombobox = <T extends AnyModelConfigEntity>(
export const useModelCombobox = <T extends AnyModelConfig>(
arg: UseModelComboboxArg<T>
): UseModelComboboxReturn => {
const { t } = useTranslation();

View File

@ -626,7 +626,7 @@ export const canvasSlice = createSlice({
},
extraReducers: (builder) => {
builder.addCase(modelChanged, (state, action) => {
if (action.meta.previousModel?.base_model === action.payload?.base_model) {
if (action.meta.previousModel?.base === action.payload?.base) {
// The base model hasn't changed, we don't need to optimize the size
return;
}

View File

@ -11,12 +11,7 @@ import { selectGenerationSlice } from 'features/parameters/store/generationSlice
import { pick } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type {
ControlNetModelConfigEntity,
IPAdapterModelConfigEntity,
T2IAdapterModelConfigEntity,
} from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import type { AnyModelConfig, ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'services/api/types';
type ParamControlAdapterModelProps = {
id: string;
@ -29,21 +24,21 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
const controlAdapterType = useControlAdapterType(id);
const model = useControlAdapterModel(id);
const dispatch = useAppDispatch();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base_model);
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const mainModel = useAppSelector(selectMainModel);
const { t } = useTranslation();
const models = useControlAdapterModelEntities(controlAdapterType);
const _onChange = useCallback(
(model: ControlNetModelConfigEntity | IPAdapterModelConfigEntity | T2IAdapterModelConfigEntity | null) => {
(model: ControlNetConfig | IPAdapterConfig | T2IAdapterConfig | null) => {
if (!model) {
return;
}
dispatch(
controlAdapterModelChanged({
id,
model: pick(model, 'base_model', 'model_name'),
model: pick(model, 'base', 'key'),
})
);
},
@ -57,7 +52,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
const getIsDisabled = useCallback(
(model: AnyModelConfig): boolean => {
const isCompatible = currentBaseModel === model.base_model;
const isCompatible = currentBaseModel === model.base;
const hasMainModel = Boolean(currentBaseModel);
return !hasMainModel || !isCompatible;
},
@ -73,7 +68,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
return (
<Tooltip label={value?.description}>
<FormControl isDisabled={!isEnabled} isInvalid={!value || mainModel?.base_model !== model?.base_model}>
<FormControl isDisabled={!isEnabled} isInvalid={!value || mainModel?.base !== model?.base}>
<Combobox
options={options}
placeholder={t('controlnet.selectModel')}

View File

@ -6,14 +6,14 @@ import { useCallback, useMemo } from 'react';
import { useControlAdapterModels } from './useControlAdapterModels';
export const useAddControlAdapter = (type: ControlAdapterType) => {
const baseModel = useAppSelector((s) => s.generation.model?.base_model);
const baseModel = useAppSelector((s) => s.generation.model?.base);
const dispatch = useAppDispatch();
const models = useControlAdapterModels(type);
const firstModel = useMemo(() => {
// prefer to use a model that matches the base model
const firstCompatibleModel = models.filter((m) => (baseModel ? m.base_model === baseModel : true))[0];
const firstCompatibleModel = models.filter((m) => (baseModel ? m.base === baseModel : true))[0];
if (firstCompatibleModel) {
return firstCompatibleModel;

View File

@ -236,7 +236,8 @@ export const controlAdaptersSlice = createSlice({
let processorType: ControlAdapterProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
if (model.model_name.includes(modelSubstring)) {
// TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType
if (model.key.includes(modelSubstring)) {
processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
@ -359,7 +360,8 @@ export const controlAdaptersSlice = createSlice({
let processorType: ControlAdapterProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
if (cn.model?.model_name.includes(modelSubstring)) {
// TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType
if (cn.model?.key.includes(modelSubstring)) {
processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}

View File

@ -6,18 +6,18 @@ import type { EmbeddingSelectProps } from 'features/embedding/types';
import { t } from 'i18next';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import type { TextualInversionModelConfigEntity } from 'services/api/endpoints/models';
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
import type { TextualInversionConfig } from 'services/api/types';
const noOptionsMessage = () => t('embedding.noMatchingEmbedding');
export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps) => {
const { t } = useTranslation();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base_model);
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const getIsDisabled = useCallback(
(embedding: TextualInversionModelConfigEntity): boolean => {
(embedding: TextualInversionConfig): boolean => {
const isCompatible = currentBaseModel === embedding.base_model;
const hasMainModel = Boolean(currentBaseModel);
return !hasMainModel || !isCompatible;
@ -27,7 +27,7 @@ export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps
const { data, isLoading } = useGetTextualInversionModelsQuery();
const _onChange = useCallback(
(embedding: TextualInversionModelConfigEntity | null) => {
(embedding: TextualInversionConfig | null) => {
if (!embedding) {
return;
}

View File

@ -208,8 +208,8 @@ const ImageMetadataActions = (props: Props) => {
{metadata.seed !== undefined && metadata.seed !== null && (
<ImageMetadataItem label={t('metadata.seed')} value={metadata.seed} onClick={handleRecallSeed} />
)}
{metadata.model !== undefined && metadata.model !== null && metadata.model.model_name && (
<ImageMetadataItem label={t('metadata.model')} value={metadata.model.model_name} onClick={handleRecallModel} />
{metadata.model !== undefined && metadata.model !== null && metadata.model.key && (
<ImageMetadataItem label={t('metadata.model')} value={metadata.model.key} onClick={handleRecallModel} />
)}
{metadata.width && (
<ImageMetadataItem label={t('metadata.width')} value={metadata.width} onClick={handleRecallWidth} />
@ -222,7 +222,7 @@ const ImageMetadataActions = (props: Props) => {
)}
<ImageMetadataItem
label={t('metadata.vae')}
value={metadata.vae?.model_name ?? 'Default'}
value={metadata.vae?.key ?? 'Default'}
onClick={handleRecallVaeModel}
/>
{metadata.steps && (
@ -269,7 +269,7 @@ const ImageMetadataActions = (props: Props) => {
<ImageMetadataItem
key={index}
label="LoRA"
value={`${lora.lora.model_name} - ${lora.weight}`}
value={`${lora.lora.key} - ${lora.weight}`}
onClick={handleRecallLoRA.bind(null, lora)}
/>
);
@ -279,7 +279,7 @@ const ImageMetadataActions = (props: Props) => {
<ImageMetadataItem
key={index}
label="ControlNet"
value={`${controlnet.control_model?.model_name} - ${controlnet.control_weight}`}
value={`${controlnet.control_model?.key} - ${controlnet.control_weight}`}
onClick={handleRecallControlNet.bind(null, controlnet)}
/>
))}
@ -287,7 +287,7 @@ const ImageMetadataActions = (props: Props) => {
<ImageMetadataItem
key={index}
label="IP Adapter"
value={`${ipAdapter.ip_adapter_model?.model_name} - ${ipAdapter.weight}`}
value={`${ipAdapter.ip_adapter_model?.key} - ${ipAdapter.weight}`}
onClick={handleRecallIPAdapter.bind(null, ipAdapter)}
/>
))}
@ -295,7 +295,7 @@ const ImageMetadataActions = (props: Props) => {
<ImageMetadataItem
key={index}
label="T2I Adapter"
value={`${t2iAdapter.t2i_adapter_model?.model_name} - ${t2iAdapter.weight}`}
value={`${t2iAdapter.t2i_adapter_model?.key} - ${t2iAdapter.weight}`}
onClick={handleRecallT2IAdapter.bind(null, t2iAdapter)}
/>
))}

View File

@ -44,7 +44,7 @@ export const LoRACard = memo((props: LoRACardProps) => {
<CardHeader>
<Flex alignItems="center" justifyContent="space-between" width="100%" gap={2}>
<Text noOfLines={1} wordBreak="break-all" color={lora.isEnabled ? 'base.200' : 'base.500'}>
{lora.model_name}
{lora.key}
</Text>
<Flex alignItems="center" gap={2}>
<Switch size="sm" onChange={handleSetLoraToggle} isChecked={lora.isEnabled} />

View File

@ -18,7 +18,7 @@ export const LoRAList = memo(() => {
return (
<Flex flexWrap="wrap" gap={2}>
{lorasArray.map((lora) => (
<LoRACard key={lora.model_name} lora={lora} />
<LoRACard key={lora.key} lora={lora} />
))}
</Flex>
);

View File

@ -7,7 +7,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { LoRAModelConfigEntity } from 'services/api/endpoints/models';
import type { LoRAConfig } from 'services/api/endpoints/models';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras);
@ -19,7 +19,7 @@ const LoRASelect = () => {
const addedLoRAs = useAppSelector(selectAddedLoRAs);
const currentBaseModel = useAppSelector((s) => s.generation.model?.base_model);
const getIsDisabled = (lora: LoRAModelConfigEntity): boolean => {
const getIsDisabled = (lora: LoRAConfig): boolean => {
const isCompatible = currentBaseModel === lora.base_model;
const isAdded = Boolean(addedLoRAs[lora.id]);
const hasMainModel = Boolean(currentBaseModel);
@ -27,7 +27,7 @@ const LoRASelect = () => {
};
const _onChange = useCallback(
(lora: LoRAModelConfigEntity | null) => {
(lora: LoRAConfig | null) => {
if (!lora) {
return;
}

View File

@ -2,10 +2,9 @@ import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import type { LoRAModelConfigEntity } from 'services/api/endpoints/models';
import type { LoRAConfig } from 'services/api/types';
export type LoRA = ParameterLoRAModel & {
id: string;
weight: number;
isEnabled?: boolean;
};
@ -29,40 +28,40 @@ export const loraSlice = createSlice({
name: 'lora',
initialState: initialLoraState,
reducers: {
loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => {
const { model_name, id, base_model } = action.payload;
state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig };
loraAdded: (state, action: PayloadAction<LoRAConfig>) => {
const { key, base } = action.payload;
state.loras[key] = { key, base, ...defaultLoRAConfig };
},
loraRecalled: (state, action: PayloadAction<LoRAModelConfigEntity & { weight: number }>) => {
const { model_name, id, base_model, weight } = action.payload;
state.loras[id] = { id, model_name, base_model, weight, isEnabled: true };
loraRecalled: (state, action: PayloadAction<LoRAConfig & { weight: number }>) => {
const { key, base, weight } = action.payload;
state.loras[key] = { key, base, weight, isEnabled: true };
},
loraRemoved: (state, action: PayloadAction<string>) => {
const id = action.payload;
delete state.loras[id];
const key = action.payload;
delete state.loras[key];
},
lorasCleared: (state) => {
state.loras = {};
},
loraWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => {
const { id, weight } = action.payload;
const lora = state.loras[id];
loraWeightChanged: (state, action: PayloadAction<{ key: string; weight: number }>) => {
const { key, weight } = action.payload;
const lora = state.loras[key];
if (!lora) {
return;
}
lora.weight = weight;
},
loraWeightReset: (state, action: PayloadAction<string>) => {
const id = action.payload;
const lora = state.loras[id];
const key = action.payload;
const lora = state.loras[key];
if (!lora) {
return;
}
lora.weight = defaultLoRAConfig.weight;
},
loraIsEnabledChanged: (state, action: PayloadAction<Pick<LoRA, 'id' | 'isEnabled'>>) => {
const { id, isEnabled } = action.payload;
const lora = state.loras[id];
loraIsEnabledChanged: (state, action: PayloadAction<Pick<LoRA, 'key' | 'isEnabled'>>) => {
const { key, isEnabled } = action.payload;
const lora = state.loras[key];
if (!lora) {
return;
}

View File

@ -3,9 +3,9 @@ import { memo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
import type {
DiffusersModelConfigEntity,
LoRAModelConfigEntity,
MainModelConfigEntity,
DiffusersModelConfig,
LoRAConfig,
MainModelConfig,
} from 'services/api/endpoints/models';
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
@ -38,7 +38,7 @@ const ModelManagerPanel = () => {
};
type ModelEditProps = {
model: MainModelConfigEntity | LoRAModelConfigEntity | undefined;
model: MainModelConfig | LoRAConfig | undefined;
};
const ModelEdit = (props: ModelEditProps) => {
@ -50,7 +50,7 @@ const ModelEdit = (props: ModelEditProps) => {
}
if (model?.model_format === 'diffusers') {
return <DiffusersModelEdit key={model.id} model={model as DiffusersModelConfigEntity} />;
return <DiffusersModelEdit key={model.id} model={model as DiffusersModelConfig} />;
}
if (model?.model_type === 'lora') {

View File

@ -21,14 +21,14 @@ import { memo, useCallback, useEffect, useState } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { CheckpointModelConfigEntity } from 'services/api/endpoints/models';
import type { CheckpointModelConfig } from 'services/api/endpoints/models';
import { useGetCheckpointConfigsQuery, useUpdateMainModelsMutation } from 'services/api/endpoints/models';
import type { CheckpointModelConfig } from 'services/api/types';
import ModelConvert from './ModelConvert';
type CheckpointModelEditProps = {
model: CheckpointModelConfigEntity;
model: CheckpointModelConfig;
};
const CheckpointModelEdit = (props: CheckpointModelEditProps) => {

View File

@ -9,12 +9,12 @@ import { memo, useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { DiffusersModelConfigEntity } from 'services/api/endpoints/models';
import type { DiffusersModelConfig } from 'services/api/endpoints/models';
import { useUpdateMainModelsMutation } from 'services/api/endpoints/models';
import type { DiffusersModelConfig } from 'services/api/types';
type DiffusersModelEditProps = {
model: DiffusersModelConfigEntity;
model: DiffusersModelConfig;
};
const DiffusersModelEdit = (props: DiffusersModelEditProps) => {

View File

@ -8,12 +8,12 @@ import { memo, useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { LoRAModelConfigEntity } from 'services/api/endpoints/models';
import type { LoRAConfig } from 'services/api/endpoints/models';
import { useUpdateLoRAModelsMutation } from 'services/api/endpoints/models';
import type { LoRAModelConfig } from 'services/api/types';
import type { LoRAConfig } from 'services/api/types';
type LoRAModelEditProps = {
model: LoRAModelConfigEntity;
model: LoRAConfig;
};
const LoRAModelEdit = (props: LoRAModelEditProps) => {
@ -30,7 +30,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
control,
formState: { errors },
reset,
} = useForm<LoRAModelConfig>({
} = useForm<LoRAConfig>({
defaultValues: {
model_name: model.model_name ? model.model_name : '',
base_model: model.base_model,
@ -42,7 +42,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
mode: 'onChange',
});
const onSubmit = useCallback<SubmitHandler<LoRAModelConfig>>(
const onSubmit = useCallback<SubmitHandler<LoRAConfig>>(
(values) => {
const responseBody = {
base_model: model.base_model,
@ -53,7 +53,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
updateLoRAModel(responseBody)
.unwrap()
.then((payload) => {
reset(payload as LoRAModelConfig, { keepDefaultValues: true });
reset(payload as LoRAConfig, { keepDefaultValues: true });
dispatch(
addToast(
makeToast({
@ -106,7 +106,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
<FormLabel>{t('modelManager.description')}</FormLabel>
<Input {...register('description')} />
</FormControl>
<BaseModelSelect<LoRAModelConfig> control={control} name="base_model" />
<BaseModelSelect<LoRAConfig> control={control} name="base_model" />
<FormControl isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>

View File

@ -5,7 +5,7 @@ import type { ChangeEvent, PropsWithChildren } from 'react';
import { memo, useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
import type { LoRAModelConfigEntity, MainModelConfigEntity } from 'services/api/endpoints/models';
import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
import ModelListItem from './ModelListItem';
@ -127,7 +127,7 @@ const ModelList = (props: ModelListProps) => {
export default memo(ModelList);
const modelsFilter = <T extends MainModelConfigEntity | LoRAModelConfigEntity>(
const modelsFilter = <T extends MainModelConfig | LoRAConfig>(
data: EntityState<T, string> | undefined,
model_type: ModelType,
model_format: ModelFormat | undefined,
@ -163,7 +163,7 @@ StyledModelContainer.displayName = 'StyledModelContainer';
type ModelListWrapperProps = {
title: string;
modelList: MainModelConfigEntity[] | LoRAModelConfigEntity[];
modelList: MainModelConfig[] | LoRAConfig[];
selected: ModelListProps;
};

View File

@ -15,11 +15,11 @@ import { makeToast } from 'features/system/util/makeToast';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi';
import type { LoRAModelConfigEntity, MainModelConfigEntity } from 'services/api/endpoints/models';
import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
import { useDeleteLoRAModelsMutation, useDeleteMainModelsMutation } from 'services/api/endpoints/models';
type ModelListItemProps = {
model: MainModelConfigEntity | LoRAModelConfigEntity;
model: MainModelConfig | LoRAConfig;
isSelected: boolean;
setSelectedModelId: (v: string | undefined) => void;
};

View File

@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import type { ControlNetModelConfigEntity } from 'services/api/endpoints/models';
import type { ControlNetConfig } from 'services/api/endpoints/models';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
import type { FieldComponentProps } from './types';
@ -17,7 +17,7 @@ const ControlNetModelFieldInputComponent = (props: Props) => {
const { data, isLoading } = useGetControlNetModelsQuery();
const _onChange = useCallback(
(value: ControlNetModelConfigEntity | null) => {
(value: ControlNetConfig | null) => {
if (!value) {
return;
}

View File

@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import type { IPAdapterModelConfigEntity } from 'services/api/endpoints/models';
import type { IPAdapterConfig } from 'services/api/endpoints/models';
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
import type { FieldComponentProps } from './types';
@ -17,7 +17,7 @@ const IPAdapterModelFieldInputComponent = (
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery();
const _onChange = useCallback(
(value: IPAdapterModelConfigEntity | null) => {
(value: IPAdapterConfig | null) => {
if (!value) {
return;
}

View File

@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import type { LoRAModelConfigEntity } from 'services/api/endpoints/models';
import type { LoRAConfig } from 'services/api/endpoints/models';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
import type { FieldComponentProps } from './types';
@ -16,7 +16,7 @@ const LoRAModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const { data, isLoading } = useGetLoRAModelsQuery();
const _onChange = useCallback(
(value: LoRAModelConfigEntity | null) => {
(value: LoRAConfig | null) => {
if (!value) {
return;
}

View File

@ -6,7 +6,7 @@ import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { NON_SDXL_MAIN_MODELS } from 'services/api/constants';
import type { MainModelConfigEntity } from 'services/api/endpoints/models';
import type { MainModelConfig } from 'services/api/endpoints/models';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import type { FieldComponentProps } from './types';
@ -18,7 +18,7 @@ const MainModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const { data, isLoading } = useGetMainModelsQuery(NON_SDXL_MAIN_MODELS);
const _onChange = useCallback(
(value: MainModelConfigEntity | null) => {
(value: MainModelConfig | null) => {
if (!value) {
return;
}

View File

@ -9,7 +9,7 @@ import type {
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { REFINER_BASE_MODELS } from 'services/api/constants';
import type { MainModelConfigEntity } from 'services/api/endpoints/models';
import type { MainModelConfig } from 'services/api/endpoints/models';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import type { FieldComponentProps } from './types';
@ -21,7 +21,7 @@ const RefinerModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const { data, isLoading } = useGetMainModelsQuery(REFINER_BASE_MODELS);
const _onChange = useCallback(
(value: MainModelConfigEntity | null) => {
(value: MainModelConfig | null) => {
if (!value) {
return;
}

View File

@ -6,7 +6,7 @@ import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { SDXL_MAIN_MODELS } from 'services/api/constants';
import type { MainModelConfigEntity } from 'services/api/endpoints/models';
import type { MainModelConfig } from 'services/api/endpoints/models';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import type { FieldComponentProps } from './types';
@ -18,7 +18,7 @@ const SDXLMainModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const { data, isLoading } = useGetMainModelsQuery(SDXL_MAIN_MODELS);
const _onChange = useCallback(
(value: MainModelConfigEntity | null) => {
(value: MainModelConfig | null) => {
if (!value) {
return;
}

View File

@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import type { T2IAdapterModelConfigEntity } from 'services/api/endpoints/models';
import type { T2IAdapterConfig } from 'services/api/endpoints/models';
import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models';
import type { FieldComponentProps } from './types';
@ -18,7 +18,7 @@ const T2IAdapterModelFieldInputComponent = (
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery();
const _onChange = useCallback(
(value: T2IAdapterModelConfigEntity | null) => {
(value: T2IAdapterConfig | null) => {
if (!value) {
return;
}

View File

@ -5,7 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManager/components/SyncModel
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import type { VaeModelConfigEntity } from 'services/api/endpoints/models';
import type { VAEConfig } from 'services/api/endpoints/models';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
import type { FieldComponentProps } from './types';
@ -17,7 +17,7 @@ const VAEModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const { data, isLoading } = useGetVaeModelsQuery();
const _onChange = useCallback(
(value: VaeModelConfigEntity | null) => {
(value: VAEConfig | null) => {
if (!value) {
return;
}

View File

@ -67,11 +67,13 @@ export const zModelName = z.string().min(3);
export const zModelIdentifier = z.object({
key: z.string().min(1),
});
export const zModelFieldBase = zModelIdentifier;
export const zModelIdentifierWithBase = zModelIdentifier.extend({ base: zBaseModel });
export type BaseModel = z.infer<typeof zBaseModel>;
export type ModelType = z.infer<typeof zModelType>;
export type ModelIdentifier = z.infer<typeof zModelIdentifier>;
export const zMainModelField = zModelIdentifier;
export type ModelIdentifierWithBase = z.infer<typeof zModelIdentifierWithBase>;
export const zMainModelField = zModelFieldBase;
export type MainModelField = z.infer<typeof zMainModelField>;
export const zSDXLRefinerModelField = zModelIdentifier;
@ -91,23 +93,23 @@ export const zSubModelType = z.enum([
]);
export type SubModelType = z.infer<typeof zSubModelType>;
export const zVAEModelField = zModelIdentifier;
export const zVAEModelField = zModelFieldBase;
export const zModelInfo = zModelIdentifier.extend({
submodel_type: zSubModelType.nullish(),
});
export type ModelInfo = z.infer<typeof zModelInfo>;
export const zLoRAModelField = zModelIdentifier;
export const zLoRAModelField = zModelFieldBase;
export type LoRAModelField = z.infer<typeof zLoRAModelField>;
export const zControlNetModelField = zModelIdentifier;
export const zControlNetModelField = zModelFieldBase;
export type ControlNetModelField = z.infer<typeof zControlNetModelField>;
export const zIPAdapterModelField = zModelIdentifier;
export const zIPAdapterModelField = zModelFieldBase;
export type IPAdapterModelField = z.infer<typeof zIPAdapterModelField>;
export const zT2IAdapterModelField = zModelIdentifier;
export const zT2IAdapterModelField = zModelFieldBase;
export type T2IAdapterModelField = z.infer<typeof zT2IAdapterModelField>;
export const zLoraInfo = zModelInfo.extend({

View File

@ -14,7 +14,7 @@ import { upsertMetadata } from './metadata';
export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
const validControlNets = selectValidControlNets(state.controlAdapters).filter(
(ca) => ca.model?.base_model === state.generation.model?.base_model
(ca) => ca.model?.base === state.generation.model?.base
);
// const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as

View File

@ -14,7 +14,7 @@ import { upsertMetadata } from './metadata';
export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(
(ca) => ca.model?.base_model === state.generation.model?.base_model
(ca) => ca.model?.base === state.generation.model?.base
);
if (validIPAdapters.length) {

View File

@ -28,6 +28,7 @@ export const addLoRAsToGraph = (
* So we need to inject a LoRA chain into the graph.
*/
// TODO(MM2): check base model
const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
const loraCount = size(enabledLoRAs);
@ -48,19 +49,19 @@ export const addLoRAsToGraph = (
const loraMetadata: CoreMetadataInvocation['loras'] = [];
enabledLoRAs.forEach((lora) => {
const { model_name, base_model, weight } = lora;
const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`;
const { key, weight } = lora;
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
const loraLoaderNode: LoraLoaderInvocation = {
type: 'lora_loader',
id: currentLoraNodeId,
is_intermediate: true,
lora: { model_name, base_model },
lora: { key },
weight,
};
loraMetadata.push({
lora: { model_name, base_model },
lora: { key },
weight,
});

View File

@ -31,6 +31,7 @@ export const addSDXLLoRAsToGraph = (
* So we need to inject a LoRA chain into the graph.
*/
// TODO(MM2): check base model
const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
const loraCount = size(enabledLoRAs);
@ -60,20 +61,20 @@ export const addSDXLLoRAsToGraph = (
let currentLoraIndex = 0;
enabledLoRAs.forEach((lora) => {
const { model_name, base_model, weight } = lora;
const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`;
const { key, weight } = lora;
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
const loraLoaderNode: SDXLLoraLoaderInvocation = {
type: 'sdxl_lora_loader',
id: currentLoraNodeId,
is_intermediate: true,
lora: { model_name, base_model },
lora: { key },
weight,
};
loraMetadata.push(
zLoRAMetadataItem.parse({
lora: { model_name, base_model },
lora: { key },
weight,
})
);

View File

@ -14,7 +14,7 @@ import { upsertMetadata } from './metadata';
export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
const validT2IAdapters = selectValidT2IAdapters(state.controlAdapters).filter(
(ca) => ca.model?.base_model === state.generation.model?.base_model
(ca) => ca.model?.base === state.generation.model?.base
);
if (validT2IAdapters.length) {

View File

@ -19,7 +19,7 @@ export const buildCanvasGraph = (
let graph: NonNullableGraph;
if (generationMode === 'txt2img') {
if (state.generation.model && state.generation.model.base_model === 'sdxl') {
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = buildCanvasSDXLTextToImageGraph(state);
} else {
graph = buildCanvasTextToImageGraph(state);
@ -28,7 +28,7 @@ export const buildCanvasGraph = (
if (!canvasInitImage) {
throw new Error('Missing canvas init image');
}
if (state.generation.model && state.generation.model.base_model === 'sdxl') {
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = buildCanvasSDXLImageToImageGraph(state, canvasInitImage);
} else {
graph = buildCanvasImageToImageGraph(state, canvasInitImage);
@ -37,7 +37,7 @@ export const buildCanvasGraph = (
if (!canvasInitImage || !canvasMaskImage) {
throw new Error('Missing canvas init and mask images');
}
if (state.generation.model && state.generation.model.base_model === 'sdxl') {
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = buildCanvasSDXLInpaintGraph(state, canvasInitImage, canvasMaskImage);
} else {
graph = buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage);
@ -46,7 +46,7 @@ export const buildCanvasGraph = (
if (!canvasInitImage) {
throw new Error('Missing canvas init image');
}
if (state.generation.model && state.generation.model.base_model === 'sdxl') {
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = buildCanvasSDXLOutpaintGraph(state, canvasInitImage, canvasMaskImage);
} else {
graph = buildCanvasOutpaintGraph(state, canvasInitImage, canvasMaskImage);

View File

@ -105,7 +105,7 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
});
}
if (shouldConcatSDXLStylePrompt && model?.base_model === 'sdxl') {
if (shouldConcatSDXLStylePrompt && model?.base === 'sdxl') {
if (graph.nodes[POSITIVE_CONDITIONING]) {
firstBatchDatumList.push({
node_path: POSITIVE_CONDITIONING,

View File

@ -29,17 +29,17 @@ const ParamClipSkip = () => {
if (!model) {
return CLIP_SKIP_MAP['sd-1'].maxClip;
}
return CLIP_SKIP_MAP[model.base_model].maxClip;
return CLIP_SKIP_MAP[model.base].maxClip;
}, [model]);
const sliderMarks = useMemo(() => {
if (!model) {
return CLIP_SKIP_MAP['sd-1'].markers;
}
return CLIP_SKIP_MAP[model.base_model].markers;
return CLIP_SKIP_MAP[model.base].markers;
}, [model]);
if (model?.base_model === 'sdxl') {
if (model?.base === 'sdxl') {
return null;
}

View File

@ -15,7 +15,7 @@ import { useTranslation } from 'react-i18next';
export const ParamPositivePrompt = memo(() => {
const dispatch = useAppDispatch();
const prompt = useAppSelector((s) => s.generation.positivePrompt);
const baseModel = useAppSelector((s) => s.generation.model)?.base_model;
const baseModel = useAppSelector((s) => s.generation.model)?.base;
const textareaRef = useRef<HTMLTextAreaElement>(null);
const { t } = useTranslation();

View File

@ -9,7 +9,7 @@ import { pick } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import type { MainModelConfigEntity } from 'services/api/endpoints/models';
import type { MainModelConfig } from 'services/api/endpoints/models';
import { getModelId, mainModelsAdapterSelectors, useGetMainModelsQuery } from 'services/api/endpoints/models';
const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
@ -26,7 +26,7 @@ const ParamMainModelSelect = () => {
return mainModelsAdapterSelectors.selectById(data, getModelId(model))?.description;
}, [data, model]);
const _onChange = useCallback(
(model: MainModelConfigEntity | null) => {
(model: MainModelConfig | null) => {
if (!model) {
return;
}

View File

@ -7,7 +7,7 @@ import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/ge
import { pick } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import type { VaeModelConfigEntity } from 'services/api/endpoints/models';
import type { VAEConfig } from 'services/api/endpoints/models';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
const selector = createMemoizedSelector(selectGenerationSlice, (generation) => {
@ -21,7 +21,7 @@ const ParamVAEModelSelect = () => {
const { model, vae } = useAppSelector(selector);
const { data, isLoading } = useGetVaeModelsQuery();
const getIsDisabled = useCallback(
(vae: VaeModelConfigEntity): boolean => {
(vae: VAEConfig): boolean => {
const isCompatible = model?.base_model === vae.base_model;
const hasMainModel = Boolean(model?.base_model);
return !hasMainModel || !isCompatible;
@ -29,7 +29,7 @@ const ParamVAEModelSelect = () => {
[model?.base_model]
);
const _onChange = useCallback(
(vae: VaeModelConfigEntity | null) => {
(vae: VAEConfig | null) => {
dispatch(vaeSelected(vae ? pick(vae, 'base_model', 'model_name') : null));
},
[dispatch]

View File

@ -464,17 +464,15 @@ export const useRecallParameters = () => {
return { lora: null, error: 'Invalid LoRA model' };
}
const { base_model, model_name } = loraMetadataItem.lora;
const { lora } = loraMetadataItem;
const matchingLoRA = loraModels
? loraModelsAdapterSelectors.selectById(loraModels, `${base_model}/lora/${model_name}`)
: undefined;
const matchingLoRA = loraModels ? loraModelsAdapterSelectors.selectById(loraModels, lora.key) : undefined;
if (!matchingLoRA) {
return { lora: null, error: 'LoRA model is not installed' };
}
const isCompatibleBaseModel = matchingLoRA?.base_model === (newModel ?? model)?.base_model;
const isCompatibleBaseModel = matchingLoRA?.base === (newModel ?? model)?.base;
if (!isCompatibleBaseModel) {
return {
@ -520,17 +518,14 @@ export const useRecallParameters = () => {
controlnetMetadataItem;
const matchingControlNetModel = controlNetModels
? controlNetModelsAdapterSelectors.selectById(
controlNetModels,
`${control_model.base_model}/controlnet/${control_model.model_name}`
)
? controlNetModelsAdapterSelectors.selectById(controlNetModels, control_model.key)
: undefined;
if (!matchingControlNetModel) {
return { controlnet: null, error: 'ControlNet model is not installed' };
}
const isCompatibleBaseModel = matchingControlNetModel?.base_model === (newModel ?? model)?.base_model;
const isCompatibleBaseModel = matchingControlNetModel?.base === (newModel ?? model)?.base;
if (!isCompatibleBaseModel) {
return {
@ -597,17 +592,14 @@ export const useRecallParameters = () => {
t2iAdapterMetadataItem;
const matchingT2IAdapterModel = t2iAdapterModels
? t2iAdapterModelsAdapterSelectors.selectById(
t2iAdapterModels,
`${t2i_adapter_model.base_model}/t2i_adapter/${t2i_adapter_model.model_name}`
)
? t2iAdapterModelsAdapterSelectors.selectById(t2iAdapterModels, t2i_adapter_model.key)
: undefined;
if (!matchingT2IAdapterModel) {
return { controlnet: null, error: 'ControlNet model is not installed' };
}
const isCompatibleBaseModel = matchingT2IAdapterModel?.base_model === (newModel ?? model)?.base_model;
const isCompatibleBaseModel = matchingT2IAdapterModel?.base === (newModel ?? model)?.base;
if (!isCompatibleBaseModel) {
return {
@ -672,17 +664,14 @@ export const useRecallParameters = () => {
const { image, ip_adapter_model, weight, begin_step_percent, end_step_percent } = ipAdapterMetadataItem;
const matchingIPAdapterModel = ipAdapterModels
? ipAdapterModelsAdapterSelectors.selectById(
ipAdapterModels,
`${ip_adapter_model.base_model}/ip_adapter/${ip_adapter_model.model_name}`
)
? ipAdapterModelsAdapterSelectors.selectById(ipAdapterModels, ip_adapter_model.key)
: undefined;
if (!matchingIPAdapterModel) {
return { ipAdapter: null, error: 'IP Adapter model is not installed' };
}
const isCompatibleBaseModel = matchingIPAdapterModel?.base_model === (newModel ?? model)?.base_model;
const isCompatibleBaseModel = matchingIPAdapterModel?.base === (newModel ?? model)?.base;
if (!isCompatibleBaseModel) {
return {

View File

@ -158,15 +158,15 @@ export const generationSlice = createSlice({
// Clamp ClipSkip Based On Selected Model
// TODO(psyche): remove this special handling when https://github.com/invoke-ai/InvokeAI/issues/4583 is resolved
// WIP PR here: https://github.com/invoke-ai/InvokeAI/pull/4624
if (newModel.base_model === 'sdxl') {
if (newModel.base === 'sdxl') {
// We don't support clip skip for SDXL yet - it's not in the graphs
state.clipSkip = 0;
} else {
const { maxClip } = CLIP_SKIP_MAP[newModel.base_model];
const { maxClip } = CLIP_SKIP_MAP[newModel.base];
state.clipSkip = clamp(state.clipSkip, 0, maxClip);
}
if (action.meta.previousModel?.base_model === newModel.base_model) {
if (action.meta.previousModel?.base === newModel.base) {
// The base model hasn't changed, we don't need to optimize the size
return;
}

View File

@ -1,5 +1,6 @@
import { NUMPY_RAND_MAX } from 'app/constants';
import {
zBaseModel,
zControlNetModelField,
zIPAdapterModelField,
zLoRAModelField,
@ -104,48 +105,48 @@ export const isParameterAspectRatio = (val: unknown): val is ParameterAspectRati
// #endregion
// #region Model
export const zParameterModel = zMainModelField;
export const zParameterModel = zMainModelField.extend({ base: zBaseModel });
export type ParameterModel = z.infer<typeof zParameterModel>;
export const isParameterModel = (val: unknown): val is ParameterModel => zParameterModel.safeParse(val).success;
// #endregion
// #region SDXL Refiner Model
export const zParameterSDXLRefinerModel = zSDXLRefinerModelField;
export const zParameterSDXLRefinerModel = zSDXLRefinerModelField.extend({ base: zBaseModel });
export type ParameterSDXLRefinerModel = z.infer<typeof zParameterSDXLRefinerModel>;
export const isParameterSDXLRefinerModel = (val: unknown): val is ParameterSDXLRefinerModel =>
zParameterSDXLRefinerModel.safeParse(val).success;
// #endregion
// #region VAE Model
export const zParameterVAEModel = zVAEModelField;
export const zParameterVAEModel = zVAEModelField.extend({ base: zBaseModel });
export type ParameterVAEModel = z.infer<typeof zParameterVAEModel>;
export const isParameterVAEModel = (val: unknown): val is ParameterVAEModel =>
zParameterVAEModel.safeParse(val).success;
// #endregion
// #region LoRA Model
export const zParameterLoRAModel = zLoRAModelField;
export const zParameterLoRAModel = zLoRAModelField.extend({ base: zBaseModel });
export type ParameterLoRAModel = z.infer<typeof zParameterLoRAModel>;
export const isParameterLoRAModel = (val: unknown): val is ParameterLoRAModel =>
zParameterLoRAModel.safeParse(val).success;
// #endregion
// #region ControlNet Model
export const zParameterControlNetModel = zControlNetModelField;
export const zParameterControlNetModel = zControlNetModelField.extend({ base: zBaseModel });
export type ParameterControlNetModel = z.infer<typeof zParameterLoRAModel>;
export const isParameterControlNetModel = (val: unknown): val is ParameterControlNetModel =>
zParameterControlNetModel.safeParse(val).success;
// #endregion
// #region IP Adapter Model
export const zParameterIPAdapterModel = zIPAdapterModelField;
export const zParameterIPAdapterModel = zIPAdapterModelField.extend({ base: zBaseModel });
export type ParameterIPAdapterModel = z.infer<typeof zParameterIPAdapterModel>;
export const isParameterIPAdapterModel = (val: unknown): val is ParameterIPAdapterModel =>
zParameterIPAdapterModel.safeParse(val).success;
// #endregion
// #region T2I Adapter Model
export const zParameterT2IAdapterModel = zT2IAdapterModelField;
export const zParameterT2IAdapterModel = zT2IAdapterModelField.extend({ base: zBaseModel });
export type ParameterT2IAdapterModel = z.infer<typeof zParameterT2IAdapterModel>;
export const isParameterT2IAdapterModel = (val: unknown): val is ParameterT2IAdapterModel =>
zParameterT2IAdapterModel.safeParse(val).success;

View File

@ -1,12 +1,12 @@
import type { ModelIdentifier } from 'features/nodes/types/common';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
/**
* Gets the optimal dimension for a givel model, based on the model's base_model
* @param model The model identifier
* @returns The optimal dimension for the model
*/
export const getOptimalDimension = (model?: ModelIdentifier | null): number =>
model?.base_model === 'sdxl' ? 1024 : 512;
export const getOptimalDimension = (model?: ModelIdentifierWithBase | null): number =>
model?.base === 'sdxl' ? 1024 : 512;
const MIN_AREA_FACTOR = 0.8;
const MAX_AREA_FACTOR = 1.2;

View File

@ -7,12 +7,12 @@ import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSl
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { REFINER_BASE_MODELS } from 'services/api/constants';
import type { MainModelConfigEntity } from 'services/api/endpoints/models';
import type { MainModelConfig } from 'services/api/endpoints/models';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
const selectModel = createMemoizedSelector(selectSdxlSlice, (sdxl) => sdxl.refinerModel);
const optionsFilter = (model: MainModelConfigEntity) => model.base_model === 'sdxl-refiner';
const optionsFilter = (model: MainModelConfig) => model.base_model === 'sdxl-refiner';
const ParamSDXLRefinerModelSelect = () => {
const dispatch = useAppDispatch();
@ -20,7 +20,7 @@ const ParamSDXLRefinerModelSelect = () => {
const { t } = useTranslation();
const { data, isLoading } = useGetMainModelsQuery(REFINER_BASE_MODELS);
const _onChange = useCallback(
(model: MainModelConfigEntity | null) => {
(model: MainModelConfig | null) => {
if (!model) {
dispatch(refinerModelChanged(null));
return;

View File

@ -24,7 +24,8 @@ const formLabelProps2: FormLabelProps = {
const selectBadges = createMemoizedSelector(selectGenerationSlice, (generation) => {
const badges: (string | number)[] = [];
if (generation.vae) {
let vaeBadge = generation.vae.model_name;
// TODO(MM2): Fetch the vae name
let vaeBadge = generation.vae.key;
if (generation.vaePrecision === 'fp16') {
vaeBadge += ` ${generation.vaePrecision}`;
}

View File

@ -35,9 +35,10 @@ const badgesSelector = createMemoizedSelector(selectLoraSlice, selectGenerationS
const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length;
const loraTabBadges = enabledLoRAsCount ? [enabledLoRAsCount] : [];
const accordionBadges: (string | number)[] = [];
// TODO(MM2): fetch model name
if (generation.model) {
accordionBadges.push(generation.model.model_name);
accordionBadges.push(generation.model.base_model);
accordionBadges.push(generation.model.key);
accordionBadges.push(generation.model.base);
}
return { loraTabBadges, accordionBadges };

View File

@ -56,7 +56,7 @@ const selector = createMemoizedSelector(
if (hrfEnabled) {
badges.push('HiRes Fix');
}
return { badges, activeTabName, isSDXL: model?.base_model === 'sdxl' };
return { badges, activeTabName, isSDXL: model?.base === 'sdxl' };
}
);

View File

@ -22,7 +22,7 @@ const overlayScrollbarsStyles: CSSProperties = {
const ParametersPanel = () => {
const activeTabName = useAppSelector(activeTabNameSelector);
const isSDXL = useAppSelector((s) => s.generation.model?.base_model === 'sdxl');
const isSDXL = useAppSelector((s) => s.generation.model?.base === 'sdxl');
return (
<Flex w="full" h="full" flexDir="column" gap={2}>

View File

@ -1,64 +1,26 @@
import type { EntityState } from '@reduxjs/toolkit';
import type { EntityAdapter, EntityState } from '@reduxjs/toolkit';
import { createEntityAdapter } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import { cloneDeep } from 'lodash-es';
import queryString from 'query-string';
import type { operations, paths } from 'services/api/schema';
import type {
AnyModelConfig,
BaseModelType,
CheckpointModelConfig,
ControlNetModelConfig,
DiffusersModelConfig,
ControlNetConfig,
ImportModelConfig,
IPAdapterModelConfig,
LoRAModelConfig,
IPAdapterConfig,
LoRAConfig,
MainModelConfig,
MergeModelConfig,
ModelType,
T2IAdapterModelConfig,
TextualInversionModelConfig,
VaeModelConfig,
T2IAdapterConfig,
TextualInversionConfig,
VAEConfig,
} from 'services/api/types';
import type { ApiTagDescription } from '..';
import type { ApiTagDescription, tagTypes } from '..';
import { api, LIST_TAG } from '..';
export type DiffusersModelConfigEntity = DiffusersModelConfig & { id: string };
export type CheckpointModelConfigEntity = CheckpointModelConfig & {
id: string;
};
export type MainModelConfigEntity = DiffusersModelConfigEntity | CheckpointModelConfigEntity;
export type LoRAModelConfigEntity = LoRAModelConfig & { id: string };
export type ControlNetModelConfigEntity = ControlNetModelConfig & {
id: string;
};
export type IPAdapterModelConfigEntity = IPAdapterModelConfig & {
id: string;
};
export type T2IAdapterModelConfigEntity = T2IAdapterModelConfig & {
id: string;
};
export type TextualInversionModelConfigEntity = TextualInversionModelConfig & {
id: string;
};
export type VaeModelConfigEntity = VaeModelConfig & { id: string };
export type AnyModelConfigEntity =
| MainModelConfigEntity
| LoRAModelConfigEntity
| ControlNetModelConfigEntity
| IPAdapterModelConfigEntity
| T2IAdapterModelConfigEntity
| TextualInversionModelConfigEntity
| VaeModelConfigEntity;
type UpdateMainModelArg = {
base_model: BaseModelType;
model_name: string;
@ -68,11 +30,11 @@ type UpdateMainModelArg = {
type UpdateLoRAModelArg = {
base_model: BaseModelType;
model_name: string;
body: LoRAModelConfig;
body: LoRAConfig;
};
type UpdateMainModelResponse =
paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json'];
paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
type UpdateLoRAModelResponse = UpdateMainModelResponse;
@ -128,59 +90,71 @@ type CheckpointConfigsResponse =
type SearchFolderArg = operations['search_for_models']['parameters']['query'];
export const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
export const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
export const loraModelsAdapter = createEntityAdapter<LoRAConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const controlNetModelsAdapter = createEntityAdapter<ControlNetModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
export const controlNetModelsAdapter = createEntityAdapter<ControlNetConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
export const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
export const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const textualInversionModelsAdapter = createEntityAdapter<TextualInversionModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
export const textualInversionModelsAdapter = createEntityAdapter<TextualInversionConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const textualInversionModelsAdapterSelectors = textualInversionModelsAdapter.getSelectors(
undefined,
getSelectorsOptions
);
export const vaeModelsAdapter = createEntityAdapter<VaeModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
export const vaeModelsAdapter = createEntityAdapter<VAEConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined, getSelectorsOptions);
export const getModelId = ({
base_model,
model_type,
model_name,
}: Pick<AnyModelConfig, 'base_model' | 'model_name' | 'model_type'>) => `${base_model}/${model_type}/${model_name}`;
const buildProvidesTags =
<TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) =>
(result: EntityState<TEntity, string> | undefined) => {
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
const createModelEntities = <T extends AnyModelConfigEntity>(models: AnyModelConfig[]): T[] => {
const entityArray: T[] = [];
models.forEach((model) => {
const entity = {
...cloneDeep(model),
id: getModelId(model),
} as T;
entityArray.push(entity);
});
return entityArray;
};
if (result) {
tags.push(
...result.ids.map((id) => ({
type: tagType,
id,
}))
);
}
return tags;
};
const buildTransformResponse =
<T extends AnyModelConfig>(adapter: EntityAdapter<T, string>) =>
(response: { models: T[] }) => {
return adapter.setAll(adapter.getInitialState(), response.models);
};
export const modelsApi = api.injectEndpoints({
endpoints: (build) => ({
getMainModels: build.query<EntityState<MainModelConfigEntity, string>, BaseModelType[]>({
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
query: (base_models) => {
const params = {
model_type: 'main',
@ -190,24 +164,8 @@ export const modelsApi = api.injectEndpoints({
const query = queryString.stringify(params, { arrayFormat: 'none' });
return `models/?${query}`;
},
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'MainModel', id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'MainModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: { models: MainModelConfig[] }) => {
const entities = createModelEntities<MainModelConfigEntity>(response.models);
return mainModelsAdapter.setAll(mainModelsAdapter.getInitialState(), entities);
},
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
}),
updateMainModels: build.mutation<UpdateMainModelResponse, UpdateMainModelArg>({
query: ({ base_model, model_name, body }) => {
@ -277,26 +235,10 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['Model'],
}),
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity, string>, void>({
getLoRAModels: build.query<EntityState<LoRAConfig, string>, void>({
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'LoRAModel', id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'LoRAModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: { models: LoRAModelConfig[] }) => {
const entities = createModelEntities<LoRAModelConfigEntity>(response.models);
return loraModelsAdapter.setAll(loraModelsAdapter.getInitialState(), entities);
},
providesTags: buildProvidesTags<LoRAConfig>('LoRAModel'),
transformResponse: buildTransformResponse<LoRAConfig>(loraModelsAdapter),
}),
updateLoRAModels: build.mutation<UpdateLoRAModelResponse, UpdateLoRAModelArg>({
query: ({ base_model, model_name, body }) => {
@ -317,110 +259,30 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
}),
getControlNetModels: build.query<EntityState<ControlNetModelConfigEntity, string>, void>({
getControlNetModels: build.query<EntityState<ControlNetConfig, string>, void>({
query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }),
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'ControlNetModel', id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'ControlNetModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: { models: ControlNetModelConfig[] }) => {
const entities = createModelEntities<ControlNetModelConfigEntity>(response.models);
return controlNetModelsAdapter.setAll(controlNetModelsAdapter.getInitialState(), entities);
},
providesTags: buildProvidesTags<ControlNetConfig>('ControlNetModel'),
transformResponse: buildTransformResponse<ControlNetConfig>(controlNetModelsAdapter),
}),
getIPAdapterModels: build.query<EntityState<IPAdapterModelConfigEntity, string>, void>({
getIPAdapterModels: build.query<EntityState<IPAdapterConfig, string>, void>({
query: () => ({ url: 'models/', params: { model_type: 'ip_adapter' } }),
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'IPAdapterModel', id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'IPAdapterModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: { models: IPAdapterModelConfig[] }) => {
const entities = createModelEntities<IPAdapterModelConfigEntity>(response.models);
return ipAdapterModelsAdapter.setAll(ipAdapterModelsAdapter.getInitialState(), entities);
},
providesTags: buildProvidesTags<IPAdapterConfig>('IPAdapterModel'),
transformResponse: buildTransformResponse<IPAdapterConfig>(ipAdapterModelsAdapter),
}),
getT2IAdapterModels: build.query<EntityState<T2IAdapterModelConfigEntity, string>, void>({
getT2IAdapterModels: build.query<EntityState<T2IAdapterConfig, string>, void>({
query: () => ({ url: 'models/', params: { model_type: 't2i_adapter' } }),
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'T2IAdapterModel', id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'T2IAdapterModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: { models: T2IAdapterModelConfig[] }) => {
const entities = createModelEntities<T2IAdapterModelConfigEntity>(response.models);
return t2iAdapterModelsAdapter.setAll(t2iAdapterModelsAdapter.getInitialState(), entities);
},
providesTags: buildProvidesTags<T2IAdapterConfig>('T2IAdapterModel'),
transformResponse: buildTransformResponse<T2IAdapterConfig>(t2iAdapterModelsAdapter),
}),
getVaeModels: build.query<EntityState<VaeModelConfigEntity, string>, void>({
getVaeModels: build.query<EntityState<VAEConfig, string>, void>({
query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'VaeModel', id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'VaeModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: { models: VaeModelConfig[] }) => {
const entities = createModelEntities<VaeModelConfigEntity>(response.models);
return vaeModelsAdapter.setAll(vaeModelsAdapter.getInitialState(), entities);
},
providesTags: buildProvidesTags<VAEConfig>('VaeModel'),
transformResponse: buildTransformResponse<VAEConfig>(vaeModelsAdapter),
}),
getTextualInversionModels: build.query<EntityState<TextualInversionModelConfigEntity, string>, void>({
getTextualInversionModels: build.query<EntityState<TextualInversionConfig, string>, void>({
query: () => ({ url: 'models/', params: { model_type: 'embedding' } }),
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'TextualInversionModel', id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'TextualInversionModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: { models: TextualInversionModelConfig[] }) => {
const entities = createModelEntities<TextualInversionModelConfigEntity>(response.models);
return textualInversionModelsAdapter.setAll(textualInversionModelsAdapter.getInitialState(), entities);
},
providesTags: buildProvidesTags<TextualInversionConfig>('TextualInversionModel'),
transformResponse: buildTransformResponse<TextualInversionConfig>(textualInversionModelsAdapter),
}),
getModelsInFolder: build.query<SearchFolderResponse, SearchFolderArg>({
query: (arg) => {

View File

@ -2,6 +2,7 @@ import type { UseToastOptions } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import type { components, paths } from 'services/api/schema';
import type { O } from 'ts-toolbelt';
import type { SetRequired } from 'type-fest';
export type S = components['schemas'];
@ -54,40 +55,34 @@ export type LoRAModelFormat = S['LoRAModelFormat'];
export type ControlNetModelField = S['ControlNetModelField'];
export type IPAdapterModelField = S['IPAdapterModelField'];
export type T2IAdapterModelField = S['T2IAdapterModelField'];
export type ModelsList = S['invokeai__app__api__routers__models__ModelsList'];
export type ControlField = S['ControlField'];
export type IPAdapterField = S['IPAdapterField'];
// Model Configs
export type LoRAModelConfig = S['LoRAModelConfig'];
export type VaeModelConfig = S['VaeModelConfig'];
export type ControlNetModelCheckpointConfig = S['ControlNetModelCheckpointConfig'];
export type ControlNetModelDiffusersConfig = S['ControlNetModelDiffusersConfig'];
export type ControlNetModelConfig = ControlNetModelCheckpointConfig | ControlNetModelDiffusersConfig;
export type IPAdapterModelInvokeAIConfig = S['IPAdapterModelInvokeAIConfig'];
export type IPAdapterModelConfig = IPAdapterModelInvokeAIConfig;
export type T2IAdapterModelDiffusersConfig = S['T2IAdapterModelDiffusersConfig'];
export type T2IAdapterModelConfig = T2IAdapterModelDiffusersConfig;
export type TextualInversionModelConfig = S['TextualInversionModelConfig'];
export type DiffusersModelConfig =
| S['StableDiffusion1ModelDiffusersConfig']
| S['StableDiffusion2ModelDiffusersConfig']
| S['StableDiffusionXLModelDiffusersConfig'];
export type CheckpointModelConfig =
| S['StableDiffusion1ModelCheckpointConfig']
| S['StableDiffusion2ModelCheckpointConfig']
| S['StableDiffusionXLModelCheckpointConfig'];
// TODO(MM2): Can we make key required in the pydantic model?
type KeyRequired<T extends {key?: string}> = SetRequired<T, 'key'>;
export type LoRAConfig = KeyRequired<S['LoRAConfig']>;
// TODO(MM2): Can we rename this from Vae -> VAE
export type VAEConfig = KeyRequired<S['VaeCheckpointConfig']> | KeyRequired<S['VaeDiffusersConfig']>;
export type ControlNetConfig = KeyRequired<S['ControlNetDiffusersConfig']> | KeyRequired<S['ControlNetCheckpointConfig']>;
export type IPAdapterConfig = KeyRequired<S['IPAdapterConfig']>;
// TODO(MM2): Can we rename this to T2IAdapterConfig
export type T2IAdapterConfig = KeyRequired<S['T2IConfig']>;
export type TextualInversionConfig = KeyRequired<S['TextualInversionConfig']>;
export type DiffusersModelConfig = KeyRequired<S['MainDiffusersConfig']>;
export type CheckpointModelConfig = KeyRequired<S['MainCheckpointConfig']>;
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
export type AnyModelConfig =
| LoRAModelConfig
| VaeModelConfig
| ControlNetModelConfig
| IPAdapterModelConfig
| T2IAdapterModelConfig
| TextualInversionModelConfig
| LoRAConfig
| VAEConfig
| ControlNetConfig
| IPAdapterConfig
| T2IAdapterConfig
| TextualInversionConfig
| MainModelConfig;
export type MergeModelConfig = S['Body_merge_models'];
export type MergeModelConfig = S['Body_merge'];
export type ImportModelConfig = S['Body_import_model'];
// Graphs