mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): update all components and logic to use enriched ModelIdentifierField
This commit is contained in:
parent
4433b78e59
commit
133c90e116
@ -10,7 +10,7 @@ export const ReduxInit = memo((props: PropsWithChildren) => {
|
||||
const dispatch = useAppDispatch();
|
||||
useGlobalModifiersInit();
|
||||
useEffect(() => {
|
||||
dispatch(modelChanged({ key: 'test_model', base: 'sd-1' }));
|
||||
dispatch(modelChanged({ key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' }));
|
||||
}, []);
|
||||
|
||||
return props.children;
|
||||
|
@ -2,7 +2,7 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import type { EntityState } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { GroupBase } from 'chakra-react-select';
|
||||
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { groupBy, map, reduce } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -10,7 +10,7 @@ import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
|
||||
modelEntities: EntityState<T, string> | undefined;
|
||||
selectedModel?: ModelIdentifierWithBase | null;
|
||||
selectedModel?: ModelIdentifierField | null;
|
||||
onChange: (value: T | null) => void;
|
||||
getIsDisabled?: (model: T) => boolean;
|
||||
isLoading?: boolean;
|
||||
|
@ -1,6 +1,6 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import type { EntityState } from '@reduxjs/toolkit';
|
||||
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { map } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -8,7 +8,7 @@ import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
type UseModelComboboxArg<T extends AnyModelConfig> = {
|
||||
modelEntities: EntityState<T, string> | undefined;
|
||||
selectedModel?: ModelIdentifierWithBase | null;
|
||||
selectedModel?: ModelIdentifierField | null;
|
||||
onChange: (value: T | null) => void;
|
||||
getIsDisabled?: (model: T) => boolean;
|
||||
optionsFilter?: (model: T) => boolean;
|
||||
|
@ -1,7 +1,7 @@
|
||||
import type { Item } from '@invoke-ai/ui-library';
|
||||
import type { EntityState } from '@reduxjs/toolkit';
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
|
||||
import { filter } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
@ -11,7 +11,7 @@ import type { AnyModelConfig } from 'services/api/types';
|
||||
type UseModelCustomSelectArg<T extends AnyModelConfig> = {
|
||||
data: EntityState<T, string> | undefined;
|
||||
isLoading: boolean;
|
||||
selectedModel?: ModelIdentifierWithBase | null;
|
||||
selectedModel?: ModelIdentifierField | null;
|
||||
onChange: (value: T | null) => void;
|
||||
modelFilter?: (model: T) => boolean;
|
||||
isModelDisabled?: (model: T) => boolean;
|
||||
|
@ -4,6 +4,7 @@ import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter';
|
||||
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { cloneDeep, merge, uniq } from 'lodash-es';
|
||||
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import { socketInvocationError } from 'services/events/actions';
|
||||
@ -197,7 +198,7 @@ export const controlAdaptersSlice = createSlice({
|
||||
return;
|
||||
}
|
||||
|
||||
const model = { key: modelConfig.key, base: modelConfig.base };
|
||||
const model = zModelIdentifierField.parse(modelConfig);
|
||||
|
||||
if (!isControlNetOrT2IAdapter(cn)) {
|
||||
caAdapter.updateOne(state, { id, changes: { model } });
|
||||
|
@ -1,7 +1,7 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
|
||||
@ -31,7 +31,7 @@ export const loraSlice = createSlice({
|
||||
initialState: initialLoraState,
|
||||
reducers: {
|
||||
loraAdded: (state, action: PayloadAction<LoRAModelConfig>) => {
|
||||
const model = getModelKeyAndBase(action.payload);
|
||||
const model = zModelIdentifierField.parse(action.payload);
|
||||
state.loras[model.key] = { ...defaultLoRAConfig, model };
|
||||
},
|
||||
loraRecalled: (state, action: PayloadAction<LoRA>) => {
|
||||
|
@ -13,13 +13,13 @@ import type {
|
||||
} from 'features/metadata/types';
|
||||
import { fetchModelConfig } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { validators } from 'features/metadata/util/validators';
|
||||
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { t } from 'i18next';
|
||||
|
||||
import { parsers } from './parsers';
|
||||
import { recallers } from './recallers';
|
||||
|
||||
const renderModelConfigValue: MetadataRenderValueFunc<ModelIdentifierWithBase> = async (value) => {
|
||||
const renderModelConfigValue: MetadataRenderValueFunc<ModelIdentifierField> = async (value) => {
|
||||
try {
|
||||
const modelConfig = await fetchModelConfig(value.key);
|
||||
return `${modelConfig.name} (${modelConfig.base.toUpperCase()})`;
|
||||
|
@ -1,5 +1,4 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||
import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types';
|
||||
@ -105,8 +104,3 @@ export const getModelKey = async (modelIdentifier: unknown, type: ModelType, mes
|
||||
}
|
||||
throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`);
|
||||
};
|
||||
|
||||
export const getModelKeyAndBase = (modelConfig: AnyModelConfig): ModelIdentifierWithBase => ({
|
||||
key: modelConfig.key,
|
||||
base: modelConfig.base,
|
||||
});
|
||||
|
@ -13,12 +13,7 @@ import type {
|
||||
T2IAdapterConfigMetadata,
|
||||
} from 'features/metadata/types';
|
||||
import { fetchModelConfigWithTypeGuard, getModelKey } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import {
|
||||
zControlField,
|
||||
zIPAdapterField,
|
||||
zModelIdentifierWithBase,
|
||||
zT2IAdapterField,
|
||||
} from 'features/nodes/types/common';
|
||||
import { zControlField, zIPAdapterField, zModelIdentifierField, zT2IAdapterField } from 'features/nodes/types/common';
|
||||
import type {
|
||||
ParameterCFGRescaleMultiplier,
|
||||
ParameterCFGScale,
|
||||
@ -181,7 +176,7 @@ const parseMainModel: MetadataParseFunc<ParameterModel> = async (metadata) => {
|
||||
const model = await getProperty(metadata, 'model', undefined);
|
||||
const key = await getModelKey(model, 'main');
|
||||
const mainModelConfig = await fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig);
|
||||
const modelIdentifier = zModelIdentifierWithBase.parse(mainModelConfig);
|
||||
const modelIdentifier = zModelIdentifierField.parse(mainModelConfig);
|
||||
return modelIdentifier;
|
||||
};
|
||||
|
||||
@ -189,7 +184,7 @@ const parseRefinerModel: MetadataParseFunc<ParameterSDXLRefinerModel> = async (m
|
||||
const refiner_model = await getProperty(metadata, 'refiner_model', undefined);
|
||||
const key = await getModelKey(refiner_model, 'main');
|
||||
const refinerModelConfig = await fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig);
|
||||
const modelIdentifier = zModelIdentifierWithBase.parse(refinerModelConfig);
|
||||
const modelIdentifier = zModelIdentifierField.parse(refinerModelConfig);
|
||||
return modelIdentifier;
|
||||
};
|
||||
|
||||
@ -197,7 +192,7 @@ const parseVAEModel: MetadataParseFunc<ParameterVAEModel> = async (metadata) =>
|
||||
const vae = await getProperty(metadata, 'vae', undefined);
|
||||
const key = await getModelKey(vae, 'vae');
|
||||
const vaeModelConfig = await fetchModelConfigWithTypeGuard(key, isVAEModelConfig);
|
||||
const modelIdentifier = zModelIdentifierWithBase.parse(vaeModelConfig);
|
||||
const modelIdentifier = zModelIdentifierField.parse(vaeModelConfig);
|
||||
return modelIdentifier;
|
||||
};
|
||||
|
||||
@ -211,7 +206,7 @@ const parseLoRA: MetadataParseFunc<LoRA> = async (metadataItem) => {
|
||||
const loraModelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
|
||||
|
||||
return {
|
||||
model: zModelIdentifierWithBase.parse(loraModelConfig),
|
||||
model: zModelIdentifierField.parse(loraModelConfig),
|
||||
weight: isParameterLoRAWeight(weight) ? weight : defaultLoRAConfig.weight,
|
||||
isEnabled: true,
|
||||
};
|
||||
@ -258,7 +253,7 @@ const parseControlNet: MetadataParseFunc<ControlNetConfigMetadata> = async (meta
|
||||
const controlNet: ControlNetConfigMetadata = {
|
||||
type: 'controlnet',
|
||||
isEnabled: true,
|
||||
model: zModelIdentifierWithBase.parse(controlNetModel),
|
||||
model: zModelIdentifierField.parse(controlNetModel),
|
||||
weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight,
|
||||
beginStepPct: begin_step_percent ?? initialControlNet.beginStepPct,
|
||||
endStepPct: end_step_percent ?? initialControlNet.endStepPct,
|
||||
@ -309,7 +304,7 @@ const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfigMetadata> = async (meta
|
||||
const t2iAdapter: T2IAdapterConfigMetadata = {
|
||||
type: 't2i_adapter',
|
||||
isEnabled: true,
|
||||
model: zModelIdentifierWithBase.parse(t2iAdapterModel),
|
||||
model: zModelIdentifierField.parse(t2iAdapterModel),
|
||||
weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight,
|
||||
beginStepPct: begin_step_percent ?? initialT2IAdapter.beginStepPct,
|
||||
endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct,
|
||||
@ -354,7 +349,7 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
|
||||
id: uuidv4(),
|
||||
type: 'ip_adapter',
|
||||
isEnabled: true,
|
||||
model: zModelIdentifierWithBase.parse(ipAdapterModel),
|
||||
model: zModelIdentifierField.parse(ipAdapterModel),
|
||||
controlImage: image?.image_name ?? null,
|
||||
weight: weight ?? initialIPAdapter.weight,
|
||||
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
|
||||
|
@ -1,5 +1,4 @@
|
||||
import type {
|
||||
BaseModel,
|
||||
BoardField,
|
||||
Classification,
|
||||
ColorField,
|
||||
@ -7,6 +6,7 @@ import type {
|
||||
ImageField,
|
||||
ImageOutput,
|
||||
IPAdapterField,
|
||||
ModelIdentifierField,
|
||||
ProgressImage,
|
||||
SchedulerField,
|
||||
T2IAdapterField,
|
||||
@ -33,10 +33,9 @@ describe('Common types', () => {
|
||||
test('T2IAdapterField', () => assert<Equals<T2IAdapterField, S['T2IAdapterField']>>());
|
||||
|
||||
// Model component types
|
||||
test('BaseModel', () => assert<Equals<BaseModel, S['BaseModelType']>>());
|
||||
test('ModelIdentifier', () => assert<Equals<ModelIdentifierField, S['ModelIdentifierField']>>());
|
||||
|
||||
// Misc types
|
||||
// @ts-expect-error TODO(psyche): There is no `ProgressImage` in the server types yet
|
||||
test('ProgressImage', () => assert<Equals<ProgressImage, S['ProgressImage']>>());
|
||||
test('ImageOutput', () => assert<Equals<ImageOutput, S['ImageOutput']>>());
|
||||
test('Classification', () => assert<Equals<Classification, S['Classification']>>());
|
||||
|
@ -55,6 +55,17 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
|
||||
|
||||
// #region Model-related schemas
|
||||
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
|
||||
const zModelType = z.enum([
|
||||
'main',
|
||||
'vae',
|
||||
'lora',
|
||||
'controlnet',
|
||||
't2i_adapter',
|
||||
'ip_adapter',
|
||||
'embedding',
|
||||
'onnx',
|
||||
'clip_vision',
|
||||
]);
|
||||
const zSubModelType = z.enum([
|
||||
'unet',
|
||||
'text_encoder',
|
||||
@ -67,26 +78,25 @@ const zSubModelType = z.enum([
|
||||
'scheduler',
|
||||
'safety_checker',
|
||||
]);
|
||||
|
||||
const zModelIdentifier = z.object({
|
||||
export const zModelIdentifierField = z.object({
|
||||
key: z.string().min(1),
|
||||
hash: z.string().min(1),
|
||||
name: z.string().min(1),
|
||||
base: zBaseModel,
|
||||
type: zModelType,
|
||||
submodel_type: zSubModelType.nullish(),
|
||||
});
|
||||
export const isModelIdentifier = (field: unknown): field is ModelIdentifier =>
|
||||
zModelIdentifier.safeParse(field).success;
|
||||
export const isModelIdentifier = (field: unknown): field is ModelIdentifierField =>
|
||||
zModelIdentifierField.safeParse(field).success;
|
||||
export const isModelIdentifierV2 = (field: unknown): field is ModelIdentifierV2 =>
|
||||
zModelIdentifierV2.safeParse(field).success;
|
||||
const zModelFieldBase = zModelIdentifier;
|
||||
export const zModelIdentifierWithBase = zModelIdentifier.extend({ base: zBaseModel });
|
||||
export type BaseModel = z.infer<typeof zBaseModel>;
|
||||
type ModelIdentifier = z.infer<typeof zModelIdentifier>;
|
||||
export type ModelIdentifierWithBase = z.infer<typeof zModelIdentifierWithBase>;
|
||||
export type ModelIdentifierField = z.infer<typeof zModelIdentifierField>;
|
||||
// #endregion
|
||||
|
||||
// #region Control Adapters
|
||||
export const zControlField = z.object({
|
||||
image: zImageField,
|
||||
control_model: zModelFieldBase,
|
||||
control_model: zModelIdentifierField,
|
||||
control_weight: z.union([z.number(), z.array(z.number())]).optional(),
|
||||
begin_step_percent: z.number().optional(),
|
||||
end_step_percent: z.number().optional(),
|
||||
@ -97,7 +107,7 @@ export type ControlField = z.infer<typeof zControlField>;
|
||||
|
||||
export const zIPAdapterField = z.object({
|
||||
image: zImageField,
|
||||
ip_adapter_model: zModelFieldBase,
|
||||
ip_adapter_model: zModelIdentifierField,
|
||||
weight: z.number(),
|
||||
begin_step_percent: z.number().optional(),
|
||||
end_step_percent: z.number().optional(),
|
||||
@ -106,7 +116,7 @@ export type IPAdapterField = z.infer<typeof zIPAdapterField>;
|
||||
|
||||
export const zT2IAdapterField = z.object({
|
||||
image: zImageField,
|
||||
t2i_adapter_model: zModelFieldBase,
|
||||
t2i_adapter_model: zModelIdentifierField,
|
||||
weight: z.union([z.number(), z.array(z.number())]).optional(),
|
||||
begin_step_percent: z.number().optional(),
|
||||
end_step_percent: z.number().optional(),
|
||||
|
@ -1,6 +1,6 @@
|
||||
import { z } from 'zod';
|
||||
|
||||
import { zBoardField, zColorField, zImageField, zModelIdentifierWithBase, zSchedulerField } from './common';
|
||||
import { zBoardField, zColorField, zImageField, zModelIdentifierField, zSchedulerField } from './common';
|
||||
|
||||
/**
|
||||
* zod schemas & inferred types for fields.
|
||||
@ -277,7 +277,7 @@ export const isColorFieldInputTemplate = (val: unknown): val is ColorFieldInputT
|
||||
const zMainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('MainModelField'),
|
||||
});
|
||||
export const zMainModelFieldValue = zModelIdentifierWithBase.optional();
|
||||
export const zMainModelFieldValue = zModelIdentifierField.optional();
|
||||
const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zMainModelFieldValue,
|
||||
});
|
||||
@ -348,7 +348,7 @@ export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLR
|
||||
const zVAEModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('VAEModelField'),
|
||||
});
|
||||
export const zVAEModelFieldValue = zModelIdentifierWithBase.optional();
|
||||
export const zVAEModelFieldValue = zModelIdentifierField.optional();
|
||||
const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zVAEModelFieldValue,
|
||||
});
|
||||
@ -372,7 +372,7 @@ export const isVAEModelFieldInputTemplate = (val: unknown): val is VAEModelField
|
||||
const zLoRAModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('LoRAModelField'),
|
||||
});
|
||||
export const zLoRAModelFieldValue = zModelIdentifierWithBase.optional();
|
||||
export const zLoRAModelFieldValue = zModelIdentifierField.optional();
|
||||
const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zLoRAModelFieldValue,
|
||||
});
|
||||
@ -396,7 +396,7 @@ export const isLoRAModelFieldInputTemplate = (val: unknown): val is LoRAModelFie
|
||||
const zControlNetModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ControlNetModelField'),
|
||||
});
|
||||
export const zControlNetModelFieldValue = zModelIdentifierWithBase.optional();
|
||||
export const zControlNetModelFieldValue = zModelIdentifierField.optional();
|
||||
const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zControlNetModelFieldValue,
|
||||
});
|
||||
@ -420,7 +420,7 @@ export const isControlNetModelFieldInputTemplate = (val: unknown): val is Contro
|
||||
const zIPAdapterModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('IPAdapterModelField'),
|
||||
});
|
||||
export const zIPAdapterModelFieldValue = zModelIdentifierWithBase.optional();
|
||||
export const zIPAdapterModelFieldValue = zModelIdentifierField.optional();
|
||||
const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zIPAdapterModelFieldValue,
|
||||
});
|
||||
@ -444,7 +444,7 @@ export const isIPAdapterModelFieldInputTemplate = (val: unknown): val is IPAdapt
|
||||
const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('T2IAdapterModelField'),
|
||||
});
|
||||
export const zT2IAdapterModelFieldValue = zModelIdentifierWithBase.optional();
|
||||
export const zT2IAdapterModelFieldValue = zModelIdentifierField.optional();
|
||||
const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zT2IAdapterModelFieldValue,
|
||||
});
|
||||
|
@ -1,15 +1,10 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { filter, size } from 'lodash-es';
|
||||
import {
|
||||
type CoreMetadataInvocation,
|
||||
isLoRAModelConfig,
|
||||
type LoRALoaderInvocation,
|
||||
type NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
import type { CoreMetadataInvocation, LoRALoaderInvocation, NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { CLIP_SKIP, LORA_LOADER, MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, POSITIVE_CONDITIONING } from './constants';
|
||||
import { getModelMetadataField, upsertMetadata } from './metadata';
|
||||
import { upsertMetadata } from './metadata';
|
||||
|
||||
export const addLoRAsToGraph = async (
|
||||
state: RootState,
|
||||
@ -49,19 +44,18 @@ export const addLoRAsToGraph = async (
|
||||
const { weight } = lora;
|
||||
const { key } = lora.model;
|
||||
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
|
||||
const parsedModel = zModelIdentifierField.parse(lora.model);
|
||||
|
||||
const loraLoaderNode: LoRALoaderInvocation = {
|
||||
type: 'lora_loader',
|
||||
id: currentLoraNodeId,
|
||||
is_intermediate: true,
|
||||
lora: { key },
|
||||
lora: parsedModel,
|
||||
weight,
|
||||
};
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
|
||||
|
||||
loraMetadata.push({
|
||||
model: getModelMetadataField(modelConfig),
|
||||
model: parsedModel,
|
||||
weight,
|
||||
});
|
||||
|
||||
|
@ -1,12 +1,7 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { filter, size } from 'lodash-es';
|
||||
import {
|
||||
type CoreMetadataInvocation,
|
||||
isLoRAModelConfig,
|
||||
type NonNullableGraph,
|
||||
type SDXLLoRALoaderInvocation,
|
||||
} from 'services/api/types';
|
||||
import type { CoreMetadataInvocation, NonNullableGraph, SDXLLoRALoaderInvocation } from 'services/api/types';
|
||||
|
||||
import {
|
||||
LORA_LOADER,
|
||||
@ -16,7 +11,7 @@ import {
|
||||
SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { getModelMetadataField, upsertMetadata } from './metadata';
|
||||
import { upsertMetadata } from './metadata';
|
||||
|
||||
export const addSDXLLoRAsToGraph = async (
|
||||
state: RootState,
|
||||
@ -63,20 +58,18 @@ export const addSDXLLoRAsToGraph = async (
|
||||
|
||||
enabledLoRAs.forEach(async (lora) => {
|
||||
const { weight } = lora;
|
||||
const { key } = lora.model;
|
||||
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
|
||||
const currentLoraNodeId = `${LORA_LOADER}_${lora.model.key}`;
|
||||
const parsedModel = zModelIdentifierField.parse(lora.model);
|
||||
|
||||
const loraLoaderNode: SDXLLoRALoaderInvocation = {
|
||||
type: 'sdxl_lora_loader',
|
||||
id: currentLoraNodeId,
|
||||
is_intermediate: true,
|
||||
lora: { key },
|
||||
lora: parsedModel,
|
||||
weight,
|
||||
};
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
|
||||
|
||||
loraMetadata.push({ model: getModelMetadataField(modelConfig), weight });
|
||||
loraMetadata.push({ model: parsedModel, weight });
|
||||
|
||||
// add to graph
|
||||
graph.nodes[currentLoraNodeId] = loraLoaderNode;
|
||||
|
@ -3,6 +3,7 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
@ -24,7 +25,11 @@ const ParamMainModelSelect = () => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
dispatch(modelSelected({ key: model.key, base: model.base }));
|
||||
try {
|
||||
dispatch(modelSelected(zModelIdentifierField.parse(model)));
|
||||
} catch {
|
||||
// no-op
|
||||
}
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
@ -3,7 +3,7 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/generationSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -30,7 +30,7 @@ const ParamVAEModelSelect = () => {
|
||||
);
|
||||
const _onChange = useCallback(
|
||||
(vae: VAEModelConfig | null) => {
|
||||
dispatch(vaeSelected(vae ? getModelKeyAndBase(vae) : null));
|
||||
dispatch(vaeSelected(vae ? zModelIdentifierField.parse(vae) : null));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
@ -1,6 +1,6 @@
|
||||
import { NUMPY_RAND_MAX } from 'app/constants';
|
||||
import { roundToMultiple } from 'common/util/roundDownToMultiple';
|
||||
import { zModelIdentifierWithBase, zSchedulerField } from 'features/nodes/types/common';
|
||||
import { zModelIdentifierField, zSchedulerField } from 'features/nodes/types/common';
|
||||
import { z } from 'zod';
|
||||
|
||||
/**
|
||||
@ -92,37 +92,37 @@ export const isParameterHeight = (val: unknown): val is ParameterHeight => zPara
|
||||
// #endregion
|
||||
|
||||
// #region Model
|
||||
export const zParameterModel = zModelIdentifierWithBase;
|
||||
export const zParameterModel = zModelIdentifierField;
|
||||
export type ParameterModel = z.infer<typeof zParameterModel>;
|
||||
// #endregion
|
||||
|
||||
// #region SDXL Refiner Model
|
||||
const zParameterSDXLRefinerModel = zModelIdentifierWithBase;
|
||||
const zParameterSDXLRefinerModel = zModelIdentifierField;
|
||||
export type ParameterSDXLRefinerModel = z.infer<typeof zParameterSDXLRefinerModel>;
|
||||
// #endregion
|
||||
|
||||
// #region VAE Model
|
||||
export const zParameterVAEModel = zModelIdentifierWithBase;
|
||||
export const zParameterVAEModel = zModelIdentifierField;
|
||||
export type ParameterVAEModel = z.infer<typeof zParameterVAEModel>;
|
||||
// #endregion
|
||||
|
||||
// #region LoRA Model
|
||||
const zParameterLoRAModel = zModelIdentifierWithBase;
|
||||
const zParameterLoRAModel = zModelIdentifierField;
|
||||
export type ParameterLoRAModel = z.infer<typeof zParameterLoRAModel>;
|
||||
// #endregion
|
||||
|
||||
// #region ControlNet Model
|
||||
const zParameterControlNetModel = zModelIdentifierWithBase;
|
||||
const zParameterControlNetModel = zModelIdentifierField;
|
||||
export type ParameterControlNetModel = z.infer<typeof zParameterControlNetModel>;
|
||||
// #endregion
|
||||
|
||||
// #region IP Adapter Model
|
||||
const zParameterIPAdapterModel = zModelIdentifierWithBase;
|
||||
const zParameterIPAdapterModel = zModelIdentifierField;
|
||||
export type ParameterIPAdapterModel = z.infer<typeof zParameterIPAdapterModel>;
|
||||
// #endregion
|
||||
|
||||
// #region T2I Adapter Model
|
||||
const zParameterT2IAdapterModel = zModelIdentifierWithBase;
|
||||
const zParameterT2IAdapterModel = zModelIdentifierField;
|
||||
export type ParameterT2IAdapterModel = z.infer<typeof zParameterT2IAdapterModel>;
|
||||
// #endregion
|
||||
|
||||
|
@ -1,11 +1,11 @@
|
||||
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||
|
||||
/**
|
||||
* Gets the optimal dimension for a givel model, based on the model's base_model
|
||||
* @param model The model identifier
|
||||
* @returns The optimal dimension for the model
|
||||
*/
|
||||
export const getOptimalDimension = (model?: ModelIdentifierWithBase | null): number =>
|
||||
export const getOptimalDimension = (model?: ModelIdentifierField | null): number =>
|
||||
model?.base === 'sdxl' ? 1024 : 512;
|
||||
|
||||
const MIN_AREA_FACTOR = 0.8;
|
||||
|
@ -3,7 +3,7 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useModelCombobox } from 'common/hooks/useModelCombobox';
|
||||
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -26,7 +26,7 @@ const ParamSDXLRefinerModelSelect = () => {
|
||||
dispatch(refinerModelChanged(null));
|
||||
return;
|
||||
}
|
||||
dispatch(refinerModelChanged(getModelKeyAndBase(model)));
|
||||
dispatch(refinerModelChanged(zModelIdentifierField.parse(model)));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
Loading…
x
Reference in New Issue
Block a user