mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): update node editor to use model object format
similar to the previous commit, update the node editor to not just store models as strings - instead, store the model object. the model select components in nodes are now just kinda copy-pastes over the linear UI versions of the same components, but they were different enough that we can't just share them. i explored adding some props to override the linear ui components' logic, but it was too brittle. so just copy/paste.
This commit is contained in:
parent
a071873327
commit
8dd4ca5723
@ -1,15 +1,17 @@
|
||||
import { Flex, Text } from '@chakra-ui/react';
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
VaeModelInputFieldTemplate,
|
||||
VaeModelInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { forEach, isString } from 'lodash-es';
|
||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { modelIdToLoRAModelParam } from 'features/parameters/util/modelIdToLoRAModelParam';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
@ -20,17 +22,10 @@ const LoRAModelInputFieldComponent = (
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
const lora = field.value;
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { data: loraModels } = useGetLoRAModelsQuery();
|
||||
|
||||
const selectedModel = useMemo(
|
||||
() => loraModels?.entities[field.value ?? loraModels.ids[0]],
|
||||
[loraModels?.entities, loraModels?.ids, field.value]
|
||||
);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!loraModels) {
|
||||
return [];
|
||||
@ -38,62 +33,78 @@ const LoRAModelInputFieldComponent = (
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(loraModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
forEach(loraModels.entities, (lora, id) => {
|
||||
if (!lora) {
|
||||
return;
|
||||
}
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.model_name,
|
||||
group: MODEL_TYPE_MAP[model.base_model],
|
||||
label: lora.model_name,
|
||||
group: MODEL_TYPE_MAP[lora.base_model],
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
|
||||
}, [loraModels]);
|
||||
|
||||
const handleValueChanged = useCallback(
|
||||
const selectedLoRAModel = useMemo(
|
||||
() =>
|
||||
loraModels?.entities[`${lora?.base_model}/lora/${lora?.model_name}`] ??
|
||||
null,
|
||||
[loraModels?.entities, lora?.base_model, lora?.model_name]
|
||||
);
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
|
||||
const newLoRAModel = modelIdToLoRAModelParam(v);
|
||||
|
||||
if (!newLoRAModel) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
fieldValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: v,
|
||||
value: newLoRAModel,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (field.value && loraModels?.ids.includes(field.value)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const firstLora = loraModels?.ids[0];
|
||||
|
||||
if (!isString(firstLora)) {
|
||||
return;
|
||||
}
|
||||
|
||||
handleValueChanged(firstLora);
|
||||
}, [field.value, handleValueChanged, loraModels?.ids]);
|
||||
if (loraModels?.ids.length === 0) {
|
||||
return (
|
||||
<Flex sx={{ justifyContent: 'center', p: 2 }}>
|
||||
<Text sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}>
|
||||
No LoRAs Loaded
|
||||
</Text>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
tooltip={selectedModel?.description}
|
||||
value={selectedLoRAModel?.id ?? null}
|
||||
label={
|
||||
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
|
||||
selectedLoRAModel?.base_model &&
|
||||
MODEL_TYPE_MAP[selectedLoRAModel?.base_model]
|
||||
}
|
||||
value={field.value}
|
||||
placeholder="Pick one"
|
||||
placeholder={data.length > 0 ? 'Select a LoRA' : 'No LoRAs available'}
|
||||
data={data}
|
||||
onChange={handleValueChanged}
|
||||
nothingFound="No matching LoRAs"
|
||||
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||
disabled={data.length === 0}
|
||||
filter={(value, item: SelectItem) =>
|
||||
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||
item.value.toLowerCase().includes(value.toLowerCase().trim())
|
||||
}
|
||||
onChange={handleChange}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
@ -1,28 +1,29 @@
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
MainModelInputFieldValue,
|
||||
ModelInputFieldTemplate,
|
||||
ModelInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { forEach, isString } from 'lodash-es';
|
||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
const ModelInputFieldComponent = (
|
||||
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
|
||||
props: FieldComponentProps<MainModelInputFieldValue, ModelInputFieldTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { data: mainModels } = useGetMainModelsQuery();
|
||||
const { data: mainModels, isLoading } = useGetMainModelsQuery();
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!mainModels) {
|
||||
@ -46,52 +47,58 @@ const ModelInputFieldComponent = (
|
||||
return data;
|
||||
}, [mainModels]);
|
||||
|
||||
// grab the full model entity from the RTK Query cache
|
||||
// TODO: maybe we should just store the full model entity in state?
|
||||
const selectedModel = useMemo(
|
||||
() => mainModels?.entities[field.value ?? mainModels.ids[0]],
|
||||
[mainModels?.entities, mainModels?.ids, field.value]
|
||||
() =>
|
||||
mainModels?.entities[
|
||||
`${field.value?.base_model}/main/${field.value?.model_name}`
|
||||
] ?? null,
|
||||
[field.value?.base_model, field.value?.model_name, mainModels?.entities]
|
||||
);
|
||||
|
||||
const handleValueChanged = useCallback(
|
||||
const handleChangeModel = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
|
||||
const newModel = modelIdToMainModelParam(v);
|
||||
|
||||
if (!newModel) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
fieldValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: v,
|
||||
value: newModel,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (field.value && mainModels?.ids.includes(field.value)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const firstModel = mainModels?.ids[0];
|
||||
|
||||
if (!isString(firstModel)) {
|
||||
return;
|
||||
}
|
||||
|
||||
handleValueChanged(firstModel);
|
||||
}, [field.value, handleValueChanged, mainModels?.ids]);
|
||||
|
||||
return (
|
||||
return isLoading ? (
|
||||
<IAIMantineSelect
|
||||
label={t('modelManager.model')}
|
||||
placeholder="Loading..."
|
||||
disabled={true}
|
||||
data={[]}
|
||||
/>
|
||||
) : (
|
||||
<IAIMantineSelect
|
||||
tooltip={selectedModel?.description}
|
||||
label={
|
||||
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
|
||||
}
|
||||
value={field.value}
|
||||
placeholder="Pick one"
|
||||
value={selectedModel?.id}
|
||||
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
|
||||
data={data}
|
||||
onChange={handleValueChanged}
|
||||
error={data.length === 0}
|
||||
disabled={data.length === 0}
|
||||
onChange={handleChangeModel}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
@ -1,16 +0,0 @@
|
||||
import {
|
||||
UNetInputFieldTemplate,
|
||||
UNetInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo } from 'react';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
const UNetInputFieldComponent = (
|
||||
props: FieldComponentProps<UNetInputFieldValue, UNetInputFieldTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
export default memo(UNetInputFieldComponent);
|
@ -1,16 +0,0 @@
|
||||
import {
|
||||
VaeInputFieldTemplate,
|
||||
VaeInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo } from 'react';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
const VaeInputFieldComponent = (
|
||||
props: FieldComponentProps<VaeInputFieldValue, VaeInputFieldTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
export default memo(VaeInputFieldComponent);
|
@ -1,14 +1,16 @@
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
VaeModelInputFieldTemplate,
|
||||
VaeModelInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||
import { FieldComponentProps } from './types';
|
||||
@ -20,73 +22,83 @@ const VaeModelInputFieldComponent = (
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
const vae = field.value;
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { data: vaeModels } = useGetVaeModelsQuery();
|
||||
|
||||
const selectedModel = useMemo(
|
||||
() => vaeModels?.entities[field.value ?? vaeModels.ids[0]],
|
||||
[vaeModels?.entities, vaeModels?.ids, field.value]
|
||||
);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!vaeModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
const data: SelectItem[] = [
|
||||
{
|
||||
value: 'default',
|
||||
label: 'Default',
|
||||
group: 'Default',
|
||||
},
|
||||
];
|
||||
|
||||
forEach(vaeModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
forEach(vaeModels.entities, (vae, id) => {
|
||||
if (!vae) {
|
||||
return;
|
||||
}
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.model_name,
|
||||
group: MODEL_TYPE_MAP[model.base_model],
|
||||
label: vae.model_name,
|
||||
group: MODEL_TYPE_MAP[vae.base_model],
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
|
||||
}, [vaeModels]);
|
||||
|
||||
const handleValueChanged = useCallback(
|
||||
// grab the full model entity from the RTK Query cache
|
||||
const selectedVaeModel = useMemo(
|
||||
() =>
|
||||
vaeModels?.entities[`${vae?.base_model}/vae/${vae?.model_name}`] ?? null,
|
||||
[vaeModels?.entities, vae]
|
||||
);
|
||||
|
||||
const handleChangeModel = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
|
||||
const newVaeModel = modelIdToVAEModelParam(v);
|
||||
|
||||
if (!newVaeModel) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
fieldValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: v,
|
||||
value: newVaeModel,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (field.value && vaeModels?.ids.includes(field.value)) {
|
||||
return;
|
||||
}
|
||||
handleValueChanged('auto');
|
||||
}, [field.value, handleValueChanged, vaeModels?.ids]);
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
tooltip={selectedModel?.description}
|
||||
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||
tooltip={selectedVaeModel?.description}
|
||||
label={
|
||||
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
|
||||
selectedVaeModel?.base_model &&
|
||||
MODEL_TYPE_MAP[selectedVaeModel?.base_model]
|
||||
}
|
||||
value={field.value}
|
||||
placeholder="Pick one"
|
||||
value={selectedVaeModel?.id ?? 'default'}
|
||||
placeholder="Default"
|
||||
data={data}
|
||||
onChange={handleValueChanged}
|
||||
onChange={handleChangeModel}
|
||||
disabled={data.length === 0}
|
||||
clearable
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
@ -1,5 +1,10 @@
|
||||
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import {
|
||||
LoRAModelParam,
|
||||
MainModelParam,
|
||||
VaeModelParam,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { cloneDeep, uniqBy } from 'lodash-es';
|
||||
import { OpenAPIV3 } from 'openapi-types';
|
||||
import { RgbaColor } from 'react-colorful';
|
||||
@ -73,7 +78,10 @@ const nodesSlice = createSlice({
|
||||
| ImageField
|
||||
| RgbaColor
|
||||
| undefined
|
||||
| ImageField[];
|
||||
| ImageField[]
|
||||
| MainModelParam
|
||||
| VaeModelParam
|
||||
| LoRAModelParam;
|
||||
}>
|
||||
) => {
|
||||
const { nodeId, fieldName, value } = action.payload;
|
||||
|
@ -1,3 +1,8 @@
|
||||
import {
|
||||
LoRAModelParam,
|
||||
MainModelParam,
|
||||
VaeModelParam,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { OpenAPIV3 } from 'openapi-types';
|
||||
import { RgbaColor } from 'react-colorful';
|
||||
import { Graph, ImageDTO, ImageField } from 'services/api/types';
|
||||
@ -92,7 +97,7 @@ export type InputFieldValue =
|
||||
| VaeInputFieldValue
|
||||
| ControlInputFieldValue
|
||||
| EnumInputFieldValue
|
||||
| ModelInputFieldValue
|
||||
| MainModelInputFieldValue
|
||||
| VaeModelInputFieldValue
|
||||
| LoRAModelInputFieldValue
|
||||
| ArrayInputFieldValue
|
||||
@ -229,19 +234,19 @@ export type ImageCollectionInputFieldValue = FieldValueBase & {
|
||||
value?: ImageField[];
|
||||
};
|
||||
|
||||
export type ModelInputFieldValue = FieldValueBase & {
|
||||
export type MainModelInputFieldValue = FieldValueBase & {
|
||||
type: 'model';
|
||||
value?: string;
|
||||
value?: MainModelParam;
|
||||
};
|
||||
|
||||
export type VaeModelInputFieldValue = FieldValueBase & {
|
||||
type: 'vae_model';
|
||||
value?: string;
|
||||
value?: VaeModelParam;
|
||||
};
|
||||
|
||||
export type LoRAModelInputFieldValue = FieldValueBase & {
|
||||
type: 'lora_model';
|
||||
value?: string;
|
||||
value?: LoRAModelParam;
|
||||
};
|
||||
|
||||
export type ArrayInputFieldValue = FieldValueBase & {
|
||||
|
@ -1,8 +1,5 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { InputFieldValue } from 'features/nodes/types/types';
|
||||
import { modelIdToLoRAModelParam } from 'features/parameters/util/modelIdToLoRAModelParam';
|
||||
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||
import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
|
||||
import { cloneDeep, omit, reduce } from 'lodash-es';
|
||||
import { Graph } from 'services/api/types';
|
||||
import { AnyInvocation } from 'services/events/types';
|
||||
@ -27,24 +24,6 @@ export const parseFieldValue = (field: InputFieldValue) => {
|
||||
}
|
||||
}
|
||||
|
||||
if (field.type === 'model') {
|
||||
if (field.value) {
|
||||
return modelIdToMainModelParam(field.value);
|
||||
}
|
||||
}
|
||||
|
||||
if (field.type === 'vae_model') {
|
||||
if (field.value) {
|
||||
return modelIdToVAEModelParam(field.value);
|
||||
}
|
||||
}
|
||||
|
||||
if (field.type === 'lora_model') {
|
||||
if (field.value) {
|
||||
return modelIdToLoRAModelParam(field.value);
|
||||
}
|
||||
}
|
||||
|
||||
return field.value;
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user