Merge branch 'main' into feat/nodes/freeu

This commit is contained in:
Kent Keirsey 2023-11-06 09:04:54 -08:00 committed by GitHub
commit ff8a8a1963
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 45 additions and 178 deletions

View File

@ -182,8 +182,8 @@ class IntegerMathInvocation(BaseInvocation):
operation: INTEGER_OPERATIONS = InputField( operation: INTEGER_OPERATIONS = InputField(
default="ADD", description="The operation to perform", ui_choice_labels=INTEGER_OPERATIONS_LABELS default="ADD", description="The operation to perform", ui_choice_labels=INTEGER_OPERATIONS_LABELS
) )
a: int = InputField(default=0, description=FieldDescriptions.num_1) a: int = InputField(default=1, description=FieldDescriptions.num_1)
b: int = InputField(default=0, description=FieldDescriptions.num_2) b: int = InputField(default=1, description=FieldDescriptions.num_2)
@field_validator("b") @field_validator("b")
def no_unrepresentable_results(cls, v: int, info: ValidationInfo): def no_unrepresentable_results(cls, v: int, info: ValidationInfo):
@ -256,8 +256,8 @@ class FloatMathInvocation(BaseInvocation):
operation: FLOAT_OPERATIONS = InputField( operation: FLOAT_OPERATIONS = InputField(
default="ADD", description="The operation to perform", ui_choice_labels=FLOAT_OPERATIONS_LABELS default="ADD", description="The operation to perform", ui_choice_labels=FLOAT_OPERATIONS_LABELS
) )
a: float = InputField(default=0, description=FieldDescriptions.num_1) a: float = InputField(default=1, description=FieldDescriptions.num_1)
b: float = InputField(default=0, description=FieldDescriptions.num_2) b: float = InputField(default=1, description=FieldDescriptions.num_2)
@field_validator("b") @field_validator("b")
def no_unrepresentable_results(cls, v: float, info: ValidationInfo): def no_unrepresentable_results(cls, v: float, info: ValidationInfo):

View File

@ -546,11 +546,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# Handle ControlNet(s) and T2I-Adapter(s) # Handle ControlNet(s) and T2I-Adapter(s)
down_block_additional_residuals = None down_block_additional_residuals = None
mid_block_additional_residual = None mid_block_additional_residual = None
if control_data is not None and t2i_adapter_data is not None: down_intrablock_additional_residuals = None
# if control_data is not None and t2i_adapter_data is not None:
# TODO(ryand): This is a limitation of the UNet2DConditionModel API, not a fundamental incompatibility # TODO(ryand): This is a limitation of the UNet2DConditionModel API, not a fundamental incompatibility
# between ControlNets and T2I-Adapters. We will try to fix this upstream in diffusers. # between ControlNets and T2I-Adapters. We will try to fix this upstream in diffusers.
raise Exception("ControlNet(s) and T2I-Adapter(s) cannot be used simultaneously (yet).") # raise Exception("ControlNet(s) and T2I-Adapter(s) cannot be used simultaneously (yet).")
elif control_data is not None: # elif control_data is not None:
if control_data is not None:
down_block_additional_residuals, mid_block_additional_residual = self.invokeai_diffuser.do_controlnet_step( down_block_additional_residuals, mid_block_additional_residual = self.invokeai_diffuser.do_controlnet_step(
control_data=control_data, control_data=control_data,
sample=latent_model_input, sample=latent_model_input,
@ -559,7 +561,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count=total_step_count, total_step_count=total_step_count,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
) )
elif t2i_adapter_data is not None: # elif t2i_adapter_data is not None:
if t2i_adapter_data is not None:
accum_adapter_state = None accum_adapter_state = None
for single_t2i_adapter_data in t2i_adapter_data: for single_t2i_adapter_data in t2i_adapter_data:
# Determine the T2I-Adapter weights for the current denoising step. # Determine the T2I-Adapter weights for the current denoising step.
@ -584,7 +587,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
for idx, value in enumerate(single_t2i_adapter_data.adapter_state): for idx, value in enumerate(single_t2i_adapter_data.adapter_state):
accum_adapter_state[idx] += value * t2i_adapter_weight accum_adapter_state[idx] += value * t2i_adapter_weight
down_block_additional_residuals = accum_adapter_state # down_block_additional_residuals = accum_adapter_state
down_intrablock_additional_residuals = accum_adapter_state
uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step( uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
sample=latent_model_input, sample=latent_model_input,
@ -593,8 +597,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count=total_step_count, total_step_count=total_step_count,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
# extra: # extra:
down_block_additional_residuals=down_block_additional_residuals, down_block_additional_residuals=down_block_additional_residuals, # for ControlNet
mid_block_additional_residual=mid_block_additional_residual, mid_block_additional_residual=mid_block_additional_residual, # for ControlNet
down_intrablock_additional_residuals=down_intrablock_additional_residuals, # for T2I-Adapter
) )
guidance_scale = conditioning_data.guidance_scale guidance_scale = conditioning_data.guidance_scale

View File

@ -260,7 +260,6 @@ class InvokeAIDiffuserComponent:
conditioning_data, conditioning_data,
**kwargs, **kwargs,
) )
else: else:
( (
unconditioned_next_x, unconditioned_next_x,
@ -410,6 +409,15 @@ class InvokeAIDiffuserComponent:
uncond_down_block.append(_uncond_down) uncond_down_block.append(_uncond_down)
cond_down_block.append(_cond_down) cond_down_block.append(_cond_down)
uncond_down_intrablock, cond_down_intrablock = None, None
down_intrablock_additional_residuals = kwargs.pop("down_intrablock_additional_residuals", None)
if down_intrablock_additional_residuals is not None:
uncond_down_intrablock, cond_down_intrablock = [], []
for down_intrablock in down_intrablock_additional_residuals:
_uncond_down, _cond_down = down_intrablock.chunk(2)
uncond_down_intrablock.append(_uncond_down)
cond_down_intrablock.append(_cond_down)
uncond_mid_block, cond_mid_block = None, None uncond_mid_block, cond_mid_block = None, None
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None) mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
if mid_block_additional_residual is not None: if mid_block_additional_residual is not None:
@ -441,6 +449,7 @@ class InvokeAIDiffuserComponent:
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=uncond_down_block, down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block, mid_block_additional_residual=uncond_mid_block,
down_intrablock_additional_residuals=uncond_down_intrablock,
added_cond_kwargs=added_cond_kwargs, added_cond_kwargs=added_cond_kwargs,
**kwargs, **kwargs,
) )
@ -470,6 +479,7 @@ class InvokeAIDiffuserComponent:
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=cond_down_block, down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block, mid_block_additional_residual=cond_mid_block,
down_intrablock_additional_residuals=cond_down_intrablock,
added_cond_kwargs=added_cond_kwargs, added_cond_kwargs=added_cond_kwargs,
**kwargs, **kwargs,
) )
@ -494,6 +504,15 @@ class InvokeAIDiffuserComponent:
uncond_down_block.append(_uncond_down) uncond_down_block.append(_uncond_down)
cond_down_block.append(_cond_down) cond_down_block.append(_cond_down)
uncond_down_intrablock, cond_down_intrablock = None, None
down_intrablock_additional_residuals = kwargs.pop("down_intrablock_additional_residuals", None)
if down_intrablock_additional_residuals is not None:
uncond_down_intrablock, cond_down_intrablock = [], []
for down_intrablock in down_intrablock_additional_residuals:
_uncond_down, _cond_down = down_intrablock.chunk(2)
uncond_down_intrablock.append(_uncond_down)
cond_down_intrablock.append(_cond_down)
uncond_mid_block, cond_mid_block = None, None uncond_mid_block, cond_mid_block = None, None
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None) mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
if mid_block_additional_residual is not None: if mid_block_additional_residual is not None:
@ -522,6 +541,7 @@ class InvokeAIDiffuserComponent:
{"swap_cross_attn_context": cross_attn_processor_context}, {"swap_cross_attn_context": cross_attn_processor_context},
down_block_additional_residuals=uncond_down_block, down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block, mid_block_additional_residual=uncond_mid_block,
down_intrablock_additional_residuals=uncond_down_intrablock,
added_cond_kwargs=added_cond_kwargs, added_cond_kwargs=added_cond_kwargs,
**kwargs, **kwargs,
) )
@ -541,6 +561,7 @@ class InvokeAIDiffuserComponent:
{"swap_cross_attn_context": cross_attn_processor_context}, {"swap_cross_attn_context": cross_attn_processor_context},
down_block_additional_residuals=cond_down_block, down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block, mid_block_additional_residual=cond_mid_block,
down_intrablock_additional_residuals=cond_down_intrablock,
added_cond_kwargs=added_cond_kwargs, added_cond_kwargs=added_cond_kwargs,
**kwargs, **kwargs,
) )

View File

@ -12,6 +12,7 @@ import { addFirstListImagesListener } from './listeners/addFirstListImagesListen
import { addAnyEnqueuedListener } from './listeners/anyEnqueued'; import { addAnyEnqueuedListener } from './listeners/anyEnqueued';
import { addAppConfigReceivedListener } from './listeners/appConfigReceived'; import { addAppConfigReceivedListener } from './listeners/appConfigReceived';
import { addAppStartedListener } from './listeners/appStarted'; import { addAppStartedListener } from './listeners/appStarted';
import { addBatchEnqueuedListener } from './listeners/batchEnqueued';
import { addDeleteBoardAndImagesFulfilledListener } from './listeners/boardAndImagesDeleted'; import { addDeleteBoardAndImagesFulfilledListener } from './listeners/boardAndImagesDeleted';
import { addBoardIdSelectedListener } from './listeners/boardIdSelected'; import { addBoardIdSelectedListener } from './listeners/boardIdSelected';
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard'; import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
@ -71,8 +72,6 @@ import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSa
import { addTabChangedListener } from './listeners/tabChanged'; import { addTabChangedListener } from './listeners/tabChanged';
import { addUpscaleRequestedListener } from './listeners/upscaleRequested'; import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
import { addWorkflowLoadedListener } from './listeners/workflowLoaded'; import { addWorkflowLoadedListener } from './listeners/workflowLoaded';
import { addBatchEnqueuedListener } from './listeners/batchEnqueued';
import { addControlAdapterAddedOrEnabledListener } from './listeners/controlAdapterAddedOrEnabled';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
@ -200,7 +199,3 @@ addTabChangedListener();
// Dynamic prompts // Dynamic prompts
addDynamicPromptsListener(); addDynamicPromptsListener();
// Display toast when controlnet or t2i adapter enabled
// TODO: Remove when they can both be enabled at same time
addControlAdapterAddedOrEnabledListener();

View File

@ -1,87 +0,0 @@
import { isAnyOf } from '@reduxjs/toolkit';
import {
controlAdapterAdded,
controlAdapterAddedFromImage,
controlAdapterIsEnabledChanged,
controlAdapterRecalled,
selectControlAdapterAll,
selectControlAdapterById,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { ControlAdapterType } from 'features/controlAdapters/store/types';
import { addToast } from 'features/system/store/systemSlice';
import i18n from 'i18n';
import { startAppListening } from '..';
const isAnyControlAdapterAddedOrEnabled = isAnyOf(
controlAdapterAdded,
controlAdapterAddedFromImage,
controlAdapterRecalled,
controlAdapterIsEnabledChanged
);
/**
* Until we can have both controlnet and t2i adapter enabled at once, they are mutually exclusive
* This displays a toast when one is enabled and the other is already enabled, or one is added
* with the other enabled
*/
export const addControlAdapterAddedOrEnabledListener = () => {
startAppListening({
matcher: isAnyControlAdapterAddedOrEnabled,
effect: async (action, { dispatch, getOriginalState }) => {
const controlAdapters = getOriginalState().controlAdapters;
const hasEnabledControlNets = selectControlAdapterAll(
controlAdapters
).some((ca) => ca.isEnabled && ca.type === 'controlnet');
const hasEnabledT2IAdapters = selectControlAdapterAll(
controlAdapters
).some((ca) => ca.isEnabled && ca.type === 't2i_adapter');
let caType: ControlAdapterType | null = null;
if (controlAdapterAdded.match(action)) {
caType = action.payload.type;
}
if (controlAdapterAddedFromImage.match(action)) {
caType = action.payload.type;
}
if (controlAdapterRecalled.match(action)) {
caType = action.payload.type;
}
if (controlAdapterIsEnabledChanged.match(action)) {
const _caType = selectControlAdapterById(
controlAdapters,
action.payload.id
)?.type;
if (!_caType) {
return;
}
caType = _caType;
}
if (
(caType === 'controlnet' && hasEnabledT2IAdapters) ||
(caType === 't2i_adapter' && hasEnabledControlNets)
) {
const title =
caType === 'controlnet'
? i18n.t('controlnet.controlNetEnabledT2IDisabled')
: i18n.t('controlnet.t2iEnabledControlNetDisabled');
const description = i18n.t('controlnet.controlNetT2IMutexDesc');
dispatch(
addToast({
title,
description,
status: 'warning',
})
);
}
},
});
};

View File

@ -88,61 +88,6 @@ export const selectValidT2IAdapters = (controlAdapters: ControlAdaptersState) =>
(ca.processorType === 'none' && Boolean(ca.controlImage))) (ca.processorType === 'none' && Boolean(ca.controlImage)))
); );
// TODO: I think we can safely remove this?
// const disableAllIPAdapters = (
// state: ControlAdaptersState,
// exclude?: string
// ) => {
// const updates: Update<ControlAdapterConfig>[] = selectAllIPAdapters(state)
// .filter((ca) => ca.id !== exclude)
// .map((ca) => ({
// id: ca.id,
// changes: { isEnabled: false },
// }));
// caAdapter.updateMany(state, updates);
// };
const disableAllControlNets = (
state: ControlAdaptersState,
exclude?: string
) => {
const updates: Update<ControlAdapterConfig>[] = selectAllControlNets(state)
.filter((ca) => ca.id !== exclude)
.map((ca) => ({
id: ca.id,
changes: { isEnabled: false },
}));
caAdapter.updateMany(state, updates);
};
const disableAllT2IAdapters = (
state: ControlAdaptersState,
exclude?: string
) => {
const updates: Update<ControlAdapterConfig>[] = selectAllT2IAdapters(state)
.filter((ca) => ca.id !== exclude)
.map((ca) => ({
id: ca.id,
changes: { isEnabled: false },
}));
caAdapter.updateMany(state, updates);
};
const disableIncompatibleControlAdapters = (
state: ControlAdaptersState,
type: ControlAdapterType,
exclude?: string
) => {
if (type === 'controlnet') {
// we cannot do controlnet + t2i adapter, if we are enabled a controlnet, disable all t2is
disableAllT2IAdapters(state, exclude);
}
if (type === 't2i_adapter') {
// we cannot do controlnet + t2i adapter, if we are enabled a t2i, disable controlnets
disableAllControlNets(state, exclude);
}
};
export const controlAdaptersSlice = createSlice({ export const controlAdaptersSlice = createSlice({
name: 'controlAdapters', name: 'controlAdapters',
initialState: initialControlAdapterState, initialState: initialControlAdapterState,
@ -158,7 +103,6 @@ export const controlAdaptersSlice = createSlice({
) => { ) => {
const { id, type, overrides } = action.payload; const { id, type, overrides } = action.payload;
caAdapter.addOne(state, buildControlAdapter(id, type, overrides)); caAdapter.addOne(state, buildControlAdapter(id, type, overrides));
disableIncompatibleControlAdapters(state, type, id);
}, },
prepare: ({ prepare: ({
type, type,
@ -175,8 +119,6 @@ export const controlAdaptersSlice = createSlice({
action: PayloadAction<ControlAdapterConfig> action: PayloadAction<ControlAdapterConfig>
) => { ) => {
caAdapter.addOne(state, action.payload); caAdapter.addOne(state, action.payload);
const { type, id } = action.payload;
disableIncompatibleControlAdapters(state, type, id);
}, },
controlAdapterDuplicated: { controlAdapterDuplicated: {
reducer: ( reducer: (
@ -196,8 +138,6 @@ export const controlAdaptersSlice = createSlice({
isEnabled: true, isEnabled: true,
}); });
caAdapter.addOne(state, newControlAdapter); caAdapter.addOne(state, newControlAdapter);
const { type } = newControlAdapter;
disableIncompatibleControlAdapters(state, type, newId);
}, },
prepare: (id: string) => { prepare: (id: string) => {
return { payload: { id, newId: uuidv4() } }; return { payload: { id, newId: uuidv4() } };
@ -217,7 +157,6 @@ export const controlAdaptersSlice = createSlice({
state, state,
buildControlAdapter(id, type, { controlImage }) buildControlAdapter(id, type, { controlImage })
); );
disableIncompatibleControlAdapters(state, type, id);
}, },
prepare: (payload: { prepare: (payload: {
type: ControlAdapterType; type: ControlAdapterType;
@ -235,12 +174,6 @@ export const controlAdaptersSlice = createSlice({
) => { ) => {
const { id, isEnabled } = action.payload; const { id, isEnabled } = action.payload;
caAdapter.updateOne(state, { id, changes: { isEnabled } }); caAdapter.updateOne(state, { id, changes: { isEnabled } });
if (isEnabled) {
// we are enabling a control adapter. due to limitations in the current system, we may need to disable other adapters
// TODO: disable when multiple IP adapters are supported
const ca = selectControlAdapterById(state, id);
ca && disableIncompatibleControlAdapters(state, ca.type, id);
}
}, },
controlAdapterImageChanged: ( controlAdapterImageChanged: (
state, state,

View File

@ -8808,11 +8808,11 @@ export type components = {
ui_order: number | null; ui_order: number | null;
}; };
/** /**
* StableDiffusionOnnxModelFormat * IPAdapterModelFormat
* @description An enumeration. * @description An enumeration.
* @enum {string} * @enum {string}
*/ */
StableDiffusionOnnxModelFormat: "olive" | "onnx"; IPAdapterModelFormat: "invokeai";
/** /**
* IPAdapterModelFormat * IPAdapterModelFormat
* @description An enumeration. * @description An enumeration.
@ -8832,11 +8832,11 @@ export type components = {
*/ */
StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/** /**
* CLIPVisionModelFormat * StableDiffusionOnnxModelFormat
* @description An enumeration. * @description An enumeration.
* @enum {string} * @enum {string}
*/ */
CLIPVisionModelFormat: "diffusers"; StableDiffusionOnnxModelFormat: "olive" | "onnx";
/** /**
* StableDiffusion1ModelFormat * StableDiffusion1ModelFormat
* @description An enumeration. * @description An enumeration.