fix(ui): fix refiner missing from model manager

Rolled back the earlier split of the refiner model query.

Now, when you use `useGetMainModelsQuery()`, you must provide it an array of base model types.

They are provided as constants for simplicity:
- ALL_BASE_MODELS
- NON_REFINER_BASE_MODELS
- REFINER_BASE_MODELS

Opted to just use args for the hook instead of wrapping the hook in another hook, we can tidy this up later if desired.
This commit is contained in:
psychedelicious 2023-07-26 11:04:02 +10:00
parent 6fa244a343
commit cbcd416b70
19 changed files with 72 additions and 75 deletions

View File

@ -19,7 +19,9 @@ import { startAppListening } from '..';
export const addModelsLoadedListener = () => {
startAppListening({
matcher: modelsApi.endpoints.getMainModels.matchFulfilled,
predicate: (state, action) =>
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
!action.meta.arg.originalArgs.includes('sdxl-refiner'),
effect: async (action, { getState, dispatch }) => {
// models loaded, we need to ensure the selected model is available and if not, select the first one
const log = logger('models');
@ -64,7 +66,9 @@ export const addModelsLoadedListener = () => {
},
});
startAppListening({
matcher: modelsApi.endpoints.getSDXLRefinerModels.matchFulfilled,
predicate: (state, action) =>
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
action.meta.arg.originalArgs.includes('sdxl-refiner'),
effect: async (action, { getState, dispatch }) => {
// models loaded, we need to ensure the selected model is available and if not, select the first one
const log = logger('models');

View File

@ -14,6 +14,7 @@ import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
import { useFeatureStatus } from '../../../system/hooks/useFeatureStatus';
@ -27,7 +28,9 @@ const ModelInputFieldComponent = (
const { t } = useTranslation();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
const { data: mainModels, isLoading } = useGetMainModelsQuery();
const { data: mainModels, isLoading } = useGetMainModelsQuery(
NON_REFINER_BASE_MODELS
);
const data = useMemo(() => {
if (!mainModels) {

View File

@ -13,7 +13,8 @@ import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetSDXLRefinerModelsQuery } from 'services/api/endpoints/models';
import { REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
const RefinerModelInputFieldComponent = (
@ -27,7 +28,8 @@ const RefinerModelInputFieldComponent = (
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data: refinerModels, isLoading } = useGetSDXLRefinerModelsQuery();
const { data: refinerModels, isLoading } =
useGetMainModelsQuery(REFINER_BASE_MODELS);
const data = useMemo(() => {
if (!refinerModels) {

View File

@ -14,6 +14,7 @@ import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
import { forEach } from 'lodash-es';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
@ -29,8 +30,10 @@ const ParamMainModelSelect = () => {
const { model } = useAppSelector(selector);
const { data: mainModels, isLoading } = useGetMainModelsQuery();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
const { data: mainModels, isLoading } = useGetMainModelsQuery(
NON_REFINER_BASE_MODELS
);
const data = useMemo(() => {
if (!mainModels) {

View File

@ -3,9 +3,9 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider';
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
import { setRefinerAestheticScore } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
const selector = createSelector(
[stateSelector],

View File

@ -4,10 +4,10 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider';
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
import { setRefinerCFGScale } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
const selector = createSelector(
[stateSelector],

View File

@ -11,7 +11,8 @@ import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useGetSDXLRefinerModelsQuery } from 'services/api/endpoints/models';
import { REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
const selector = createSelector(
stateSelector,
@ -24,7 +25,8 @@ const ParamSDXLRefinerModelSelect = () => {
const { model } = useAppSelector(selector);
const { data: refinerModels, isLoading } = useGetSDXLRefinerModelsQuery();
const { data: refinerModels, isLoading } =
useGetMainModelsQuery(REFINER_BASE_MODELS);
const data = useMemo(() => {
if (!refinerModels) {

View File

@ -7,11 +7,11 @@ import {
SCHEDULER_LABEL_MAP,
SchedulerParam,
} from 'features/parameters/types/parameterSchemas';
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
import { setRefinerScheduler } from 'features/sdxl/store/sdxlSlice';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
const selector = createSelector(
stateSelector,

View File

@ -3,9 +3,9 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider';
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
import { setRefinerStart } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
const selector = createSelector(
[stateSelector],

View File

@ -4,10 +4,10 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider';
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
import { setRefinerSteps } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
const selector = createSelector(
[stateSelector],

View File

@ -1,9 +1,9 @@
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
import { setShouldUseSDXLRefiner } from 'features/sdxl/store/sdxlSlice';
import { ChangeEvent } from 'react';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
export default function ParamUseSDXLRefiner() {
const shouldUseSDXLRefiner = useAppSelector(

View File

@ -1,11 +0,0 @@
import { useGetSDXLRefinerModelsQuery } from 'services/api/endpoints/models';
export const useIsRefinerAvailable = () => {
const { isRefinerAvailable } = useGetSDXLRefinerModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
isRefinerAvailable: data ? data.ids.length > 0 : false,
}),
});
return isRefinerAvailable;
};

View File

@ -16,6 +16,7 @@ import {
useImportMainModelsMutation,
} from 'services/api/endpoints/models';
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
import { ALL_BASE_MODELS } from 'services/api/constants';
export default function FoundModelsList() {
const searchFolder = useAppSelector(
@ -24,7 +25,7 @@ export default function FoundModelsList() {
const [nameFilter, setNameFilter] = useState<string>('');
// Get paths of models that are already installed
const { data: installedModels } = useGetMainModelsQuery();
const { data: installedModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
// Get all model paths from a given directory
const { foundModels, alreadyInstalled, filteredModels } =

View File

@ -1,5 +1,4 @@
import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react';
import { makeToast } from 'features/system/util/makeToast';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
@ -8,9 +7,11 @@ import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import IAISlider from 'common/components/IAISlider';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { pickBy } from 'lodash-es';
import { useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
import {
useGetMainModelsQuery,
useMergeMainModelsMutation,
@ -32,7 +33,7 @@ export default function MergeModelsPanel() {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { data } = useGetMainModelsQuery();
const { data } = useGetMainModelsQuery(ALL_BASE_MODELS);
const [mergeModels, { isLoading }] = useMergeMainModelsMutation();

View File

@ -8,10 +8,11 @@ import {
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
import ModelList from './ModelManagerPanel/ModelList';
import { ALL_BASE_MODELS } from 'services/api/constants';
export default function ModelManagerPanel() {
const [selectedModelId, setSelectedModelId] = useState<string>();
const { model } = useGetMainModelsQuery(undefined, {
const { model } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
model: selectedModelId ? data?.entities[selectedModelId] : undefined,
}),

View File

@ -11,6 +11,7 @@ import {
useGetMainModelsQuery,
} from 'services/api/endpoints/models';
import ModelListItem from './ModelListItem';
import { ALL_BASE_MODELS } from 'services/api/constants';
type ModelListProps = {
selectedModelId: string | undefined;
@ -26,13 +27,13 @@ const ModelList = (props: ModelListProps) => {
const [modelFormatFilter, setModelFormatFilter] =
useState<ModelFormat>('images');
const { filteredDiffusersModels } = useGetMainModelsQuery(undefined, {
const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter),
}),
});
const { filteredCheckpointModels } = useGetMainModelsQuery(undefined, {
const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter),
}),

View File

@ -0,0 +1,16 @@
import { BaseModelType } from './types';
export const ALL_BASE_MODELS: BaseModelType[] = [
'sd-1',
'sd-2',
'sdxl',
'sdxl-refiner',
];
export const NON_REFINER_BASE_MODELS: BaseModelType[] = [
'sd-1',
'sd-2',
'sdxl',
];
export const REFINER_BASE_MODELS: BaseModelType[] = ['sdxl-refiner'];

View File

@ -107,9 +107,6 @@ type SearchFolderArg = operations['search_for_models']['parameters']['query'];
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
const sdxlRefinerModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
@ -147,11 +144,14 @@ const createModelEntities = <T extends AnyModelConfigEntity>(
export const modelsApi = api.injectEndpoints({
endpoints: (build) => ({
getMainModels: build.query<EntityState<MainModelConfigEntity>, void>({
query: () => {
getMainModels: build.query<
EntityState<MainModelConfigEntity>,
BaseModelType[]
>({
query: (base_models) => {
const params = {
model_type: 'main',
base_models: ['sd-1', 'sd-2', 'sdxl'],
base_models,
};
const query = queryString.stringify(params, { arrayFormat: 'none' });
@ -187,43 +187,6 @@ export const modelsApi = api.injectEndpoints({
);
},
}),
getSDXLRefinerModels: build.query<EntityState<MainModelConfigEntity>, void>(
{
query: () => ({
url: 'models/',
params: { model_type: 'main', base_models: ['sdxl-refiner'] },
}),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ type: 'SDXLRefinerModel', id: LIST_TAG },
];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'SDXLRefinerModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (
response: { models: MainModelConfig[] },
meta,
arg
) => {
const entities = createModelEntities<MainModelConfigEntity>(
response.models
);
return sdxlRefinerModelsAdapter.setAll(
sdxlRefinerModelsAdapter.getInitialState(),
entities
);
},
}
),
updateMainModels: build.mutation<
UpdateMainModelResponse,
UpdateMainModelArg
@ -494,7 +457,6 @@ export const modelsApi = api.injectEndpoints({
export const {
useGetMainModelsQuery,
useGetSDXLRefinerModelsQuery,
useGetControlNetModelsQuery,
useGetLoRAModelsQuery,
useGetTextualInversionModelsQuery,

View File

@ -0,0 +1,12 @@
import { REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
export const useIsRefinerAvailable = () => {
const { isRefinerAvailable } = useGetMainModelsQuery(REFINER_BASE_MODELS, {
selectFromResult: ({ data }) => ({
isRefinerAvailable: data ? data.ids.length > 0 : false,
}),
});
return isRefinerAvailable;
};