mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: Addvae_model
input type front end
This commit is contained in:
parent
5ad6b64721
commit
38660a2162
@ -3,20 +3,21 @@ import { memo } from 'react';
|
||||
import { InputFieldTemplate, InputFieldValue } from '../types/types';
|
||||
import ArrayInputFieldComponent from './fields/ArrayInputFieldComponent';
|
||||
import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent';
|
||||
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
|
||||
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
|
||||
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
|
||||
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
|
||||
import UNetInputFieldComponent from './fields/UNetInputFieldComponent';
|
||||
import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
|
||||
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
|
||||
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
|
||||
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
|
||||
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
|
||||
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
|
||||
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
|
||||
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
|
||||
import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
|
||||
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
|
||||
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
|
||||
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
|
||||
import StringInputFieldComponent from './fields/StringInputFieldComponent';
|
||||
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
|
||||
import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
|
||||
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
|
||||
import UNetInputFieldComponent from './fields/UNetInputFieldComponent';
|
||||
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
|
||||
import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent';
|
||||
|
||||
type InputFieldComponentProps = {
|
||||
nodeId: string;
|
||||
@ -152,6 +153,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'vae_model' && template.type === 'vae_model') {
|
||||
return (
|
||||
<VaeModelInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'array' && template.type === 'array') {
|
||||
return (
|
||||
<ArrayInputFieldComponent
|
||||
|
@ -0,0 +1,104 @@
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
VaeModelInputFieldTemplate,
|
||||
VaeModelInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
||||
import { MODEL_TYPE_MAP } from 'features/system/components/VAESelect';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useListModelsQuery } from 'services/api/endpoints/models';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
const VaeModelInputFieldComponent = (
|
||||
props: FieldComponentProps<
|
||||
VaeModelInputFieldValue,
|
||||
VaeModelInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { data: vaeModels } = useListModelsQuery({
|
||||
model_type: 'vae',
|
||||
});
|
||||
|
||||
const selectedModel = useMemo(
|
||||
() => vaeModels?.entities[field.value ?? vaeModels.ids[0]],
|
||||
[vaeModels?.entities, vaeModels?.ids, field.value]
|
||||
);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!vaeModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [
|
||||
{
|
||||
value: 'auto',
|
||||
label: 'auto',
|
||||
group: 'Automatic',
|
||||
},
|
||||
];
|
||||
|
||||
forEach(vaeModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.name,
|
||||
group: MODEL_TYPE_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [vaeModels]);
|
||||
|
||||
const handleValueChanged = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
fieldValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: v,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (field.value && vaeModels?.ids.includes(field.value)) {
|
||||
return;
|
||||
}
|
||||
handleValueChanged('none');
|
||||
}, [field.value, handleValueChanged, vaeModels?.ids]);
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
tooltip={selectedModel?.description}
|
||||
label={
|
||||
selectedModel?.base_model &&
|
||||
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
|
||||
}
|
||||
value={field.value}
|
||||
placeholder="Pick one"
|
||||
data={data}
|
||||
onChange={handleValueChanged}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(VaeModelInputFieldComponent);
|
@ -64,6 +64,7 @@ export type FieldType =
|
||||
| 'vae'
|
||||
| 'control'
|
||||
| 'model'
|
||||
| 'vae_model'
|
||||
| 'array'
|
||||
| 'item'
|
||||
| 'color'
|
||||
@ -91,6 +92,7 @@ export type InputFieldValue =
|
||||
| ControlInputFieldValue
|
||||
| EnumInputFieldValue
|
||||
| ModelInputFieldValue
|
||||
| VaeModelInputFieldValue
|
||||
| ArrayInputFieldValue
|
||||
| ItemInputFieldValue
|
||||
| ColorInputFieldValue
|
||||
@ -116,6 +118,7 @@ export type InputFieldTemplate =
|
||||
| ControlInputFieldTemplate
|
||||
| EnumInputFieldTemplate
|
||||
| ModelInputFieldTemplate
|
||||
| VaeModelInputFieldTemplate
|
||||
| ArrayInputFieldTemplate
|
||||
| ItemInputFieldTemplate
|
||||
| ColorInputFieldTemplate
|
||||
@ -228,6 +231,11 @@ export type ModelInputFieldValue = FieldValueBase & {
|
||||
value?: string;
|
||||
};
|
||||
|
||||
export type VaeModelInputFieldValue = FieldValueBase & {
|
||||
type: 'vae_model';
|
||||
value?: string;
|
||||
};
|
||||
|
||||
export type ArrayInputFieldValue = FieldValueBase & {
|
||||
type: 'array';
|
||||
value?: (string | number)[];
|
||||
@ -337,6 +345,11 @@ export type ModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'model';
|
||||
};
|
||||
|
||||
export type VaeModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: string;
|
||||
type: 'vae_model';
|
||||
};
|
||||
|
||||
export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: [];
|
||||
type: 'array';
|
||||
|
@ -24,6 +24,7 @@ import {
|
||||
TypeHints,
|
||||
UNetInputFieldTemplate,
|
||||
VaeInputFieldTemplate,
|
||||
VaeModelInputFieldTemplate,
|
||||
} from '../types/types';
|
||||
|
||||
export type BaseFieldProperties = 'name' | 'title' | 'description';
|
||||
@ -175,6 +176,21 @@ const buildModelInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildVaeModelInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): VaeModelInputFieldTemplate => {
|
||||
const template: VaeModelInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'vae_model',
|
||||
inputRequirement: 'always',
|
||||
inputKind: 'direct',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildImageInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -441,6 +457,9 @@ export const buildInputFieldTemplate = (
|
||||
if (['model'].includes(fieldType)) {
|
||||
return buildModelInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['vae_model'].includes(fieldType)) {
|
||||
return buildVaeModelInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['enum'].includes(fieldType)) {
|
||||
return buildEnumInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
|
@ -75,6 +75,10 @@ export const buildInputFieldValue = (
|
||||
if (template.type === 'model') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'vae_model') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
return fieldValue;
|
||||
|
@ -5,6 +5,7 @@ import { Graph } from 'services/api/types';
|
||||
import { AnyInvocation } from 'services/events/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||
import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
|
||||
|
||||
/**
|
||||
* We need to do special handling for some fields
|
||||
@ -31,6 +32,12 @@ export const parseFieldValue = (field: InputFieldValue) => {
|
||||
}
|
||||
}
|
||||
|
||||
if (field.type === 'vae_model') {
|
||||
if (field.value) {
|
||||
return modelIdToVAEModelField(field.value);
|
||||
}
|
||||
}
|
||||
|
||||
return field.value;
|
||||
};
|
||||
|
||||
|
@ -0,0 +1,16 @@
|
||||
import { BaseModelType } from 'services/api/types';
|
||||
|
||||
/**
|
||||
* Crudely converts a model id to a main model field
|
||||
* TODO: Make better
|
||||
*/
|
||||
export const modelIdToVAEModelField = (modelId: string): VAEModelField => {
|
||||
const [base_model, model_type, model_name] = modelId.split('/');
|
||||
|
||||
const field: VAEModelField = {
|
||||
base_model: base_model as BaseModelType,
|
||||
model_name,
|
||||
};
|
||||
|
||||
return field;
|
||||
};
|
Loading…
Reference in New Issue
Block a user