mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add support for T2I-Adapter in node workflows (#4612)
* Bump diffusers to 0.21.2. * Add T2IAdapterInvocation boilerplate. * Add T2I-Adapter model to model-management. * (minor) Tidy prepare_control_image(...). * Add logic to run the T2I-Adapter models at the start of the DenoiseLatentsInvocation. * Add logic for applying T2I-Adapter weights and accumulating. * Add T2IAdapter to MODEL_CLASSES map. * yarn typegen * Add model probes for T2I-Adapter models. * Add all of the frontend boilerplate required to use T2I-Adapter in the nodes editor. * Add T2IAdapterModel.convert_if_required(...). * Fix errors in T2I-Adapter input image sizing logic. * Fix bug with handling of multiple T2I-Adapters. * black / flake8 * Fix typo * yarn build * Add num_channels param to prepare_control_image(...). * Link to upstream diffusers bugfix PR that currently requires a workaround. * feat: Add Color Map Preprocessor Needed for the color T2I Adapter * feat: Add Color Map Preprocessor to Linear UI * Revert "feat: Add Color Map Preprocessor" This reverts commita1119a00bf
. * Revert "feat: Add Color Map Preprocessor to Linear UI" This reverts commitbd8a9b82d8
. * Fix T2I-Adapter field rendering in workflow editor. * yarn build, yarn typegen --------- Co-authored-by: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
@ -16,6 +16,7 @@ import SchedulerInputField from './inputs/SchedulerInputField';
|
||||
import StringInputField from './inputs/StringInputField';
|
||||
import VaeModelInputField from './inputs/VaeModelInputField';
|
||||
import IPAdapterModelInputField from './inputs/IPAdapterModelInputField';
|
||||
import T2IAdapterModelInputField from './inputs/T2IAdapterModelInputField';
|
||||
import BoardInputField from './inputs/BoardInputField';
|
||||
|
||||
type InputFieldProps = {
|
||||
@ -188,6 +189,18 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'T2IAdapterModelField' &&
|
||||
fieldTemplate?.type === 'T2IAdapterModelField'
|
||||
) {
|
||||
return (
|
||||
<T2IAdapterModelInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
|
||||
return (
|
||||
<ColorInputField
|
||||
|
@ -0,0 +1,19 @@
|
||||
import {
|
||||
T2IAdapterInputFieldTemplate,
|
||||
T2IAdapterInputFieldValue,
|
||||
T2IAdapterPolymorphicInputFieldTemplate,
|
||||
T2IAdapterPolymorphicInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo } from 'react';
|
||||
|
||||
const T2IAdapterInputFieldComponent = (
|
||||
_props: FieldComponentProps<
|
||||
T2IAdapterInputFieldValue | T2IAdapterPolymorphicInputFieldValue,
|
||||
T2IAdapterInputFieldTemplate | T2IAdapterPolymorphicInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
return null;
|
||||
};
|
||||
|
||||
export default memo(T2IAdapterInputFieldComponent);
|
@ -0,0 +1,100 @@
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
T2IAdapterModelInputFieldTemplate,
|
||||
T2IAdapterModelInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToT2IAdapterModelParam } from 'features/parameters/util/modelIdToT2IAdapterModelParam';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const T2IAdapterModelInputFieldComponent = (
|
||||
props: FieldComponentProps<
|
||||
T2IAdapterModelInputFieldValue,
|
||||
T2IAdapterModelInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const t2iAdapterModel = field.value;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery();
|
||||
|
||||
// grab the full model entity from the RTK Query cache
|
||||
const selectedModel = useMemo(
|
||||
() =>
|
||||
t2iAdapterModels?.entities[
|
||||
`${t2iAdapterModel?.base_model}/t2i_adapter/${t2iAdapterModel?.model_name}`
|
||||
] ?? null,
|
||||
[
|
||||
t2iAdapterModel?.base_model,
|
||||
t2iAdapterModel?.model_name,
|
||||
t2iAdapterModels?.entities,
|
||||
]
|
||||
);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!t2iAdapterModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(t2iAdapterModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.model_name,
|
||||
group: MODEL_TYPE_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [t2iAdapterModels]);
|
||||
|
||||
const handleValueChanged = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
|
||||
const newT2IAdapterModel = modelIdToT2IAdapterModelParam(v);
|
||||
|
||||
if (!newT2IAdapterModel) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
fieldT2IAdapterModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: newT2IAdapterModel,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
className="nowheel nodrag"
|
||||
tooltip={selectedModel?.description}
|
||||
value={selectedModel?.id ?? null}
|
||||
placeholder="Pick one"
|
||||
error={!selectedModel}
|
||||
data={data}
|
||||
onChange={handleValueChanged}
|
||||
sx={{ width: '100%' }}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(T2IAdapterModelInputFieldComponent);
|
@ -55,6 +55,7 @@ import {
|
||||
SchedulerInputFieldValue,
|
||||
SDXLRefinerModelInputFieldValue,
|
||||
StringInputFieldValue,
|
||||
T2IAdapterModelInputFieldValue,
|
||||
VaeModelInputFieldValue,
|
||||
Workflow,
|
||||
} from '../types/types';
|
||||
@ -645,6 +646,12 @@ const nodesSlice = createSlice({
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldT2IAdapterModelValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<T2IAdapterModelInputFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldEnumModelValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<EnumInputFieldValue>
|
||||
@ -1009,6 +1016,7 @@ export const {
|
||||
fieldEnumModelValueChanged,
|
||||
fieldImageValueChanged,
|
||||
fieldIPAdapterModelValueChanged,
|
||||
fieldT2IAdapterModelValueChanged,
|
||||
fieldLabelChanged,
|
||||
fieldLoRAModelValueChanged,
|
||||
fieldMainModelValueChanged,
|
||||
|
@ -31,6 +31,7 @@ export const COLLECTION_TYPES: FieldType[] = [
|
||||
'ConditioningCollection',
|
||||
'ControlCollection',
|
||||
'ColorCollection',
|
||||
'T2IAdapterCollection',
|
||||
];
|
||||
|
||||
export const POLYMORPHIC_TYPES: FieldType[] = [
|
||||
@ -43,6 +44,7 @@ export const POLYMORPHIC_TYPES: FieldType[] = [
|
||||
'ConditioningPolymorphic',
|
||||
'ControlPolymorphic',
|
||||
'ColorPolymorphic',
|
||||
'T2IAdapterPolymorphic',
|
||||
];
|
||||
|
||||
export const MODEL_TYPES: FieldType[] = [
|
||||
@ -57,6 +59,7 @@ export const MODEL_TYPES: FieldType[] = [
|
||||
'UNetField',
|
||||
'VaeField',
|
||||
'ClipField',
|
||||
'T2IAdapterModelField',
|
||||
];
|
||||
|
||||
export const COLLECTION_MAP: FieldTypeMapWithNumber = {
|
||||
@ -70,6 +73,7 @@ export const COLLECTION_MAP: FieldTypeMapWithNumber = {
|
||||
ConditioningField: 'ConditioningCollection',
|
||||
ControlField: 'ControlCollection',
|
||||
ColorField: 'ColorCollection',
|
||||
T2IAdapterField: 'T2IAdapterCollection',
|
||||
};
|
||||
export const isCollectionItemType = (
|
||||
itemType: string | undefined
|
||||
@ -87,6 +91,7 @@ export const SINGLE_TO_POLYMORPHIC_MAP: FieldTypeMapWithNumber = {
|
||||
ConditioningField: 'ConditioningPolymorphic',
|
||||
ControlField: 'ControlPolymorphic',
|
||||
ColorField: 'ColorPolymorphic',
|
||||
T2IAdapterField: 'T2IAdapterPolymorphic',
|
||||
};
|
||||
|
||||
export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = {
|
||||
@ -99,6 +104,7 @@ export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = {
|
||||
ConditioningPolymorphic: 'ConditioningField',
|
||||
ControlPolymorphic: 'ControlField',
|
||||
ColorPolymorphic: 'ColorField',
|
||||
T2IAdapterPolymorphic: 'T2IAdapterField',
|
||||
};
|
||||
|
||||
export const TYPES_WITH_INPUT_COMPONENTS: FieldType[] = [
|
||||
@ -123,6 +129,7 @@ export const TYPES_WITH_INPUT_COMPONENTS: FieldType[] = [
|
||||
'Scheduler',
|
||||
'IPAdapterModelField',
|
||||
'BoardField',
|
||||
'T2IAdapterModelField',
|
||||
];
|
||||
|
||||
export const isPolymorphicItemType = (
|
||||
@ -272,7 +279,7 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
title: t('nodes.integerPolymorphic'),
|
||||
},
|
||||
IPAdapterField: {
|
||||
color: 'green.300',
|
||||
color: 'teal.500',
|
||||
description: 'IP-Adapter info passed between nodes.',
|
||||
title: 'IP-Adapter',
|
||||
},
|
||||
@ -341,6 +348,26 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
description: t('nodes.stringPolymorphicDescription'),
|
||||
title: t('nodes.stringPolymorphic'),
|
||||
},
|
||||
T2IAdapterCollection: {
|
||||
color: 'teal.500',
|
||||
description: t('nodes.t2iAdapterCollectionDescription'),
|
||||
title: t('nodes.t2iAdapterCollection'),
|
||||
},
|
||||
T2IAdapterField: {
|
||||
color: 'teal.500',
|
||||
description: t('nodes.t2iAdapterFieldDescription'),
|
||||
title: t('nodes.t2iAdapterField'),
|
||||
},
|
||||
T2IAdapterModelField: {
|
||||
color: 'teal.500',
|
||||
description: 'TODO',
|
||||
title: 'T2I-Adapter',
|
||||
},
|
||||
T2IAdapterPolymorphic: {
|
||||
color: 'teal.500',
|
||||
description: 'T2I-Adapter info passed between nodes.',
|
||||
title: 'T2I-Adapter Polymorphic',
|
||||
},
|
||||
UNetField: {
|
||||
color: 'red.500',
|
||||
description: t('nodes.uNetFieldDescription'),
|
||||
|
@ -114,6 +114,10 @@ export const zFieldType = z.enum([
|
||||
'string',
|
||||
'StringCollection',
|
||||
'StringPolymorphic',
|
||||
'T2IAdapterCollection',
|
||||
'T2IAdapterField',
|
||||
'T2IAdapterModelField',
|
||||
'T2IAdapterPolymorphic',
|
||||
'UNetField',
|
||||
'VaeField',
|
||||
'VaeModelField',
|
||||
@ -426,6 +430,48 @@ export type IPAdapterInputFieldValue = z.infer<
|
||||
typeof zIPAdapterInputFieldValue
|
||||
>;
|
||||
|
||||
export const zT2IAdapterModel = zModelIdentifier;
|
||||
export type T2IAdapterModel = z.infer<typeof zT2IAdapterModel>;
|
||||
|
||||
export const zT2IAdapterField = z.object({
|
||||
image: zImageField,
|
||||
t2i_adapter_model: zT2IAdapterModel,
|
||||
weight: z.union([z.number(), z.array(z.number())]).optional(),
|
||||
begin_step_percent: z.number().optional(),
|
||||
end_step_percent: z.number().optional(),
|
||||
resize_mode: z
|
||||
.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple'])
|
||||
.optional(),
|
||||
});
|
||||
export type T2IAdapterField = z.infer<typeof zT2IAdapterField>;
|
||||
|
||||
export const zT2IAdapterInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('T2IAdapterField'),
|
||||
value: zT2IAdapterField.optional(),
|
||||
});
|
||||
export type T2IAdapterInputFieldValue = z.infer<
|
||||
typeof zT2IAdapterInputFieldValue
|
||||
>;
|
||||
|
||||
export const zT2IAdapterPolymorphicInputFieldValue =
|
||||
zInputFieldValueBase.extend({
|
||||
type: z.literal('T2IAdapterPolymorphic'),
|
||||
value: z.union([zT2IAdapterField, z.array(zT2IAdapterField)]).optional(),
|
||||
});
|
||||
export type T2IAdapterPolymorphicInputFieldValue = z.infer<
|
||||
typeof zT2IAdapterPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zT2IAdapterCollectionInputFieldValue = zInputFieldValueBase.extend(
|
||||
{
|
||||
type: z.literal('T2IAdapterCollection'),
|
||||
value: z.array(zT2IAdapterField).optional(),
|
||||
}
|
||||
);
|
||||
export type T2IAdapterCollectionInputFieldValue = z.infer<
|
||||
typeof zT2IAdapterCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zModelType = z.enum([
|
||||
'onnx',
|
||||
'main',
|
||||
@ -592,6 +638,17 @@ export type IPAdapterModelInputFieldValue = z.infer<
|
||||
typeof zIPAdapterModelInputFieldValue
|
||||
>;
|
||||
|
||||
export const zT2IAdapterModelField = zModelIdentifier;
|
||||
export type T2IAdapterModelField = z.infer<typeof zT2IAdapterModelField>;
|
||||
|
||||
export const zT2IAdapterModelInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('T2IAdapterModelField'),
|
||||
value: zT2IAdapterModelField.optional(),
|
||||
});
|
||||
export type T2IAdapterModelInputFieldValue = z.infer<
|
||||
typeof zT2IAdapterModelInputFieldValue
|
||||
>;
|
||||
|
||||
export const zCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('Collection'),
|
||||
value: z.array(z.any()).optional(), // TODO: should this field ever have a value?
|
||||
@ -688,6 +745,10 @@ export const zInputFieldValue = z.discriminatedUnion('type', [
|
||||
zStringCollectionInputFieldValue,
|
||||
zStringPolymorphicInputFieldValue,
|
||||
zStringInputFieldValue,
|
||||
zT2IAdapterInputFieldValue,
|
||||
zT2IAdapterModelInputFieldValue,
|
||||
zT2IAdapterCollectionInputFieldValue,
|
||||
zT2IAdapterPolymorphicInputFieldValue,
|
||||
zUNetInputFieldValue,
|
||||
zVaeInputFieldValue,
|
||||
zVaeModelInputFieldValue,
|
||||
@ -889,6 +950,24 @@ export type IPAdapterInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'IPAdapterField';
|
||||
};
|
||||
|
||||
export type T2IAdapterInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'T2IAdapterField';
|
||||
};
|
||||
|
||||
export type T2IAdapterCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'T2IAdapterCollection';
|
||||
item_default?: T2IAdapterField;
|
||||
};
|
||||
|
||||
export type T2IAdapterPolymorphicInputFieldTemplate = Omit<
|
||||
T2IAdapterInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'T2IAdapterPolymorphic';
|
||||
};
|
||||
|
||||
export type EnumInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: string;
|
||||
type: 'enum';
|
||||
@ -931,6 +1010,11 @@ export type IPAdapterModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'IPAdapterModelField';
|
||||
};
|
||||
|
||||
export type T2IAdapterModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: string;
|
||||
type: 'T2IAdapterModelField';
|
||||
};
|
||||
|
||||
export type CollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: [];
|
||||
type: 'Collection';
|
||||
@ -1016,6 +1100,10 @@ export type InputFieldTemplate =
|
||||
| StringCollectionInputFieldTemplate
|
||||
| StringPolymorphicInputFieldTemplate
|
||||
| StringInputFieldTemplate
|
||||
| T2IAdapterInputFieldTemplate
|
||||
| T2IAdapterCollectionInputFieldTemplate
|
||||
| T2IAdapterModelInputFieldTemplate
|
||||
| T2IAdapterPolymorphicInputFieldTemplate
|
||||
| UNetInputFieldTemplate
|
||||
| VaeInputFieldTemplate
|
||||
| VaeModelInputFieldTemplate;
|
||||
|
@ -62,6 +62,11 @@ import {
|
||||
ConditioningField,
|
||||
IPAdapterInputFieldTemplate,
|
||||
IPAdapterModelInputFieldTemplate,
|
||||
T2IAdapterField,
|
||||
T2IAdapterInputFieldTemplate,
|
||||
T2IAdapterModelInputFieldTemplate,
|
||||
T2IAdapterPolymorphicInputFieldTemplate,
|
||||
T2IAdapterCollectionInputFieldTemplate,
|
||||
BoardInputFieldTemplate,
|
||||
InputFieldTemplate,
|
||||
} from '../types/types';
|
||||
@ -452,6 +457,19 @@ const buildIPAdapterModelInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildT2IAdapterModelInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): T2IAdapterModelInputFieldTemplate => {
|
||||
const template: T2IAdapterModelInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'T2IAdapterModelField',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildBoardInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -691,6 +709,46 @@ const buildIPAdapterInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildT2IAdapterInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): T2IAdapterInputFieldTemplate => {
|
||||
const template: T2IAdapterInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'T2IAdapterField',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildT2IAdapterPolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): T2IAdapterPolymorphicInputFieldTemplate => {
|
||||
const template: T2IAdapterPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'T2IAdapterPolymorphic',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildT2IAdapterCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): T2IAdapterCollectionInputFieldTemplate => {
|
||||
const template: T2IAdapterCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'T2IAdapterCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
item_default: (schemaObject.item_default as T2IAdapterField) ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildEnumInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -910,6 +968,10 @@ const TEMPLATE_BUILDER_MAP: {
|
||||
string: buildStringInputFieldTemplate,
|
||||
StringCollection: buildStringCollectionInputFieldTemplate,
|
||||
StringPolymorphic: buildStringPolymorphicInputFieldTemplate,
|
||||
T2IAdapterCollection: buildT2IAdapterCollectionInputFieldTemplate,
|
||||
T2IAdapterField: buildT2IAdapterInputFieldTemplate,
|
||||
T2IAdapterModelField: buildT2IAdapterModelInputFieldTemplate,
|
||||
T2IAdapterPolymorphic: buildT2IAdapterPolymorphicInputFieldTemplate,
|
||||
UNetField: buildUNetInputFieldTemplate,
|
||||
VaeField: buildVaeInputFieldTemplate,
|
||||
VaeModelField: buildVaeModelInputFieldTemplate,
|
||||
|
@ -45,6 +45,10 @@ const FIELD_VALUE_FALLBACK_MAP: {
|
||||
string: '',
|
||||
StringCollection: [],
|
||||
StringPolymorphic: '',
|
||||
T2IAdapterCollection: [],
|
||||
T2IAdapterField: undefined,
|
||||
T2IAdapterModelField: undefined,
|
||||
T2IAdapterPolymorphic: undefined,
|
||||
UNetField: undefined,
|
||||
VaeField: undefined,
|
||||
VaeModelField: undefined,
|
||||
|
@ -340,6 +340,17 @@ export const zIPAdapterModel = z.object({
|
||||
* Type alias for model parameter, inferred from its zod schema
|
||||
*/
|
||||
export type IPAdapterModelParam = z.infer<typeof zIPAdapterModel>;
|
||||
/**
|
||||
* Zod schema for T2I-Adapter models
|
||||
*/
|
||||
export const zT2IAdapterModel = z.object({
|
||||
model_name: z.string().min(1),
|
||||
base_model: zBaseModel,
|
||||
});
|
||||
/**
|
||||
* Type alias for model parameter, inferred from its zod schema
|
||||
*/
|
||||
export type T2IAdapterModelParam = z.infer<typeof zT2IAdapterModel>;
|
||||
/**
|
||||
* Zod schema for l2l strength parameter
|
||||
*/
|
||||
|
@ -0,0 +1,29 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { zT2IAdapterModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { T2IAdapterModelField } from 'services/api/types';
|
||||
|
||||
export const modelIdToT2IAdapterModelParam = (
|
||||
t2iAdapterModelId: string
|
||||
): T2IAdapterModelField | undefined => {
|
||||
const log = logger('models');
|
||||
const [base_model, _model_type, model_name] = t2iAdapterModelId.split('/');
|
||||
|
||||
const result = zT2IAdapterModel.safeParse({
|
||||
base_model,
|
||||
model_name,
|
||||
});
|
||||
|
||||
if (!result.success) {
|
||||
log.error(
|
||||
{
|
||||
t2iAdapterModelId,
|
||||
errors: result.error.format(),
|
||||
},
|
||||
'Failed to parse T2I-Adapter model id'
|
||||
);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
return result.data;
|
||||
};
|
@ -6,6 +6,7 @@ import {
|
||||
CheckpointModelConfig,
|
||||
ControlNetModelConfig,
|
||||
IPAdapterModelConfig,
|
||||
T2IAdapterModelConfig,
|
||||
DiffusersModelConfig,
|
||||
ImportModelConfig,
|
||||
LoRAModelConfig,
|
||||
@ -41,6 +42,10 @@ export type IPAdapterModelConfigEntity = IPAdapterModelConfig & {
|
||||
id: string;
|
||||
};
|
||||
|
||||
export type T2IAdapterModelConfigEntity = T2IAdapterModelConfig & {
|
||||
id: string;
|
||||
};
|
||||
|
||||
export type TextualInversionModelConfigEntity = TextualInversionModelConfig & {
|
||||
id: string;
|
||||
};
|
||||
@ -53,6 +58,7 @@ type AnyModelConfigEntity =
|
||||
| LoRAModelConfigEntity
|
||||
| ControlNetModelConfigEntity
|
||||
| IPAdapterModelConfigEntity
|
||||
| T2IAdapterModelConfigEntity
|
||||
| TextualInversionModelConfigEntity
|
||||
| VaeModelConfigEntity;
|
||||
|
||||
@ -145,6 +151,10 @@ export const ipAdapterModelsAdapter =
|
||||
createEntityAdapter<IPAdapterModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
});
|
||||
export const t2iAdapterModelsAdapter =
|
||||
createEntityAdapter<T2IAdapterModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
});
|
||||
export const textualInversionModelsAdapter =
|
||||
createEntityAdapter<TextualInversionModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
@ -470,6 +480,37 @@ export const modelsApi = api.injectEndpoints({
|
||||
);
|
||||
},
|
||||
}),
|
||||
getT2IAdapterModels: build.query<
|
||||
EntityState<T2IAdapterModelConfigEntity>,
|
||||
void
|
||||
>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 't2i_adapter' } }),
|
||||
providesTags: (result) => {
|
||||
const tags: ApiTagDescription[] = [
|
||||
{ type: 'T2IAdapterModel', id: LIST_TAG },
|
||||
];
|
||||
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: 'T2IAdapterModel' as const,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
},
|
||||
transformResponse: (response: { models: T2IAdapterModelConfig[] }) => {
|
||||
const entities = createModelEntities<T2IAdapterModelConfigEntity>(
|
||||
response.models
|
||||
);
|
||||
return t2iAdapterModelsAdapter.setAll(
|
||||
t2iAdapterModelsAdapter.getInitialState(),
|
||||
entities
|
||||
);
|
||||
},
|
||||
}),
|
||||
getVaeModels: build.query<EntityState<VaeModelConfigEntity>, void>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
|
||||
providesTags: (result) => {
|
||||
@ -567,6 +608,7 @@ export const {
|
||||
useGetOnnxModelsQuery,
|
||||
useGetControlNetModelsQuery,
|
||||
useGetIPAdapterModelsQuery,
|
||||
useGetT2IAdapterModelsQuery,
|
||||
useGetLoRAModelsQuery,
|
||||
useGetTextualInversionModelsQuery,
|
||||
useGetVaeModelsQuery,
|
||||
|
594
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
594
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
File diff suppressed because one or more lines are too long
@ -67,6 +67,7 @@ export type VAEModelField = s['VAEModelField'];
|
||||
export type LoRAModelField = s['LoRAModelField'];
|
||||
export type ControlNetModelField = s['ControlNetModelField'];
|
||||
export type IPAdapterModelField = s['IPAdapterModelField'];
|
||||
export type T2IAdapterModelField = s['T2IAdapterModelField'];
|
||||
export type ModelsList = s['ModelsList'];
|
||||
export type ControlField = s['ControlField'];
|
||||
export type IPAdapterField = s['IPAdapterField'];
|
||||
@ -83,6 +84,9 @@ export type ControlNetModelConfig =
|
||||
| ControlNetModelDiffusersConfig;
|
||||
export type IPAdapterModelInvokeAIConfig = s['IPAdapterModelInvokeAIConfig'];
|
||||
export type IPAdapterModelConfig = IPAdapterModelInvokeAIConfig;
|
||||
export type T2IAdapterModelDiffusersConfig =
|
||||
s['T2IAdapterModelDiffusersConfig'];
|
||||
export type T2IAdapterModelConfig = T2IAdapterModelDiffusersConfig;
|
||||
export type TextualInversionModelConfig = s['TextualInversionModelConfig'];
|
||||
export type DiffusersModelConfig =
|
||||
| s['StableDiffusion1ModelDiffusersConfig']
|
||||
@ -99,6 +103,7 @@ export type AnyModelConfig =
|
||||
| VaeModelConfig
|
||||
| ControlNetModelConfig
|
||||
| IPAdapterModelConfig
|
||||
| T2IAdapterModelConfig
|
||||
| TextualInversionModelConfig
|
||||
| MainModelConfig
|
||||
| OnnxModelConfig;
|
||||
|
Reference in New Issue
Block a user