mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): fix refiner missing from model manager
Rolled back the earlier split of the refiner model query. Now, when you use `useGetMainModelsQuery()`, you must provide it an array of base model types. They are provided as constants for simplicity: - ALL_BASE_MODELS - NON_REFINER_BASE_MODELS - REFINER_BASE_MODELS Opted to just use args for the hook instead of wrapping the hook in another hook, we can tidy this up later if desired.
This commit is contained in:
parent
6fa244a343
commit
cbcd416b70
@ -19,7 +19,9 @@ import { startAppListening } from '..';
|
|||||||
|
|
||||||
export const addModelsLoadedListener = () => {
|
export const addModelsLoadedListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
matcher: modelsApi.endpoints.getMainModels.matchFulfilled,
|
predicate: (state, action) =>
|
||||||
|
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
|
||||||
|
!action.meta.arg.originalArgs.includes('sdxl-refiner'),
|
||||||
effect: async (action, { getState, dispatch }) => {
|
effect: async (action, { getState, dispatch }) => {
|
||||||
// models loaded, we need to ensure the selected model is available and if not, select the first one
|
// models loaded, we need to ensure the selected model is available and if not, select the first one
|
||||||
const log = logger('models');
|
const log = logger('models');
|
||||||
@ -64,7 +66,9 @@ export const addModelsLoadedListener = () => {
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
startAppListening({
|
startAppListening({
|
||||||
matcher: modelsApi.endpoints.getSDXLRefinerModels.matchFulfilled,
|
predicate: (state, action) =>
|
||||||
|
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
|
||||||
|
action.meta.arg.originalArgs.includes('sdxl-refiner'),
|
||||||
effect: async (action, { getState, dispatch }) => {
|
effect: async (action, { getState, dispatch }) => {
|
||||||
// models loaded, we need to ensure the selected model is available and if not, select the first one
|
// models loaded, we need to ensure the selected model is available and if not, select the first one
|
||||||
const log = logger('models');
|
const log = logger('models');
|
||||||
|
@ -14,6 +14,7 @@ import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels
|
|||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
||||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||||
import { FieldComponentProps } from './types';
|
import { FieldComponentProps } from './types';
|
||||||
import { useFeatureStatus } from '../../../system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from '../../../system/hooks/useFeatureStatus';
|
||||||
@ -27,7 +28,9 @@ const ModelInputFieldComponent = (
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
||||||
|
|
||||||
const { data: mainModels, isLoading } = useGetMainModelsQuery();
|
const { data: mainModels, isLoading } = useGetMainModelsQuery(
|
||||||
|
NON_REFINER_BASE_MODELS
|
||||||
|
);
|
||||||
|
|
||||||
const data = useMemo(() => {
|
const data = useMemo(() => {
|
||||||
if (!mainModels) {
|
if (!mainModels) {
|
||||||
|
@ -13,7 +13,8 @@ import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels
|
|||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useGetSDXLRefinerModelsQuery } from 'services/api/endpoints/models';
|
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
||||||
|
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||||
import { FieldComponentProps } from './types';
|
import { FieldComponentProps } from './types';
|
||||||
|
|
||||||
const RefinerModelInputFieldComponent = (
|
const RefinerModelInputFieldComponent = (
|
||||||
@ -27,7 +28,8 @@ const RefinerModelInputFieldComponent = (
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const { data: refinerModels, isLoading } = useGetSDXLRefinerModelsQuery();
|
const { data: refinerModels, isLoading } =
|
||||||
|
useGetMainModelsQuery(REFINER_BASE_MODELS);
|
||||||
|
|
||||||
const data = useMemo(() => {
|
const data = useMemo(() => {
|
||||||
if (!refinerModels) {
|
if (!refinerModels) {
|
||||||
|
@ -14,6 +14,7 @@ import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
|||||||
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||||
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
|
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
|
||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
|
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
||||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||||
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
|
||||||
|
|
||||||
@ -29,8 +30,10 @@ const ParamMainModelSelect = () => {
|
|||||||
|
|
||||||
const { model } = useAppSelector(selector);
|
const { model } = useAppSelector(selector);
|
||||||
|
|
||||||
const { data: mainModels, isLoading } = useGetMainModelsQuery();
|
|
||||||
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
||||||
|
const { data: mainModels, isLoading } = useGetMainModelsQuery(
|
||||||
|
NON_REFINER_BASE_MODELS
|
||||||
|
);
|
||||||
|
|
||||||
const data = useMemo(() => {
|
const data = useMemo(() => {
|
||||||
if (!mainModels) {
|
if (!mainModels) {
|
||||||
|
@ -3,9 +3,9 @@ import { stateSelector } from 'app/store/store';
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
|
|
||||||
import { setRefinerAestheticScore } from 'features/sdxl/store/sdxlSlice';
|
import { setRefinerAestheticScore } from 'features/sdxl/store/sdxlSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector],
|
[stateSelector],
|
||||||
|
@ -4,10 +4,10 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAINumberInput from 'common/components/IAINumberInput';
|
import IAINumberInput from 'common/components/IAINumberInput';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
|
|
||||||
import { setRefinerCFGScale } from 'features/sdxl/store/sdxlSlice';
|
import { setRefinerCFGScale } from 'features/sdxl/store/sdxlSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector],
|
[stateSelector],
|
||||||
|
@ -11,7 +11,8 @@ import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
|
|||||||
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
|
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
|
||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { useGetSDXLRefinerModelsQuery } from 'services/api/endpoints/models';
|
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
||||||
|
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
stateSelector,
|
stateSelector,
|
||||||
@ -24,7 +25,8 @@ const ParamSDXLRefinerModelSelect = () => {
|
|||||||
|
|
||||||
const { model } = useAppSelector(selector);
|
const { model } = useAppSelector(selector);
|
||||||
|
|
||||||
const { data: refinerModels, isLoading } = useGetSDXLRefinerModelsQuery();
|
const { data: refinerModels, isLoading } =
|
||||||
|
useGetMainModelsQuery(REFINER_BASE_MODELS);
|
||||||
|
|
||||||
const data = useMemo(() => {
|
const data = useMemo(() => {
|
||||||
if (!refinerModels) {
|
if (!refinerModels) {
|
||||||
|
@ -7,11 +7,11 @@ import {
|
|||||||
SCHEDULER_LABEL_MAP,
|
SCHEDULER_LABEL_MAP,
|
||||||
SchedulerParam,
|
SchedulerParam,
|
||||||
} from 'features/parameters/types/parameterSchemas';
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
|
|
||||||
import { setRefinerScheduler } from 'features/sdxl/store/sdxlSlice';
|
import { setRefinerScheduler } from 'features/sdxl/store/sdxlSlice';
|
||||||
import { map } from 'lodash-es';
|
import { map } from 'lodash-es';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
stateSelector,
|
stateSelector,
|
||||||
|
@ -3,9 +3,9 @@ import { stateSelector } from 'app/store/store';
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
|
|
||||||
import { setRefinerStart } from 'features/sdxl/store/sdxlSlice';
|
import { setRefinerStart } from 'features/sdxl/store/sdxlSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector],
|
[stateSelector],
|
||||||
|
@ -4,10 +4,10 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAINumberInput from 'common/components/IAINumberInput';
|
import IAINumberInput from 'common/components/IAINumberInput';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
|
|
||||||
import { setRefinerSteps } from 'features/sdxl/store/sdxlSlice';
|
import { setRefinerSteps } from 'features/sdxl/store/sdxlSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector],
|
[stateSelector],
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAISwitch from 'common/components/IAISwitch';
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
|
|
||||||
import { setShouldUseSDXLRefiner } from 'features/sdxl/store/sdxlSlice';
|
import { setShouldUseSDXLRefiner } from 'features/sdxl/store/sdxlSlice';
|
||||||
import { ChangeEvent } from 'react';
|
import { ChangeEvent } from 'react';
|
||||||
|
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||||
|
|
||||||
export default function ParamUseSDXLRefiner() {
|
export default function ParamUseSDXLRefiner() {
|
||||||
const shouldUseSDXLRefiner = useAppSelector(
|
const shouldUseSDXLRefiner = useAppSelector(
|
||||||
|
@ -1,11 +0,0 @@
|
|||||||
import { useGetSDXLRefinerModelsQuery } from 'services/api/endpoints/models';
|
|
||||||
|
|
||||||
export const useIsRefinerAvailable = () => {
|
|
||||||
const { isRefinerAvailable } = useGetSDXLRefinerModelsQuery(undefined, {
|
|
||||||
selectFromResult: ({ data }) => ({
|
|
||||||
isRefinerAvailable: data ? data.ids.length > 0 : false,
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
|
|
||||||
return isRefinerAvailable;
|
|
||||||
};
|
|
@ -16,6 +16,7 @@ import {
|
|||||||
useImportMainModelsMutation,
|
useImportMainModelsMutation,
|
||||||
} from 'services/api/endpoints/models';
|
} from 'services/api/endpoints/models';
|
||||||
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
|
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
|
||||||
|
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||||
|
|
||||||
export default function FoundModelsList() {
|
export default function FoundModelsList() {
|
||||||
const searchFolder = useAppSelector(
|
const searchFolder = useAppSelector(
|
||||||
@ -24,7 +25,7 @@ export default function FoundModelsList() {
|
|||||||
const [nameFilter, setNameFilter] = useState<string>('');
|
const [nameFilter, setNameFilter] = useState<string>('');
|
||||||
|
|
||||||
// Get paths of models that are already installed
|
// Get paths of models that are already installed
|
||||||
const { data: installedModels } = useGetMainModelsQuery();
|
const { data: installedModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
|
||||||
|
|
||||||
// Get all model paths from a given directory
|
// Get all model paths from a given directory
|
||||||
const { foundModels, alreadyInstalled, filteredModels } =
|
const { foundModels, alreadyInstalled, filteredModels } =
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react';
|
import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react';
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
import IAIInput from 'common/components/IAIInput';
|
import IAIInput from 'common/components/IAIInput';
|
||||||
@ -8,9 +7,11 @@ import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
|||||||
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
import { pickBy } from 'lodash-es';
|
import { pickBy } from 'lodash-es';
|
||||||
import { useMemo, useState } from 'react';
|
import { useMemo, useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||||
import {
|
import {
|
||||||
useGetMainModelsQuery,
|
useGetMainModelsQuery,
|
||||||
useMergeMainModelsMutation,
|
useMergeMainModelsMutation,
|
||||||
@ -32,7 +33,7 @@ export default function MergeModelsPanel() {
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const { data } = useGetMainModelsQuery();
|
const { data } = useGetMainModelsQuery(ALL_BASE_MODELS);
|
||||||
|
|
||||||
const [mergeModels, { isLoading }] = useMergeMainModelsMutation();
|
const [mergeModels, { isLoading }] = useMergeMainModelsMutation();
|
||||||
|
|
||||||
|
@ -8,10 +8,11 @@ import {
|
|||||||
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
|
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
|
||||||
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
|
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
|
||||||
import ModelList from './ModelManagerPanel/ModelList';
|
import ModelList from './ModelManagerPanel/ModelList';
|
||||||
|
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||||
|
|
||||||
export default function ModelManagerPanel() {
|
export default function ModelManagerPanel() {
|
||||||
const [selectedModelId, setSelectedModelId] = useState<string>();
|
const [selectedModelId, setSelectedModelId] = useState<string>();
|
||||||
const { model } = useGetMainModelsQuery(undefined, {
|
const { model } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||||
selectFromResult: ({ data }) => ({
|
selectFromResult: ({ data }) => ({
|
||||||
model: selectedModelId ? data?.entities[selectedModelId] : undefined,
|
model: selectedModelId ? data?.entities[selectedModelId] : undefined,
|
||||||
}),
|
}),
|
||||||
|
@ -11,6 +11,7 @@ import {
|
|||||||
useGetMainModelsQuery,
|
useGetMainModelsQuery,
|
||||||
} from 'services/api/endpoints/models';
|
} from 'services/api/endpoints/models';
|
||||||
import ModelListItem from './ModelListItem';
|
import ModelListItem from './ModelListItem';
|
||||||
|
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||||
|
|
||||||
type ModelListProps = {
|
type ModelListProps = {
|
||||||
selectedModelId: string | undefined;
|
selectedModelId: string | undefined;
|
||||||
@ -26,13 +27,13 @@ const ModelList = (props: ModelListProps) => {
|
|||||||
const [modelFormatFilter, setModelFormatFilter] =
|
const [modelFormatFilter, setModelFormatFilter] =
|
||||||
useState<ModelFormat>('images');
|
useState<ModelFormat>('images');
|
||||||
|
|
||||||
const { filteredDiffusersModels } = useGetMainModelsQuery(undefined, {
|
const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||||
selectFromResult: ({ data }) => ({
|
selectFromResult: ({ data }) => ({
|
||||||
filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter),
|
filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter),
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
|
||||||
const { filteredCheckpointModels } = useGetMainModelsQuery(undefined, {
|
const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||||
selectFromResult: ({ data }) => ({
|
selectFromResult: ({ data }) => ({
|
||||||
filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter),
|
filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter),
|
||||||
}),
|
}),
|
||||||
|
16
invokeai/frontend/web/src/services/api/constants.ts
Normal file
16
invokeai/frontend/web/src/services/api/constants.ts
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
import { BaseModelType } from './types';
|
||||||
|
|
||||||
|
export const ALL_BASE_MODELS: BaseModelType[] = [
|
||||||
|
'sd-1',
|
||||||
|
'sd-2',
|
||||||
|
'sdxl',
|
||||||
|
'sdxl-refiner',
|
||||||
|
];
|
||||||
|
|
||||||
|
export const NON_REFINER_BASE_MODELS: BaseModelType[] = [
|
||||||
|
'sd-1',
|
||||||
|
'sd-2',
|
||||||
|
'sdxl',
|
||||||
|
];
|
||||||
|
|
||||||
|
export const REFINER_BASE_MODELS: BaseModelType[] = ['sdxl-refiner'];
|
@ -107,9 +107,6 @@ type SearchFolderArg = operations['search_for_models']['parameters']['query'];
|
|||||||
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
||||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||||
});
|
});
|
||||||
const sdxlRefinerModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
|
||||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
|
||||||
});
|
|
||||||
const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
|
const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
|
||||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||||
});
|
});
|
||||||
@ -147,11 +144,14 @@ const createModelEntities = <T extends AnyModelConfigEntity>(
|
|||||||
|
|
||||||
export const modelsApi = api.injectEndpoints({
|
export const modelsApi = api.injectEndpoints({
|
||||||
endpoints: (build) => ({
|
endpoints: (build) => ({
|
||||||
getMainModels: build.query<EntityState<MainModelConfigEntity>, void>({
|
getMainModels: build.query<
|
||||||
query: () => {
|
EntityState<MainModelConfigEntity>,
|
||||||
|
BaseModelType[]
|
||||||
|
>({
|
||||||
|
query: (base_models) => {
|
||||||
const params = {
|
const params = {
|
||||||
model_type: 'main',
|
model_type: 'main',
|
||||||
base_models: ['sd-1', 'sd-2', 'sdxl'],
|
base_models,
|
||||||
};
|
};
|
||||||
|
|
||||||
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
||||||
@ -187,43 +187,6 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
);
|
);
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
getSDXLRefinerModels: build.query<EntityState<MainModelConfigEntity>, void>(
|
|
||||||
{
|
|
||||||
query: () => ({
|
|
||||||
url: 'models/',
|
|
||||||
params: { model_type: 'main', base_models: ['sdxl-refiner'] },
|
|
||||||
}),
|
|
||||||
providesTags: (result, error, arg) => {
|
|
||||||
const tags: ApiFullTagDescription[] = [
|
|
||||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
|
||||||
];
|
|
||||||
|
|
||||||
if (result) {
|
|
||||||
tags.push(
|
|
||||||
...result.ids.map((id) => ({
|
|
||||||
type: 'SDXLRefinerModel' as const,
|
|
||||||
id,
|
|
||||||
}))
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return tags;
|
|
||||||
},
|
|
||||||
transformResponse: (
|
|
||||||
response: { models: MainModelConfig[] },
|
|
||||||
meta,
|
|
||||||
arg
|
|
||||||
) => {
|
|
||||||
const entities = createModelEntities<MainModelConfigEntity>(
|
|
||||||
response.models
|
|
||||||
);
|
|
||||||
return sdxlRefinerModelsAdapter.setAll(
|
|
||||||
sdxlRefinerModelsAdapter.getInitialState(),
|
|
||||||
entities
|
|
||||||
);
|
|
||||||
},
|
|
||||||
}
|
|
||||||
),
|
|
||||||
updateMainModels: build.mutation<
|
updateMainModels: build.mutation<
|
||||||
UpdateMainModelResponse,
|
UpdateMainModelResponse,
|
||||||
UpdateMainModelArg
|
UpdateMainModelArg
|
||||||
@ -494,7 +457,6 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
|
|
||||||
export const {
|
export const {
|
||||||
useGetMainModelsQuery,
|
useGetMainModelsQuery,
|
||||||
useGetSDXLRefinerModelsQuery,
|
|
||||||
useGetControlNetModelsQuery,
|
useGetControlNetModelsQuery,
|
||||||
useGetLoRAModelsQuery,
|
useGetLoRAModelsQuery,
|
||||||
useGetTextualInversionModelsQuery,
|
useGetTextualInversionModelsQuery,
|
||||||
|
@ -0,0 +1,12 @@
|
|||||||
|
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
||||||
|
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
|
export const useIsRefinerAvailable = () => {
|
||||||
|
const { isRefinerAvailable } = useGetMainModelsQuery(REFINER_BASE_MODELS, {
|
||||||
|
selectFromResult: ({ data }) => ({
|
||||||
|
isRefinerAvailable: data ? data.ids.length > 0 : false,
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
return isRefinerAvailable;
|
||||||
|
};
|
Loading…
x
Reference in New Issue
Block a user