feat: Move SDXL Refiner to own route & set appropriate disabled statuses

This commit is contained in:
blessedcoolant 2023-07-25 22:14:19 +12:00 committed by psychedelicious
parent 8d1b8179af
commit 5202610160
9 changed files with 159 additions and 32 deletions

View File

@ -9,6 +9,10 @@ import {
zMainModel, zMainModel,
zVaeModel, zVaeModel,
} from 'features/parameters/types/parameterSchemas'; } from 'features/parameters/types/parameterSchemas';
import {
refinerModelChanged,
setIsRefinerAvailable,
} from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es'; import { forEach, some } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models'; import { modelsApi } from 'services/api/endpoints/models';
import { startAppListening } from '..'; import { startAppListening } from '..';
@ -59,6 +63,53 @@ export const addModelsLoadedListener = () => {
dispatch(modelChanged(result.data)); dispatch(modelChanged(result.data));
}, },
}); });
startAppListening({
matcher: modelsApi.endpoints.getSDXLRefinerModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// models loaded, we need to ensure the selected model is available and if not, select the first one
const log = logger('models');
log.info(
{ models: action.payload.entities },
`SDXL Refiner models loaded (${action.payload.ids.length})`
);
const currentModel = getState().sdxl.refinerModel;
const isCurrentModelAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === currentModel?.model_name &&
m?.base_model === currentModel?.base_model
);
if (isCurrentModelAvailable) {
return;
}
const firstModelId = action.payload.ids[0];
const firstModel = action.payload.entities[firstModelId];
if (!firstModel) {
// No models loaded at all
dispatch(refinerModelChanged(null));
dispatch(setIsRefinerAvailable(false));
return;
}
const result = zMainModel.safeParse(firstModel);
if (!result.success) {
log.error(
{ error: result.error.format() },
'Failed to parse SDXL Refiner Model'
);
return;
}
dispatch(refinerModelChanged(result.data));
dispatch(setIsRefinerAvailable(true));
},
});
startAppListening({ startAppListening({
matcher: modelsApi.endpoints.getVaeModels.matchFulfilled, matcher: modelsApi.endpoints.getVaeModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => { effect: async (action, { getState, dispatch }) => {

View File

@ -9,10 +9,11 @@ import { memo, useCallback } from 'react';
const selector = createSelector( const selector = createSelector(
[stateSelector], [stateSelector],
({ sdxl, hotkeys }) => { ({ sdxl, hotkeys }) => {
const { refinerAestheticScore } = sdxl; const { refinerAestheticScore, isRefinerAvailable } = sdxl;
const { shift } = hotkeys; const { shift } = hotkeys;
return { return {
isRefinerAvailable,
refinerAestheticScore, refinerAestheticScore,
shift, shift,
}; };
@ -21,7 +22,8 @@ const selector = createSelector(
); );
const ParamSDXLRefinerAestheticScore = () => { const ParamSDXLRefinerAestheticScore = () => {
const { refinerAestheticScore, shift } = useAppSelector(selector); const { refinerAestheticScore, shift, isRefinerAvailable } =
useAppSelector(selector);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const handleChange = useCallback( const handleChange = useCallback(
@ -48,6 +50,7 @@ const ParamSDXLRefinerAestheticScore = () => {
withReset withReset
withSliderMarks withSliderMarks
isInteger={false} isInteger={false}
isDisabled={!isRefinerAvailable}
/> />
); );
}; };

View File

@ -11,11 +11,12 @@ import { useTranslation } from 'react-i18next';
const selector = createSelector( const selector = createSelector(
[stateSelector], [stateSelector],
({ sdxl, ui, hotkeys }) => { ({ sdxl, ui, hotkeys }) => {
const { refinerCFGScale } = sdxl; const { refinerCFGScale, isRefinerAvailable } = sdxl;
const { shouldUseSliders } = ui; const { shouldUseSliders } = ui;
const { shift } = hotkeys; const { shift } = hotkeys;
return { return {
isRefinerAvailable,
refinerCFGScale, refinerCFGScale,
shouldUseSliders, shouldUseSliders,
shift, shift,
@ -25,7 +26,8 @@ const selector = createSelector(
); );
const ParamSDXLRefinerCFGScale = () => { const ParamSDXLRefinerCFGScale = () => {
const { refinerCFGScale, shouldUseSliders, shift } = useAppSelector(selector); const { refinerCFGScale, shouldUseSliders, shift, isRefinerAvailable } =
useAppSelector(selector);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
@ -53,6 +55,7 @@ const ParamSDXLRefinerCFGScale = () => {
withReset withReset
withSliderMarks withSliderMarks
isInteger={false} isInteger={false}
isDisabled={!isRefinerAvailable}
/> />
) : ( ) : (
<IAINumberInput <IAINumberInput
@ -64,6 +67,7 @@ const ParamSDXLRefinerCFGScale = () => {
value={refinerCFGScale} value={refinerCFGScale}
isInteger={false} isInteger={false}
numberInputFieldProps={{ textAlign: 'center' }} numberInputFieldProps={{ textAlign: 'center' }}
isDisabled={!isRefinerAvailable}
/> />
); );
}; };

View File

@ -11,7 +11,7 @@ import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton'; import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import { useGetSDXLRefinerModelsQuery } from 'services/api/endpoints/models';
const selector = createSelector( const selector = createSelector(
stateSelector, stateSelector,
@ -24,39 +24,37 @@ const ParamSDXLRefinerModelSelect = () => {
const { model } = useAppSelector(selector); const { model } = useAppSelector(selector);
const { data: mainModels, isLoading } = useGetMainModelsQuery(); const { data: sdxlModels, isLoading } = useGetSDXLRefinerModelsQuery();
const data = useMemo(() => { const data = useMemo(() => {
if (!mainModels) { if (!sdxlModels) {
return []; return [];
} }
const data: SelectItem[] = []; const data: SelectItem[] = [];
forEach(mainModels.entities, (model, id) => { forEach(sdxlModels.entities, (model, id) => {
if (!model) { if (!model) {
return; return;
} }
if (['sdxl-refiner'].includes(model.base_model)) { data.push({
data.push({ value: id,
value: id, label: model.model_name,
label: model.model_name, group: MODEL_TYPE_MAP[model.base_model],
group: MODEL_TYPE_MAP[model.base_model], });
});
}
}); });
return data; return data;
}, [mainModels]); }, [sdxlModels]);
// grab the full model entity from the RTK Query cache // grab the full model entity from the RTK Query cache
// TODO: maybe we should just store the full model entity in state? // TODO: maybe we should just store the full model entity in state?
const selectedModel = useMemo( const selectedModel = useMemo(
() => () =>
mainModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ?? sdxlModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ??
null, null,
[mainModels?.entities, model] [sdxlModels?.entities, model]
); );
const handleChangeModel = useCallback( const handleChangeModel = useCallback(

View File

@ -15,7 +15,7 @@ import { useTranslation } from 'react-i18next';
const selector = createSelector( const selector = createSelector(
stateSelector, stateSelector,
({ ui, sdxl }) => { ({ ui, sdxl }) => {
const { refinerScheduler } = sdxl; const { refinerScheduler, isRefinerAvailable } = sdxl;
const { favoriteSchedulers: enabledSchedulers } = ui; const { favoriteSchedulers: enabledSchedulers } = ui;
const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({ const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({
@ -27,6 +27,7 @@ const selector = createSelector(
})).sort((a, b) => a.label.localeCompare(b.label)); })).sort((a, b) => a.label.localeCompare(b.label));
return { return {
isRefinerAvailable,
refinerScheduler, refinerScheduler,
data, data,
}; };
@ -37,7 +38,8 @@ const selector = createSelector(
const ParamSDXLRefinerScheduler = () => { const ParamSDXLRefinerScheduler = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { refinerScheduler, data } = useAppSelector(selector); const { refinerScheduler, data, isRefinerAvailable } =
useAppSelector(selector);
const handleChange = useCallback( const handleChange = useCallback(
(v: string | null) => { (v: string | null) => {
@ -56,6 +58,7 @@ const ParamSDXLRefinerScheduler = () => {
value={refinerScheduler} value={refinerScheduler}
data={data} data={data}
onChange={handleChange} onChange={handleChange}
disabled={!isRefinerAvailable}
/> />
); );
}; };

View File

@ -9,10 +9,11 @@ import { memo, useCallback } from 'react';
const selector = createSelector( const selector = createSelector(
[stateSelector], [stateSelector],
({ sdxl, hotkeys }) => { ({ sdxl, hotkeys }) => {
const { refinerStart } = sdxl; const { refinerStart, isRefinerAvailable } = sdxl;
const { shift } = hotkeys; const { shift } = hotkeys;
return { return {
isRefinerAvailable,
refinerStart, refinerStart,
shift, shift,
}; };
@ -21,7 +22,7 @@ const selector = createSelector(
); );
const ParamSDXLRefinerStart = () => { const ParamSDXLRefinerStart = () => {
const { refinerStart, shift } = useAppSelector(selector); const { refinerStart, shift, isRefinerAvailable } = useAppSelector(selector);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const handleChange = useCallback( const handleChange = useCallback(
@ -48,6 +49,7 @@ const ParamSDXLRefinerStart = () => {
withReset withReset
withSliderMarks withSliderMarks
isInteger={false} isInteger={false}
isDisabled={!isRefinerAvailable}
/> />
); );
}; };

View File

@ -12,10 +12,11 @@ import { useTranslation } from 'react-i18next';
const selector = createSelector( const selector = createSelector(
[stateSelector], [stateSelector],
({ sdxl, ui }) => { ({ sdxl, ui }) => {
const { refinerSteps } = sdxl; const { refinerSteps, isRefinerAvailable } = sdxl;
const { shouldUseSliders } = ui; const { shouldUseSliders } = ui;
return { return {
isRefinerAvailable,
refinerSteps, refinerSteps,
shouldUseSliders, shouldUseSliders,
}; };
@ -24,7 +25,8 @@ const selector = createSelector(
); );
const ParamSDXLRefinerSteps = () => { const ParamSDXLRefinerSteps = () => {
const { refinerSteps, shouldUseSliders } = useAppSelector(selector); const { refinerSteps, shouldUseSliders, isRefinerAvailable } =
useAppSelector(selector);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
@ -51,6 +53,7 @@ const ParamSDXLRefinerSteps = () => {
withReset withReset
withSliderMarks withSliderMarks
sliderNumberInputProps={{ max: 500 }} sliderNumberInputProps={{ max: 500 }}
isDisabled={!isRefinerAvailable}
/> />
) : ( ) : (
<IAINumberInput <IAINumberInput
@ -61,6 +64,7 @@ const ParamSDXLRefinerSteps = () => {
onChange={handleChange} onChange={handleChange}
value={refinerSteps} value={refinerSteps}
numberInputFieldProps={{ textAlign: 'center' }} numberInputFieldProps={{ textAlign: 'center' }}
isDisabled={!isRefinerAvailable}
/> />
); );
}; };

View File

@ -24,7 +24,7 @@ export default function ParamUseSDXLRefiner() {
label="Use Refiner" label="Use Refiner"
isChecked={shouldUseSDXLRefiner} isChecked={shouldUseSDXLRefiner}
onChange={handleUseSDXLRefinerChange} onChange={handleUseSDXLRefinerChange}
isDisabled={isRefinerAvailable} isDisabled={!isRefinerAvailable}
/> />
); );
} }

View File

@ -107,6 +107,9 @@ type SearchFolderArg = operations['search_for_models']['parameters']['query'];
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({ const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
}); });
const sdxlRefinerModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({ const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
}); });
@ -176,6 +179,43 @@ export const modelsApi = api.injectEndpoints({
); );
}, },
}), }),
getSDXLRefinerModels: build.query<EntityState<MainModelConfigEntity>, void>(
{
query: () => ({
url: 'models/',
params: { model_type: 'main', base_models: 'sdxl-refiner' },
}),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ type: 'SDXLRefinerModel', id: LIST_TAG },
];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'SDXLRefinerModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (
response: { models: MainModelConfig[] },
meta,
arg
) => {
const entities = createModelEntities<MainModelConfigEntity>(
response.models
);
return sdxlRefinerModelsAdapter.setAll(
sdxlRefinerModelsAdapter.getInitialState(),
entities
);
},
}
),
updateMainModels: build.mutation< updateMainModels: build.mutation<
UpdateMainModelResponse, UpdateMainModelResponse,
UpdateMainModelArg UpdateMainModelArg
@ -187,7 +227,10 @@ export const modelsApi = api.injectEndpoints({
body: body, body: body,
}; };
}, },
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
],
}), }),
importMainModels: build.mutation< importMainModels: build.mutation<
ImportMainModelResponse, ImportMainModelResponse,
@ -200,7 +243,10 @@ export const modelsApi = api.injectEndpoints({
body: body, body: body,
}; };
}, },
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
],
}), }),
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({ addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
query: ({ body }) => { query: ({ body }) => {
@ -210,7 +256,10 @@ export const modelsApi = api.injectEndpoints({
body: body, body: body,
}; };
}, },
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
],
}), }),
deleteMainModels: build.mutation< deleteMainModels: build.mutation<
DeleteMainModelResponse, DeleteMainModelResponse,
@ -222,7 +271,10 @@ export const modelsApi = api.injectEndpoints({
method: 'DELETE', method: 'DELETE',
}; };
}, },
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
],
}), }),
convertMainModels: build.mutation< convertMainModels: build.mutation<
ConvertMainModelResponse, ConvertMainModelResponse,
@ -235,7 +287,10 @@ export const modelsApi = api.injectEndpoints({
params: params, params: params,
}; };
}, },
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
],
}), }),
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({ mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
query: ({ base_model, body }) => { query: ({ base_model, body }) => {
@ -245,7 +300,10 @@ export const modelsApi = api.injectEndpoints({
body: body, body: body,
}; };
}, },
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
],
}), }),
syncModels: build.mutation<SyncModelsResponse, void>({ syncModels: build.mutation<SyncModelsResponse, void>({
query: () => { query: () => {
@ -254,7 +312,10 @@ export const modelsApi = api.injectEndpoints({
method: 'POST', method: 'POST',
}; };
}, },
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
],
}), }),
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({ getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'lora' } }), query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
@ -425,6 +486,7 @@ export const modelsApi = api.injectEndpoints({
export const { export const {
useGetMainModelsQuery, useGetMainModelsQuery,
useGetSDXLRefinerModelsQuery,
useGetControlNetModelsQuery, useGetControlNetModelsQuery,
useGetLoRAModelsQuery, useGetLoRAModelsQuery,
useGetTextualInversionModelsQuery, useGetTextualInversionModelsQuery,