add warning if no upscale model or no tile controlnet for base model

This commit is contained in:
Mary Hipp 2024-07-18 20:04:44 -04:00 committed by psychedelicious
parent d2bf3629bf
commit 5ab36e0433
5 changed files with 77 additions and 6 deletions

View File

@ -203,6 +203,10 @@ const createSelector = (templates: Templates) =>
if (!upscale.upscaleModel) { if (!upscale.upscaleModel) {
reasons.push({ content: "No upscale model selected" }) reasons.push({ content: "No upscale model selected" })
} }
if (!upscale.tileControlnetModel) {
reasons.push({ content: "No valid tile controlnet available" })
}
} else { } else {
// Handling for all other tabs // Handling for all other tabs
selectControlAdapterAll(controlAdapters) selectControlAdapterAll(controlAdapters)

View File

@ -3,6 +3,7 @@ import type { GraphType } from 'features/nodes/util/graph/generation/Graph';
import { Graph } from 'features/nodes/util/graph/generation/Graph'; import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { isParamESRGANModelName } from 'features/parameters/store/postprocessingSlice'; import { isParamESRGANModelName } from 'features/parameters/store/postprocessingSlice';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
import { CLIP_SKIP, CONTROL_NET_COLLECT, ESRGAN, IMAGE_TO_LATENTS, LATENTS_TO_IMAGE, MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, NOISE, POSITIVE_CONDITIONING, RESIZE, SDXL_MODEL_LOADER, TILED_MULTI_DIFFUSION_DENOISE_LATENTS, UNSHARP_MASK, VAE_LOADER } from './constants'; import { CLIP_SKIP, CONTROL_NET_COLLECT, ESRGAN, IMAGE_TO_LATENTS, LATENTS_TO_IMAGE, MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, NOISE, POSITIVE_CONDITIONING, RESIZE, SDXL_MODEL_LOADER, TILED_MULTI_DIFFUSION_DENOISE_LATENTS, UNSHARP_MASK, VAE_LOADER } from './constants';
import { addLoRAs } from './generation/addLoRAs'; import { addLoRAs } from './generation/addLoRAs';
import { addSDXLLoRas } from './generation/addSDXLLoRAs'; import { addSDXLLoRas } from './generation/addSDXLLoRAs';
@ -25,8 +26,8 @@ export const buildMultidiffusionUpscsaleGraph = async (state: RootState): Promis
assert(model, 'No model found in state'); assert(model, 'No model found in state');
assert(upscaleModel, 'No upscale model found in state'); assert(upscaleModel, 'No upscale model found in state');
assert(upscaleInitialImage, 'No initial image found in state'); assert(upscaleInitialImage, 'No initial image found in state');
assert(isParamESRGANModelName(upscaleModel.name), "") assert(isParamESRGANModelName(upscaleModel.name), "Model must be valid upscale model")
assert(scale) assert(scale, 'Scale is required')
const g = new Graph() const g = new Graph()

View File

@ -2,7 +2,7 @@ import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store'; import type { PersistConfig, RootState } from 'app/store/store';
import type { ParameterSpandrelImageToImageModel } from 'features/parameters/types/parameterSchemas'; import type { ParameterSpandrelImageToImageModel } from 'features/parameters/types/parameterSchemas';
import type { ImageDTO } from 'services/api/types'; import type { ControlNetModelConfig, ImageDTO } from 'services/api/types';
interface UpscaleState { interface UpscaleState {
@ -14,6 +14,7 @@ interface UpscaleState {
creativity: number; creativity: number;
tiledVAE: boolean; tiledVAE: boolean;
scale: number | null; scale: number | null;
tileControlnetModel: ControlNetModelConfig | null
} }
const initialUpscaleState: UpscaleState = { const initialUpscaleState: UpscaleState = {
@ -24,7 +25,8 @@ const initialUpscaleState: UpscaleState = {
structure: 0, structure: 0,
creativity: 0, creativity: 0,
tiledVAE: false, tiledVAE: false,
scale: null scale: null,
tileControlnetModel: null
}; };
export const upscaleSlice = createSlice({ export const upscaleSlice = createSlice({
@ -60,10 +62,13 @@ export const upscaleSlice = createSlice({
scaleChanged: (state, action: PayloadAction<number | null>) => { scaleChanged: (state, action: PayloadAction<number | null>) => {
state.scale = action.payload; state.scale = action.payload;
}, },
tileControlnetModelChanged: (state, action: PayloadAction<ControlNetModelConfig | null>) => {
state.tileControlnetModel = action.payload;
},
}, },
}); });
export const { upscaleModelChanged, upscaleInitialImageChanged, tiledVAEChanged, structureChanged, creativityChanged, sharpnessChanged, scaleChanged } = upscaleSlice.actions; export const { upscaleModelChanged, upscaleInitialImageChanged, tiledVAEChanged, structureChanged, creativityChanged, sharpnessChanged, scaleChanged, tileControlnetModelChanged } = upscaleSlice.actions;
export const selectUpscalelice = (state: RootState) => state.upscale; export const selectUpscalelice = (state: RootState) => state.upscale;

View File

@ -0,0 +1,59 @@
import { Flex, Link, Text } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from '../../../../app/store/storeHooks';
import { useControlNetModels } from '../../../../services/api/hooks/modelsByType';
import { useCallback, useEffect, useMemo } from 'react';
import { tileControlnetModelChanged } from '../../../parameters/store/upscaleSlice';
import { MODEL_TYPE_SHORT_MAP } from '../../../parameters/types/constants';
import { setActiveTab } from '../../../ui/store/uiSlice';
export const MultidiffusionWarning = () => {
const model = useAppSelector((s) => s.generation.model);
const { tileControlnetModel, upscaleModel } = useAppSelector((s) => s.upscale);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useControlNetModels();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const shouldShowButton = useMemo(() => !disabledTabs.includes('models'), [disabledTabs]);
useEffect(() => {
const validModel = modelConfigs.find((cnetModel) => {
return cnetModel.base === model?.base && cnetModel.name.toLowerCase().includes('tile');
});
dispatch(tileControlnetModelChanged(validModel || null));
}, [model?.base, modelConfigs, dispatch]);
const warningText = useMemo(() => {
if (!model) {
return `a model`;
}
if (!upscaleModel && !tileControlnetModel) {
return `an upscaler model and ${MODEL_TYPE_SHORT_MAP[model.base]} tile controlnet`;
}
if (!upscaleModel) {
return 'an upscaler model';
}
if (!tileControlnetModel) {
return `a ${MODEL_TYPE_SHORT_MAP[model.base]} tile controlnet`;
}
}, [model?.base, upscaleModel, tileControlnetModel]);
const handleGoToModelManager = useCallback(() => {
dispatch(setActiveTab('models'));
}, [dispatch]);
if (!warningText || isLoading || !shouldShowButton) {
return <></>;
}
return (
<Flex bg="error.500" borderRadius={'base'} padding="2" direction="column">
<Text fontSize="xs" textAlign="center" display={'inline-block'}>
Visit{' '}
<Link fontWeight="bold" onClick={handleGoToModelManager}>
Model Manager
</Link>{' '}
to install {warningText} required by this feature
</Text>
</Flex>
);
};

View File

@ -5,6 +5,7 @@ import ParamCreativity from 'features/parameters/components/Upscale/ParamCreativ
import ParamSharpness from 'features/parameters/components/Upscale/ParamSharpness'; import ParamSharpness from 'features/parameters/components/Upscale/ParamSharpness';
import ParamSpandrelModel from 'features/parameters/components/Upscale/ParamSpandrelModel'; import ParamSpandrelModel from 'features/parameters/components/Upscale/ParamSpandrelModel';
import ParamStructure from 'features/parameters/components/Upscale/ParamStructure'; import ParamStructure from 'features/parameters/components/Upscale/ParamStructure';
import { ParamTiledVAEToggle } from 'features/parameters/components/Upscale/ParamTiledVAEToggle';
import { selectUpscalelice } from 'features/parameters/store/upscaleSlice'; import { selectUpscalelice } from 'features/parameters/store/upscaleSlice';
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle'; import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle'; import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
@ -13,7 +14,7 @@ import { useTranslation } from 'react-i18next';
import { UpscaleInitialImage } from './UpscaleInitialImage'; import { UpscaleInitialImage } from './UpscaleInitialImage';
import { UpscaleSizeDetails } from './UpscaleSizeDetails'; import { UpscaleSizeDetails } from './UpscaleSizeDetails';
import { ParamTiledVAEToggle } from '../../../parameters/components/Upscale/ParamTiledVAEToggle'; import { MultidiffusionWarning } from './MultidiffusionWarning';
const selector = createMemoizedSelector([selectUpscalelice], (upscale) => { const selector = createMemoizedSelector([selectUpscalelice], (upscale) => {
const badges: string[] = []; const badges: string[] = [];
@ -46,6 +47,7 @@ export const UpscaleSettingsAccordion = memo(() => {
<Flex direction="column" w="full" alignItems="center" gap={4}> <Flex direction="column" w="full" alignItems="center" gap={4}>
<ParamSpandrelModel /> <ParamSpandrelModel />
<UpscaleSizeDetails /> <UpscaleSizeDetails />
<MultidiffusionWarning />
</Flex> </Flex>
</Flex> </Flex>
<Expander label={t('accordions.advanced.options')} isOpen={isOpenExpander} onToggle={onToggleExpander}> <Expander label={t('accordions.advanced.options')} isOpen={isOpenExpander} onToggle={onToggleExpander}>