feat: Addvae_model input type front end

This commit is contained in:
blessedcoolant 2023-06-30 06:27:36 +12:00 committed by psychedelicious
parent 5ad6b64721
commit 38660a2162
7 changed files with 183 additions and 9 deletions

View File

@ -3,20 +3,21 @@ import { memo } from 'react';
import { InputFieldTemplate, InputFieldValue } from '../types/types';
import ArrayInputFieldComponent from './fields/ArrayInputFieldComponent';
import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent';
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
import UNetInputFieldComponent from './fields/UNetInputFieldComponent';
import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
import StringInputFieldComponent from './fields/StringInputFieldComponent';
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
import UNetInputFieldComponent from './fields/UNetInputFieldComponent';
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent';
type InputFieldComponentProps = {
nodeId: string;
@ -152,6 +153,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
);
}
if (type === 'vae_model' && template.type === 'vae_model') {
return (
<VaeModelInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'array' && template.type === 'array') {
return (
<ArrayInputFieldComponent

View File

@ -0,0 +1,104 @@
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
VaeModelInputFieldTemplate,
VaeModelInputFieldValue,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
import { MODEL_TYPE_MAP } from 'features/system/components/VAESelect';
import { forEach } from 'lodash-es';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useListModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
const VaeModelInputFieldComponent = (
props: FieldComponentProps<
VaeModelInputFieldValue,
VaeModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data: vaeModels } = useListModelsQuery({
model_type: 'vae',
});
const selectedModel = useMemo(
() => vaeModels?.entities[field.value ?? vaeModels.ids[0]],
[vaeModels?.entities, vaeModels?.ids, field.value]
);
const data = useMemo(() => {
if (!vaeModels) {
return [];
}
const data: SelectItem[] = [
{
value: 'auto',
label: 'auto',
group: 'Automatic',
},
];
forEach(vaeModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [vaeModels]);
const handleValueChanged = useCallback(
(v: string | null) => {
if (!v) {
return;
}
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: v,
})
);
},
[dispatch, field.name, nodeId]
);
useEffect(() => {
if (field.value && vaeModels?.ids.includes(field.value)) {
return;
}
handleValueChanged('none');
}, [field.value, handleValueChanged, vaeModels?.ids]);
return (
<IAIMantineSelect
tooltip={selectedModel?.description}
label={
selectedModel?.base_model &&
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
}
value={field.value}
placeholder="Pick one"
data={data}
onChange={handleValueChanged}
/>
);
};
export default memo(VaeModelInputFieldComponent);

View File

@ -64,6 +64,7 @@ export type FieldType =
| 'vae'
| 'control'
| 'model'
| 'vae_model'
| 'array'
| 'item'
| 'color'
@ -91,6 +92,7 @@ export type InputFieldValue =
| ControlInputFieldValue
| EnumInputFieldValue
| ModelInputFieldValue
| VaeModelInputFieldValue
| ArrayInputFieldValue
| ItemInputFieldValue
| ColorInputFieldValue
@ -116,6 +118,7 @@ export type InputFieldTemplate =
| ControlInputFieldTemplate
| EnumInputFieldTemplate
| ModelInputFieldTemplate
| VaeModelInputFieldTemplate
| ArrayInputFieldTemplate
| ItemInputFieldTemplate
| ColorInputFieldTemplate
@ -228,6 +231,11 @@ export type ModelInputFieldValue = FieldValueBase & {
value?: string;
};
export type VaeModelInputFieldValue = FieldValueBase & {
type: 'vae_model';
value?: string;
};
export type ArrayInputFieldValue = FieldValueBase & {
type: 'array';
value?: (string | number)[];
@ -337,6 +345,11 @@ export type ModelInputFieldTemplate = InputFieldTemplateBase & {
type: 'model';
};
export type VaeModelInputFieldTemplate = InputFieldTemplateBase & {
default: string;
type: 'vae_model';
};
export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
default: [];
type: 'array';

View File

@ -24,6 +24,7 @@ import {
TypeHints,
UNetInputFieldTemplate,
VaeInputFieldTemplate,
VaeModelInputFieldTemplate,
} from '../types/types';
export type BaseFieldProperties = 'name' | 'title' | 'description';
@ -175,6 +176,21 @@ const buildModelInputFieldTemplate = ({
return template;
};
const buildVaeModelInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): VaeModelInputFieldTemplate => {
const template: VaeModelInputFieldTemplate = {
...baseField,
type: 'vae_model',
inputRequirement: 'always',
inputKind: 'direct',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildImageInputFieldTemplate = ({
schemaObject,
baseField,
@ -441,6 +457,9 @@ export const buildInputFieldTemplate = (
if (['model'].includes(fieldType)) {
return buildModelInputFieldTemplate({ schemaObject, baseField });
}
if (['vae_model'].includes(fieldType)) {
return buildVaeModelInputFieldTemplate({ schemaObject, baseField });
}
if (['enum'].includes(fieldType)) {
return buildEnumInputFieldTemplate({ schemaObject, baseField });
}

View File

@ -75,6 +75,10 @@ export const buildInputFieldValue = (
if (template.type === 'model') {
fieldValue.value = undefined;
}
if (template.type === 'vae_model') {
fieldValue.value = undefined;
}
}
return fieldValue;

View File

@ -5,6 +5,7 @@ import { Graph } from 'services/api/types';
import { AnyInvocation } from 'services/events/types';
import { v4 as uuidv4 } from 'uuid';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
/**
* We need to do special handling for some fields
@ -31,6 +32,12 @@ export const parseFieldValue = (field: InputFieldValue) => {
}
}
if (field.type === 'vae_model') {
if (field.value) {
return modelIdToVAEModelField(field.value);
}
}
return field.value;
};

View File

@ -0,0 +1,16 @@
import { BaseModelType } from 'services/api/types';
/**
* Crudely converts a model id to a main model field
* TODO: Make better
*/
export const modelIdToVAEModelField = (modelId: string): VAEModelField => {
const [base_model, model_type, model_name] = modelId.split('/');
const field: VAEModelField = {
base_model: base_model as BaseModelType,
model_name,
};
return field;
};