mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): fix controlNet models
- update controlnet state to use object format for model - update model-parsing helper functions to log errors - update nodes components, types and state - remove controlnets from state when models are loaded and the controlnet's model is not available
This commit is contained in:
@ -90,7 +90,7 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
transitionDuration: '0.1s',
|
||||
}}
|
||||
>
|
||||
<ParamControlNetModel controlNetId={controlNetId} model={model} />
|
||||
<ParamControlNetModel controlNetId={controlNetId} />
|
||||
</Box>
|
||||
<IAIIconButton
|
||||
size="sm"
|
||||
|
@ -1,36 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIFullCheckbox from 'common/components/IAIFullCheckbox';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import {
|
||||
controlNetToggled,
|
||||
isControlNetImagePreprocessedToggled,
|
||||
} from 'features/controlNet/store/controlNetSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
||||
type ParamControlNetIsEnabledProps = {
|
||||
controlNetId: string;
|
||||
isControlImageProcessed: boolean;
|
||||
};
|
||||
|
||||
const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
|
||||
const { controlNetId, isControlImageProcessed } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleIsControlImageProcessedToggled = useCallback(() => {
|
||||
dispatch(
|
||||
isControlNetImagePreprocessedToggled({
|
||||
controlNetId,
|
||||
})
|
||||
);
|
||||
}, [controlNetId, dispatch]);
|
||||
|
||||
return (
|
||||
<IAISwitch
|
||||
label="Preprocess"
|
||||
isChecked={isControlImageProcessed}
|
||||
onChange={handleIsControlImageProcessedToggled}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamControlNetIsEnabled);
|
@ -1,39 +1,45 @@
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
type ParamControlNetModelProps = {
|
||||
controlNetId: string;
|
||||
model: string;
|
||||
};
|
||||
|
||||
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
||||
const { controlNetId, model } = props;
|
||||
const { controlNetId } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
|
||||
const currentMainModel = useAppSelector(
|
||||
(state: RootState) => state.generation.model
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ generation, controlNet }) => {
|
||||
const { model } = generation;
|
||||
const controlNetModel = controlNet.controlNets[controlNetId]?.model;
|
||||
return { mainModel: model, controlNetModel };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[controlNetId]
|
||||
);
|
||||
|
||||
const { mainModel, controlNetModel } = useAppSelector(selector);
|
||||
|
||||
const { data: controlNetModels } = useGetControlNetModelsQuery();
|
||||
|
||||
const handleModelChanged = useCallback(
|
||||
(val: string | null) => {
|
||||
if (!val) return;
|
||||
dispatch(controlNetModelChanged({ controlNetId, model: val }));
|
||||
},
|
||||
[controlNetId, dispatch]
|
||||
);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!controlNetModels) {
|
||||
return [];
|
||||
@ -46,7 +52,7 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
||||
return;
|
||||
}
|
||||
|
||||
const disabled = currentMainModel?.base_model !== model.base_model;
|
||||
const disabled = model?.base_model !== mainModel?.base_model;
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
@ -60,16 +66,52 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [controlNetModels, currentMainModel?.base_model]);
|
||||
}, [controlNetModels, mainModel?.base_model]);
|
||||
|
||||
// grab the full model entity from the RTK Query cache
|
||||
const selectedModel = useMemo(
|
||||
() =>
|
||||
controlNetModels?.entities[
|
||||
`${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}`
|
||||
] ?? null,
|
||||
[
|
||||
controlNetModel?.base_model,
|
||||
controlNetModel?.model_name,
|
||||
controlNetModels?.entities,
|
||||
]
|
||||
);
|
||||
|
||||
const handleModelChanged = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
|
||||
const newControlNetModel = modelIdToControlNetModelParam(v);
|
||||
|
||||
if (!newControlNetModel) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
controlNetModelChanged({ controlNetId, model: newControlNetModel })
|
||||
);
|
||||
},
|
||||
[controlNetId, dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIMantineSearchableSelect
|
||||
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||
data={data}
|
||||
value={model}
|
||||
error={
|
||||
!selectedModel || mainModel?.base_model !== selectedModel.base_model
|
||||
}
|
||||
placeholder="Select a model"
|
||||
value={selectedModel?.id ?? null}
|
||||
onChange={handleModelChanged}
|
||||
disabled={!isReady}
|
||||
tooltip={model}
|
||||
disabled={isBusy}
|
||||
tooltip={selectedModel?.description}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
@ -1,23 +1,20 @@
|
||||
import { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { imageDeleted } from 'services/api/thunks/image';
|
||||
import { isAnySessionRejected } from 'services/api/thunks/session';
|
||||
import { appSocketInvocationError } from 'services/events/actions';
|
||||
import { controlNetImageProcessed } from './actions';
|
||||
import {
|
||||
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
|
||||
CONTROLNET_PROCESSORS,
|
||||
} from './constants';
|
||||
import {
|
||||
ControlNetProcessorType,
|
||||
RequiredCannyImageProcessorInvocation,
|
||||
RequiredControlNetProcessorNode,
|
||||
} from './types';
|
||||
import {
|
||||
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
|
||||
// CONTROLNET_MODELS,
|
||||
CONTROLNET_PROCESSORS,
|
||||
// ControlNetModelName,
|
||||
} from './constants';
|
||||
import { controlNetImageProcessed } from './actions';
|
||||
import { imageDeleted, imageUrlsReceived } from 'services/api/thunks/image';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { isAnySessionRejected } from 'services/api/thunks/session';
|
||||
import { appSocketInvocationError } from 'services/events/actions';
|
||||
|
||||
export type ControlModes =
|
||||
| 'balanced'
|
||||
@ -27,7 +24,7 @@ export type ControlModes =
|
||||
|
||||
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
||||
isEnabled: true,
|
||||
model: '',
|
||||
model: null,
|
||||
weight: 1,
|
||||
beginStepPct: 0,
|
||||
endStepPct: 1,
|
||||
@ -43,7 +40,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
||||
export type ControlNetConfig = {
|
||||
controlNetId: string;
|
||||
isEnabled: boolean;
|
||||
model: string;
|
||||
model: ControlNetModelParam | null;
|
||||
weight: number;
|
||||
beginStepPct: number;
|
||||
endStepPct: number;
|
||||
@ -148,7 +145,7 @@ export const controlNetSlice = createSlice({
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
controlNetId: string;
|
||||
model: string;
|
||||
model: ControlNetModelParam;
|
||||
}>
|
||||
) => {
|
||||
const { controlNetId, model } = action.payload;
|
||||
@ -159,7 +156,7 @@ export const controlNetSlice = createSlice({
|
||||
let processorType: ControlNetProcessorType | undefined = undefined;
|
||||
|
||||
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
||||
if (model.includes(modelSubstring)) {
|
||||
if (model.model_name.includes(modelSubstring)) {
|
||||
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||
break;
|
||||
}
|
||||
@ -253,7 +250,11 @@ export const controlNetSlice = createSlice({
|
||||
let processorType: ControlNetProcessorType | undefined = undefined;
|
||||
|
||||
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
||||
if (state.controlNets[controlNetId].model.includes(modelSubstring)) {
|
||||
if (
|
||||
state.controlNets[controlNetId].model?.model_name.includes(
|
||||
modelSubstring
|
||||
)
|
||||
) {
|
||||
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||
break;
|
||||
}
|
||||
@ -287,7 +288,8 @@ export const controlNetSlice = createSlice({
|
||||
});
|
||||
|
||||
builder.addCase(imageDeleted.pending, (state, action) => {
|
||||
// Preemptively remove the image from the gallery
|
||||
// Preemptively remove the image from all controlnets
|
||||
// TODO: doesn't the imageusage stuff do this for us?
|
||||
const { image_name } = action.meta.arg;
|
||||
forEach(state.controlNets, (c) => {
|
||||
if (c.controlImage === image_name) {
|
||||
@ -300,21 +302,6 @@ export const controlNetSlice = createSlice({
|
||||
});
|
||||
});
|
||||
|
||||
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||
// const { image_name, image_url, thumbnail_url } = action.payload;
|
||||
|
||||
// forEach(state.controlNets, (c) => {
|
||||
// if (c.controlImage?.image_name === image_name) {
|
||||
// c.controlImage.image_url = image_url;
|
||||
// c.controlImage.thumbnail_url = thumbnail_url;
|
||||
// }
|
||||
// if (c.processedControlImage?.image_name === image_name) {
|
||||
// c.processedControlImage.image_url = image_url;
|
||||
// c.processedControlImage.thumbnail_url = thumbnail_url;
|
||||
// }
|
||||
// });
|
||||
// });
|
||||
|
||||
builder.addCase(appSocketInvocationError, (state, action) => {
|
||||
state.pendingControlImages = [];
|
||||
});
|
||||
|
Reference in New Issue
Block a user