diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts index 57981918d8..c3f710c09c 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts @@ -9,6 +9,10 @@ import { zMainModel, zVaeModel, } from 'features/parameters/types/parameterSchemas'; +import { + refinerModelChanged, + setIsRefinerAvailable, +} from 'features/sdxl/store/sdxlSlice'; import { forEach, some } from 'lodash-es'; import { modelsApi } from 'services/api/endpoints/models'; import { startAppListening } from '..'; @@ -59,6 +63,53 @@ export const addModelsLoadedListener = () => { dispatch(modelChanged(result.data)); }, }); + startAppListening({ + matcher: modelsApi.endpoints.getSDXLRefinerModels.matchFulfilled, + effect: async (action, { getState, dispatch }) => { + // models loaded, we need to ensure the selected model is available and if not, select the first one + const log = logger('models'); + log.info( + { models: action.payload.entities }, + `SDXL Refiner models loaded (${action.payload.ids.length})` + ); + + const currentModel = getState().sdxl.refinerModel; + + const isCurrentModelAvailable = some( + action.payload.entities, + (m) => + m?.model_name === currentModel?.model_name && + m?.base_model === currentModel?.base_model + ); + + if (isCurrentModelAvailable) { + return; + } + + const firstModelId = action.payload.ids[0]; + const firstModel = action.payload.entities[firstModelId]; + + if (!firstModel) { + // No models loaded at all + dispatch(refinerModelChanged(null)); + dispatch(setIsRefinerAvailable(false)); + return; + } + + const result = zMainModel.safeParse(firstModel); + + if (!result.success) { + log.error( + { error: result.error.format() }, + 'Failed to parse SDXL Refiner Model' + ); + return; + } + + dispatch(refinerModelChanged(result.data)); + dispatch(setIsRefinerAvailable(true)); + }, + }); startAppListening({ matcher: modelsApi.endpoints.getVaeModels.matchFulfilled, effect: async (action, { getState, dispatch }) => { diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerAestheticScore.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerAestheticScore.tsx index bab9011eb2..8210bd0385 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerAestheticScore.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerAestheticScore.tsx @@ -9,10 +9,11 @@ import { memo, useCallback } from 'react'; const selector = createSelector( [stateSelector], ({ sdxl, hotkeys }) => { - const { refinerAestheticScore } = sdxl; + const { refinerAestheticScore, isRefinerAvailable } = sdxl; const { shift } = hotkeys; return { + isRefinerAvailable, refinerAestheticScore, shift, }; @@ -21,7 +22,8 @@ const selector = createSelector( ); const ParamSDXLRefinerAestheticScore = () => { - const { refinerAestheticScore, shift } = useAppSelector(selector); + const { refinerAestheticScore, shift, isRefinerAvailable } = + useAppSelector(selector); const dispatch = useAppDispatch(); const handleChange = useCallback( @@ -48,6 +50,7 @@ const ParamSDXLRefinerAestheticScore = () => { withReset withSliderMarks isInteger={false} + isDisabled={!isRefinerAvailable} /> ); }; diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerCFGScale.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerCFGScale.tsx index 3e5bb3b9b5..3371d4048e 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerCFGScale.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerCFGScale.tsx @@ -11,11 +11,12 @@ import { useTranslation } from 'react-i18next'; const selector = createSelector( [stateSelector], ({ sdxl, ui, hotkeys }) => { - const { refinerCFGScale } = sdxl; + const { refinerCFGScale, isRefinerAvailable } = sdxl; const { shouldUseSliders } = ui; const { shift } = hotkeys; return { + isRefinerAvailable, refinerCFGScale, shouldUseSliders, shift, @@ -25,7 +26,8 @@ const selector = createSelector( ); const ParamSDXLRefinerCFGScale = () => { - const { refinerCFGScale, shouldUseSliders, shift } = useAppSelector(selector); + const { refinerCFGScale, shouldUseSliders, shift, isRefinerAvailable } = + useAppSelector(selector); const dispatch = useAppDispatch(); const { t } = useTranslation(); @@ -53,6 +55,7 @@ const ParamSDXLRefinerCFGScale = () => { withReset withSliderMarks isInteger={false} + isDisabled={!isRefinerAvailable} /> ) : ( { value={refinerCFGScale} isInteger={false} numberInputFieldProps={{ textAlign: 'center' }} + isDisabled={!isRefinerAvailable} /> ); }; diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx index 1b623e23f8..5c052af562 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx @@ -11,7 +11,7 @@ import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice'; import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton'; import { forEach } from 'lodash-es'; import { memo, useCallback, useMemo } from 'react'; -import { useGetMainModelsQuery } from 'services/api/endpoints/models'; +import { useGetSDXLRefinerModelsQuery } from 'services/api/endpoints/models'; const selector = createSelector( stateSelector, @@ -24,39 +24,37 @@ const ParamSDXLRefinerModelSelect = () => { const { model } = useAppSelector(selector); - const { data: mainModels, isLoading } = useGetMainModelsQuery(); + const { data: sdxlModels, isLoading } = useGetSDXLRefinerModelsQuery(); const data = useMemo(() => { - if (!mainModels) { + if (!sdxlModels) { return []; } const data: SelectItem[] = []; - forEach(mainModels.entities, (model, id) => { + forEach(sdxlModels.entities, (model, id) => { if (!model) { return; } - if (['sdxl-refiner'].includes(model.base_model)) { - data.push({ - value: id, - label: model.model_name, - group: MODEL_TYPE_MAP[model.base_model], - }); - } + data.push({ + value: id, + label: model.model_name, + group: MODEL_TYPE_MAP[model.base_model], + }); }); return data; - }, [mainModels]); + }, [sdxlModels]); // grab the full model entity from the RTK Query cache // TODO: maybe we should just store the full model entity in state? const selectedModel = useMemo( () => - mainModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ?? + sdxlModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ?? null, - [mainModels?.entities, model] + [sdxlModels?.entities, model] ); const handleChangeModel = useCallback( diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerScheduler.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerScheduler.tsx index fd2b1457d5..8fd62df176 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerScheduler.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerScheduler.tsx @@ -15,7 +15,7 @@ import { useTranslation } from 'react-i18next'; const selector = createSelector( stateSelector, ({ ui, sdxl }) => { - const { refinerScheduler } = sdxl; + const { refinerScheduler, isRefinerAvailable } = sdxl; const { favoriteSchedulers: enabledSchedulers } = ui; const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({ @@ -27,6 +27,7 @@ const selector = createSelector( })).sort((a, b) => a.label.localeCompare(b.label)); return { + isRefinerAvailable, refinerScheduler, data, }; @@ -37,7 +38,8 @@ const selector = createSelector( const ParamSDXLRefinerScheduler = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { refinerScheduler, data } = useAppSelector(selector); + const { refinerScheduler, data, isRefinerAvailable } = + useAppSelector(selector); const handleChange = useCallback( (v: string | null) => { @@ -56,6 +58,7 @@ const ParamSDXLRefinerScheduler = () => { value={refinerScheduler} data={data} onChange={handleChange} + disabled={!isRefinerAvailable} /> ); }; diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerStart.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerStart.tsx index e4b590e988..df4512691e 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerStart.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerStart.tsx @@ -9,10 +9,11 @@ import { memo, useCallback } from 'react'; const selector = createSelector( [stateSelector], ({ sdxl, hotkeys }) => { - const { refinerStart } = sdxl; + const { refinerStart, isRefinerAvailable } = sdxl; const { shift } = hotkeys; return { + isRefinerAvailable, refinerStart, shift, }; @@ -21,7 +22,7 @@ const selector = createSelector( ); const ParamSDXLRefinerStart = () => { - const { refinerStart, shift } = useAppSelector(selector); + const { refinerStart, shift, isRefinerAvailable } = useAppSelector(selector); const dispatch = useAppDispatch(); const handleChange = useCallback( @@ -48,6 +49,7 @@ const ParamSDXLRefinerStart = () => { withReset withSliderMarks isInteger={false} + isDisabled={!isRefinerAvailable} /> ); }; diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerSteps.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerSteps.tsx index c3e61b70e4..07b0f4abfe 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerSteps.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerSteps.tsx @@ -12,10 +12,11 @@ import { useTranslation } from 'react-i18next'; const selector = createSelector( [stateSelector], ({ sdxl, ui }) => { - const { refinerSteps } = sdxl; + const { refinerSteps, isRefinerAvailable } = sdxl; const { shouldUseSliders } = ui; return { + isRefinerAvailable, refinerSteps, shouldUseSliders, }; @@ -24,7 +25,8 @@ const selector = createSelector( ); const ParamSDXLRefinerSteps = () => { - const { refinerSteps, shouldUseSliders } = useAppSelector(selector); + const { refinerSteps, shouldUseSliders, isRefinerAvailable } = + useAppSelector(selector); const dispatch = useAppDispatch(); const { t } = useTranslation(); @@ -51,6 +53,7 @@ const ParamSDXLRefinerSteps = () => { withReset withSliderMarks sliderNumberInputProps={{ max: 500 }} + isDisabled={!isRefinerAvailable} /> ) : ( { onChange={handleChange} value={refinerSteps} numberInputFieldProps={{ textAlign: 'center' }} + isDisabled={!isRefinerAvailable} /> ); }; diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamUseSDXLRefiner.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamUseSDXLRefiner.tsx index 9da8286910..1f33d8d8a3 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamUseSDXLRefiner.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamUseSDXLRefiner.tsx @@ -24,7 +24,7 @@ export default function ParamUseSDXLRefiner() { label="Use Refiner" isChecked={shouldUseSDXLRefiner} onChange={handleUseSDXLRefinerChange} - isDisabled={isRefinerAvailable} + isDisabled={!isRefinerAvailable} /> ); } diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index ff82bc2802..f1ff731d0e 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -107,6 +107,9 @@ type SearchFolderArg = operations['search_for_models']['parameters']['query']; const mainModelsAdapter = createEntityAdapter({ sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), }); +const sdxlRefinerModelsAdapter = createEntityAdapter({ + sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), +}); const loraModelsAdapter = createEntityAdapter({ sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), }); @@ -176,6 +179,43 @@ export const modelsApi = api.injectEndpoints({ ); }, }), + getSDXLRefinerModels: build.query, void>( + { + query: () => ({ + url: 'models/', + params: { model_type: 'main', base_models: 'sdxl-refiner' }, + }), + providesTags: (result, error, arg) => { + const tags: ApiFullTagDescription[] = [ + { type: 'SDXLRefinerModel', id: LIST_TAG }, + ]; + + if (result) { + tags.push( + ...result.ids.map((id) => ({ + type: 'SDXLRefinerModel' as const, + id, + })) + ); + } + + return tags; + }, + transformResponse: ( + response: { models: MainModelConfig[] }, + meta, + arg + ) => { + const entities = createModelEntities( + response.models + ); + return sdxlRefinerModelsAdapter.setAll( + sdxlRefinerModelsAdapter.getInitialState(), + entities + ); + }, + } + ), updateMainModels: build.mutation< UpdateMainModelResponse, UpdateMainModelArg @@ -187,7 +227,10 @@ export const modelsApi = api.injectEndpoints({ body: body, }; }, - invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], + invalidatesTags: [ + { type: 'MainModel', id: LIST_TAG }, + { type: 'SDXLRefinerModel', id: LIST_TAG }, + ], }), importMainModels: build.mutation< ImportMainModelResponse, @@ -200,7 +243,10 @@ export const modelsApi = api.injectEndpoints({ body: body, }; }, - invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], + invalidatesTags: [ + { type: 'MainModel', id: LIST_TAG }, + { type: 'SDXLRefinerModel', id: LIST_TAG }, + ], }), addMainModels: build.mutation({ query: ({ body }) => { @@ -210,7 +256,10 @@ export const modelsApi = api.injectEndpoints({ body: body, }; }, - invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], + invalidatesTags: [ + { type: 'MainModel', id: LIST_TAG }, + { type: 'SDXLRefinerModel', id: LIST_TAG }, + ], }), deleteMainModels: build.mutation< DeleteMainModelResponse, @@ -222,7 +271,10 @@ export const modelsApi = api.injectEndpoints({ method: 'DELETE', }; }, - invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], + invalidatesTags: [ + { type: 'MainModel', id: LIST_TAG }, + { type: 'SDXLRefinerModel', id: LIST_TAG }, + ], }), convertMainModels: build.mutation< ConvertMainModelResponse, @@ -235,7 +287,10 @@ export const modelsApi = api.injectEndpoints({ params: params, }; }, - invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], + invalidatesTags: [ + { type: 'MainModel', id: LIST_TAG }, + { type: 'SDXLRefinerModel', id: LIST_TAG }, + ], }), mergeMainModels: build.mutation({ query: ({ base_model, body }) => { @@ -245,7 +300,10 @@ export const modelsApi = api.injectEndpoints({ body: body, }; }, - invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], + invalidatesTags: [ + { type: 'MainModel', id: LIST_TAG }, + { type: 'SDXLRefinerModel', id: LIST_TAG }, + ], }), syncModels: build.mutation({ query: () => { @@ -254,7 +312,10 @@ export const modelsApi = api.injectEndpoints({ method: 'POST', }; }, - invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], + invalidatesTags: [ + { type: 'MainModel', id: LIST_TAG }, + { type: 'SDXLRefinerModel', id: LIST_TAG }, + ], }), getLoRAModels: build.query, void>({ query: () => ({ url: 'models/', params: { model_type: 'lora' } }), @@ -425,6 +486,7 @@ export const modelsApi = api.injectEndpoints({ export const { useGetMainModelsQuery, + useGetSDXLRefinerModelsQuery, useGetControlNetModelsQuery, useGetLoRAModelsQuery, useGetTextualInversionModelsQuery,