feat(ui): update node editor to use model object format

similar to the previous commit, update the node editor to not just store models as strings - instead, store the model object.

the model select components in nodes are now just kinda copy-pastes over the linear UI versions of the same components, but they were different enough that we can't just share them.

i explored adding some props to override the linear ui components' logic, but it was too brittle. so just copy/paste.
This commit is contained in:
psychedelicious 2023-07-14 14:53:33 +10:00
parent a071873327
commit 8dd4ca5723
8 changed files with 141 additions and 151 deletions

View File

@ -1,15 +1,17 @@
import { Flex, Text } from '@chakra-ui/react';
import { SelectItem } from '@mantine/core'; import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { import {
VaeModelInputFieldTemplate, VaeModelInputFieldTemplate,
VaeModelInputFieldValue, VaeModelInputFieldValue,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { forEach, isString } from 'lodash-es'; import { modelIdToLoRAModelParam } from 'features/parameters/util/modelIdToLoRAModelParam';
import { memo, useCallback, useEffect, useMemo } from 'react'; import { forEach } from 'lodash-es';
import { useTranslation } from 'react-i18next'; import { memo, useCallback, useMemo } from 'react';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types'; import { FieldComponentProps } from './types';
@ -20,17 +22,10 @@ const LoRAModelInputFieldComponent = (
> >
) => { ) => {
const { nodeId, field } = props; const { nodeId, field } = props;
const lora = field.value;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data: loraModels } = useGetLoRAModelsQuery(); const { data: loraModels } = useGetLoRAModelsQuery();
const selectedModel = useMemo(
() => loraModels?.entities[field.value ?? loraModels.ids[0]],
[loraModels?.entities, loraModels?.ids, field.value]
);
const data = useMemo(() => { const data = useMemo(() => {
if (!loraModels) { if (!loraModels) {
return []; return [];
@ -38,62 +33,78 @@ const LoRAModelInputFieldComponent = (
const data: SelectItem[] = []; const data: SelectItem[] = [];
forEach(loraModels.entities, (model, id) => { forEach(loraModels.entities, (lora, id) => {
if (!model) { if (!lora) {
return; return;
} }
data.push({ data.push({
value: id, value: id,
label: model.model_name, label: lora.model_name,
group: MODEL_TYPE_MAP[model.base_model], group: MODEL_TYPE_MAP[lora.base_model],
}); });
}); });
return data; return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
}, [loraModels]); }, [loraModels]);
const handleValueChanged = useCallback( const selectedLoRAModel = useMemo(
() =>
loraModels?.entities[`${lora?.base_model}/lora/${lora?.model_name}`] ??
null,
[loraModels?.entities, lora?.base_model, lora?.model_name]
);
const handleChange = useCallback(
(v: string | null) => { (v: string | null) => {
if (!v) { if (!v) {
return; return;
} }
const newLoRAModel = modelIdToLoRAModelParam(v);
if (!newLoRAModel) {
return;
}
dispatch( dispatch(
fieldValueChanged({ fieldValueChanged({
nodeId, nodeId,
fieldName: field.name, fieldName: field.name,
value: v, value: newLoRAModel,
}) })
); );
}, },
[dispatch, field.name, nodeId] [dispatch, field.name, nodeId]
); );
useEffect(() => { if (loraModels?.ids.length === 0) {
if (field.value && loraModels?.ids.includes(field.value)) { return (
return; <Flex sx={{ justifyContent: 'center', p: 2 }}>
<Text sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}>
No LoRAs Loaded
</Text>
</Flex>
);
} }
const firstLora = loraModels?.ids[0];
if (!isString(firstLora)) {
return;
}
handleValueChanged(firstLora);
}, [field.value, handleValueChanged, loraModels?.ids]);
return ( return (
<IAIMantineSelect <IAIMantineSelect
tooltip={selectedModel?.description} value={selectedLoRAModel?.id ?? null}
label={ label={
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model] selectedLoRAModel?.base_model &&
MODEL_TYPE_MAP[selectedLoRAModel?.base_model]
} }
value={field.value} placeholder={data.length > 0 ? 'Select a LoRA' : 'No LoRAs available'}
placeholder="Pick one"
data={data} data={data}
onChange={handleValueChanged} nothingFound="No matching LoRAs"
itemComponent={IAIMantineSelectItemWithTooltip}
disabled={data.length === 0}
filter={(value, item: SelectItem) =>
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim())
}
onChange={handleChange}
/> />
); );
}; };

View File

@ -1,28 +1,29 @@
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { import {
MainModelInputFieldValue,
ModelInputFieldTemplate, ModelInputFieldTemplate,
ModelInputFieldValue,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { SelectItem } from '@mantine/core';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { forEach, isString } from 'lodash-es'; import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { memo, useCallback, useEffect, useMemo } from 'react'; import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types'; import { FieldComponentProps } from './types';
const ModelInputFieldComponent = ( const ModelInputFieldComponent = (
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate> props: FieldComponentProps<MainModelInputFieldValue, ModelInputFieldTemplate>
) => { ) => {
const { nodeId, field } = props; const { nodeId, field } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { data: mainModels } = useGetMainModelsQuery(); const { data: mainModels, isLoading } = useGetMainModelsQuery();
const data = useMemo(() => { const data = useMemo(() => {
if (!mainModels) { if (!mainModels) {
@ -46,52 +47,58 @@ const ModelInputFieldComponent = (
return data; return data;
}, [mainModels]); }, [mainModels]);
// grab the full model entity from the RTK Query cache
// TODO: maybe we should just store the full model entity in state?
const selectedModel = useMemo( const selectedModel = useMemo(
() => mainModels?.entities[field.value ?? mainModels.ids[0]], () =>
[mainModels?.entities, mainModels?.ids, field.value] mainModels?.entities[
`${field.value?.base_model}/main/${field.value?.model_name}`
] ?? null,
[field.value?.base_model, field.value?.model_name, mainModels?.entities]
); );
const handleValueChanged = useCallback( const handleChangeModel = useCallback(
(v: string | null) => { (v: string | null) => {
if (!v) { if (!v) {
return; return;
} }
const newModel = modelIdToMainModelParam(v);
if (!newModel) {
return;
}
dispatch( dispatch(
fieldValueChanged({ fieldValueChanged({
nodeId, nodeId,
fieldName: field.name, fieldName: field.name,
value: v, value: newModel,
}) })
); );
}, },
[dispatch, field.name, nodeId] [dispatch, field.name, nodeId]
); );
useEffect(() => { return isLoading ? (
if (field.value && mainModels?.ids.includes(field.value)) { <IAIMantineSelect
return; label={t('modelManager.model')}
} placeholder="Loading..."
disabled={true}
const firstModel = mainModels?.ids[0]; data={[]}
/>
if (!isString(firstModel)) { ) : (
return;
}
handleValueChanged(firstModel);
}, [field.value, handleValueChanged, mainModels?.ids]);
return (
<IAIMantineSelect <IAIMantineSelect
tooltip={selectedModel?.description} tooltip={selectedModel?.description}
label={ label={
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model] selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
} }
value={field.value} value={selectedModel?.id}
placeholder="Pick one" placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
data={data} data={data}
onChange={handleValueChanged} error={data.length === 0}
disabled={data.length === 0}
onChange={handleChangeModel}
/> />
); );
}; };

View File

@ -1,16 +0,0 @@
import {
UNetInputFieldTemplate,
UNetInputFieldValue,
} from 'features/nodes/types/types';
import { memo } from 'react';
import { FieldComponentProps } from './types';
const UNetInputFieldComponent = (
props: FieldComponentProps<UNetInputFieldValue, UNetInputFieldTemplate>
) => {
const { nodeId, field } = props;
return null;
};
export default memo(UNetInputFieldComponent);

View File

@ -1,16 +0,0 @@
import {
VaeInputFieldTemplate,
VaeInputFieldValue,
} from 'features/nodes/types/types';
import { memo } from 'react';
import { FieldComponentProps } from './types';
const VaeInputFieldComponent = (
props: FieldComponentProps<VaeInputFieldValue, VaeInputFieldTemplate>
) => {
const { nodeId, field } = props;
return null;
};
export default memo(VaeInputFieldComponent);

View File

@ -1,14 +1,16 @@
import { SelectItem } from '@mantine/core'; import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { import {
VaeModelInputFieldTemplate, VaeModelInputFieldTemplate,
VaeModelInputFieldValue, VaeModelInputFieldValue,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
import { memo, useCallback, useEffect, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types'; import { FieldComponentProps } from './types';
@ -20,73 +22,83 @@ const VaeModelInputFieldComponent = (
> >
) => { ) => {
const { nodeId, field } = props; const { nodeId, field } = props;
const vae = field.value;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { data: vaeModels } = useGetVaeModelsQuery(); const { data: vaeModels } = useGetVaeModelsQuery();
const selectedModel = useMemo(
() => vaeModels?.entities[field.value ?? vaeModels.ids[0]],
[vaeModels?.entities, vaeModels?.ids, field.value]
);
const data = useMemo(() => { const data = useMemo(() => {
if (!vaeModels) { if (!vaeModels) {
return []; return [];
} }
const data: SelectItem[] = []; const data: SelectItem[] = [
{
value: 'default',
label: 'Default',
group: 'Default',
},
];
forEach(vaeModels.entities, (model, id) => { forEach(vaeModels.entities, (vae, id) => {
if (!model) { if (!vae) {
return; return;
} }
data.push({ data.push({
value: id, value: id,
label: model.model_name, label: vae.model_name,
group: MODEL_TYPE_MAP[model.base_model], group: MODEL_TYPE_MAP[vae.base_model],
}); });
}); });
return data; return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
}, [vaeModels]); }, [vaeModels]);
const handleValueChanged = useCallback( // grab the full model entity from the RTK Query cache
const selectedVaeModel = useMemo(
() =>
vaeModels?.entities[`${vae?.base_model}/vae/${vae?.model_name}`] ?? null,
[vaeModels?.entities, vae]
);
const handleChangeModel = useCallback(
(v: string | null) => { (v: string | null) => {
if (!v) { if (!v) {
return; return;
} }
const newVaeModel = modelIdToVAEModelParam(v);
if (!newVaeModel) {
return;
}
dispatch( dispatch(
fieldValueChanged({ fieldValueChanged({
nodeId, nodeId,
fieldName: field.name, fieldName: field.name,
value: v, value: newVaeModel,
}) })
); );
}, },
[dispatch, field.name, nodeId] [dispatch, field.name, nodeId]
); );
useEffect(() => {
if (field.value && vaeModels?.ids.includes(field.value)) {
return;
}
handleValueChanged('auto');
}, [field.value, handleValueChanged, vaeModels?.ids]);
return ( return (
<IAIMantineSelect <IAIMantineSelect
tooltip={selectedModel?.description} itemComponent={IAIMantineSelectItemWithTooltip}
tooltip={selectedVaeModel?.description}
label={ label={
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model] selectedVaeModel?.base_model &&
MODEL_TYPE_MAP[selectedVaeModel?.base_model]
} }
value={field.value} value={selectedVaeModel?.id ?? 'default'}
placeholder="Pick one" placeholder="Default"
data={data} data={data}
onChange={handleValueChanged} onChange={handleChangeModel}
disabled={data.length === 0}
clearable
/> />
); );
}; };

View File

@ -1,5 +1,10 @@
import { createSlice, PayloadAction } from '@reduxjs/toolkit'; import { createSlice, PayloadAction } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import {
LoRAModelParam,
MainModelParam,
VaeModelParam,
} from 'features/parameters/types/parameterSchemas';
import { cloneDeep, uniqBy } from 'lodash-es'; import { cloneDeep, uniqBy } from 'lodash-es';
import { OpenAPIV3 } from 'openapi-types'; import { OpenAPIV3 } from 'openapi-types';
import { RgbaColor } from 'react-colorful'; import { RgbaColor } from 'react-colorful';
@ -73,7 +78,10 @@ const nodesSlice = createSlice({
| ImageField | ImageField
| RgbaColor | RgbaColor
| undefined | undefined
| ImageField[]; | ImageField[]
| MainModelParam
| VaeModelParam
| LoRAModelParam;
}> }>
) => { ) => {
const { nodeId, fieldName, value } = action.payload; const { nodeId, fieldName, value } = action.payload;

View File

@ -1,3 +1,8 @@
import {
LoRAModelParam,
MainModelParam,
VaeModelParam,
} from 'features/parameters/types/parameterSchemas';
import { OpenAPIV3 } from 'openapi-types'; import { OpenAPIV3 } from 'openapi-types';
import { RgbaColor } from 'react-colorful'; import { RgbaColor } from 'react-colorful';
import { Graph, ImageDTO, ImageField } from 'services/api/types'; import { Graph, ImageDTO, ImageField } from 'services/api/types';
@ -92,7 +97,7 @@ export type InputFieldValue =
| VaeInputFieldValue | VaeInputFieldValue
| ControlInputFieldValue | ControlInputFieldValue
| EnumInputFieldValue | EnumInputFieldValue
| ModelInputFieldValue | MainModelInputFieldValue
| VaeModelInputFieldValue | VaeModelInputFieldValue
| LoRAModelInputFieldValue | LoRAModelInputFieldValue
| ArrayInputFieldValue | ArrayInputFieldValue
@ -229,19 +234,19 @@ export type ImageCollectionInputFieldValue = FieldValueBase & {
value?: ImageField[]; value?: ImageField[];
}; };
export type ModelInputFieldValue = FieldValueBase & { export type MainModelInputFieldValue = FieldValueBase & {
type: 'model'; type: 'model';
value?: string; value?: MainModelParam;
}; };
export type VaeModelInputFieldValue = FieldValueBase & { export type VaeModelInputFieldValue = FieldValueBase & {
type: 'vae_model'; type: 'vae_model';
value?: string; value?: VaeModelParam;
}; };
export type LoRAModelInputFieldValue = FieldValueBase & { export type LoRAModelInputFieldValue = FieldValueBase & {
type: 'lora_model'; type: 'lora_model';
value?: string; value?: LoRAModelParam;
}; };
export type ArrayInputFieldValue = FieldValueBase & { export type ArrayInputFieldValue = FieldValueBase & {

View File

@ -1,8 +1,5 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { InputFieldValue } from 'features/nodes/types/types'; import { InputFieldValue } from 'features/nodes/types/types';
import { modelIdToLoRAModelParam } from 'features/parameters/util/modelIdToLoRAModelParam';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
import { cloneDeep, omit, reduce } from 'lodash-es'; import { cloneDeep, omit, reduce } from 'lodash-es';
import { Graph } from 'services/api/types'; import { Graph } from 'services/api/types';
import { AnyInvocation } from 'services/events/types'; import { AnyInvocation } from 'services/events/types';
@ -27,24 +24,6 @@ export const parseFieldValue = (field: InputFieldValue) => {
} }
} }
if (field.type === 'model') {
if (field.value) {
return modelIdToMainModelParam(field.value);
}
}
if (field.type === 'vae_model') {
if (field.value) {
return modelIdToVAEModelParam(field.value);
}
}
if (field.type === 'lora_model') {
if (field.value) {
return modelIdToLoRAModelParam(field.value);
}
}
return field.value; return field.value;
}; };