mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): fix controlNet models
- update controlnet state to use object format for model - update model-parsing helper functions to log errors - update nodes components, types and state - remove controlnets from state when models are loaded and the controlnet's model is not available
This commit is contained in:
parent
76dc47e88d
commit
0d41346417
@ -10,6 +10,7 @@ import { zMainModel } from 'features/parameters/types/parameterSchemas';
|
|||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
|
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
|
||||||
|
|
||||||
const moduleLog = log.child({ module: 'models' });
|
const moduleLog = log.child({ module: 'models' });
|
||||||
|
|
||||||
@ -51,7 +52,14 @@ export const addModelSelectedListener = () => {
|
|||||||
modelsCleared += 1;
|
modelsCleared += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: handle incompatible controlnet; pending model manager support
|
const { controlNets } = state.controlNet;
|
||||||
|
forEach(controlNets, (controlNet, controlNetId) => {
|
||||||
|
if (controlNet.model?.base_model !== base_model) {
|
||||||
|
dispatch(controlNetRemoved({ controlNetId }));
|
||||||
|
modelsCleared += 1;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
if (modelsCleared > 0) {
|
if (modelsCleared > 0) {
|
||||||
dispatch(
|
dispatch(
|
||||||
addToast(
|
addToast(
|
||||||
|
@ -11,6 +11,7 @@ import {
|
|||||||
import { forEach, some } from 'lodash-es';
|
import { forEach, some } from 'lodash-es';
|
||||||
import { modelsApi } from 'services/api/endpoints/models';
|
import { modelsApi } from 'services/api/endpoints/models';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
|
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
|
||||||
|
|
||||||
const moduleLog = log.child({ module: 'models' });
|
const moduleLog = log.child({ module: 'models' });
|
||||||
|
|
||||||
@ -127,7 +128,22 @@ export const addModelsLoadedListener = () => {
|
|||||||
matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
|
matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
|
||||||
effect: async (action, { getState, dispatch }) => {
|
effect: async (action, { getState, dispatch }) => {
|
||||||
// ControlNet models loaded - need to remove missing ControlNets from state
|
// ControlNet models loaded - need to remove missing ControlNets from state
|
||||||
// TODO: pending model manager controlnet support
|
const controlNets = getState().controlNet.controlNets;
|
||||||
|
|
||||||
|
forEach(controlNets, (controlNet, controlNetId) => {
|
||||||
|
const isControlNetAvailable = some(
|
||||||
|
action.payload.entities,
|
||||||
|
(m) =>
|
||||||
|
m?.model_name === controlNet?.model?.model_name &&
|
||||||
|
m?.base_model === controlNet?.model?.base_model
|
||||||
|
);
|
||||||
|
|
||||||
|
if (isControlNetAvailable) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(controlNetRemoved({ controlNetId }));
|
||||||
|
});
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -5,6 +5,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
|||||||
import { validateSeedWeights } from 'common/util/seedWeightPairs';
|
import { validateSeedWeights } from 'common/util/seedWeightPairs';
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
import { modelsApi } from '../../services/api/endpoints/models';
|
import { modelsApi } from '../../services/api/endpoints/models';
|
||||||
|
import { forEach } from 'lodash-es';
|
||||||
|
|
||||||
const readinessSelector = createSelector(
|
const readinessSelector = createSelector(
|
||||||
[stateSelector, activeTabNameSelector],
|
[stateSelector, activeTabNameSelector],
|
||||||
@ -52,6 +53,13 @@ const readinessSelector = createSelector(
|
|||||||
reasonsWhyNotReady.push('Seed-Weights badly formatted.');
|
reasonsWhyNotReady.push('Seed-Weights badly formatted.');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
forEach(state.controlNet.controlNets, (controlNet, id) => {
|
||||||
|
if (!controlNet.model) {
|
||||||
|
isReady = false;
|
||||||
|
reasonsWhyNotReady.push('ControlNet ${id} has no model selected.');
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
// All good
|
// All good
|
||||||
return { isReady, reasonsWhyNotReady };
|
return { isReady, reasonsWhyNotReady };
|
||||||
},
|
},
|
||||||
|
@ -90,7 +90,7 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
transitionDuration: '0.1s',
|
transitionDuration: '0.1s',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<ParamControlNetModel controlNetId={controlNetId} model={model} />
|
<ParamControlNetModel controlNetId={controlNetId} />
|
||||||
</Box>
|
</Box>
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
size="sm"
|
size="sm"
|
||||||
|
@ -1,36 +0,0 @@
|
|||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
|
||||||
import IAIFullCheckbox from 'common/components/IAIFullCheckbox';
|
|
||||||
import IAISwitch from 'common/components/IAISwitch';
|
|
||||||
import {
|
|
||||||
controlNetToggled,
|
|
||||||
isControlNetImagePreprocessedToggled,
|
|
||||||
} from 'features/controlNet/store/controlNetSlice';
|
|
||||||
import { memo, useCallback } from 'react';
|
|
||||||
|
|
||||||
type ParamControlNetIsEnabledProps = {
|
|
||||||
controlNetId: string;
|
|
||||||
isControlImageProcessed: boolean;
|
|
||||||
};
|
|
||||||
|
|
||||||
const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
|
|
||||||
const { controlNetId, isControlImageProcessed } = props;
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
const handleIsControlImageProcessedToggled = useCallback(() => {
|
|
||||||
dispatch(
|
|
||||||
isControlNetImagePreprocessedToggled({
|
|
||||||
controlNetId,
|
|
||||||
})
|
|
||||||
);
|
|
||||||
}, [controlNetId, dispatch]);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<IAISwitch
|
|
||||||
label="Preprocess"
|
|
||||||
isChecked={isControlImageProcessed}
|
|
||||||
onChange={handleIsControlImageProcessedToggled}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(ParamControlNetIsEnabled);
|
|
@ -1,39 +1,45 @@
|
|||||||
import { SelectItem } from '@mantine/core';
|
import { SelectItem } from '@mantine/core';
|
||||||
import { RootState } from 'app/store/store';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||||
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
|
||||||
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
|
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
|
||||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
|
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
|
||||||
|
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
type ParamControlNetModelProps = {
|
type ParamControlNetModelProps = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
model: string;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
||||||
const { controlNetId, model } = props;
|
const { controlNetId } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const isReady = useIsReadyToInvoke();
|
const isBusy = useAppSelector(selectIsBusy);
|
||||||
|
|
||||||
const currentMainModel = useAppSelector(
|
const selector = useMemo(
|
||||||
(state: RootState) => state.generation.model
|
() =>
|
||||||
|
createSelector(
|
||||||
|
stateSelector,
|
||||||
|
({ generation, controlNet }) => {
|
||||||
|
const { model } = generation;
|
||||||
|
const controlNetModel = controlNet.controlNets[controlNetId]?.model;
|
||||||
|
return { mainModel: model, controlNetModel };
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
),
|
||||||
|
[controlNetId]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const { mainModel, controlNetModel } = useAppSelector(selector);
|
||||||
|
|
||||||
const { data: controlNetModels } = useGetControlNetModelsQuery();
|
const { data: controlNetModels } = useGetControlNetModelsQuery();
|
||||||
|
|
||||||
const handleModelChanged = useCallback(
|
|
||||||
(val: string | null) => {
|
|
||||||
if (!val) return;
|
|
||||||
dispatch(controlNetModelChanged({ controlNetId, model: val }));
|
|
||||||
},
|
|
||||||
[controlNetId, dispatch]
|
|
||||||
);
|
|
||||||
|
|
||||||
const data = useMemo(() => {
|
const data = useMemo(() => {
|
||||||
if (!controlNetModels) {
|
if (!controlNetModels) {
|
||||||
return [];
|
return [];
|
||||||
@ -46,7 +52,7 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const disabled = currentMainModel?.base_model !== model.base_model;
|
const disabled = model?.base_model !== mainModel?.base_model;
|
||||||
|
|
||||||
data.push({
|
data.push({
|
||||||
value: id,
|
value: id,
|
||||||
@ -60,16 +66,52 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
return data;
|
return data;
|
||||||
}, [controlNetModels, currentMainModel?.base_model]);
|
}, [controlNetModels, mainModel?.base_model]);
|
||||||
|
|
||||||
|
// grab the full model entity from the RTK Query cache
|
||||||
|
const selectedModel = useMemo(
|
||||||
|
() =>
|
||||||
|
controlNetModels?.entities[
|
||||||
|
`${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}`
|
||||||
|
] ?? null,
|
||||||
|
[
|
||||||
|
controlNetModel?.base_model,
|
||||||
|
controlNetModel?.model_name,
|
||||||
|
controlNetModels?.entities,
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleModelChanged = useCallback(
|
||||||
|
(v: string | null) => {
|
||||||
|
if (!v) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const newControlNetModel = modelIdToControlNetModelParam(v);
|
||||||
|
|
||||||
|
if (!newControlNetModel) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
controlNetModelChanged({ controlNetId, model: newControlNetModel })
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[controlNetId, dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAIMantineSearchableSelect
|
<IAIMantineSearchableSelect
|
||||||
itemComponent={IAIMantineSelectItemWithTooltip}
|
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||||
data={data}
|
data={data}
|
||||||
value={model}
|
error={
|
||||||
|
!selectedModel || mainModel?.base_model !== selectedModel.base_model
|
||||||
|
}
|
||||||
|
placeholder="Select a model"
|
||||||
|
value={selectedModel?.id ?? null}
|
||||||
onChange={handleModelChanged}
|
onChange={handleModelChanged}
|
||||||
disabled={!isReady}
|
disabled={isBusy}
|
||||||
tooltip={model}
|
tooltip={selectedModel?.description}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -1,23 +1,20 @@
|
|||||||
import { PayloadAction } from '@reduxjs/toolkit';
|
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas';
|
||||||
|
import { forEach } from 'lodash-es';
|
||||||
|
import { imageDeleted } from 'services/api/thunks/image';
|
||||||
|
import { isAnySessionRejected } from 'services/api/thunks/session';
|
||||||
|
import { appSocketInvocationError } from 'services/events/actions';
|
||||||
|
import { controlNetImageProcessed } from './actions';
|
||||||
|
import {
|
||||||
|
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
|
||||||
|
CONTROLNET_PROCESSORS,
|
||||||
|
} from './constants';
|
||||||
import {
|
import {
|
||||||
ControlNetProcessorType,
|
ControlNetProcessorType,
|
||||||
RequiredCannyImageProcessorInvocation,
|
RequiredCannyImageProcessorInvocation,
|
||||||
RequiredControlNetProcessorNode,
|
RequiredControlNetProcessorNode,
|
||||||
} from './types';
|
} from './types';
|
||||||
import {
|
|
||||||
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
|
|
||||||
// CONTROLNET_MODELS,
|
|
||||||
CONTROLNET_PROCESSORS,
|
|
||||||
// ControlNetModelName,
|
|
||||||
} from './constants';
|
|
||||||
import { controlNetImageProcessed } from './actions';
|
|
||||||
import { imageDeleted, imageUrlsReceived } from 'services/api/thunks/image';
|
|
||||||
import { forEach } from 'lodash-es';
|
|
||||||
import { isAnySessionRejected } from 'services/api/thunks/session';
|
|
||||||
import { appSocketInvocationError } from 'services/events/actions';
|
|
||||||
|
|
||||||
export type ControlModes =
|
export type ControlModes =
|
||||||
| 'balanced'
|
| 'balanced'
|
||||||
@ -27,7 +24,7 @@ export type ControlModes =
|
|||||||
|
|
||||||
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
||||||
isEnabled: true,
|
isEnabled: true,
|
||||||
model: '',
|
model: null,
|
||||||
weight: 1,
|
weight: 1,
|
||||||
beginStepPct: 0,
|
beginStepPct: 0,
|
||||||
endStepPct: 1,
|
endStepPct: 1,
|
||||||
@ -43,7 +40,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
|||||||
export type ControlNetConfig = {
|
export type ControlNetConfig = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
isEnabled: boolean;
|
isEnabled: boolean;
|
||||||
model: string;
|
model: ControlNetModelParam | null;
|
||||||
weight: number;
|
weight: number;
|
||||||
beginStepPct: number;
|
beginStepPct: number;
|
||||||
endStepPct: number;
|
endStepPct: number;
|
||||||
@ -148,7 +145,7 @@ export const controlNetSlice = createSlice({
|
|||||||
state,
|
state,
|
||||||
action: PayloadAction<{
|
action: PayloadAction<{
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
model: string;
|
model: ControlNetModelParam;
|
||||||
}>
|
}>
|
||||||
) => {
|
) => {
|
||||||
const { controlNetId, model } = action.payload;
|
const { controlNetId, model } = action.payload;
|
||||||
@ -159,7 +156,7 @@ export const controlNetSlice = createSlice({
|
|||||||
let processorType: ControlNetProcessorType | undefined = undefined;
|
let processorType: ControlNetProcessorType | undefined = undefined;
|
||||||
|
|
||||||
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
||||||
if (model.includes(modelSubstring)) {
|
if (model.model_name.includes(modelSubstring)) {
|
||||||
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -253,7 +250,11 @@ export const controlNetSlice = createSlice({
|
|||||||
let processorType: ControlNetProcessorType | undefined = undefined;
|
let processorType: ControlNetProcessorType | undefined = undefined;
|
||||||
|
|
||||||
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
||||||
if (state.controlNets[controlNetId].model.includes(modelSubstring)) {
|
if (
|
||||||
|
state.controlNets[controlNetId].model?.model_name.includes(
|
||||||
|
modelSubstring
|
||||||
|
)
|
||||||
|
) {
|
||||||
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -287,7 +288,8 @@ export const controlNetSlice = createSlice({
|
|||||||
});
|
});
|
||||||
|
|
||||||
builder.addCase(imageDeleted.pending, (state, action) => {
|
builder.addCase(imageDeleted.pending, (state, action) => {
|
||||||
// Preemptively remove the image from the gallery
|
// Preemptively remove the image from all controlnets
|
||||||
|
// TODO: doesn't the imageusage stuff do this for us?
|
||||||
const { image_name } = action.meta.arg;
|
const { image_name } = action.meta.arg;
|
||||||
forEach(state.controlNets, (c) => {
|
forEach(state.controlNets, (c) => {
|
||||||
if (c.controlImage === image_name) {
|
if (c.controlImage === image_name) {
|
||||||
@ -300,21 +302,6 @@ export const controlNetSlice = createSlice({
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
|
||||||
// const { image_name, image_url, thumbnail_url } = action.payload;
|
|
||||||
|
|
||||||
// forEach(state.controlNets, (c) => {
|
|
||||||
// if (c.controlImage?.image_name === image_name) {
|
|
||||||
// c.controlImage.image_url = image_url;
|
|
||||||
// c.controlImage.thumbnail_url = thumbnail_url;
|
|
||||||
// }
|
|
||||||
// if (c.processedControlImage?.image_name === image_name) {
|
|
||||||
// c.processedControlImage.image_url = image_url;
|
|
||||||
// c.processedControlImage.thumbnail_url = thumbnail_url;
|
|
||||||
// }
|
|
||||||
// });
|
|
||||||
// });
|
|
||||||
|
|
||||||
builder.addCase(appSocketInvocationError, (state, action) => {
|
builder.addCase(appSocketInvocationError, (state, action) => {
|
||||||
state.pendingControlImages = [];
|
state.pendingControlImages = [];
|
||||||
});
|
});
|
||||||
|
@ -6,9 +6,10 @@ import {
|
|||||||
ControlNetModelInputFieldTemplate,
|
ControlNetModelInputFieldTemplate,
|
||||||
ControlNetModelInputFieldValue,
|
ControlNetModelInputFieldValue,
|
||||||
} from 'features/nodes/types/types';
|
} from 'features/nodes/types/types';
|
||||||
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
import { forEach, isString } from 'lodash-es';
|
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
|
||||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
import { forEach } from 'lodash-es';
|
||||||
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||||
import { FieldComponentProps } from './types';
|
import { FieldComponentProps } from './types';
|
||||||
@ -20,15 +21,23 @@ const ControlNetModelInputFieldComponent = (
|
|||||||
>
|
>
|
||||||
) => {
|
) => {
|
||||||
const { nodeId, field } = props;
|
const { nodeId, field } = props;
|
||||||
|
const controlNetModel = field.value;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const { data: controlNetModels } = useGetControlNetModelsQuery();
|
const { data: controlNetModels } = useGetControlNetModelsQuery();
|
||||||
|
|
||||||
|
// grab the full model entity from the RTK Query cache
|
||||||
const selectedModel = useMemo(
|
const selectedModel = useMemo(
|
||||||
() => controlNetModels?.entities[field.value ?? controlNetModels.ids[0]],
|
() =>
|
||||||
[controlNetModels?.entities, controlNetModels?.ids, field.value]
|
controlNetModels?.entities[
|
||||||
|
`${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}`
|
||||||
|
] ?? null,
|
||||||
|
[
|
||||||
|
controlNetModel?.base_model,
|
||||||
|
controlNetModel?.model_name,
|
||||||
|
controlNetModels?.entities,
|
||||||
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
const data = useMemo(() => {
|
const data = useMemo(() => {
|
||||||
@ -45,8 +54,8 @@ const ControlNetModelInputFieldComponent = (
|
|||||||
|
|
||||||
data.push({
|
data.push({
|
||||||
value: id,
|
value: id,
|
||||||
label: model.name,
|
label: model.model_name,
|
||||||
group: BASE_MODEL_NAME_MAP[model.base_model],
|
group: MODEL_TYPE_MAP[model.base_model],
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -59,40 +68,32 @@ const ControlNetModelInputFieldComponent = (
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const newControlNetModel = modelIdToControlNetModelParam(v);
|
||||||
|
|
||||||
|
if (!newControlNetModel) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
dispatch(
|
dispatch(
|
||||||
fieldValueChanged({
|
fieldValueChanged({
|
||||||
nodeId,
|
nodeId,
|
||||||
fieldName: field.name,
|
fieldName: field.name,
|
||||||
value: v,
|
value: newControlNetModel,
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
[dispatch, field.name, nodeId]
|
[dispatch, field.name, nodeId]
|
||||||
);
|
);
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (field.value && controlNetModels?.ids.includes(field.value)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const firstLora = controlNetModels?.ids[0];
|
|
||||||
|
|
||||||
if (!isString(firstLora)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
handleValueChanged(firstLora);
|
|
||||||
}, [field.value, handleValueChanged, controlNetModels?.ids]);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAIMantineSelect
|
<IAIMantineSelect
|
||||||
tooltip={selectedModel?.description}
|
tooltip={selectedModel?.description}
|
||||||
label={
|
label={
|
||||||
selectedModel?.base_model &&
|
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
|
||||||
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
|
|
||||||
}
|
}
|
||||||
value={field.value}
|
value={selectedModel?.id ?? null}
|
||||||
placeholder="Pick one"
|
placeholder="Pick one"
|
||||||
|
error={!selectedModel}
|
||||||
data={data}
|
data={data}
|
||||||
onChange={handleValueChanged}
|
onChange={handleValueChanged}
|
||||||
/>
|
/>
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
|
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import {
|
import {
|
||||||
|
ControlNetModelParam,
|
||||||
LoRAModelParam,
|
LoRAModelParam,
|
||||||
MainModelParam,
|
MainModelParam,
|
||||||
VaeModelParam,
|
VaeModelParam,
|
||||||
@ -81,7 +82,8 @@ const nodesSlice = createSlice({
|
|||||||
| ImageField[]
|
| ImageField[]
|
||||||
| MainModelParam
|
| MainModelParam
|
||||||
| VaeModelParam
|
| VaeModelParam
|
||||||
| LoRAModelParam;
|
| LoRAModelParam
|
||||||
|
| ControlNetModelParam;
|
||||||
}>
|
}>
|
||||||
) => {
|
) => {
|
||||||
const { nodeId, fieldName, value } = action.payload;
|
const { nodeId, fieldName, value } = action.payload;
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import {
|
import {
|
||||||
|
ControlNetModelParam,
|
||||||
LoRAModelParam,
|
LoRAModelParam,
|
||||||
MainModelParam,
|
MainModelParam,
|
||||||
VaeModelParam,
|
VaeModelParam,
|
||||||
@ -254,7 +255,7 @@ export type LoRAModelInputFieldValue = FieldValueBase & {
|
|||||||
|
|
||||||
export type ControlNetModelInputFieldValue = FieldValueBase & {
|
export type ControlNetModelInputFieldValue = FieldValueBase & {
|
||||||
type: 'controlnet_model';
|
type: 'controlnet_model';
|
||||||
value?: string;
|
value?: ControlNetModelParam;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type ArrayInputFieldValue = FieldValueBase & {
|
export type ArrayInputFieldValue = FieldValueBase & {
|
||||||
|
@ -1,14 +0,0 @@
|
|||||||
import { BaseModelType, ControlNetModelField } from 'services/api/types';
|
|
||||||
|
|
||||||
export const modelIdToControlNetModelField = (
|
|
||||||
controlNetModelId: string
|
|
||||||
): ControlNetModelField => {
|
|
||||||
const [base_model, model_type, model_name] = controlNetModelId.split('/');
|
|
||||||
|
|
||||||
const field: ControlNetModelField = {
|
|
||||||
base_model: base_model as BaseModelType,
|
|
||||||
model_name,
|
|
||||||
};
|
|
||||||
|
|
||||||
return field;
|
|
||||||
};
|
|
@ -37,6 +37,7 @@ const ParamVAEModelSelect = () => {
|
|||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// add a "default" option, this means use the main model's included VAE
|
||||||
const data: SelectItem[] = [
|
const data: SelectItem[] = [
|
||||||
{
|
{
|
||||||
value: 'default',
|
value: 'default',
|
||||||
|
@ -180,6 +180,23 @@ export type LoRAModelParam = z.infer<typeof zLoRAModel>;
|
|||||||
*/
|
*/
|
||||||
export const isValidLoRAModel = (val: unknown): val is LoRAModelParam =>
|
export const isValidLoRAModel = (val: unknown): val is LoRAModelParam =>
|
||||||
zLoRAModel.safeParse(val).success;
|
zLoRAModel.safeParse(val).success;
|
||||||
|
/**
|
||||||
|
* Zod schema for ControlNet models
|
||||||
|
*/
|
||||||
|
export const zControlNetModel = z.object({
|
||||||
|
model_name: z.string().min(1),
|
||||||
|
base_model: zBaseModel,
|
||||||
|
});
|
||||||
|
/**
|
||||||
|
* Type alias for model parameter, inferred from its zod schema
|
||||||
|
*/
|
||||||
|
export type ControlNetModelParam = z.infer<typeof zLoRAModel>;
|
||||||
|
/**
|
||||||
|
* Validates/type-guards a value as a model parameter
|
||||||
|
*/
|
||||||
|
export const isValidControlNetModel = (
|
||||||
|
val: unknown
|
||||||
|
): val is ControlNetModelParam => zControlNetModel.safeParse(val).success;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Zod schema for l2l strength parameter
|
* Zod schema for l2l strength parameter
|
||||||
|
@ -0,0 +1,30 @@
|
|||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { zControlNetModel } from 'features/parameters/types/parameterSchemas';
|
||||||
|
import { ControlNetModelField } from 'services/api/types';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ module: 'models' });
|
||||||
|
|
||||||
|
export const modelIdToControlNetModelParam = (
|
||||||
|
controlNetModelId: string
|
||||||
|
): ControlNetModelField | undefined => {
|
||||||
|
const [base_model, model_type, model_name] = controlNetModelId.split('/');
|
||||||
|
|
||||||
|
const result = zControlNetModel.safeParse({
|
||||||
|
base_model,
|
||||||
|
model_name,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!result.success) {
|
||||||
|
moduleLog.error(
|
||||||
|
{
|
||||||
|
controlNetModelId,
|
||||||
|
errors: result.error.format(),
|
||||||
|
},
|
||||||
|
'Failed to parse ControlNet model id'
|
||||||
|
);
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.data;
|
||||||
|
};
|
@ -1,9 +1,12 @@
|
|||||||
import { LoRAModelParam, zLoRAModel } from '../types/parameterSchemas';
|
import { LoRAModelParam, zLoRAModel } from '../types/parameterSchemas';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ module: 'models' });
|
||||||
|
|
||||||
export const modelIdToLoRAModelParam = (
|
export const modelIdToLoRAModelParam = (
|
||||||
loraId: string
|
loraModelId: string
|
||||||
): LoRAModelParam | undefined => {
|
): LoRAModelParam | undefined => {
|
||||||
const [base_model, model_type, model_name] = loraId.split('/');
|
const [base_model, model_type, model_name] = loraModelId.split('/');
|
||||||
|
|
||||||
const result = zLoRAModel.safeParse({
|
const result = zLoRAModel.safeParse({
|
||||||
base_model,
|
base_model,
|
||||||
@ -11,6 +14,13 @@ export const modelIdToLoRAModelParam = (
|
|||||||
});
|
});
|
||||||
|
|
||||||
if (!result.success) {
|
if (!result.success) {
|
||||||
|
moduleLog.error(
|
||||||
|
{
|
||||||
|
loraModelId,
|
||||||
|
errors: result.error.format(),
|
||||||
|
},
|
||||||
|
'Failed to parse LoRA model id'
|
||||||
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,11 +2,14 @@ import {
|
|||||||
MainModelParam,
|
MainModelParam,
|
||||||
zMainModel,
|
zMainModel,
|
||||||
} from 'features/parameters/types/parameterSchemas';
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ module: 'models' });
|
||||||
|
|
||||||
export const modelIdToMainModelParam = (
|
export const modelIdToMainModelParam = (
|
||||||
modelId: string
|
mainModelId: string
|
||||||
): MainModelParam | undefined => {
|
): MainModelParam | undefined => {
|
||||||
const [base_model, model_type, model_name] = modelId.split('/');
|
const [base_model, model_type, model_name] = mainModelId.split('/');
|
||||||
|
|
||||||
const result = zMainModel.safeParse({
|
const result = zMainModel.safeParse({
|
||||||
base_model,
|
base_model,
|
||||||
@ -14,6 +17,13 @@ export const modelIdToMainModelParam = (
|
|||||||
});
|
});
|
||||||
|
|
||||||
if (!result.success) {
|
if (!result.success) {
|
||||||
|
moduleLog.error(
|
||||||
|
{
|
||||||
|
mainModelId,
|
||||||
|
errors: result.error.format(),
|
||||||
|
},
|
||||||
|
'Failed to parse main model id'
|
||||||
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
import { VaeModelParam, zVaeModel } from '../types/parameterSchemas';
|
import { VaeModelParam, zVaeModel } from '../types/parameterSchemas';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ module: 'models' });
|
||||||
|
|
||||||
export const modelIdToVAEModelParam = (
|
export const modelIdToVAEModelParam = (
|
||||||
modelId: string
|
vaeModelId: string
|
||||||
): VaeModelParam | undefined => {
|
): VaeModelParam | undefined => {
|
||||||
const [base_model, model_type, model_name] = modelId.split('/');
|
const [base_model, model_type, model_name] = vaeModelId.split('/');
|
||||||
|
|
||||||
const result = zVaeModel.safeParse({
|
const result = zVaeModel.safeParse({
|
||||||
base_model,
|
base_model,
|
||||||
@ -11,6 +14,13 @@ export const modelIdToVAEModelParam = (
|
|||||||
});
|
});
|
||||||
|
|
||||||
if (!result.success) {
|
if (!result.success) {
|
||||||
|
moduleLog.error(
|
||||||
|
{
|
||||||
|
vaeModelId,
|
||||||
|
errors: result.error.format(),
|
||||||
|
},
|
||||||
|
'Failed to parse VAE model id'
|
||||||
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1935,12 +1935,12 @@ export type components = {
|
|||||||
* Width
|
* Width
|
||||||
* @description The width to resize to (px)
|
* @description The width to resize to (px)
|
||||||
*/
|
*/
|
||||||
width: number;
|
width?: number;
|
||||||
/**
|
/**
|
||||||
* Height
|
* Height
|
||||||
* @description The height to resize to (px)
|
* @description The height to resize to (px)
|
||||||
*/
|
*/
|
||||||
height: number;
|
height?: number;
|
||||||
/**
|
/**
|
||||||
* Resample Mode
|
* Resample Mode
|
||||||
* @description The resampling mode
|
* @description The resampling mode
|
||||||
@ -3302,7 +3302,7 @@ export type components = {
|
|||||||
/** ModelsList */
|
/** ModelsList */
|
||||||
ModelsList: {
|
ModelsList: {
|
||||||
/** Models */
|
/** Models */
|
||||||
models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[];
|
models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[];
|
||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
* MultiplyInvocation
|
* MultiplyInvocation
|
||||||
@ -3922,14 +3922,16 @@ export type components = {
|
|||||||
latents?: components["schemas"]["LatentsField"];
|
latents?: components["schemas"]["LatentsField"];
|
||||||
/**
|
/**
|
||||||
* Width
|
* Width
|
||||||
* @description The width to resize to (px)
|
* @description The width to resize to (px)
|
||||||
|
* @default 512
|
||||||
*/
|
*/
|
||||||
width: number;
|
width?: number;
|
||||||
/**
|
/**
|
||||||
* Height
|
* Height
|
||||||
* @description The height to resize to (px)
|
* @description The height to resize to (px)
|
||||||
|
* @default 512
|
||||||
*/
|
*/
|
||||||
height: number;
|
height?: number;
|
||||||
/**
|
/**
|
||||||
* Mode
|
* Mode
|
||||||
* @description The interpolation mode
|
* @description The interpolation mode
|
||||||
@ -5009,7 +5011,7 @@ export type operations = {
|
|||||||
/** @description The model imported successfully */
|
/** @description The model imported successfully */
|
||||||
201: {
|
201: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
|
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
/** @description The model could not be found */
|
/** @description The model could not be found */
|
||||||
@ -5077,14 +5079,14 @@ export type operations = {
|
|||||||
};
|
};
|
||||||
requestBody: {
|
requestBody: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
|
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
responses: {
|
responses: {
|
||||||
/** @description The model was updated successfully */
|
/** @description The model was updated successfully */
|
||||||
200: {
|
200: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
|
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
/** @description Bad request */
|
/** @description Bad request */
|
||||||
@ -5118,7 +5120,7 @@ export type operations = {
|
|||||||
/** @description Model converted successfully */
|
/** @description Model converted successfully */
|
||||||
200: {
|
200: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
|
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
/** @description Bad request */
|
/** @description Bad request */
|
||||||
@ -5153,7 +5155,7 @@ export type operations = {
|
|||||||
/** @description Model converted successfully */
|
/** @description Model converted successfully */
|
||||||
200: {
|
200: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
|
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
/** @description Incompatible models */
|
/** @description Incompatible models */
|
||||||
|
Loading…
Reference in New Issue
Block a user