feat(ui): fix controlNet models

- update controlnet state to use object format for model
- update model-parsing helper functions to log errors
- update nodes components, types and state
- remove controlnets from state when models are loaded and the controlnet's model is not available
This commit is contained in:
psychedelicious 2023-07-15 12:19:24 +10:00
parent 76dc47e88d
commit 0d41346417
18 changed files with 249 additions and 154 deletions

View File

@ -10,6 +10,7 @@ import { zMainModel } from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
const moduleLog = log.child({ module: 'models' }); const moduleLog = log.child({ module: 'models' });
@ -51,7 +52,14 @@ export const addModelSelectedListener = () => {
modelsCleared += 1; modelsCleared += 1;
} }
// TODO: handle incompatible controlnet; pending model manager support const { controlNets } = state.controlNet;
forEach(controlNets, (controlNet, controlNetId) => {
if (controlNet.model?.base_model !== base_model) {
dispatch(controlNetRemoved({ controlNetId }));
modelsCleared += 1;
}
});
if (modelsCleared > 0) { if (modelsCleared > 0) {
dispatch( dispatch(
addToast( addToast(

View File

@ -11,6 +11,7 @@ import {
import { forEach, some } from 'lodash-es'; import { forEach, some } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models'; import { modelsApi } from 'services/api/endpoints/models';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
const moduleLog = log.child({ module: 'models' }); const moduleLog = log.child({ module: 'models' });
@ -127,7 +128,22 @@ export const addModelsLoadedListener = () => {
matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled, matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => { effect: async (action, { getState, dispatch }) => {
// ControlNet models loaded - need to remove missing ControlNets from state // ControlNet models loaded - need to remove missing ControlNets from state
// TODO: pending model manager controlnet support const controlNets = getState().controlNet.controlNets;
forEach(controlNets, (controlNet, controlNetId) => {
const isControlNetAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === controlNet?.model?.model_name &&
m?.base_model === controlNet?.model?.base_model
);
if (isControlNetAvailable) {
return;
}
dispatch(controlNetRemoved({ controlNetId }));
});
}, },
}); });
}; };

View File

@ -5,6 +5,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { validateSeedWeights } from 'common/util/seedWeightPairs'; import { validateSeedWeights } from 'common/util/seedWeightPairs';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { modelsApi } from '../../services/api/endpoints/models'; import { modelsApi } from '../../services/api/endpoints/models';
import { forEach } from 'lodash-es';
const readinessSelector = createSelector( const readinessSelector = createSelector(
[stateSelector, activeTabNameSelector], [stateSelector, activeTabNameSelector],
@ -52,6 +53,13 @@ const readinessSelector = createSelector(
reasonsWhyNotReady.push('Seed-Weights badly formatted.'); reasonsWhyNotReady.push('Seed-Weights badly formatted.');
} }
forEach(state.controlNet.controlNets, (controlNet, id) => {
if (!controlNet.model) {
isReady = false;
reasonsWhyNotReady.push('ControlNet ${id} has no model selected.');
}
});
// All good // All good
return { isReady, reasonsWhyNotReady }; return { isReady, reasonsWhyNotReady };
}, },

View File

@ -90,7 +90,7 @@ const ControlNet = (props: ControlNetProps) => {
transitionDuration: '0.1s', transitionDuration: '0.1s',
}} }}
> >
<ParamControlNetModel controlNetId={controlNetId} model={model} /> <ParamControlNetModel controlNetId={controlNetId} />
</Box> </Box>
<IAIIconButton <IAIIconButton
size="sm" size="sm"

View File

@ -1,36 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAIFullCheckbox from 'common/components/IAIFullCheckbox';
import IAISwitch from 'common/components/IAISwitch';
import {
controlNetToggled,
isControlNetImagePreprocessedToggled,
} from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
type ParamControlNetIsEnabledProps = {
controlNetId: string;
isControlImageProcessed: boolean;
};
const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
const { controlNetId, isControlImageProcessed } = props;
const dispatch = useAppDispatch();
const handleIsControlImageProcessedToggled = useCallback(() => {
dispatch(
isControlNetImagePreprocessedToggled({
controlNetId,
})
);
}, [controlNetId, dispatch]);
return (
<IAISwitch
label="Preprocess"
isChecked={isControlImageProcessed}
onChange={handleIsControlImageProcessedToggled}
/>
);
};
export default memo(ParamControlNetIsEnabled);

View File

@ -1,39 +1,45 @@
import { SelectItem } from '@mantine/core'; import { SelectItem } from '@mantine/core';
import { RootState } from 'app/store/store'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice'; import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models'; import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
type ParamControlNetModelProps = { type ParamControlNetModelProps = {
controlNetId: string; controlNetId: string;
model: string;
}; };
const ParamControlNetModel = (props: ParamControlNetModelProps) => { const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const { controlNetId, model } = props; const { controlNetId } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const currentMainModel = useAppSelector( const selector = useMemo(
(state: RootState) => state.generation.model () =>
createSelector(
stateSelector,
({ generation, controlNet }) => {
const { model } = generation;
const controlNetModel = controlNet.controlNets[controlNetId]?.model;
return { mainModel: model, controlNetModel };
},
defaultSelectorOptions
),
[controlNetId]
); );
const { mainModel, controlNetModel } = useAppSelector(selector);
const { data: controlNetModels } = useGetControlNetModelsQuery(); const { data: controlNetModels } = useGetControlNetModelsQuery();
const handleModelChanged = useCallback(
(val: string | null) => {
if (!val) return;
dispatch(controlNetModelChanged({ controlNetId, model: val }));
},
[controlNetId, dispatch]
);
const data = useMemo(() => { const data = useMemo(() => {
if (!controlNetModels) { if (!controlNetModels) {
return []; return [];
@ -46,7 +52,7 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
return; return;
} }
const disabled = currentMainModel?.base_model !== model.base_model; const disabled = model?.base_model !== mainModel?.base_model;
data.push({ data.push({
value: id, value: id,
@ -60,16 +66,52 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
}); });
return data; return data;
}, [controlNetModels, currentMainModel?.base_model]); }, [controlNetModels, mainModel?.base_model]);
// grab the full model entity from the RTK Query cache
const selectedModel = useMemo(
() =>
controlNetModels?.entities[
`${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}`
] ?? null,
[
controlNetModel?.base_model,
controlNetModel?.model_name,
controlNetModels?.entities,
]
);
const handleModelChanged = useCallback(
(v: string | null) => {
if (!v) {
return;
}
const newControlNetModel = modelIdToControlNetModelParam(v);
if (!newControlNetModel) {
return;
}
dispatch(
controlNetModelChanged({ controlNetId, model: newControlNetModel })
);
},
[controlNetId, dispatch]
);
return ( return (
<IAIMantineSearchableSelect <IAIMantineSearchableSelect
itemComponent={IAIMantineSelectItemWithTooltip} itemComponent={IAIMantineSelectItemWithTooltip}
data={data} data={data}
value={model} error={
!selectedModel || mainModel?.base_model !== selectedModel.base_model
}
placeholder="Select a model"
value={selectedModel?.id ?? null}
onChange={handleModelChanged} onChange={handleModelChanged}
disabled={!isReady} disabled={isBusy}
tooltip={model} tooltip={selectedModel?.description}
/> />
); );
}; };

View File

@ -1,23 +1,20 @@
import { PayloadAction } from '@reduxjs/toolkit'; import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { ImageDTO } from 'services/api/types'; import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas';
import { forEach } from 'lodash-es';
import { imageDeleted } from 'services/api/thunks/image';
import { isAnySessionRejected } from 'services/api/thunks/session';
import { appSocketInvocationError } from 'services/events/actions';
import { controlNetImageProcessed } from './actions';
import {
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
CONTROLNET_PROCESSORS,
} from './constants';
import { import {
ControlNetProcessorType, ControlNetProcessorType,
RequiredCannyImageProcessorInvocation, RequiredCannyImageProcessorInvocation,
RequiredControlNetProcessorNode, RequiredControlNetProcessorNode,
} from './types'; } from './types';
import {
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
// CONTROLNET_MODELS,
CONTROLNET_PROCESSORS,
// ControlNetModelName,
} from './constants';
import { controlNetImageProcessed } from './actions';
import { imageDeleted, imageUrlsReceived } from 'services/api/thunks/image';
import { forEach } from 'lodash-es';
import { isAnySessionRejected } from 'services/api/thunks/session';
import { appSocketInvocationError } from 'services/events/actions';
export type ControlModes = export type ControlModes =
| 'balanced' | 'balanced'
@ -27,7 +24,7 @@ export type ControlModes =
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = { export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
isEnabled: true, isEnabled: true,
model: '', model: null,
weight: 1, weight: 1,
beginStepPct: 0, beginStepPct: 0,
endStepPct: 1, endStepPct: 1,
@ -43,7 +40,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
export type ControlNetConfig = { export type ControlNetConfig = {
controlNetId: string; controlNetId: string;
isEnabled: boolean; isEnabled: boolean;
model: string; model: ControlNetModelParam | null;
weight: number; weight: number;
beginStepPct: number; beginStepPct: number;
endStepPct: number; endStepPct: number;
@ -148,7 +145,7 @@ export const controlNetSlice = createSlice({
state, state,
action: PayloadAction<{ action: PayloadAction<{
controlNetId: string; controlNetId: string;
model: string; model: ControlNetModelParam;
}> }>
) => { ) => {
const { controlNetId, model } = action.payload; const { controlNetId, model } = action.payload;
@ -159,7 +156,7 @@ export const controlNetSlice = createSlice({
let processorType: ControlNetProcessorType | undefined = undefined; let processorType: ControlNetProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) { for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (model.includes(modelSubstring)) { if (model.model_name.includes(modelSubstring)) {
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring]; processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break; break;
} }
@ -253,7 +250,11 @@ export const controlNetSlice = createSlice({
let processorType: ControlNetProcessorType | undefined = undefined; let processorType: ControlNetProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) { for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (state.controlNets[controlNetId].model.includes(modelSubstring)) { if (
state.controlNets[controlNetId].model?.model_name.includes(
modelSubstring
)
) {
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring]; processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break; break;
} }
@ -287,7 +288,8 @@ export const controlNetSlice = createSlice({
}); });
builder.addCase(imageDeleted.pending, (state, action) => { builder.addCase(imageDeleted.pending, (state, action) => {
// Preemptively remove the image from the gallery // Preemptively remove the image from all controlnets
// TODO: doesn't the imageusage stuff do this for us?
const { image_name } = action.meta.arg; const { image_name } = action.meta.arg;
forEach(state.controlNets, (c) => { forEach(state.controlNets, (c) => {
if (c.controlImage === image_name) { if (c.controlImage === image_name) {
@ -300,21 +302,6 @@ export const controlNetSlice = createSlice({
}); });
}); });
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
// const { image_name, image_url, thumbnail_url } = action.payload;
// forEach(state.controlNets, (c) => {
// if (c.controlImage?.image_name === image_name) {
// c.controlImage.image_url = image_url;
// c.controlImage.thumbnail_url = thumbnail_url;
// }
// if (c.processedControlImage?.image_name === image_name) {
// c.processedControlImage.image_url = image_url;
// c.processedControlImage.thumbnail_url = thumbnail_url;
// }
// });
// });
builder.addCase(appSocketInvocationError, (state, action) => { builder.addCase(appSocketInvocationError, (state, action) => {
state.pendingControlImages = []; state.pendingControlImages = [];
}); });

View File

@ -6,9 +6,10 @@ import {
ControlNetModelInputFieldTemplate, ControlNetModelInputFieldTemplate,
ControlNetModelInputFieldValue, ControlNetModelInputFieldValue,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { forEach, isString } from 'lodash-es'; import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
import { memo, useCallback, useEffect, useMemo } from 'react'; import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models'; import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types'; import { FieldComponentProps } from './types';
@ -20,15 +21,23 @@ const ControlNetModelInputFieldComponent = (
> >
) => { ) => {
const { nodeId, field } = props; const { nodeId, field } = props;
const controlNetModel = field.value;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { data: controlNetModels } = useGetControlNetModelsQuery(); const { data: controlNetModels } = useGetControlNetModelsQuery();
// grab the full model entity from the RTK Query cache
const selectedModel = useMemo( const selectedModel = useMemo(
() => controlNetModels?.entities[field.value ?? controlNetModels.ids[0]], () =>
[controlNetModels?.entities, controlNetModels?.ids, field.value] controlNetModels?.entities[
`${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}`
] ?? null,
[
controlNetModel?.base_model,
controlNetModel?.model_name,
controlNetModels?.entities,
]
); );
const data = useMemo(() => { const data = useMemo(() => {
@ -45,8 +54,8 @@ const ControlNetModelInputFieldComponent = (
data.push({ data.push({
value: id, value: id,
label: model.name, label: model.model_name,
group: BASE_MODEL_NAME_MAP[model.base_model], group: MODEL_TYPE_MAP[model.base_model],
}); });
}); });
@ -59,40 +68,32 @@ const ControlNetModelInputFieldComponent = (
return; return;
} }
const newControlNetModel = modelIdToControlNetModelParam(v);
if (!newControlNetModel) {
return;
}
dispatch( dispatch(
fieldValueChanged({ fieldValueChanged({
nodeId, nodeId,
fieldName: field.name, fieldName: field.name,
value: v, value: newControlNetModel,
}) })
); );
}, },
[dispatch, field.name, nodeId] [dispatch, field.name, nodeId]
); );
useEffect(() => {
if (field.value && controlNetModels?.ids.includes(field.value)) {
return;
}
const firstLora = controlNetModels?.ids[0];
if (!isString(firstLora)) {
return;
}
handleValueChanged(firstLora);
}, [field.value, handleValueChanged, controlNetModels?.ids]);
return ( return (
<IAIMantineSelect <IAIMantineSelect
tooltip={selectedModel?.description} tooltip={selectedModel?.description}
label={ label={
selectedModel?.base_model && selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
} }
value={field.value} value={selectedModel?.id ?? null}
placeholder="Pick one" placeholder="Pick one"
error={!selectedModel}
data={data} data={data}
onChange={handleValueChanged} onChange={handleValueChanged}
/> />

View File

@ -1,6 +1,7 @@
import { createSlice, PayloadAction } from '@reduxjs/toolkit'; import { createSlice, PayloadAction } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { import {
ControlNetModelParam,
LoRAModelParam, LoRAModelParam,
MainModelParam, MainModelParam,
VaeModelParam, VaeModelParam,
@ -81,7 +82,8 @@ const nodesSlice = createSlice({
| ImageField[] | ImageField[]
| MainModelParam | MainModelParam
| VaeModelParam | VaeModelParam
| LoRAModelParam; | LoRAModelParam
| ControlNetModelParam;
}> }>
) => { ) => {
const { nodeId, fieldName, value } = action.payload; const { nodeId, fieldName, value } = action.payload;

View File

@ -1,4 +1,5 @@
import { import {
ControlNetModelParam,
LoRAModelParam, LoRAModelParam,
MainModelParam, MainModelParam,
VaeModelParam, VaeModelParam,
@ -254,7 +255,7 @@ export type LoRAModelInputFieldValue = FieldValueBase & {
export type ControlNetModelInputFieldValue = FieldValueBase & { export type ControlNetModelInputFieldValue = FieldValueBase & {
type: 'controlnet_model'; type: 'controlnet_model';
value?: string; value?: ControlNetModelParam;
}; };
export type ArrayInputFieldValue = FieldValueBase & { export type ArrayInputFieldValue = FieldValueBase & {

View File

@ -1,14 +0,0 @@
import { BaseModelType, ControlNetModelField } from 'services/api/types';
export const modelIdToControlNetModelField = (
controlNetModelId: string
): ControlNetModelField => {
const [base_model, model_type, model_name] = controlNetModelId.split('/');
const field: ControlNetModelField = {
base_model: base_model as BaseModelType,
model_name,
};
return field;
};

View File

@ -37,6 +37,7 @@ const ParamVAEModelSelect = () => {
return []; return [];
} }
// add a "default" option, this means use the main model's included VAE
const data: SelectItem[] = [ const data: SelectItem[] = [
{ {
value: 'default', value: 'default',

View File

@ -180,6 +180,23 @@ export type LoRAModelParam = z.infer<typeof zLoRAModel>;
*/ */
export const isValidLoRAModel = (val: unknown): val is LoRAModelParam => export const isValidLoRAModel = (val: unknown): val is LoRAModelParam =>
zLoRAModel.safeParse(val).success; zLoRAModel.safeParse(val).success;
/**
* Zod schema for ControlNet models
*/
export const zControlNetModel = z.object({
model_name: z.string().min(1),
base_model: zBaseModel,
});
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type ControlNetModelParam = z.infer<typeof zLoRAModel>;
/**
* Validates/type-guards a value as a model parameter
*/
export const isValidControlNetModel = (
val: unknown
): val is ControlNetModelParam => zControlNetModel.safeParse(val).success;
/** /**
* Zod schema for l2l strength parameter * Zod schema for l2l strength parameter

View File

@ -0,0 +1,30 @@
import { log } from 'app/logging/useLogger';
import { zControlNetModel } from 'features/parameters/types/parameterSchemas';
import { ControlNetModelField } from 'services/api/types';
const moduleLog = log.child({ module: 'models' });
export const modelIdToControlNetModelParam = (
controlNetModelId: string
): ControlNetModelField | undefined => {
const [base_model, model_type, model_name] = controlNetModelId.split('/');
const result = zControlNetModel.safeParse({
base_model,
model_name,
});
if (!result.success) {
moduleLog.error(
{
controlNetModelId,
errors: result.error.format(),
},
'Failed to parse ControlNet model id'
);
return;
}
return result.data;
};

View File

@ -1,9 +1,12 @@
import { LoRAModelParam, zLoRAModel } from '../types/parameterSchemas'; import { LoRAModelParam, zLoRAModel } from '../types/parameterSchemas';
import { log } from 'app/logging/useLogger';
const moduleLog = log.child({ module: 'models' });
export const modelIdToLoRAModelParam = ( export const modelIdToLoRAModelParam = (
loraId: string loraModelId: string
): LoRAModelParam | undefined => { ): LoRAModelParam | undefined => {
const [base_model, model_type, model_name] = loraId.split('/'); const [base_model, model_type, model_name] = loraModelId.split('/');
const result = zLoRAModel.safeParse({ const result = zLoRAModel.safeParse({
base_model, base_model,
@ -11,6 +14,13 @@ export const modelIdToLoRAModelParam = (
}); });
if (!result.success) { if (!result.success) {
moduleLog.error(
{
loraModelId,
errors: result.error.format(),
},
'Failed to parse LoRA model id'
);
return; return;
} }

View File

@ -2,11 +2,14 @@ import {
MainModelParam, MainModelParam,
zMainModel, zMainModel,
} from 'features/parameters/types/parameterSchemas'; } from 'features/parameters/types/parameterSchemas';
import { log } from 'app/logging/useLogger';
const moduleLog = log.child({ module: 'models' });
export const modelIdToMainModelParam = ( export const modelIdToMainModelParam = (
modelId: string mainModelId: string
): MainModelParam | undefined => { ): MainModelParam | undefined => {
const [base_model, model_type, model_name] = modelId.split('/'); const [base_model, model_type, model_name] = mainModelId.split('/');
const result = zMainModel.safeParse({ const result = zMainModel.safeParse({
base_model, base_model,
@ -14,6 +17,13 @@ export const modelIdToMainModelParam = (
}); });
if (!result.success) { if (!result.success) {
moduleLog.error(
{
mainModelId,
errors: result.error.format(),
},
'Failed to parse main model id'
);
return; return;
} }

View File

@ -1,9 +1,12 @@
import { VaeModelParam, zVaeModel } from '../types/parameterSchemas'; import { VaeModelParam, zVaeModel } from '../types/parameterSchemas';
import { log } from 'app/logging/useLogger';
const moduleLog = log.child({ module: 'models' });
export const modelIdToVAEModelParam = ( export const modelIdToVAEModelParam = (
modelId: string vaeModelId: string
): VaeModelParam | undefined => { ): VaeModelParam | undefined => {
const [base_model, model_type, model_name] = modelId.split('/'); const [base_model, model_type, model_name] = vaeModelId.split('/');
const result = zVaeModel.safeParse({ const result = zVaeModel.safeParse({
base_model, base_model,
@ -11,6 +14,13 @@ export const modelIdToVAEModelParam = (
}); });
if (!result.success) { if (!result.success) {
moduleLog.error(
{
vaeModelId,
errors: result.error.format(),
},
'Failed to parse VAE model id'
);
return; return;
} }

View File

@ -1935,12 +1935,12 @@ export type components = {
* Width * Width
* @description The width to resize to (px) * @description The width to resize to (px)
*/ */
width: number; width?: number;
/** /**
* Height * Height
* @description The height to resize to (px) * @description The height to resize to (px)
*/ */
height: number; height?: number;
/** /**
* Resample Mode * Resample Mode
* @description The resampling mode * @description The resampling mode
@ -3302,7 +3302,7 @@ export type components = {
/** ModelsList */ /** ModelsList */
ModelsList: { ModelsList: {
/** Models */ /** Models */
models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[]; models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[];
}; };
/** /**
* MultiplyInvocation * MultiplyInvocation
@ -3922,14 +3922,16 @@ export type components = {
latents?: components["schemas"]["LatentsField"]; latents?: components["schemas"]["LatentsField"];
/** /**
* Width * Width
* @description The width to resize to (px) * @description The width to resize to (px)
* @default 512
*/ */
width: number; width?: number;
/** /**
* Height * Height
* @description The height to resize to (px) * @description The height to resize to (px)
* @default 512
*/ */
height: number; height?: number;
/** /**
* Mode * Mode
* @description The interpolation mode * @description The interpolation mode
@ -5009,7 +5011,7 @@ export type operations = {
/** @description The model imported successfully */ /** @description The model imported successfully */
201: { 201: {
content: { content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
}; };
}; };
/** @description The model could not be found */ /** @description The model could not be found */
@ -5077,14 +5079,14 @@ export type operations = {
}; };
requestBody: { requestBody: {
content: { content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
}; };
}; };
responses: { responses: {
/** @description The model was updated successfully */ /** @description The model was updated successfully */
200: { 200: {
content: { content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
}; };
}; };
/** @description Bad request */ /** @description Bad request */
@ -5118,7 +5120,7 @@ export type operations = {
/** @description Model converted successfully */ /** @description Model converted successfully */
200: { 200: {
content: { content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
}; };
}; };
/** @description Bad request */ /** @description Bad request */
@ -5153,7 +5155,7 @@ export type operations = {
/** @description Model converted successfully */ /** @description Model converted successfully */
200: { 200: {
content: { content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
}; };
}; };
/** @description Incompatible models */ /** @description Incompatible models */