mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: Addvae_model
input type front end
This commit is contained in:
parent
5ad6b64721
commit
38660a2162
@ -3,20 +3,21 @@ import { memo } from 'react';
|
|||||||
import { InputFieldTemplate, InputFieldValue } from '../types/types';
|
import { InputFieldTemplate, InputFieldValue } from '../types/types';
|
||||||
import ArrayInputFieldComponent from './fields/ArrayInputFieldComponent';
|
import ArrayInputFieldComponent from './fields/ArrayInputFieldComponent';
|
||||||
import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent';
|
import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent';
|
||||||
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
|
|
||||||
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
|
|
||||||
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
|
|
||||||
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
|
|
||||||
import UNetInputFieldComponent from './fields/UNetInputFieldComponent';
|
|
||||||
import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
|
import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
|
||||||
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
|
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
|
||||||
|
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
|
||||||
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
|
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
|
||||||
|
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
|
||||||
|
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
|
||||||
|
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
|
||||||
|
import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
|
||||||
|
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
|
||||||
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
|
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
|
||||||
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
|
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
|
||||||
import StringInputFieldComponent from './fields/StringInputFieldComponent';
|
import StringInputFieldComponent from './fields/StringInputFieldComponent';
|
||||||
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
|
import UNetInputFieldComponent from './fields/UNetInputFieldComponent';
|
||||||
import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
|
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
|
||||||
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
|
import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent';
|
||||||
|
|
||||||
type InputFieldComponentProps = {
|
type InputFieldComponentProps = {
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
@ -152,6 +153,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (type === 'vae_model' && template.type === 'vae_model') {
|
||||||
|
return (
|
||||||
|
<VaeModelInputFieldComponent
|
||||||
|
nodeId={nodeId}
|
||||||
|
field={field}
|
||||||
|
template={template}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
if (type === 'array' && template.type === 'array') {
|
if (type === 'array' && template.type === 'array') {
|
||||||
return (
|
return (
|
||||||
<ArrayInputFieldComponent
|
<ArrayInputFieldComponent
|
||||||
|
@ -0,0 +1,104 @@
|
|||||||
|
import { SelectItem } from '@mantine/core';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
|
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
|
import {
|
||||||
|
VaeModelInputFieldTemplate,
|
||||||
|
VaeModelInputFieldValue,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
|
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
||||||
|
import { MODEL_TYPE_MAP } from 'features/system/components/VAESelect';
|
||||||
|
import { forEach } from 'lodash-es';
|
||||||
|
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useListModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
import { FieldComponentProps } from './types';
|
||||||
|
|
||||||
|
const VaeModelInputFieldComponent = (
|
||||||
|
props: FieldComponentProps<
|
||||||
|
VaeModelInputFieldValue,
|
||||||
|
VaeModelInputFieldTemplate
|
||||||
|
>
|
||||||
|
) => {
|
||||||
|
const { nodeId, field } = props;
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const { data: vaeModels } = useListModelsQuery({
|
||||||
|
model_type: 'vae',
|
||||||
|
});
|
||||||
|
|
||||||
|
const selectedModel = useMemo(
|
||||||
|
() => vaeModels?.entities[field.value ?? vaeModels.ids[0]],
|
||||||
|
[vaeModels?.entities, vaeModels?.ids, field.value]
|
||||||
|
);
|
||||||
|
|
||||||
|
const data = useMemo(() => {
|
||||||
|
if (!vaeModels) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const data: SelectItem[] = [
|
||||||
|
{
|
||||||
|
value: 'auto',
|
||||||
|
label: 'auto',
|
||||||
|
group: 'Automatic',
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
forEach(vaeModels.entities, (model, id) => {
|
||||||
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
data.push({
|
||||||
|
value: id,
|
||||||
|
label: model.name,
|
||||||
|
group: MODEL_TYPE_MAP[model.base_model],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [vaeModels]);
|
||||||
|
|
||||||
|
const handleValueChanged = useCallback(
|
||||||
|
(v: string | null) => {
|
||||||
|
if (!v) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
fieldValueChanged({
|
||||||
|
nodeId,
|
||||||
|
fieldName: field.name,
|
||||||
|
value: v,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[dispatch, field.name, nodeId]
|
||||||
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (field.value && vaeModels?.ids.includes(field.value)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
handleValueChanged('none');
|
||||||
|
}, [field.value, handleValueChanged, vaeModels?.ids]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAIMantineSelect
|
||||||
|
tooltip={selectedModel?.description}
|
||||||
|
label={
|
||||||
|
selectedModel?.base_model &&
|
||||||
|
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
|
||||||
|
}
|
||||||
|
value={field.value}
|
||||||
|
placeholder="Pick one"
|
||||||
|
data={data}
|
||||||
|
onChange={handleValueChanged}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(VaeModelInputFieldComponent);
|
@ -64,6 +64,7 @@ export type FieldType =
|
|||||||
| 'vae'
|
| 'vae'
|
||||||
| 'control'
|
| 'control'
|
||||||
| 'model'
|
| 'model'
|
||||||
|
| 'vae_model'
|
||||||
| 'array'
|
| 'array'
|
||||||
| 'item'
|
| 'item'
|
||||||
| 'color'
|
| 'color'
|
||||||
@ -91,6 +92,7 @@ export type InputFieldValue =
|
|||||||
| ControlInputFieldValue
|
| ControlInputFieldValue
|
||||||
| EnumInputFieldValue
|
| EnumInputFieldValue
|
||||||
| ModelInputFieldValue
|
| ModelInputFieldValue
|
||||||
|
| VaeModelInputFieldValue
|
||||||
| ArrayInputFieldValue
|
| ArrayInputFieldValue
|
||||||
| ItemInputFieldValue
|
| ItemInputFieldValue
|
||||||
| ColorInputFieldValue
|
| ColorInputFieldValue
|
||||||
@ -116,6 +118,7 @@ export type InputFieldTemplate =
|
|||||||
| ControlInputFieldTemplate
|
| ControlInputFieldTemplate
|
||||||
| EnumInputFieldTemplate
|
| EnumInputFieldTemplate
|
||||||
| ModelInputFieldTemplate
|
| ModelInputFieldTemplate
|
||||||
|
| VaeModelInputFieldTemplate
|
||||||
| ArrayInputFieldTemplate
|
| ArrayInputFieldTemplate
|
||||||
| ItemInputFieldTemplate
|
| ItemInputFieldTemplate
|
||||||
| ColorInputFieldTemplate
|
| ColorInputFieldTemplate
|
||||||
@ -228,6 +231,11 @@ export type ModelInputFieldValue = FieldValueBase & {
|
|||||||
value?: string;
|
value?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type VaeModelInputFieldValue = FieldValueBase & {
|
||||||
|
type: 'vae_model';
|
||||||
|
value?: string;
|
||||||
|
};
|
||||||
|
|
||||||
export type ArrayInputFieldValue = FieldValueBase & {
|
export type ArrayInputFieldValue = FieldValueBase & {
|
||||||
type: 'array';
|
type: 'array';
|
||||||
value?: (string | number)[];
|
value?: (string | number)[];
|
||||||
@ -337,6 +345,11 @@ export type ModelInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
type: 'model';
|
type: 'model';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type VaeModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: string;
|
||||||
|
type: 'vae_model';
|
||||||
|
};
|
||||||
|
|
||||||
export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
|
export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: [];
|
default: [];
|
||||||
type: 'array';
|
type: 'array';
|
||||||
|
@ -24,6 +24,7 @@ import {
|
|||||||
TypeHints,
|
TypeHints,
|
||||||
UNetInputFieldTemplate,
|
UNetInputFieldTemplate,
|
||||||
VaeInputFieldTemplate,
|
VaeInputFieldTemplate,
|
||||||
|
VaeModelInputFieldTemplate,
|
||||||
} from '../types/types';
|
} from '../types/types';
|
||||||
|
|
||||||
export type BaseFieldProperties = 'name' | 'title' | 'description';
|
export type BaseFieldProperties = 'name' | 'title' | 'description';
|
||||||
@ -175,6 +176,21 @@ const buildModelInputFieldTemplate = ({
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildVaeModelInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): VaeModelInputFieldTemplate => {
|
||||||
|
const template: VaeModelInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'vae_model',
|
||||||
|
inputRequirement: 'always',
|
||||||
|
inputKind: 'direct',
|
||||||
|
default: schemaObject.default ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
const buildImageInputFieldTemplate = ({
|
const buildImageInputFieldTemplate = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
baseField,
|
baseField,
|
||||||
@ -441,6 +457,9 @@ export const buildInputFieldTemplate = (
|
|||||||
if (['model'].includes(fieldType)) {
|
if (['model'].includes(fieldType)) {
|
||||||
return buildModelInputFieldTemplate({ schemaObject, baseField });
|
return buildModelInputFieldTemplate({ schemaObject, baseField });
|
||||||
}
|
}
|
||||||
|
if (['vae_model'].includes(fieldType)) {
|
||||||
|
return buildVaeModelInputFieldTemplate({ schemaObject, baseField });
|
||||||
|
}
|
||||||
if (['enum'].includes(fieldType)) {
|
if (['enum'].includes(fieldType)) {
|
||||||
return buildEnumInputFieldTemplate({ schemaObject, baseField });
|
return buildEnumInputFieldTemplate({ schemaObject, baseField });
|
||||||
}
|
}
|
||||||
|
@ -75,6 +75,10 @@ export const buildInputFieldValue = (
|
|||||||
if (template.type === 'model') {
|
if (template.type === 'model') {
|
||||||
fieldValue.value = undefined;
|
fieldValue.value = undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (template.type === 'vae_model') {
|
||||||
|
fieldValue.value = undefined;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return fieldValue;
|
return fieldValue;
|
||||||
|
@ -5,6 +5,7 @@ import { Graph } from 'services/api/types';
|
|||||||
import { AnyInvocation } from 'services/events/types';
|
import { AnyInvocation } from 'services/events/types';
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||||
|
import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* We need to do special handling for some fields
|
* We need to do special handling for some fields
|
||||||
@ -31,6 +32,12 @@ export const parseFieldValue = (field: InputFieldValue) => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (field.type === 'vae_model') {
|
||||||
|
if (field.value) {
|
||||||
|
return modelIdToVAEModelField(field.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return field.value;
|
return field.value;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -0,0 +1,16 @@
|
|||||||
|
import { BaseModelType } from 'services/api/types';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Crudely converts a model id to a main model field
|
||||||
|
* TODO: Make better
|
||||||
|
*/
|
||||||
|
export const modelIdToVAEModelField = (modelId: string): VAEModelField => {
|
||||||
|
const [base_model, model_type, model_name] = modelId.split('/');
|
||||||
|
|
||||||
|
const field: VAEModelField = {
|
||||||
|
base_model: base_model as BaseModelType,
|
||||||
|
model_name,
|
||||||
|
};
|
||||||
|
|
||||||
|
return field;
|
||||||
|
};
|
Loading…
x
Reference in New Issue
Block a user