Add support for T2I-Adapter in node workflows (#4612)

* Bump diffusers to 0.21.2.

* Add T2IAdapterInvocation boilerplate.

* Add T2I-Adapter model to model-management.

* (minor) Tidy prepare_control_image(...).

* Add logic to run the T2I-Adapter models at the start of the DenoiseLatentsInvocation.

* Add logic for applying T2I-Adapter weights and accumulating.

* Add T2IAdapter to MODEL_CLASSES map.

* yarn typegen

* Add model probes for T2I-Adapter models.

* Add all of the frontend boilerplate required to use T2I-Adapter in the nodes editor.

* Add T2IAdapterModel.convert_if_required(...).

* Fix errors in T2I-Adapter input image sizing logic.

* Fix bug with handling of multiple T2I-Adapters.

* black / flake8

* Fix typo

* yarn build

* Add num_channels param to prepare_control_image(...).

* Link to upstream diffusers bugfix PR that currently requires a workaround.

* feat: Add Color Map Preprocessor

Needed for the color T2I Adapter

* feat: Add Color Map Preprocessor to Linear UI

* Revert "feat: Add Color Map Preprocessor"

This reverts commit a1119a00bf.

* Revert "feat: Add Color Map Preprocessor to Linear UI"

This reverts commit bd8a9b82d8.

* Fix T2I-Adapter field rendering in workflow editor.

* yarn build, yarn typegen

---------

Co-authored-by: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
Ryan Dick
2023-10-05 01:29:16 -04:00
committed by GitHub
parent fbe6452c45
commit 78377469db
32 changed files with 1610 additions and 248 deletions

View File

@ -16,6 +16,7 @@ import SchedulerInputField from './inputs/SchedulerInputField';
import StringInputField from './inputs/StringInputField';
import VaeModelInputField from './inputs/VaeModelInputField';
import IPAdapterModelInputField from './inputs/IPAdapterModelInputField';
import T2IAdapterModelInputField from './inputs/T2IAdapterModelInputField';
import BoardInputField from './inputs/BoardInputField';
type InputFieldProps = {
@ -188,6 +189,18 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
);
}
if (
field?.type === 'T2IAdapterModelField' &&
fieldTemplate?.type === 'T2IAdapterModelField'
) {
return (
<T2IAdapterModelInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
return (
<ColorInputField

View File

@ -0,0 +1,19 @@
import {
T2IAdapterInputFieldTemplate,
T2IAdapterInputFieldValue,
T2IAdapterPolymorphicInputFieldTemplate,
T2IAdapterPolymorphicInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo } from 'react';
const T2IAdapterInputFieldComponent = (
_props: FieldComponentProps<
T2IAdapterInputFieldValue | T2IAdapterPolymorphicInputFieldValue,
T2IAdapterInputFieldTemplate | T2IAdapterPolymorphicInputFieldTemplate
>
) => {
return null;
};
export default memo(T2IAdapterInputFieldComponent);

View File

@ -0,0 +1,100 @@
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
T2IAdapterModelInputFieldTemplate,
T2IAdapterModelInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToT2IAdapterModelParam } from 'features/parameters/util/modelIdToT2IAdapterModelParam';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models';
const T2IAdapterModelInputFieldComponent = (
props: FieldComponentProps<
T2IAdapterModelInputFieldValue,
T2IAdapterModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const t2iAdapterModel = field.value;
const dispatch = useAppDispatch();
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery();
// grab the full model entity from the RTK Query cache
const selectedModel = useMemo(
() =>
t2iAdapterModels?.entities[
`${t2iAdapterModel?.base_model}/t2i_adapter/${t2iAdapterModel?.model_name}`
] ?? null,
[
t2iAdapterModel?.base_model,
t2iAdapterModel?.model_name,
t2iAdapterModels?.entities,
]
);
const data = useMemo(() => {
if (!t2iAdapterModels) {
return [];
}
const data: SelectItem[] = [];
forEach(t2iAdapterModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [t2iAdapterModels]);
const handleValueChanged = useCallback(
(v: string | null) => {
if (!v) {
return;
}
const newT2IAdapterModel = modelIdToT2IAdapterModelParam(v);
if (!newT2IAdapterModel) {
return;
}
dispatch(
fieldT2IAdapterModelValueChanged({
nodeId,
fieldName: field.name,
value: newT2IAdapterModel,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<IAIMantineSelect
className="nowheel nodrag"
tooltip={selectedModel?.description}
value={selectedModel?.id ?? null}
placeholder="Pick one"
error={!selectedModel}
data={data}
onChange={handleValueChanged}
sx={{ width: '100%' }}
/>
);
};
export default memo(T2IAdapterModelInputFieldComponent);

View File

@ -55,6 +55,7 @@ import {
SchedulerInputFieldValue,
SDXLRefinerModelInputFieldValue,
StringInputFieldValue,
T2IAdapterModelInputFieldValue,
VaeModelInputFieldValue,
Workflow,
} from '../types/types';
@ -645,6 +646,12 @@ const nodesSlice = createSlice({
) => {
fieldValueReducer(state, action);
},
fieldT2IAdapterModelValueChanged: (
state,
action: FieldValueAction<T2IAdapterModelInputFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldEnumModelValueChanged: (
state,
action: FieldValueAction<EnumInputFieldValue>
@ -1009,6 +1016,7 @@ export const {
fieldEnumModelValueChanged,
fieldImageValueChanged,
fieldIPAdapterModelValueChanged,
fieldT2IAdapterModelValueChanged,
fieldLabelChanged,
fieldLoRAModelValueChanged,
fieldMainModelValueChanged,

View File

@ -31,6 +31,7 @@ export const COLLECTION_TYPES: FieldType[] = [
'ConditioningCollection',
'ControlCollection',
'ColorCollection',
'T2IAdapterCollection',
];
export const POLYMORPHIC_TYPES: FieldType[] = [
@ -43,6 +44,7 @@ export const POLYMORPHIC_TYPES: FieldType[] = [
'ConditioningPolymorphic',
'ControlPolymorphic',
'ColorPolymorphic',
'T2IAdapterPolymorphic',
];
export const MODEL_TYPES: FieldType[] = [
@ -57,6 +59,7 @@ export const MODEL_TYPES: FieldType[] = [
'UNetField',
'VaeField',
'ClipField',
'T2IAdapterModelField',
];
export const COLLECTION_MAP: FieldTypeMapWithNumber = {
@ -70,6 +73,7 @@ export const COLLECTION_MAP: FieldTypeMapWithNumber = {
ConditioningField: 'ConditioningCollection',
ControlField: 'ControlCollection',
ColorField: 'ColorCollection',
T2IAdapterField: 'T2IAdapterCollection',
};
export const isCollectionItemType = (
itemType: string | undefined
@ -87,6 +91,7 @@ export const SINGLE_TO_POLYMORPHIC_MAP: FieldTypeMapWithNumber = {
ConditioningField: 'ConditioningPolymorphic',
ControlField: 'ControlPolymorphic',
ColorField: 'ColorPolymorphic',
T2IAdapterField: 'T2IAdapterPolymorphic',
};
export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = {
@ -99,6 +104,7 @@ export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = {
ConditioningPolymorphic: 'ConditioningField',
ControlPolymorphic: 'ControlField',
ColorPolymorphic: 'ColorField',
T2IAdapterPolymorphic: 'T2IAdapterField',
};
export const TYPES_WITH_INPUT_COMPONENTS: FieldType[] = [
@ -123,6 +129,7 @@ export const TYPES_WITH_INPUT_COMPONENTS: FieldType[] = [
'Scheduler',
'IPAdapterModelField',
'BoardField',
'T2IAdapterModelField',
];
export const isPolymorphicItemType = (
@ -272,7 +279,7 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
title: t('nodes.integerPolymorphic'),
},
IPAdapterField: {
color: 'green.300',
color: 'teal.500',
description: 'IP-Adapter info passed between nodes.',
title: 'IP-Adapter',
},
@ -341,6 +348,26 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
description: t('nodes.stringPolymorphicDescription'),
title: t('nodes.stringPolymorphic'),
},
T2IAdapterCollection: {
color: 'teal.500',
description: t('nodes.t2iAdapterCollectionDescription'),
title: t('nodes.t2iAdapterCollection'),
},
T2IAdapterField: {
color: 'teal.500',
description: t('nodes.t2iAdapterFieldDescription'),
title: t('nodes.t2iAdapterField'),
},
T2IAdapterModelField: {
color: 'teal.500',
description: 'TODO',
title: 'T2I-Adapter',
},
T2IAdapterPolymorphic: {
color: 'teal.500',
description: 'T2I-Adapter info passed between nodes.',
title: 'T2I-Adapter Polymorphic',
},
UNetField: {
color: 'red.500',
description: t('nodes.uNetFieldDescription'),

View File

@ -114,6 +114,10 @@ export const zFieldType = z.enum([
'string',
'StringCollection',
'StringPolymorphic',
'T2IAdapterCollection',
'T2IAdapterField',
'T2IAdapterModelField',
'T2IAdapterPolymorphic',
'UNetField',
'VaeField',
'VaeModelField',
@ -426,6 +430,48 @@ export type IPAdapterInputFieldValue = z.infer<
typeof zIPAdapterInputFieldValue
>;
export const zT2IAdapterModel = zModelIdentifier;
export type T2IAdapterModel = z.infer<typeof zT2IAdapterModel>;
export const zT2IAdapterField = z.object({
image: zImageField,
t2i_adapter_model: zT2IAdapterModel,
weight: z.union([z.number(), z.array(z.number())]).optional(),
begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(),
resize_mode: z
.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple'])
.optional(),
});
export type T2IAdapterField = z.infer<typeof zT2IAdapterField>;
export const zT2IAdapterInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('T2IAdapterField'),
value: zT2IAdapterField.optional(),
});
export type T2IAdapterInputFieldValue = z.infer<
typeof zT2IAdapterInputFieldValue
>;
export const zT2IAdapterPolymorphicInputFieldValue =
zInputFieldValueBase.extend({
type: z.literal('T2IAdapterPolymorphic'),
value: z.union([zT2IAdapterField, z.array(zT2IAdapterField)]).optional(),
});
export type T2IAdapterPolymorphicInputFieldValue = z.infer<
typeof zT2IAdapterPolymorphicInputFieldValue
>;
export const zT2IAdapterCollectionInputFieldValue = zInputFieldValueBase.extend(
{
type: z.literal('T2IAdapterCollection'),
value: z.array(zT2IAdapterField).optional(),
}
);
export type T2IAdapterCollectionInputFieldValue = z.infer<
typeof zT2IAdapterCollectionInputFieldValue
>;
export const zModelType = z.enum([
'onnx',
'main',
@ -592,6 +638,17 @@ export type IPAdapterModelInputFieldValue = z.infer<
typeof zIPAdapterModelInputFieldValue
>;
export const zT2IAdapterModelField = zModelIdentifier;
export type T2IAdapterModelField = z.infer<typeof zT2IAdapterModelField>;
export const zT2IAdapterModelInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('T2IAdapterModelField'),
value: zT2IAdapterModelField.optional(),
});
export type T2IAdapterModelInputFieldValue = z.infer<
typeof zT2IAdapterModelInputFieldValue
>;
export const zCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('Collection'),
value: z.array(z.any()).optional(), // TODO: should this field ever have a value?
@ -688,6 +745,10 @@ export const zInputFieldValue = z.discriminatedUnion('type', [
zStringCollectionInputFieldValue,
zStringPolymorphicInputFieldValue,
zStringInputFieldValue,
zT2IAdapterInputFieldValue,
zT2IAdapterModelInputFieldValue,
zT2IAdapterCollectionInputFieldValue,
zT2IAdapterPolymorphicInputFieldValue,
zUNetInputFieldValue,
zVaeInputFieldValue,
zVaeModelInputFieldValue,
@ -889,6 +950,24 @@ export type IPAdapterInputFieldTemplate = InputFieldTemplateBase & {
type: 'IPAdapterField';
};
export type T2IAdapterInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'T2IAdapterField';
};
export type T2IAdapterCollectionInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'T2IAdapterCollection';
item_default?: T2IAdapterField;
};
export type T2IAdapterPolymorphicInputFieldTemplate = Omit<
T2IAdapterInputFieldTemplate,
'type'
> & {
type: 'T2IAdapterPolymorphic';
};
export type EnumInputFieldTemplate = InputFieldTemplateBase & {
default: string;
type: 'enum';
@ -931,6 +1010,11 @@ export type IPAdapterModelInputFieldTemplate = InputFieldTemplateBase & {
type: 'IPAdapterModelField';
};
export type T2IAdapterModelInputFieldTemplate = InputFieldTemplateBase & {
default: string;
type: 'T2IAdapterModelField';
};
export type CollectionInputFieldTemplate = InputFieldTemplateBase & {
default: [];
type: 'Collection';
@ -1016,6 +1100,10 @@ export type InputFieldTemplate =
| StringCollectionInputFieldTemplate
| StringPolymorphicInputFieldTemplate
| StringInputFieldTemplate
| T2IAdapterInputFieldTemplate
| T2IAdapterCollectionInputFieldTemplate
| T2IAdapterModelInputFieldTemplate
| T2IAdapterPolymorphicInputFieldTemplate
| UNetInputFieldTemplate
| VaeInputFieldTemplate
| VaeModelInputFieldTemplate;

View File

@ -62,6 +62,11 @@ import {
ConditioningField,
IPAdapterInputFieldTemplate,
IPAdapterModelInputFieldTemplate,
T2IAdapterField,
T2IAdapterInputFieldTemplate,
T2IAdapterModelInputFieldTemplate,
T2IAdapterPolymorphicInputFieldTemplate,
T2IAdapterCollectionInputFieldTemplate,
BoardInputFieldTemplate,
InputFieldTemplate,
} from '../types/types';
@ -452,6 +457,19 @@ const buildIPAdapterModelInputFieldTemplate = ({
return template;
};
const buildT2IAdapterModelInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): T2IAdapterModelInputFieldTemplate => {
const template: T2IAdapterModelInputFieldTemplate = {
...baseField,
type: 'T2IAdapterModelField',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildBoardInputFieldTemplate = ({
schemaObject,
baseField,
@ -691,6 +709,46 @@ const buildIPAdapterInputFieldTemplate = ({
return template;
};
const buildT2IAdapterInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): T2IAdapterInputFieldTemplate => {
const template: T2IAdapterInputFieldTemplate = {
...baseField,
type: 'T2IAdapterField',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildT2IAdapterPolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): T2IAdapterPolymorphicInputFieldTemplate => {
const template: T2IAdapterPolymorphicInputFieldTemplate = {
...baseField,
type: 'T2IAdapterPolymorphic',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildT2IAdapterCollectionInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): T2IAdapterCollectionInputFieldTemplate => {
const template: T2IAdapterCollectionInputFieldTemplate = {
...baseField,
type: 'T2IAdapterCollection',
default: schemaObject.default ?? [],
item_default: (schemaObject.item_default as T2IAdapterField) ?? undefined,
};
return template;
};
const buildEnumInputFieldTemplate = ({
schemaObject,
baseField,
@ -910,6 +968,10 @@ const TEMPLATE_BUILDER_MAP: {
string: buildStringInputFieldTemplate,
StringCollection: buildStringCollectionInputFieldTemplate,
StringPolymorphic: buildStringPolymorphicInputFieldTemplate,
T2IAdapterCollection: buildT2IAdapterCollectionInputFieldTemplate,
T2IAdapterField: buildT2IAdapterInputFieldTemplate,
T2IAdapterModelField: buildT2IAdapterModelInputFieldTemplate,
T2IAdapterPolymorphic: buildT2IAdapterPolymorphicInputFieldTemplate,
UNetField: buildUNetInputFieldTemplate,
VaeField: buildVaeInputFieldTemplate,
VaeModelField: buildVaeModelInputFieldTemplate,

View File

@ -45,6 +45,10 @@ const FIELD_VALUE_FALLBACK_MAP: {
string: '',
StringCollection: [],
StringPolymorphic: '',
T2IAdapterCollection: [],
T2IAdapterField: undefined,
T2IAdapterModelField: undefined,
T2IAdapterPolymorphic: undefined,
UNetField: undefined,
VaeField: undefined,
VaeModelField: undefined,

View File

@ -340,6 +340,17 @@ export const zIPAdapterModel = z.object({
* Type alias for model parameter, inferred from its zod schema
*/
export type IPAdapterModelParam = z.infer<typeof zIPAdapterModel>;
/**
* Zod schema for T2I-Adapter models
*/
export const zT2IAdapterModel = z.object({
model_name: z.string().min(1),
base_model: zBaseModel,
});
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type T2IAdapterModelParam = z.infer<typeof zT2IAdapterModel>;
/**
* Zod schema for l2l strength parameter
*/

View File

@ -0,0 +1,29 @@
import { logger } from 'app/logging/logger';
import { zT2IAdapterModel } from 'features/parameters/types/parameterSchemas';
import { T2IAdapterModelField } from 'services/api/types';
export const modelIdToT2IAdapterModelParam = (
t2iAdapterModelId: string
): T2IAdapterModelField | undefined => {
const log = logger('models');
const [base_model, _model_type, model_name] = t2iAdapterModelId.split('/');
const result = zT2IAdapterModel.safeParse({
base_model,
model_name,
});
if (!result.success) {
log.error(
{
t2iAdapterModelId,
errors: result.error.format(),
},
'Failed to parse T2I-Adapter model id'
);
return;
}
return result.data;
};

View File

@ -6,6 +6,7 @@ import {
CheckpointModelConfig,
ControlNetModelConfig,
IPAdapterModelConfig,
T2IAdapterModelConfig,
DiffusersModelConfig,
ImportModelConfig,
LoRAModelConfig,
@ -41,6 +42,10 @@ export type IPAdapterModelConfigEntity = IPAdapterModelConfig & {
id: string;
};
export type T2IAdapterModelConfigEntity = T2IAdapterModelConfig & {
id: string;
};
export type TextualInversionModelConfigEntity = TextualInversionModelConfig & {
id: string;
};
@ -53,6 +58,7 @@ type AnyModelConfigEntity =
| LoRAModelConfigEntity
| ControlNetModelConfigEntity
| IPAdapterModelConfigEntity
| T2IAdapterModelConfigEntity
| TextualInversionModelConfigEntity
| VaeModelConfigEntity;
@ -145,6 +151,10 @@ export const ipAdapterModelsAdapter =
createEntityAdapter<IPAdapterModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
export const t2iAdapterModelsAdapter =
createEntityAdapter<T2IAdapterModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
export const textualInversionModelsAdapter =
createEntityAdapter<TextualInversionModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
@ -470,6 +480,37 @@ export const modelsApi = api.injectEndpoints({
);
},
}),
getT2IAdapterModels: build.query<
EntityState<T2IAdapterModelConfigEntity>,
void
>({
query: () => ({ url: 'models/', params: { model_type: 't2i_adapter' } }),
providesTags: (result) => {
const tags: ApiTagDescription[] = [
{ type: 'T2IAdapterModel', id: LIST_TAG },
];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'T2IAdapterModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: { models: T2IAdapterModelConfig[] }) => {
const entities = createModelEntities<T2IAdapterModelConfigEntity>(
response.models
);
return t2iAdapterModelsAdapter.setAll(
t2iAdapterModelsAdapter.getInitialState(),
entities
);
},
}),
getVaeModels: build.query<EntityState<VaeModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
providesTags: (result) => {
@ -567,6 +608,7 @@ export const {
useGetOnnxModelsQuery,
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetT2IAdapterModelsQuery,
useGetLoRAModelsQuery,
useGetTextualInversionModelsQuery,
useGetVaeModelsQuery,

File diff suppressed because one or more lines are too long

View File

@ -67,6 +67,7 @@ export type VAEModelField = s['VAEModelField'];
export type LoRAModelField = s['LoRAModelField'];
export type ControlNetModelField = s['ControlNetModelField'];
export type IPAdapterModelField = s['IPAdapterModelField'];
export type T2IAdapterModelField = s['T2IAdapterModelField'];
export type ModelsList = s['ModelsList'];
export type ControlField = s['ControlField'];
export type IPAdapterField = s['IPAdapterField'];
@ -83,6 +84,9 @@ export type ControlNetModelConfig =
| ControlNetModelDiffusersConfig;
export type IPAdapterModelInvokeAIConfig = s['IPAdapterModelInvokeAIConfig'];
export type IPAdapterModelConfig = IPAdapterModelInvokeAIConfig;
export type T2IAdapterModelDiffusersConfig =
s['T2IAdapterModelDiffusersConfig'];
export type T2IAdapterModelConfig = T2IAdapterModelDiffusersConfig;
export type TextualInversionModelConfig = s['TextualInversionModelConfig'];
export type DiffusersModelConfig =
| s['StableDiffusion1ModelDiffusersConfig']
@ -99,6 +103,7 @@ export type AnyModelConfig =
| VaeModelConfig
| ControlNetModelConfig
| IPAdapterModelConfig
| T2IAdapterModelConfig
| TextualInversionModelConfig
| MainModelConfig
| OnnxModelConfig;