feat(ui): fix a lot of model-related crashes/bugs

We were storing all types of models by their model ID, which is a format like `sd-1/main/deliberate`.

This meant we had to do a lot of extra parsing, because nodes actually wants something like `{base_model: 'sd-1', model_name: 'deliberate'}`.

Some of this parsing was done with zod's error-throwing `parse()` method, and in other places it was done with brittle string parsing.

This commit refactors the state to use the object form of models.

There is still a bit of string parsing done in the to construct the ID from the object form, but it's far less complicated.

Also, the zod parsing is now done using `safeParse()`, which does not throw. This requires a few more conditional checks, but should prevent further crashes.
This commit is contained in:
psychedelicious
2023-07-14 14:14:03 +10:00
parent 14587464d5
commit a071873327
34 changed files with 342 additions and 201 deletions

View File

@ -6,7 +6,7 @@ import {
VaeModelInputFieldTemplate,
VaeModelInputFieldValue,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { forEach, isString } from 'lodash-es';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@ -46,7 +46,7 @@ const LoRAModelInputFieldComponent = (
data.push({
value: id,
label: model.model_name,
group: BASE_MODEL_NAME_MAP[model.base_model],
group: MODEL_TYPE_MAP[model.base_model],
});
});
@ -88,8 +88,7 @@ const LoRAModelInputFieldComponent = (
<IAIMantineSelect
tooltip={selectedModel?.description}
label={
selectedModel?.base_model &&
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
}
value={field.value}
placeholder="Pick one"

View File

@ -7,7 +7,7 @@ import {
} from 'features/nodes/types/types';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { forEach, isString } from 'lodash-es';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@ -39,7 +39,7 @@ const ModelInputFieldComponent = (
data.push({
value: id,
label: model.model_name,
group: BASE_MODEL_NAME_MAP[model.base_model],
group: MODEL_TYPE_MAP[model.base_model],
});
});
@ -86,8 +86,7 @@ const ModelInputFieldComponent = (
<IAIMantineSelect
tooltip={selectedModel?.description}
label={
selectedModel?.base_model &&
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
}
value={field.value}
placeholder="Pick one"

View File

@ -6,7 +6,7 @@ import {
VaeModelInputFieldTemplate,
VaeModelInputFieldValue,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { forEach } from 'lodash-es';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@ -46,7 +46,7 @@ const VaeModelInputFieldComponent = (
data.push({
value: id,
label: model.model_name,
group: BASE_MODEL_NAME_MAP[model.base_model],
group: MODEL_TYPE_MAP[model.base_model],
});
});
@ -81,8 +81,7 @@ const VaeModelInputFieldComponent = (
<IAIMantineSelect
tooltip={selectedModel?.description}
label={
selectedModel?.base_model &&
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
}
value={field.value}
placeholder="Pick one"

View File

@ -5,7 +5,6 @@ import {
LoraLoaderInvocation,
MetadataAccumulatorInvocation,
} from 'services/api/types';
import { modelIdToLoRAModelField } from '../modelIdToLoRAName';
import {
CLIP_SKIP,
LORA_LOADER,
@ -55,23 +54,22 @@ export const addLoRAsToGraph = (
let currentLoraIndex = 0;
forEach(loras, (lora) => {
const { id, name, weight } = lora;
const loraField = modelIdToLoRAModelField(id);
const currentLoraNodeId = `${LORA_LOADER}_${loraField.model_name.replace(
'.',
'_'
)}`;
const { model_name, base_model, weight } = lora;
const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`;
const loraLoaderNode: LoraLoaderInvocation = {
type: 'lora_loader',
id: currentLoraNodeId,
lora: loraField,
lora,
weight,
};
// add the lora to the metadata accumulator
if (metadataAccumulator) {
metadataAccumulator.loras.push({ lora: loraField, weight });
metadataAccumulator.loras.push({
lora: { model_name, base_model },
weight,
});
}
// add to graph

View File

@ -1,7 +1,6 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { MetadataAccumulatorInvocation } from 'services/api/types';
import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
import {
IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS,
@ -19,7 +18,6 @@ export const addVAEToGraph = (
graph: NonNullableGraph
): void => {
const { vae } = state.generation;
const vae_model = modelIdToVAEModelField(vae?.id || '');
const isAutoVae = !vae;
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
@ -30,7 +28,7 @@ export const addVAEToGraph = (
graph.nodes[VAE_LOADER] = {
type: 'vae_loader',
id: VAE_LOADER,
vae_model,
vae_model: vae,
};
}
@ -74,6 +72,6 @@ export const addVAEToGraph = (
}
if (vae && metadataAccumulator) {
metadataAccumulator.vae = vae_model;
metadataAccumulator.vae = vae;
}
};

View File

@ -1,12 +1,12 @@
import { RootState } from 'app/store/store';
import { InputFieldValue } from 'features/nodes/types/types';
import { modelIdToLoRAModelParam } from 'features/parameters/util/modelIdToLoRAModelParam';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
import { cloneDeep, omit, reduce } from 'lodash-es';
import { Graph } from 'services/api/types';
import { AnyInvocation } from 'services/events/types';
import { v4 as uuidv4 } from 'uuid';
import { modelIdToLoRAModelField } from '../modelIdToLoRAName';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
/**
* We need to do special handling for some fields
@ -29,19 +29,19 @@ export const parseFieldValue = (field: InputFieldValue) => {
if (field.type === 'model') {
if (field.value) {
return modelIdToMainModelField(field.value);
return modelIdToMainModelParam(field.value);
}
}
if (field.type === 'vae_model') {
if (field.value) {
return modelIdToVAEModelField(field.value);
return modelIdToVAEModelParam(field.value);
}
}
if (field.type === 'lora_model') {
if (field.value) {
return modelIdToLoRAModelField(field.value);
return modelIdToLoRAModelParam(field.value);
}
}

View File

@ -1,12 +0,0 @@
import { BaseModelType, LoRAModelField } from 'services/api/types';
export const modelIdToLoRAModelField = (loraId: string): LoRAModelField => {
const [base_model, model_type, model_name] = loraId.split('/');
const field: LoRAModelField = {
base_model: base_model as BaseModelType,
model_name,
};
return field;
};

View File

@ -1,16 +0,0 @@
import { BaseModelType, MainModelField } from 'services/api/types';
/**
* Crudely converts a model id to a main model field
* TODO: Make better
*/
export const modelIdToMainModelField = (modelId: string): MainModelField => {
const [base_model, model_type, model_name] = modelId.split('/');
const field: MainModelField = {
base_model: base_model as BaseModelType,
model_name,
};
return field;
};

View File

@ -1,16 +0,0 @@
import { BaseModelType, VAEModelField } from 'services/api/types';
/**
* Crudely converts a model id to a main model field
* TODO: Make better
*/
export const modelIdToVAEModelField = (modelId: string): VAEModelField => {
const [base_model, model_type, model_name] = modelId.split('/');
const field: VAEModelField = {
base_model: base_model as BaseModelType,
model_name,
};
return field;
};