mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): fix a lot of model-related crashes/bugs
We were storing all types of models by their model ID, which is a format like `sd-1/main/deliberate`. This meant we had to do a lot of extra parsing, because nodes actually wants something like `{base_model: 'sd-1', model_name: 'deliberate'}`. Some of this parsing was done with zod's error-throwing `parse()` method, and in other places it was done with brittle string parsing. This commit refactors the state to use the object form of models. There is still a bit of string parsing done in the to construct the ID from the object form, but it's far less complicated. Also, the zod parsing is now done using `safeParse()`, which does not throw. This requires a few more conditional checks, but should prevent further crashes.
This commit is contained in:
@ -6,7 +6,7 @@ import {
|
||||
VaeModelInputFieldTemplate,
|
||||
VaeModelInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { forEach, isString } from 'lodash-es';
|
||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -46,7 +46,7 @@ const LoRAModelInputFieldComponent = (
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.model_name,
|
||||
group: BASE_MODEL_NAME_MAP[model.base_model],
|
||||
group: MODEL_TYPE_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
|
||||
@ -88,8 +88,7 @@ const LoRAModelInputFieldComponent = (
|
||||
<IAIMantineSelect
|
||||
tooltip={selectedModel?.description}
|
||||
label={
|
||||
selectedModel?.base_model &&
|
||||
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
|
||||
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
|
||||
}
|
||||
value={field.value}
|
||||
placeholder="Pick one"
|
||||
|
@ -7,7 +7,7 @@ import {
|
||||
} from 'features/nodes/types/types';
|
||||
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { forEach, isString } from 'lodash-es';
|
||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -39,7 +39,7 @@ const ModelInputFieldComponent = (
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.model_name,
|
||||
group: BASE_MODEL_NAME_MAP[model.base_model],
|
||||
group: MODEL_TYPE_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
|
||||
@ -86,8 +86,7 @@ const ModelInputFieldComponent = (
|
||||
<IAIMantineSelect
|
||||
tooltip={selectedModel?.description}
|
||||
label={
|
||||
selectedModel?.base_model &&
|
||||
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
|
||||
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
|
||||
}
|
||||
value={field.value}
|
||||
placeholder="Pick one"
|
||||
|
@ -6,7 +6,7 @@ import {
|
||||
VaeModelInputFieldTemplate,
|
||||
VaeModelInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -46,7 +46,7 @@ const VaeModelInputFieldComponent = (
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.model_name,
|
||||
group: BASE_MODEL_NAME_MAP[model.base_model],
|
||||
group: MODEL_TYPE_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
|
||||
@ -81,8 +81,7 @@ const VaeModelInputFieldComponent = (
|
||||
<IAIMantineSelect
|
||||
tooltip={selectedModel?.description}
|
||||
label={
|
||||
selectedModel?.base_model &&
|
||||
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
|
||||
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
|
||||
}
|
||||
value={field.value}
|
||||
placeholder="Pick one"
|
||||
|
@ -5,7 +5,6 @@ import {
|
||||
LoraLoaderInvocation,
|
||||
MetadataAccumulatorInvocation,
|
||||
} from 'services/api/types';
|
||||
import { modelIdToLoRAModelField } from '../modelIdToLoRAName';
|
||||
import {
|
||||
CLIP_SKIP,
|
||||
LORA_LOADER,
|
||||
@ -55,23 +54,22 @@ export const addLoRAsToGraph = (
|
||||
let currentLoraIndex = 0;
|
||||
|
||||
forEach(loras, (lora) => {
|
||||
const { id, name, weight } = lora;
|
||||
const loraField = modelIdToLoRAModelField(id);
|
||||
const currentLoraNodeId = `${LORA_LOADER}_${loraField.model_name.replace(
|
||||
'.',
|
||||
'_'
|
||||
)}`;
|
||||
const { model_name, base_model, weight } = lora;
|
||||
const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`;
|
||||
|
||||
const loraLoaderNode: LoraLoaderInvocation = {
|
||||
type: 'lora_loader',
|
||||
id: currentLoraNodeId,
|
||||
lora: loraField,
|
||||
lora,
|
||||
weight,
|
||||
};
|
||||
|
||||
// add the lora to the metadata accumulator
|
||||
if (metadataAccumulator) {
|
||||
metadataAccumulator.loras.push({ lora: loraField, weight });
|
||||
metadataAccumulator.loras.push({
|
||||
lora: { model_name, base_model },
|
||||
weight,
|
||||
});
|
||||
}
|
||||
|
||||
// add to graph
|
||||
|
@ -1,7 +1,6 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { MetadataAccumulatorInvocation } from 'services/api/types';
|
||||
import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
|
||||
import {
|
||||
IMAGE_TO_IMAGE_GRAPH,
|
||||
IMAGE_TO_LATENTS,
|
||||
@ -19,7 +18,6 @@ export const addVAEToGraph = (
|
||||
graph: NonNullableGraph
|
||||
): void => {
|
||||
const { vae } = state.generation;
|
||||
const vae_model = modelIdToVAEModelField(vae?.id || '');
|
||||
|
||||
const isAutoVae = !vae;
|
||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
||||
@ -30,7 +28,7 @@ export const addVAEToGraph = (
|
||||
graph.nodes[VAE_LOADER] = {
|
||||
type: 'vae_loader',
|
||||
id: VAE_LOADER,
|
||||
vae_model,
|
||||
vae_model: vae,
|
||||
};
|
||||
}
|
||||
|
||||
@ -74,6 +72,6 @@ export const addVAEToGraph = (
|
||||
}
|
||||
|
||||
if (vae && metadataAccumulator) {
|
||||
metadataAccumulator.vae = vae_model;
|
||||
metadataAccumulator.vae = vae;
|
||||
}
|
||||
};
|
||||
|
@ -1,12 +1,12 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { InputFieldValue } from 'features/nodes/types/types';
|
||||
import { modelIdToLoRAModelParam } from 'features/parameters/util/modelIdToLoRAModelParam';
|
||||
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||
import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
|
||||
import { cloneDeep, omit, reduce } from 'lodash-es';
|
||||
import { Graph } from 'services/api/types';
|
||||
import { AnyInvocation } from 'services/events/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { modelIdToLoRAModelField } from '../modelIdToLoRAName';
|
||||
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||
import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
|
||||
|
||||
/**
|
||||
* We need to do special handling for some fields
|
||||
@ -29,19 +29,19 @@ export const parseFieldValue = (field: InputFieldValue) => {
|
||||
|
||||
if (field.type === 'model') {
|
||||
if (field.value) {
|
||||
return modelIdToMainModelField(field.value);
|
||||
return modelIdToMainModelParam(field.value);
|
||||
}
|
||||
}
|
||||
|
||||
if (field.type === 'vae_model') {
|
||||
if (field.value) {
|
||||
return modelIdToVAEModelField(field.value);
|
||||
return modelIdToVAEModelParam(field.value);
|
||||
}
|
||||
}
|
||||
|
||||
if (field.type === 'lora_model') {
|
||||
if (field.value) {
|
||||
return modelIdToLoRAModelField(field.value);
|
||||
return modelIdToLoRAModelParam(field.value);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,12 +0,0 @@
|
||||
import { BaseModelType, LoRAModelField } from 'services/api/types';
|
||||
|
||||
export const modelIdToLoRAModelField = (loraId: string): LoRAModelField => {
|
||||
const [base_model, model_type, model_name] = loraId.split('/');
|
||||
|
||||
const field: LoRAModelField = {
|
||||
base_model: base_model as BaseModelType,
|
||||
model_name,
|
||||
};
|
||||
|
||||
return field;
|
||||
};
|
@ -1,16 +0,0 @@
|
||||
import { BaseModelType, MainModelField } from 'services/api/types';
|
||||
|
||||
/**
|
||||
* Crudely converts a model id to a main model field
|
||||
* TODO: Make better
|
||||
*/
|
||||
export const modelIdToMainModelField = (modelId: string): MainModelField => {
|
||||
const [base_model, model_type, model_name] = modelId.split('/');
|
||||
|
||||
const field: MainModelField = {
|
||||
base_model: base_model as BaseModelType,
|
||||
model_name,
|
||||
};
|
||||
|
||||
return field;
|
||||
};
|
@ -1,16 +0,0 @@
|
||||
import { BaseModelType, VAEModelField } from 'services/api/types';
|
||||
|
||||
/**
|
||||
* Crudely converts a model id to a main model field
|
||||
* TODO: Make better
|
||||
*/
|
||||
export const modelIdToVAEModelField = (modelId: string): VAEModelField => {
|
||||
const [base_model, model_type, model_name] = modelId.split('/');
|
||||
|
||||
const field: VAEModelField = {
|
||||
base_model: base_model as BaseModelType,
|
||||
model_name,
|
||||
};
|
||||
|
||||
return field;
|
||||
};
|
Reference in New Issue
Block a user