mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Re-add feat/mix cnet t2iadapter (#4929)
Reverts invoke-ai/InvokeAI#4923, which was a revert on the premature merge. slide to the left. revert, revert.
This commit is contained in:
commit
cb6d0c8851
@ -546,11 +546,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
# Handle ControlNet(s) and T2I-Adapter(s)
|
||||
down_block_additional_residuals = None
|
||||
mid_block_additional_residual = None
|
||||
if control_data is not None and t2i_adapter_data is not None:
|
||||
# TODO(ryand): This is a limitation of the UNet2DConditionModel API, not a fundamental incompatibility
|
||||
# between ControlNets and T2I-Adapters. We will try to fix this upstream in diffusers.
|
||||
raise Exception("ControlNet(s) and T2I-Adapter(s) cannot be used simultaneously (yet).")
|
||||
elif control_data is not None:
|
||||
down_intrablock_additional_residuals = None
|
||||
# if control_data is not None and t2i_adapter_data is not None:
|
||||
# TODO(ryand): This is a limitation of the UNet2DConditionModel API, not a fundamental incompatibility
|
||||
# between ControlNets and T2I-Adapters. We will try to fix this upstream in diffusers.
|
||||
# raise Exception("ControlNet(s) and T2I-Adapter(s) cannot be used simultaneously (yet).")
|
||||
# elif control_data is not None:
|
||||
if control_data is not None:
|
||||
down_block_additional_residuals, mid_block_additional_residual = self.invokeai_diffuser.do_controlnet_step(
|
||||
control_data=control_data,
|
||||
sample=latent_model_input,
|
||||
@ -559,7 +561,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
total_step_count=total_step_count,
|
||||
conditioning_data=conditioning_data,
|
||||
)
|
||||
elif t2i_adapter_data is not None:
|
||||
# elif t2i_adapter_data is not None:
|
||||
if t2i_adapter_data is not None:
|
||||
accum_adapter_state = None
|
||||
for single_t2i_adapter_data in t2i_adapter_data:
|
||||
# Determine the T2I-Adapter weights for the current denoising step.
|
||||
@ -584,7 +587,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
for idx, value in enumerate(single_t2i_adapter_data.adapter_state):
|
||||
accum_adapter_state[idx] += value * t2i_adapter_weight
|
||||
|
||||
down_block_additional_residuals = accum_adapter_state
|
||||
# down_block_additional_residuals = accum_adapter_state
|
||||
down_intrablock_additional_residuals = accum_adapter_state
|
||||
|
||||
uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
|
||||
sample=latent_model_input,
|
||||
@ -593,8 +597,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
total_step_count=total_step_count,
|
||||
conditioning_data=conditioning_data,
|
||||
# extra:
|
||||
down_block_additional_residuals=down_block_additional_residuals,
|
||||
mid_block_additional_residual=mid_block_additional_residual,
|
||||
down_block_additional_residuals=down_block_additional_residuals, # for ControlNet
|
||||
mid_block_additional_residual=mid_block_additional_residual, # for ControlNet
|
||||
down_intrablock_additional_residuals=down_intrablock_additional_residuals, # for T2I-Adapter
|
||||
)
|
||||
|
||||
guidance_scale = conditioning_data.guidance_scale
|
||||
|
@ -260,7 +260,6 @@ class InvokeAIDiffuserComponent:
|
||||
conditioning_data,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
(
|
||||
unconditioned_next_x,
|
||||
@ -410,6 +409,15 @@ class InvokeAIDiffuserComponent:
|
||||
uncond_down_block.append(_uncond_down)
|
||||
cond_down_block.append(_cond_down)
|
||||
|
||||
uncond_down_intrablock, cond_down_intrablock = None, None
|
||||
down_intrablock_additional_residuals = kwargs.pop("down_intrablock_additional_residuals", None)
|
||||
if down_intrablock_additional_residuals is not None:
|
||||
uncond_down_intrablock, cond_down_intrablock = [], []
|
||||
for down_intrablock in down_intrablock_additional_residuals:
|
||||
_uncond_down, _cond_down = down_intrablock.chunk(2)
|
||||
uncond_down_intrablock.append(_uncond_down)
|
||||
cond_down_intrablock.append(_cond_down)
|
||||
|
||||
uncond_mid_block, cond_mid_block = None, None
|
||||
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
|
||||
if mid_block_additional_residual is not None:
|
||||
@ -441,6 +449,7 @@ class InvokeAIDiffuserComponent:
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
down_block_additional_residuals=uncond_down_block,
|
||||
mid_block_additional_residual=uncond_mid_block,
|
||||
down_intrablock_additional_residuals=uncond_down_intrablock,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
@ -470,6 +479,7 @@ class InvokeAIDiffuserComponent:
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
down_block_additional_residuals=cond_down_block,
|
||||
mid_block_additional_residual=cond_mid_block,
|
||||
down_intrablock_additional_residuals=cond_down_intrablock,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
@ -494,6 +504,15 @@ class InvokeAIDiffuserComponent:
|
||||
uncond_down_block.append(_uncond_down)
|
||||
cond_down_block.append(_cond_down)
|
||||
|
||||
uncond_down_intrablock, cond_down_intrablock = None, None
|
||||
down_intrablock_additional_residuals = kwargs.pop("down_intrablock_additional_residuals", None)
|
||||
if down_intrablock_additional_residuals is not None:
|
||||
uncond_down_intrablock, cond_down_intrablock = [], []
|
||||
for down_intrablock in down_intrablock_additional_residuals:
|
||||
_uncond_down, _cond_down = down_intrablock.chunk(2)
|
||||
uncond_down_intrablock.append(_uncond_down)
|
||||
cond_down_intrablock.append(_cond_down)
|
||||
|
||||
uncond_mid_block, cond_mid_block = None, None
|
||||
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
|
||||
if mid_block_additional_residual is not None:
|
||||
@ -522,6 +541,7 @@ class InvokeAIDiffuserComponent:
|
||||
{"swap_cross_attn_context": cross_attn_processor_context},
|
||||
down_block_additional_residuals=uncond_down_block,
|
||||
mid_block_additional_residual=uncond_mid_block,
|
||||
down_intrablock_additional_residuals=uncond_down_intrablock,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
@ -541,6 +561,7 @@ class InvokeAIDiffuserComponent:
|
||||
{"swap_cross_attn_context": cross_attn_processor_context},
|
||||
down_block_additional_residuals=cond_down_block,
|
||||
mid_block_additional_residual=cond_mid_block,
|
||||
down_intrablock_additional_residuals=cond_down_intrablock,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
@ -12,6 +12,7 @@ import { addFirstListImagesListener } from './listeners/addFirstListImagesListen
|
||||
import { addAnyEnqueuedListener } from './listeners/anyEnqueued';
|
||||
import { addAppConfigReceivedListener } from './listeners/appConfigReceived';
|
||||
import { addAppStartedListener } from './listeners/appStarted';
|
||||
import { addBatchEnqueuedListener } from './listeners/batchEnqueued';
|
||||
import { addDeleteBoardAndImagesFulfilledListener } from './listeners/boardAndImagesDeleted';
|
||||
import { addBoardIdSelectedListener } from './listeners/boardIdSelected';
|
||||
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
|
||||
@ -71,8 +72,6 @@ import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSa
|
||||
import { addTabChangedListener } from './listeners/tabChanged';
|
||||
import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
|
||||
import { addWorkflowLoadedListener } from './listeners/workflowLoaded';
|
||||
import { addBatchEnqueuedListener } from './listeners/batchEnqueued';
|
||||
import { addControlAdapterAddedOrEnabledListener } from './listeners/controlAdapterAddedOrEnabled';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
@ -200,7 +199,3 @@ addTabChangedListener();
|
||||
|
||||
// Dynamic prompts
|
||||
addDynamicPromptsListener();
|
||||
|
||||
// Display toast when controlnet or t2i adapter enabled
|
||||
// TODO: Remove when they can both be enabled at same time
|
||||
addControlAdapterAddedOrEnabledListener();
|
||||
|
@ -1,87 +0,0 @@
|
||||
import { isAnyOf } from '@reduxjs/toolkit';
|
||||
import {
|
||||
controlAdapterAdded,
|
||||
controlAdapterAddedFromImage,
|
||||
controlAdapterIsEnabledChanged,
|
||||
controlAdapterRecalled,
|
||||
selectControlAdapterAll,
|
||||
selectControlAdapterById,
|
||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { ControlAdapterType } from 'features/controlAdapters/store/types';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import i18n from 'i18n';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
const isAnyControlAdapterAddedOrEnabled = isAnyOf(
|
||||
controlAdapterAdded,
|
||||
controlAdapterAddedFromImage,
|
||||
controlAdapterRecalled,
|
||||
controlAdapterIsEnabledChanged
|
||||
);
|
||||
|
||||
/**
|
||||
* Until we can have both controlnet and t2i adapter enabled at once, they are mutually exclusive
|
||||
* This displays a toast when one is enabled and the other is already enabled, or one is added
|
||||
* with the other enabled
|
||||
*/
|
||||
export const addControlAdapterAddedOrEnabledListener = () => {
|
||||
startAppListening({
|
||||
matcher: isAnyControlAdapterAddedOrEnabled,
|
||||
effect: async (action, { dispatch, getOriginalState }) => {
|
||||
const controlAdapters = getOriginalState().controlAdapters;
|
||||
|
||||
const hasEnabledControlNets = selectControlAdapterAll(
|
||||
controlAdapters
|
||||
).some((ca) => ca.isEnabled && ca.type === 'controlnet');
|
||||
|
||||
const hasEnabledT2IAdapters = selectControlAdapterAll(
|
||||
controlAdapters
|
||||
).some((ca) => ca.isEnabled && ca.type === 't2i_adapter');
|
||||
|
||||
let caType: ControlAdapterType | null = null;
|
||||
|
||||
if (controlAdapterAdded.match(action)) {
|
||||
caType = action.payload.type;
|
||||
}
|
||||
|
||||
if (controlAdapterAddedFromImage.match(action)) {
|
||||
caType = action.payload.type;
|
||||
}
|
||||
|
||||
if (controlAdapterRecalled.match(action)) {
|
||||
caType = action.payload.type;
|
||||
}
|
||||
|
||||
if (controlAdapterIsEnabledChanged.match(action)) {
|
||||
const _caType = selectControlAdapterById(
|
||||
controlAdapters,
|
||||
action.payload.id
|
||||
)?.type;
|
||||
if (!_caType) {
|
||||
return;
|
||||
}
|
||||
caType = _caType;
|
||||
}
|
||||
|
||||
if (
|
||||
(caType === 'controlnet' && hasEnabledT2IAdapters) ||
|
||||
(caType === 't2i_adapter' && hasEnabledControlNets)
|
||||
) {
|
||||
const title =
|
||||
caType === 'controlnet'
|
||||
? i18n.t('controlnet.controlNetEnabledT2IDisabled')
|
||||
: i18n.t('controlnet.t2iEnabledControlNetDisabled');
|
||||
|
||||
const description = i18n.t('controlnet.controlNetT2IMutexDesc');
|
||||
|
||||
dispatch(
|
||||
addToast({
|
||||
title,
|
||||
description,
|
||||
status: 'warning',
|
||||
})
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
@ -88,61 +88,6 @@ export const selectValidT2IAdapters = (controlAdapters: ControlAdaptersState) =>
|
||||
(ca.processorType === 'none' && Boolean(ca.controlImage)))
|
||||
);
|
||||
|
||||
// TODO: I think we can safely remove this?
|
||||
// const disableAllIPAdapters = (
|
||||
// state: ControlAdaptersState,
|
||||
// exclude?: string
|
||||
// ) => {
|
||||
// const updates: Update<ControlAdapterConfig>[] = selectAllIPAdapters(state)
|
||||
// .filter((ca) => ca.id !== exclude)
|
||||
// .map((ca) => ({
|
||||
// id: ca.id,
|
||||
// changes: { isEnabled: false },
|
||||
// }));
|
||||
// caAdapter.updateMany(state, updates);
|
||||
// };
|
||||
|
||||
const disableAllControlNets = (
|
||||
state: ControlAdaptersState,
|
||||
exclude?: string
|
||||
) => {
|
||||
const updates: Update<ControlAdapterConfig>[] = selectAllControlNets(state)
|
||||
.filter((ca) => ca.id !== exclude)
|
||||
.map((ca) => ({
|
||||
id: ca.id,
|
||||
changes: { isEnabled: false },
|
||||
}));
|
||||
caAdapter.updateMany(state, updates);
|
||||
};
|
||||
|
||||
const disableAllT2IAdapters = (
|
||||
state: ControlAdaptersState,
|
||||
exclude?: string
|
||||
) => {
|
||||
const updates: Update<ControlAdapterConfig>[] = selectAllT2IAdapters(state)
|
||||
.filter((ca) => ca.id !== exclude)
|
||||
.map((ca) => ({
|
||||
id: ca.id,
|
||||
changes: { isEnabled: false },
|
||||
}));
|
||||
caAdapter.updateMany(state, updates);
|
||||
};
|
||||
|
||||
const disableIncompatibleControlAdapters = (
|
||||
state: ControlAdaptersState,
|
||||
type: ControlAdapterType,
|
||||
exclude?: string
|
||||
) => {
|
||||
if (type === 'controlnet') {
|
||||
// we cannot do controlnet + t2i adapter, if we are enabled a controlnet, disable all t2is
|
||||
disableAllT2IAdapters(state, exclude);
|
||||
}
|
||||
if (type === 't2i_adapter') {
|
||||
// we cannot do controlnet + t2i adapter, if we are enabled a t2i, disable controlnets
|
||||
disableAllControlNets(state, exclude);
|
||||
}
|
||||
};
|
||||
|
||||
export const controlAdaptersSlice = createSlice({
|
||||
name: 'controlAdapters',
|
||||
initialState: initialControlAdapterState,
|
||||
@ -158,7 +103,6 @@ export const controlAdaptersSlice = createSlice({
|
||||
) => {
|
||||
const { id, type, overrides } = action.payload;
|
||||
caAdapter.addOne(state, buildControlAdapter(id, type, overrides));
|
||||
disableIncompatibleControlAdapters(state, type, id);
|
||||
},
|
||||
prepare: ({
|
||||
type,
|
||||
@ -175,8 +119,6 @@ export const controlAdaptersSlice = createSlice({
|
||||
action: PayloadAction<ControlAdapterConfig>
|
||||
) => {
|
||||
caAdapter.addOne(state, action.payload);
|
||||
const { type, id } = action.payload;
|
||||
disableIncompatibleControlAdapters(state, type, id);
|
||||
},
|
||||
controlAdapterDuplicated: {
|
||||
reducer: (
|
||||
@ -196,8 +138,6 @@ export const controlAdaptersSlice = createSlice({
|
||||
isEnabled: true,
|
||||
});
|
||||
caAdapter.addOne(state, newControlAdapter);
|
||||
const { type } = newControlAdapter;
|
||||
disableIncompatibleControlAdapters(state, type, newId);
|
||||
},
|
||||
prepare: (id: string) => {
|
||||
return { payload: { id, newId: uuidv4() } };
|
||||
@ -217,7 +157,6 @@ export const controlAdaptersSlice = createSlice({
|
||||
state,
|
||||
buildControlAdapter(id, type, { controlImage })
|
||||
);
|
||||
disableIncompatibleControlAdapters(state, type, id);
|
||||
},
|
||||
prepare: (payload: {
|
||||
type: ControlAdapterType;
|
||||
@ -235,12 +174,6 @@ export const controlAdaptersSlice = createSlice({
|
||||
) => {
|
||||
const { id, isEnabled } = action.payload;
|
||||
caAdapter.updateOne(state, { id, changes: { isEnabled } });
|
||||
if (isEnabled) {
|
||||
// we are enabling a control adapter. due to limitations in the current system, we may need to disable other adapters
|
||||
// TODO: disable when multiple IP adapters are supported
|
||||
const ca = selectControlAdapterById(state, id);
|
||||
ca && disableIncompatibleControlAdapters(state, ca.type, id);
|
||||
}
|
||||
},
|
||||
controlAdapterImageChanged: (
|
||||
state,
|
||||
|
@ -8808,11 +8808,11 @@ export type components = {
|
||||
ui_order: number | null;
|
||||
};
|
||||
/**
|
||||
* StableDiffusionOnnxModelFormat
|
||||
* IPAdapterModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusionOnnxModelFormat: "olive" | "onnx";
|
||||
IPAdapterModelFormat: "invokeai";
|
||||
/**
|
||||
* IPAdapterModelFormat
|
||||
* @description An enumeration.
|
||||
@ -8832,11 +8832,11 @@ export type components = {
|
||||
*/
|
||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* CLIPVisionModelFormat
|
||||
* StableDiffusionOnnxModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
CLIPVisionModelFormat: "diffusers";
|
||||
StableDiffusionOnnxModelFormat: "olive" | "onnx";
|
||||
/**
|
||||
* StableDiffusion1ModelFormat
|
||||
* @description An enumeration.
|
||||
|
Loading…
Reference in New Issue
Block a user