mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: Move SDXL Refiner to own route & set appropriate disabled statuses
This commit is contained in:
parent
8d1b8179af
commit
5202610160
@ -9,6 +9,10 @@ import {
|
|||||||
zMainModel,
|
zMainModel,
|
||||||
zVaeModel,
|
zVaeModel,
|
||||||
} from 'features/parameters/types/parameterSchemas';
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
|
import {
|
||||||
|
refinerModelChanged,
|
||||||
|
setIsRefinerAvailable,
|
||||||
|
} from 'features/sdxl/store/sdxlSlice';
|
||||||
import { forEach, some } from 'lodash-es';
|
import { forEach, some } from 'lodash-es';
|
||||||
import { modelsApi } from 'services/api/endpoints/models';
|
import { modelsApi } from 'services/api/endpoints/models';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
@ -59,6 +63,53 @@ export const addModelsLoadedListener = () => {
|
|||||||
dispatch(modelChanged(result.data));
|
dispatch(modelChanged(result.data));
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
startAppListening({
|
||||||
|
matcher: modelsApi.endpoints.getSDXLRefinerModels.matchFulfilled,
|
||||||
|
effect: async (action, { getState, dispatch }) => {
|
||||||
|
// models loaded, we need to ensure the selected model is available and if not, select the first one
|
||||||
|
const log = logger('models');
|
||||||
|
log.info(
|
||||||
|
{ models: action.payload.entities },
|
||||||
|
`SDXL Refiner models loaded (${action.payload.ids.length})`
|
||||||
|
);
|
||||||
|
|
||||||
|
const currentModel = getState().sdxl.refinerModel;
|
||||||
|
|
||||||
|
const isCurrentModelAvailable = some(
|
||||||
|
action.payload.entities,
|
||||||
|
(m) =>
|
||||||
|
m?.model_name === currentModel?.model_name &&
|
||||||
|
m?.base_model === currentModel?.base_model
|
||||||
|
);
|
||||||
|
|
||||||
|
if (isCurrentModelAvailable) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const firstModelId = action.payload.ids[0];
|
||||||
|
const firstModel = action.payload.entities[firstModelId];
|
||||||
|
|
||||||
|
if (!firstModel) {
|
||||||
|
// No models loaded at all
|
||||||
|
dispatch(refinerModelChanged(null));
|
||||||
|
dispatch(setIsRefinerAvailable(false));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = zMainModel.safeParse(firstModel);
|
||||||
|
|
||||||
|
if (!result.success) {
|
||||||
|
log.error(
|
||||||
|
{ error: result.error.format() },
|
||||||
|
'Failed to parse SDXL Refiner Model'
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(refinerModelChanged(result.data));
|
||||||
|
dispatch(setIsRefinerAvailable(true));
|
||||||
|
},
|
||||||
|
});
|
||||||
startAppListening({
|
startAppListening({
|
||||||
matcher: modelsApi.endpoints.getVaeModels.matchFulfilled,
|
matcher: modelsApi.endpoints.getVaeModels.matchFulfilled,
|
||||||
effect: async (action, { getState, dispatch }) => {
|
effect: async (action, { getState, dispatch }) => {
|
||||||
|
@ -9,10 +9,11 @@ import { memo, useCallback } from 'react';
|
|||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector],
|
[stateSelector],
|
||||||
({ sdxl, hotkeys }) => {
|
({ sdxl, hotkeys }) => {
|
||||||
const { refinerAestheticScore } = sdxl;
|
const { refinerAestheticScore, isRefinerAvailable } = sdxl;
|
||||||
const { shift } = hotkeys;
|
const { shift } = hotkeys;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
isRefinerAvailable,
|
||||||
refinerAestheticScore,
|
refinerAestheticScore,
|
||||||
shift,
|
shift,
|
||||||
};
|
};
|
||||||
@ -21,7 +22,8 @@ const selector = createSelector(
|
|||||||
);
|
);
|
||||||
|
|
||||||
const ParamSDXLRefinerAestheticScore = () => {
|
const ParamSDXLRefinerAestheticScore = () => {
|
||||||
const { refinerAestheticScore, shift } = useAppSelector(selector);
|
const { refinerAestheticScore, shift, isRefinerAvailable } =
|
||||||
|
useAppSelector(selector);
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const handleChange = useCallback(
|
const handleChange = useCallback(
|
||||||
@ -48,6 +50,7 @@ const ParamSDXLRefinerAestheticScore = () => {
|
|||||||
withReset
|
withReset
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isInteger={false}
|
isInteger={false}
|
||||||
|
isDisabled={!isRefinerAvailable}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -11,11 +11,12 @@ import { useTranslation } from 'react-i18next';
|
|||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector],
|
[stateSelector],
|
||||||
({ sdxl, ui, hotkeys }) => {
|
({ sdxl, ui, hotkeys }) => {
|
||||||
const { refinerCFGScale } = sdxl;
|
const { refinerCFGScale, isRefinerAvailable } = sdxl;
|
||||||
const { shouldUseSliders } = ui;
|
const { shouldUseSliders } = ui;
|
||||||
const { shift } = hotkeys;
|
const { shift } = hotkeys;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
isRefinerAvailable,
|
||||||
refinerCFGScale,
|
refinerCFGScale,
|
||||||
shouldUseSliders,
|
shouldUseSliders,
|
||||||
shift,
|
shift,
|
||||||
@ -25,7 +26,8 @@ const selector = createSelector(
|
|||||||
);
|
);
|
||||||
|
|
||||||
const ParamSDXLRefinerCFGScale = () => {
|
const ParamSDXLRefinerCFGScale = () => {
|
||||||
const { refinerCFGScale, shouldUseSliders, shift } = useAppSelector(selector);
|
const { refinerCFGScale, shouldUseSliders, shift, isRefinerAvailable } =
|
||||||
|
useAppSelector(selector);
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
@ -53,6 +55,7 @@ const ParamSDXLRefinerCFGScale = () => {
|
|||||||
withReset
|
withReset
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isInteger={false}
|
isInteger={false}
|
||||||
|
isDisabled={!isRefinerAvailable}
|
||||||
/>
|
/>
|
||||||
) : (
|
) : (
|
||||||
<IAINumberInput
|
<IAINumberInput
|
||||||
@ -64,6 +67,7 @@ const ParamSDXLRefinerCFGScale = () => {
|
|||||||
value={refinerCFGScale}
|
value={refinerCFGScale}
|
||||||
isInteger={false}
|
isInteger={false}
|
||||||
numberInputFieldProps={{ textAlign: 'center' }}
|
numberInputFieldProps={{ textAlign: 'center' }}
|
||||||
|
isDisabled={!isRefinerAvailable}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -11,7 +11,7 @@ import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
|
|||||||
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
|
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
|
||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
import { useGetSDXLRefinerModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
stateSelector,
|
stateSelector,
|
||||||
@ -24,39 +24,37 @@ const ParamSDXLRefinerModelSelect = () => {
|
|||||||
|
|
||||||
const { model } = useAppSelector(selector);
|
const { model } = useAppSelector(selector);
|
||||||
|
|
||||||
const { data: mainModels, isLoading } = useGetMainModelsQuery();
|
const { data: sdxlModels, isLoading } = useGetSDXLRefinerModelsQuery();
|
||||||
|
|
||||||
const data = useMemo(() => {
|
const data = useMemo(() => {
|
||||||
if (!mainModels) {
|
if (!sdxlModels) {
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
|
|
||||||
const data: SelectItem[] = [];
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
forEach(mainModels.entities, (model, id) => {
|
forEach(sdxlModels.entities, (model, id) => {
|
||||||
if (!model) {
|
if (!model) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (['sdxl-refiner'].includes(model.base_model)) {
|
data.push({
|
||||||
data.push({
|
value: id,
|
||||||
value: id,
|
label: model.model_name,
|
||||||
label: model.model_name,
|
group: MODEL_TYPE_MAP[model.base_model],
|
||||||
group: MODEL_TYPE_MAP[model.base_model],
|
});
|
||||||
});
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
return data;
|
return data;
|
||||||
}, [mainModels]);
|
}, [sdxlModels]);
|
||||||
|
|
||||||
// grab the full model entity from the RTK Query cache
|
// grab the full model entity from the RTK Query cache
|
||||||
// TODO: maybe we should just store the full model entity in state?
|
// TODO: maybe we should just store the full model entity in state?
|
||||||
const selectedModel = useMemo(
|
const selectedModel = useMemo(
|
||||||
() =>
|
() =>
|
||||||
mainModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ??
|
sdxlModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ??
|
||||||
null,
|
null,
|
||||||
[mainModels?.entities, model]
|
[sdxlModels?.entities, model]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleChangeModel = useCallback(
|
const handleChangeModel = useCallback(
|
||||||
|
@ -15,7 +15,7 @@ import { useTranslation } from 'react-i18next';
|
|||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
stateSelector,
|
stateSelector,
|
||||||
({ ui, sdxl }) => {
|
({ ui, sdxl }) => {
|
||||||
const { refinerScheduler } = sdxl;
|
const { refinerScheduler, isRefinerAvailable } = sdxl;
|
||||||
const { favoriteSchedulers: enabledSchedulers } = ui;
|
const { favoriteSchedulers: enabledSchedulers } = ui;
|
||||||
|
|
||||||
const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({
|
const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({
|
||||||
@ -27,6 +27,7 @@ const selector = createSelector(
|
|||||||
})).sort((a, b) => a.label.localeCompare(b.label));
|
})).sort((a, b) => a.label.localeCompare(b.label));
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
isRefinerAvailable,
|
||||||
refinerScheduler,
|
refinerScheduler,
|
||||||
data,
|
data,
|
||||||
};
|
};
|
||||||
@ -37,7 +38,8 @@ const selector = createSelector(
|
|||||||
const ParamSDXLRefinerScheduler = () => {
|
const ParamSDXLRefinerScheduler = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { refinerScheduler, data } = useAppSelector(selector);
|
const { refinerScheduler, data, isRefinerAvailable } =
|
||||||
|
useAppSelector(selector);
|
||||||
|
|
||||||
const handleChange = useCallback(
|
const handleChange = useCallback(
|
||||||
(v: string | null) => {
|
(v: string | null) => {
|
||||||
@ -56,6 +58,7 @@ const ParamSDXLRefinerScheduler = () => {
|
|||||||
value={refinerScheduler}
|
value={refinerScheduler}
|
||||||
data={data}
|
data={data}
|
||||||
onChange={handleChange}
|
onChange={handleChange}
|
||||||
|
disabled={!isRefinerAvailable}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -9,10 +9,11 @@ import { memo, useCallback } from 'react';
|
|||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector],
|
[stateSelector],
|
||||||
({ sdxl, hotkeys }) => {
|
({ sdxl, hotkeys }) => {
|
||||||
const { refinerStart } = sdxl;
|
const { refinerStart, isRefinerAvailable } = sdxl;
|
||||||
const { shift } = hotkeys;
|
const { shift } = hotkeys;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
isRefinerAvailable,
|
||||||
refinerStart,
|
refinerStart,
|
||||||
shift,
|
shift,
|
||||||
};
|
};
|
||||||
@ -21,7 +22,7 @@ const selector = createSelector(
|
|||||||
);
|
);
|
||||||
|
|
||||||
const ParamSDXLRefinerStart = () => {
|
const ParamSDXLRefinerStart = () => {
|
||||||
const { refinerStart, shift } = useAppSelector(selector);
|
const { refinerStart, shift, isRefinerAvailable } = useAppSelector(selector);
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const handleChange = useCallback(
|
const handleChange = useCallback(
|
||||||
@ -48,6 +49,7 @@ const ParamSDXLRefinerStart = () => {
|
|||||||
withReset
|
withReset
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isInteger={false}
|
isInteger={false}
|
||||||
|
isDisabled={!isRefinerAvailable}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -12,10 +12,11 @@ import { useTranslation } from 'react-i18next';
|
|||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector],
|
[stateSelector],
|
||||||
({ sdxl, ui }) => {
|
({ sdxl, ui }) => {
|
||||||
const { refinerSteps } = sdxl;
|
const { refinerSteps, isRefinerAvailable } = sdxl;
|
||||||
const { shouldUseSliders } = ui;
|
const { shouldUseSliders } = ui;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
isRefinerAvailable,
|
||||||
refinerSteps,
|
refinerSteps,
|
||||||
shouldUseSliders,
|
shouldUseSliders,
|
||||||
};
|
};
|
||||||
@ -24,7 +25,8 @@ const selector = createSelector(
|
|||||||
);
|
);
|
||||||
|
|
||||||
const ParamSDXLRefinerSteps = () => {
|
const ParamSDXLRefinerSteps = () => {
|
||||||
const { refinerSteps, shouldUseSliders } = useAppSelector(selector);
|
const { refinerSteps, shouldUseSliders, isRefinerAvailable } =
|
||||||
|
useAppSelector(selector);
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
@ -51,6 +53,7 @@ const ParamSDXLRefinerSteps = () => {
|
|||||||
withReset
|
withReset
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
sliderNumberInputProps={{ max: 500 }}
|
sliderNumberInputProps={{ max: 500 }}
|
||||||
|
isDisabled={!isRefinerAvailable}
|
||||||
/>
|
/>
|
||||||
) : (
|
) : (
|
||||||
<IAINumberInput
|
<IAINumberInput
|
||||||
@ -61,6 +64,7 @@ const ParamSDXLRefinerSteps = () => {
|
|||||||
onChange={handleChange}
|
onChange={handleChange}
|
||||||
value={refinerSteps}
|
value={refinerSteps}
|
||||||
numberInputFieldProps={{ textAlign: 'center' }}
|
numberInputFieldProps={{ textAlign: 'center' }}
|
||||||
|
isDisabled={!isRefinerAvailable}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -24,7 +24,7 @@ export default function ParamUseSDXLRefiner() {
|
|||||||
label="Use Refiner"
|
label="Use Refiner"
|
||||||
isChecked={shouldUseSDXLRefiner}
|
isChecked={shouldUseSDXLRefiner}
|
||||||
onChange={handleUseSDXLRefinerChange}
|
onChange={handleUseSDXLRefinerChange}
|
||||||
isDisabled={isRefinerAvailable}
|
isDisabled={!isRefinerAvailable}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -107,6 +107,9 @@ type SearchFolderArg = operations['search_for_models']['parameters']['query'];
|
|||||||
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
||||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||||
});
|
});
|
||||||
|
const sdxlRefinerModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
||||||
|
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||||
|
});
|
||||||
const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
|
const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
|
||||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||||
});
|
});
|
||||||
@ -176,6 +179,43 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
);
|
);
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
|
getSDXLRefinerModels: build.query<EntityState<MainModelConfigEntity>, void>(
|
||||||
|
{
|
||||||
|
query: () => ({
|
||||||
|
url: 'models/',
|
||||||
|
params: { model_type: 'main', base_models: 'sdxl-refiner' },
|
||||||
|
}),
|
||||||
|
providesTags: (result, error, arg) => {
|
||||||
|
const tags: ApiFullTagDescription[] = [
|
||||||
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
|
];
|
||||||
|
|
||||||
|
if (result) {
|
||||||
|
tags.push(
|
||||||
|
...result.ids.map((id) => ({
|
||||||
|
type: 'SDXLRefinerModel' as const,
|
||||||
|
id,
|
||||||
|
}))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return tags;
|
||||||
|
},
|
||||||
|
transformResponse: (
|
||||||
|
response: { models: MainModelConfig[] },
|
||||||
|
meta,
|
||||||
|
arg
|
||||||
|
) => {
|
||||||
|
const entities = createModelEntities<MainModelConfigEntity>(
|
||||||
|
response.models
|
||||||
|
);
|
||||||
|
return sdxlRefinerModelsAdapter.setAll(
|
||||||
|
sdxlRefinerModelsAdapter.getInitialState(),
|
||||||
|
entities
|
||||||
|
);
|
||||||
|
},
|
||||||
|
}
|
||||||
|
),
|
||||||
updateMainModels: build.mutation<
|
updateMainModels: build.mutation<
|
||||||
UpdateMainModelResponse,
|
UpdateMainModelResponse,
|
||||||
UpdateMainModelArg
|
UpdateMainModelArg
|
||||||
@ -187,7 +227,10 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
body: body,
|
body: body,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
invalidatesTags: [
|
||||||
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
|
],
|
||||||
}),
|
}),
|
||||||
importMainModels: build.mutation<
|
importMainModels: build.mutation<
|
||||||
ImportMainModelResponse,
|
ImportMainModelResponse,
|
||||||
@ -200,7 +243,10 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
body: body,
|
body: body,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
invalidatesTags: [
|
||||||
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
|
],
|
||||||
}),
|
}),
|
||||||
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
|
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
|
||||||
query: ({ body }) => {
|
query: ({ body }) => {
|
||||||
@ -210,7 +256,10 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
body: body,
|
body: body,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
invalidatesTags: [
|
||||||
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
|
],
|
||||||
}),
|
}),
|
||||||
deleteMainModels: build.mutation<
|
deleteMainModels: build.mutation<
|
||||||
DeleteMainModelResponse,
|
DeleteMainModelResponse,
|
||||||
@ -222,7 +271,10 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
method: 'DELETE',
|
method: 'DELETE',
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
invalidatesTags: [
|
||||||
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
|
],
|
||||||
}),
|
}),
|
||||||
convertMainModels: build.mutation<
|
convertMainModels: build.mutation<
|
||||||
ConvertMainModelResponse,
|
ConvertMainModelResponse,
|
||||||
@ -235,7 +287,10 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
params: params,
|
params: params,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
invalidatesTags: [
|
||||||
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
|
],
|
||||||
}),
|
}),
|
||||||
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
|
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
|
||||||
query: ({ base_model, body }) => {
|
query: ({ base_model, body }) => {
|
||||||
@ -245,7 +300,10 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
body: body,
|
body: body,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
invalidatesTags: [
|
||||||
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
|
],
|
||||||
}),
|
}),
|
||||||
syncModels: build.mutation<SyncModelsResponse, void>({
|
syncModels: build.mutation<SyncModelsResponse, void>({
|
||||||
query: () => {
|
query: () => {
|
||||||
@ -254,7 +312,10 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
method: 'POST',
|
method: 'POST',
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
invalidatesTags: [
|
||||||
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
|
],
|
||||||
}),
|
}),
|
||||||
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
|
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
|
||||||
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
|
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
|
||||||
@ -425,6 +486,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
|
|
||||||
export const {
|
export const {
|
||||||
useGetMainModelsQuery,
|
useGetMainModelsQuery,
|
||||||
|
useGetSDXLRefinerModelsQuery,
|
||||||
useGetControlNetModelsQuery,
|
useGetControlNetModelsQuery,
|
||||||
useGetLoRAModelsQuery,
|
useGetLoRAModelsQuery,
|
||||||
useGetTextualInversionModelsQuery,
|
useGetTextualInversionModelsQuery,
|
||||||
|
Loading…
Reference in New Issue
Block a user