mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): fix refiner missing from model manager
Rolled back the earlier split of the refiner model query. Now, when you use `useGetMainModelsQuery()`, you must provide it an array of base model types. They are provided as constants for simplicity: - ALL_BASE_MODELS - NON_REFINER_BASE_MODELS - REFINER_BASE_MODELS Opted to just use args for the hook instead of wrapping the hook in another hook, we can tidy this up later if desired.
This commit is contained in:
parent
6fa244a343
commit
cbcd416b70
@ -19,7 +19,9 @@ import { startAppListening } from '..';
|
||||
|
||||
export const addModelsLoadedListener = () => {
|
||||
startAppListening({
|
||||
matcher: modelsApi.endpoints.getMainModels.matchFulfilled,
|
||||
predicate: (state, action) =>
|
||||
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
|
||||
!action.meta.arg.originalArgs.includes('sdxl-refiner'),
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
// models loaded, we need to ensure the selected model is available and if not, select the first one
|
||||
const log = logger('models');
|
||||
@ -64,7 +66,9 @@ export const addModelsLoadedListener = () => {
|
||||
},
|
||||
});
|
||||
startAppListening({
|
||||
matcher: modelsApi.endpoints.getSDXLRefinerModels.matchFulfilled,
|
||||
predicate: (state, action) =>
|
||||
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
|
||||
action.meta.arg.originalArgs.includes('sdxl-refiner'),
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
// models loaded, we need to ensure the selected model is available and if not, select the first one
|
||||
const log = logger('models');
|
||||
|
@ -14,6 +14,7 @@ import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { useFeatureStatus } from '../../../system/hooks/useFeatureStatus';
|
||||
@ -27,7 +28,9 @@ const ModelInputFieldComponent = (
|
||||
const { t } = useTranslation();
|
||||
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
||||
|
||||
const { data: mainModels, isLoading } = useGetMainModelsQuery();
|
||||
const { data: mainModels, isLoading } = useGetMainModelsQuery(
|
||||
NON_REFINER_BASE_MODELS
|
||||
);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!mainModels) {
|
||||
|
@ -13,7 +13,8 @@ import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetSDXLRefinerModelsQuery } from 'services/api/endpoints/models';
|
||||
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
const RefinerModelInputFieldComponent = (
|
||||
@ -27,7 +28,8 @@ const RefinerModelInputFieldComponent = (
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { data: refinerModels, isLoading } = useGetSDXLRefinerModelsQuery();
|
||||
const { data: refinerModels, isLoading } =
|
||||
useGetMainModelsQuery(REFINER_BASE_MODELS);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!refinerModels) {
|
||||
|
@ -14,6 +14,7 @@ import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
|
||||
|
||||
@ -29,8 +30,10 @@ const ParamMainModelSelect = () => {
|
||||
|
||||
const { model } = useAppSelector(selector);
|
||||
|
||||
const { data: mainModels, isLoading } = useGetMainModelsQuery();
|
||||
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
||||
const { data: mainModels, isLoading } = useGetMainModelsQuery(
|
||||
NON_REFINER_BASE_MODELS
|
||||
);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!mainModels) {
|
||||
|
@ -3,9 +3,9 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
|
||||
import { setRefinerAestheticScore } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
|
@ -4,10 +4,10 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAINumberInput from 'common/components/IAINumberInput';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
|
||||
import { setRefinerCFGScale } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
|
@ -11,7 +11,8 @@ import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
|
||||
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useGetSDXLRefinerModelsQuery } from 'services/api/endpoints/models';
|
||||
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
@ -24,7 +25,8 @@ const ParamSDXLRefinerModelSelect = () => {
|
||||
|
||||
const { model } = useAppSelector(selector);
|
||||
|
||||
const { data: refinerModels, isLoading } = useGetSDXLRefinerModelsQuery();
|
||||
const { data: refinerModels, isLoading } =
|
||||
useGetMainModelsQuery(REFINER_BASE_MODELS);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!refinerModels) {
|
||||
|
@ -7,11 +7,11 @@ import {
|
||||
SCHEDULER_LABEL_MAP,
|
||||
SchedulerParam,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
|
||||
import { setRefinerScheduler } from 'features/sdxl/store/sdxlSlice';
|
||||
import { map } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
|
@ -3,9 +3,9 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
|
||||
import { setRefinerStart } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
|
@ -4,10 +4,10 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAINumberInput from 'common/components/IAINumberInput';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
|
||||
import { setRefinerSteps } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
|
@ -1,9 +1,9 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
|
||||
import { setShouldUseSDXLRefiner } from 'features/sdxl/store/sdxlSlice';
|
||||
import { ChangeEvent } from 'react';
|
||||
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||
|
||||
export default function ParamUseSDXLRefiner() {
|
||||
const shouldUseSDXLRefiner = useAppSelector(
|
||||
|
@ -1,11 +0,0 @@
|
||||
import { useGetSDXLRefinerModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
export const useIsRefinerAvailable = () => {
|
||||
const { isRefinerAvailable } = useGetSDXLRefinerModelsQuery(undefined, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
isRefinerAvailable: data ? data.ids.length > 0 : false,
|
||||
}),
|
||||
});
|
||||
|
||||
return isRefinerAvailable;
|
||||
};
|
@ -16,6 +16,7 @@ import {
|
||||
useImportMainModelsMutation,
|
||||
} from 'services/api/endpoints/models';
|
||||
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
|
||||
export default function FoundModelsList() {
|
||||
const searchFolder = useAppSelector(
|
||||
@ -24,7 +25,7 @@ export default function FoundModelsList() {
|
||||
const [nameFilter, setNameFilter] = useState<string>('');
|
||||
|
||||
// Get paths of models that are already installed
|
||||
const { data: installedModels } = useGetMainModelsQuery();
|
||||
const { data: installedModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
|
||||
|
||||
// Get all model paths from a given directory
|
||||
const { foundModels, alreadyInstalled, filteredModels } =
|
||||
|
@ -1,5 +1,4 @@
|
||||
import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
@ -8,9 +7,11 @@ import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { pickBy } from 'lodash-es';
|
||||
import { useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
import {
|
||||
useGetMainModelsQuery,
|
||||
useMergeMainModelsMutation,
|
||||
@ -32,7 +33,7 @@ export default function MergeModelsPanel() {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { data } = useGetMainModelsQuery();
|
||||
const { data } = useGetMainModelsQuery(ALL_BASE_MODELS);
|
||||
|
||||
const [mergeModels, { isLoading }] = useMergeMainModelsMutation();
|
||||
|
||||
|
@ -8,10 +8,11 @@ import {
|
||||
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
|
||||
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
|
||||
import ModelList from './ModelManagerPanel/ModelList';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
|
||||
export default function ModelManagerPanel() {
|
||||
const [selectedModelId, setSelectedModelId] = useState<string>();
|
||||
const { model } = useGetMainModelsQuery(undefined, {
|
||||
const { model } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
model: selectedModelId ? data?.entities[selectedModelId] : undefined,
|
||||
}),
|
||||
|
@ -11,6 +11,7 @@ import {
|
||||
useGetMainModelsQuery,
|
||||
} from 'services/api/endpoints/models';
|
||||
import ModelListItem from './ModelListItem';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
|
||||
type ModelListProps = {
|
||||
selectedModelId: string | undefined;
|
||||
@ -26,13 +27,13 @@ const ModelList = (props: ModelListProps) => {
|
||||
const [modelFormatFilter, setModelFormatFilter] =
|
||||
useState<ModelFormat>('images');
|
||||
|
||||
const { filteredDiffusersModels } = useGetMainModelsQuery(undefined, {
|
||||
const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter),
|
||||
}),
|
||||
});
|
||||
|
||||
const { filteredCheckpointModels } = useGetMainModelsQuery(undefined, {
|
||||
const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter),
|
||||
}),
|
||||
|
16
invokeai/frontend/web/src/services/api/constants.ts
Normal file
16
invokeai/frontend/web/src/services/api/constants.ts
Normal file
@ -0,0 +1,16 @@
|
||||
import { BaseModelType } from './types';
|
||||
|
||||
export const ALL_BASE_MODELS: BaseModelType[] = [
|
||||
'sd-1',
|
||||
'sd-2',
|
||||
'sdxl',
|
||||
'sdxl-refiner',
|
||||
];
|
||||
|
||||
export const NON_REFINER_BASE_MODELS: BaseModelType[] = [
|
||||
'sd-1',
|
||||
'sd-2',
|
||||
'sdxl',
|
||||
];
|
||||
|
||||
export const REFINER_BASE_MODELS: BaseModelType[] = ['sdxl-refiner'];
|
@ -107,9 +107,6 @@ type SearchFolderArg = operations['search_for_models']['parameters']['query'];
|
||||
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
});
|
||||
const sdxlRefinerModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
});
|
||||
const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
});
|
||||
@ -147,11 +144,14 @@ const createModelEntities = <T extends AnyModelConfigEntity>(
|
||||
|
||||
export const modelsApi = api.injectEndpoints({
|
||||
endpoints: (build) => ({
|
||||
getMainModels: build.query<EntityState<MainModelConfigEntity>, void>({
|
||||
query: () => {
|
||||
getMainModels: build.query<
|
||||
EntityState<MainModelConfigEntity>,
|
||||
BaseModelType[]
|
||||
>({
|
||||
query: (base_models) => {
|
||||
const params = {
|
||||
model_type: 'main',
|
||||
base_models: ['sd-1', 'sd-2', 'sdxl'],
|
||||
base_models,
|
||||
};
|
||||
|
||||
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
||||
@ -187,43 +187,6 @@ export const modelsApi = api.injectEndpoints({
|
||||
);
|
||||
},
|
||||
}),
|
||||
getSDXLRefinerModels: build.query<EntityState<MainModelConfigEntity>, void>(
|
||||
{
|
||||
query: () => ({
|
||||
url: 'models/',
|
||||
params: { model_type: 'main', base_models: ['sdxl-refiner'] },
|
||||
}),
|
||||
providesTags: (result, error, arg) => {
|
||||
const tags: ApiFullTagDescription[] = [
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
];
|
||||
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: 'SDXLRefinerModel' as const,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
},
|
||||
transformResponse: (
|
||||
response: { models: MainModelConfig[] },
|
||||
meta,
|
||||
arg
|
||||
) => {
|
||||
const entities = createModelEntities<MainModelConfigEntity>(
|
||||
response.models
|
||||
);
|
||||
return sdxlRefinerModelsAdapter.setAll(
|
||||
sdxlRefinerModelsAdapter.getInitialState(),
|
||||
entities
|
||||
);
|
||||
},
|
||||
}
|
||||
),
|
||||
updateMainModels: build.mutation<
|
||||
UpdateMainModelResponse,
|
||||
UpdateMainModelArg
|
||||
@ -494,7 +457,6 @@ export const modelsApi = api.injectEndpoints({
|
||||
|
||||
export const {
|
||||
useGetMainModelsQuery,
|
||||
useGetSDXLRefinerModelsQuery,
|
||||
useGetControlNetModelsQuery,
|
||||
useGetLoRAModelsQuery,
|
||||
useGetTextualInversionModelsQuery,
|
||||
|
@ -0,0 +1,12 @@
|
||||
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
export const useIsRefinerAvailable = () => {
|
||||
const { isRefinerAvailable } = useGetMainModelsQuery(REFINER_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
isRefinerAvailable: data ? data.ids.length > 0 : false,
|
||||
}),
|
||||
});
|
||||
|
||||
return isRefinerAvailable;
|
||||
};
|
Loading…
Reference in New Issue
Block a user