mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): remove isRefinerAvailable
state, update refiner node
We can derive `isRefinerAvailable` from the query result (eg are there any refiner models installed). This is a piece of server state, so by using the list models response directly, we can avoid needing to manually keep the client in sync with the server. Created a `useIsRefinerAvailable()` hook to return this boolean wherever it is needed. Also updated the main models & refiner models endpoints to only return the appropriate models. Now we don't need to filter the data on these endpoints.
This commit is contained in:
parent
751c4407e4
commit
8e90f9024d
@ -138,7 +138,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
"ui": {
|
||||
"title": "SDXL Refiner Model Loader",
|
||||
"tags": ["model", "loader", "sdxl_refiner"],
|
||||
"type_hints": {"model": "model"},
|
||||
"type_hints": {"model": "refiner_model"},
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -11,7 +11,7 @@ import {
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import {
|
||||
refinerModelChanged,
|
||||
setIsRefinerAvailable,
|
||||
setShouldUseSDXLRefiner,
|
||||
} from 'features/sdxl/store/sdxlSlice';
|
||||
import { forEach, some } from 'lodash-es';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
@ -92,7 +92,7 @@ export const addModelsLoadedListener = () => {
|
||||
if (!firstModel) {
|
||||
// No models loaded at all
|
||||
dispatch(refinerModelChanged(null));
|
||||
dispatch(setIsRefinerAvailable(false));
|
||||
dispatch(setShouldUseSDXLRefiner(false));
|
||||
return;
|
||||
}
|
||||
|
||||
@ -107,7 +107,6 @@ export const addModelsLoadedListener = () => {
|
||||
}
|
||||
|
||||
dispatch(refinerModelChanged(result.data));
|
||||
dispatch(setIsRefinerAvailable(true));
|
||||
},
|
||||
});
|
||||
startAppListening({
|
||||
|
@ -20,6 +20,7 @@ import StringInputFieldComponent from './fields/StringInputFieldComponent';
|
||||
import UnetInputFieldComponent from './fields/UnetInputFieldComponent';
|
||||
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
|
||||
import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent';
|
||||
import RefinerModelInputFieldComponent from './fields/RefinerModelInputFieldComponent';
|
||||
|
||||
type InputFieldComponentProps = {
|
||||
nodeId: string;
|
||||
@ -155,6 +156,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'refiner_model' && template.type === 'refiner_model') {
|
||||
return (
|
||||
<RefinerModelInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'vae_model' && template.type === 'vae_model') {
|
||||
return (
|
||||
<VaeModelInputFieldComponent
|
||||
|
@ -0,0 +1,115 @@
|
||||
import { Box, Flex } from '@chakra-ui/react';
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
RefinerModelInputFieldTemplate,
|
||||
RefinerModelInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetSDXLRefinerModelsQuery } from 'services/api/endpoints/models';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
const RefinerModelInputFieldComponent = (
|
||||
props: FieldComponentProps<
|
||||
RefinerModelInputFieldValue,
|
||||
RefinerModelInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { data: refinerModels, isLoading } = useGetSDXLRefinerModelsQuery();
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!refinerModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(refinerModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.model_name,
|
||||
group: MODEL_TYPE_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [refinerModels]);
|
||||
|
||||
// grab the full model entity from the RTK Query cache
|
||||
// TODO: maybe we should just store the full model entity in state?
|
||||
const selectedModel = useMemo(
|
||||
() =>
|
||||
refinerModels?.entities[
|
||||
`${field.value?.base_model}/main/${field.value?.model_name}`
|
||||
] ?? null,
|
||||
[field.value?.base_model, field.value?.model_name, refinerModels?.entities]
|
||||
);
|
||||
|
||||
const handleChangeModel = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
|
||||
const newModel = modelIdToMainModelParam(v);
|
||||
|
||||
if (!newModel) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
fieldValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: newModel,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return isLoading ? (
|
||||
<IAIMantineSearchableSelect
|
||||
label={t('modelManager.model')}
|
||||
placeholder="Loading..."
|
||||
disabled={true}
|
||||
data={[]}
|
||||
/>
|
||||
) : (
|
||||
<Flex w="100%" alignItems="center" gap={2}>
|
||||
<IAIMantineSearchableSelect
|
||||
tooltip={selectedModel?.description}
|
||||
label={
|
||||
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
|
||||
}
|
||||
value={selectedModel?.id}
|
||||
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
|
||||
data={data}
|
||||
error={data.length === 0}
|
||||
disabled={data.length === 0}
|
||||
onChange={handleChangeModel}
|
||||
/>
|
||||
<Box mt={7}>
|
||||
<SyncModelsButton iconMode />
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(RefinerModelInputFieldComponent);
|
@ -17,6 +17,7 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
|
||||
ClipField: 'clip',
|
||||
VaeField: 'vae',
|
||||
model: 'model',
|
||||
refiner_model: 'refiner_model',
|
||||
vae_model: 'vae_model',
|
||||
lora_model: 'lora_model',
|
||||
controlnet_model: 'controlnet_model',
|
||||
@ -120,6 +121,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
title: 'Model',
|
||||
description: 'Models are models.',
|
||||
},
|
||||
refiner_model: {
|
||||
color: 'teal',
|
||||
colorCssVar: getColorTokenCssVariable('teal'),
|
||||
title: 'Refiner Model',
|
||||
description: 'Models are models.',
|
||||
},
|
||||
vae_model: {
|
||||
color: 'teal',
|
||||
colorCssVar: getColorTokenCssVariable('teal'),
|
||||
|
@ -70,6 +70,7 @@ export type FieldType =
|
||||
| 'vae'
|
||||
| 'control'
|
||||
| 'model'
|
||||
| 'refiner_model'
|
||||
| 'vae_model'
|
||||
| 'lora_model'
|
||||
| 'controlnet_model'
|
||||
@ -100,6 +101,7 @@ export type InputFieldValue =
|
||||
| ControlInputFieldValue
|
||||
| EnumInputFieldValue
|
||||
| MainModelInputFieldValue
|
||||
| RefinerModelInputFieldValue
|
||||
| VaeModelInputFieldValue
|
||||
| LoRAModelInputFieldValue
|
||||
| ControlNetModelInputFieldValue
|
||||
@ -128,6 +130,7 @@ export type InputFieldTemplate =
|
||||
| ControlInputFieldTemplate
|
||||
| EnumInputFieldTemplate
|
||||
| ModelInputFieldTemplate
|
||||
| RefinerModelInputFieldTemplate
|
||||
| VaeModelInputFieldTemplate
|
||||
| LoRAModelInputFieldTemplate
|
||||
| ControlNetModelInputFieldTemplate
|
||||
@ -243,6 +246,11 @@ export type MainModelInputFieldValue = FieldValueBase & {
|
||||
value?: MainModelParam;
|
||||
};
|
||||
|
||||
export type RefinerModelInputFieldValue = FieldValueBase & {
|
||||
type: 'refiner_model';
|
||||
value?: MainModelParam;
|
||||
};
|
||||
|
||||
export type VaeModelInputFieldValue = FieldValueBase & {
|
||||
type: 'vae_model';
|
||||
value?: VaeModelParam;
|
||||
@ -367,6 +375,11 @@ export type ModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'model';
|
||||
};
|
||||
|
||||
export type RefinerModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: string;
|
||||
type: 'refiner_model';
|
||||
};
|
||||
|
||||
export type VaeModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: string;
|
||||
type: 'vae_model';
|
||||
|
@ -22,6 +22,8 @@ import {
|
||||
LoRAModelInputFieldTemplate,
|
||||
ModelInputFieldTemplate,
|
||||
OutputFieldTemplate,
|
||||
RefinerModelInputFieldTemplate,
|
||||
RefinerModelInputFieldValue,
|
||||
StringInputFieldTemplate,
|
||||
TypeHints,
|
||||
UNetInputFieldTemplate,
|
||||
@ -178,6 +180,21 @@ const buildModelInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildRefinerModelInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): RefinerModelInputFieldTemplate => {
|
||||
const template: RefinerModelInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'refiner_model',
|
||||
inputRequirement: 'always',
|
||||
inputKind: 'direct',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildVaeModelInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -492,6 +509,9 @@ export const buildInputFieldTemplate = (
|
||||
if (['model'].includes(fieldType)) {
|
||||
return buildModelInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['refiner_model'].includes(fieldType)) {
|
||||
return buildRefinerModelInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['vae_model'].includes(fieldType)) {
|
||||
return buildVaeModelInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
|
@ -76,6 +76,10 @@ export const buildInputFieldValue = (
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'refiner_model') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'vae_model') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
@ -40,7 +40,7 @@ const ParamMainModelSelect = () => {
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(mainModels.entities, (model, id) => {
|
||||
if (!model || ['sdxl-refiner'].includes(model.base_model)) {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -3,17 +3,17 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
|
||||
import { setRefinerAestheticScore } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
({ sdxl, hotkeys }) => {
|
||||
const { refinerAestheticScore, isRefinerAvailable } = sdxl;
|
||||
const { refinerAestheticScore } = sdxl;
|
||||
const { shift } = hotkeys;
|
||||
|
||||
return {
|
||||
isRefinerAvailable,
|
||||
refinerAestheticScore,
|
||||
shift,
|
||||
};
|
||||
@ -22,8 +22,10 @@ const selector = createSelector(
|
||||
);
|
||||
|
||||
const ParamSDXLRefinerAestheticScore = () => {
|
||||
const { refinerAestheticScore, shift, isRefinerAvailable } =
|
||||
useAppSelector(selector);
|
||||
const { refinerAestheticScore, shift } = useAppSelector(selector);
|
||||
|
||||
const isRefinerAvailable = useIsRefinerAvailable();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleChange = useCallback(
|
||||
|
@ -4,6 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAINumberInput from 'common/components/IAINumberInput';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
|
||||
import { setRefinerCFGScale } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -11,12 +12,11 @@ import { useTranslation } from 'react-i18next';
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
({ sdxl, ui, hotkeys }) => {
|
||||
const { refinerCFGScale, isRefinerAvailable } = sdxl;
|
||||
const { refinerCFGScale } = sdxl;
|
||||
const { shouldUseSliders } = ui;
|
||||
const { shift } = hotkeys;
|
||||
|
||||
return {
|
||||
isRefinerAvailable,
|
||||
refinerCFGScale,
|
||||
shouldUseSliders,
|
||||
shift,
|
||||
@ -26,8 +26,8 @@ const selector = createSelector(
|
||||
);
|
||||
|
||||
const ParamSDXLRefinerCFGScale = () => {
|
||||
const { refinerCFGScale, shouldUseSliders, shift, isRefinerAvailable } =
|
||||
useAppSelector(selector);
|
||||
const { refinerCFGScale, shouldUseSliders, shift } = useAppSelector(selector);
|
||||
const isRefinerAvailable = useIsRefinerAvailable();
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
|
@ -24,16 +24,16 @@ const ParamSDXLRefinerModelSelect = () => {
|
||||
|
||||
const { model } = useAppSelector(selector);
|
||||
|
||||
const { data: sdxlModels, isLoading } = useGetSDXLRefinerModelsQuery();
|
||||
const { data: refinerModels, isLoading } = useGetSDXLRefinerModelsQuery();
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!sdxlModels) {
|
||||
if (!refinerModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(sdxlModels.entities, (model, id) => {
|
||||
forEach(refinerModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
@ -46,15 +46,16 @@ const ParamSDXLRefinerModelSelect = () => {
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [sdxlModels]);
|
||||
}, [refinerModels]);
|
||||
|
||||
// grab the full model entity from the RTK Query cache
|
||||
// TODO: maybe we should just store the full model entity in state?
|
||||
const selectedModel = useMemo(
|
||||
() =>
|
||||
sdxlModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ??
|
||||
null,
|
||||
[sdxlModels?.entities, model]
|
||||
refinerModels?.entities[
|
||||
`${model?.base_model}/main/${model?.model_name}`
|
||||
] ?? null,
|
||||
[refinerModels?.entities, model]
|
||||
);
|
||||
|
||||
const handleChangeModel = useCallback(
|
||||
|
@ -7,6 +7,7 @@ import {
|
||||
SCHEDULER_LABEL_MAP,
|
||||
SchedulerParam,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
|
||||
import { setRefinerScheduler } from 'features/sdxl/store/sdxlSlice';
|
||||
import { map } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
@ -15,7 +16,7 @@ import { useTranslation } from 'react-i18next';
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
({ ui, sdxl }) => {
|
||||
const { refinerScheduler, isRefinerAvailable } = sdxl;
|
||||
const { refinerScheduler } = sdxl;
|
||||
const { favoriteSchedulers: enabledSchedulers } = ui;
|
||||
|
||||
const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({
|
||||
@ -27,7 +28,6 @@ const selector = createSelector(
|
||||
})).sort((a, b) => a.label.localeCompare(b.label));
|
||||
|
||||
return {
|
||||
isRefinerAvailable,
|
||||
refinerScheduler,
|
||||
data,
|
||||
};
|
||||
@ -38,9 +38,8 @@ const selector = createSelector(
|
||||
const ParamSDXLRefinerScheduler = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const { refinerScheduler, data, isRefinerAvailable } =
|
||||
useAppSelector(selector);
|
||||
|
||||
const { refinerScheduler, data } = useAppSelector(selector);
|
||||
const isRefinerAvailable = useIsRefinerAvailable();
|
||||
const handleChange = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
|
@ -3,17 +3,17 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
|
||||
import { setRefinerStart } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
({ sdxl, hotkeys }) => {
|
||||
const { refinerStart, isRefinerAvailable } = sdxl;
|
||||
const { refinerStart } = sdxl;
|
||||
const { shift } = hotkeys;
|
||||
|
||||
return {
|
||||
isRefinerAvailable,
|
||||
refinerStart,
|
||||
shift,
|
||||
};
|
||||
@ -22,9 +22,9 @@ const selector = createSelector(
|
||||
);
|
||||
|
||||
const ParamSDXLRefinerStart = () => {
|
||||
const { refinerStart, shift, isRefinerAvailable } = useAppSelector(selector);
|
||||
const { refinerStart, shift } = useAppSelector(selector);
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const isRefinerAvailable = useIsRefinerAvailable();
|
||||
const handleChange = useCallback(
|
||||
(v: number) => dispatch(setRefinerStart(v)),
|
||||
[dispatch]
|
||||
|
@ -3,8 +3,8 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAINumberInput from 'common/components/IAINumberInput';
|
||||
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
|
||||
import { setRefinerSteps } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -12,11 +12,10 @@ import { useTranslation } from 'react-i18next';
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
({ sdxl, ui }) => {
|
||||
const { refinerSteps, isRefinerAvailable } = sdxl;
|
||||
const { refinerSteps } = sdxl;
|
||||
const { shouldUseSliders } = ui;
|
||||
|
||||
return {
|
||||
isRefinerAvailable,
|
||||
refinerSteps,
|
||||
shouldUseSliders,
|
||||
};
|
||||
@ -25,8 +24,9 @@ const selector = createSelector(
|
||||
);
|
||||
|
||||
const ParamSDXLRefinerSteps = () => {
|
||||
const { refinerSteps, shouldUseSliders, isRefinerAvailable } =
|
||||
useAppSelector(selector);
|
||||
const { refinerSteps, shouldUseSliders } = useAppSelector(selector);
|
||||
const isRefinerAvailable = useIsRefinerAvailable();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
|
||||
import { setShouldUseSDXLRefiner } from 'features/sdxl/store/sdxlSlice';
|
||||
import { ChangeEvent } from 'react';
|
||||
|
||||
@ -8,10 +9,7 @@ export default function ParamUseSDXLRefiner() {
|
||||
const shouldUseSDXLRefiner = useAppSelector(
|
||||
(state: RootState) => state.sdxl.shouldUseSDXLRefiner
|
||||
);
|
||||
|
||||
const isRefinerAvailable = useAppSelector(
|
||||
(state: RootState) => state.sdxl.isRefinerAvailable
|
||||
);
|
||||
const isRefinerAvailable = useIsRefinerAvailable();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
|
@ -0,0 +1,11 @@
|
||||
import { useGetSDXLRefinerModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
export const useIsRefinerAvailable = () => {
|
||||
const { isRefinerAvailable } = useGetSDXLRefinerModelsQuery(undefined, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
isRefinerAvailable: data ? data.ids.length > 0 : false,
|
||||
}),
|
||||
});
|
||||
|
||||
return isRefinerAvailable;
|
||||
};
|
@ -10,7 +10,6 @@ import { MainModelField } from 'services/api/types';
|
||||
type SDXLInitialState = {
|
||||
positiveStylePrompt: PositiveStylePromptSDXLParam;
|
||||
negativeStylePrompt: NegativeStylePromptSDXLParam;
|
||||
isRefinerAvailable: boolean;
|
||||
shouldUseSDXLRefiner: boolean;
|
||||
refinerModel: MainModelField | null;
|
||||
refinerSteps: number;
|
||||
@ -23,7 +22,6 @@ type SDXLInitialState = {
|
||||
const sdxlInitialState: SDXLInitialState = {
|
||||
positiveStylePrompt: '',
|
||||
negativeStylePrompt: '',
|
||||
isRefinerAvailable: false,
|
||||
shouldUseSDXLRefiner: false,
|
||||
refinerModel: null,
|
||||
refinerSteps: 20,
|
||||
@ -43,9 +41,6 @@ const sdxlSlice = createSlice({
|
||||
setNegativeStylePromptSDXL: (state, action: PayloadAction<string>) => {
|
||||
state.negativeStylePrompt = action.payload;
|
||||
},
|
||||
setIsRefinerAvailable: (state, action: PayloadAction<boolean>) => {
|
||||
state.isRefinerAvailable = action.payload;
|
||||
},
|
||||
setShouldUseSDXLRefiner: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldUseSDXLRefiner = action.payload;
|
||||
},
|
||||
@ -76,7 +71,6 @@ const sdxlSlice = createSlice({
|
||||
export const {
|
||||
setPositiveStylePromptSDXL,
|
||||
setNegativeStylePromptSDXL,
|
||||
setIsRefinerAvailable,
|
||||
setShouldUseSDXLRefiner,
|
||||
refinerModelChanged,
|
||||
setRefinerSteps,
|
||||
|
@ -148,7 +148,15 @@ const createModelEntities = <T extends AnyModelConfigEntity>(
|
||||
export const modelsApi = api.injectEndpoints({
|
||||
endpoints: (build) => ({
|
||||
getMainModels: build.query<EntityState<MainModelConfigEntity>, void>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 'main' } }),
|
||||
query: () => {
|
||||
const params = {
|
||||
model_type: 'main',
|
||||
base_models: ['sd-1', 'sd-2', 'sdxl'],
|
||||
};
|
||||
|
||||
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
||||
return `models/?${query}`;
|
||||
},
|
||||
providesTags: (result, error, arg) => {
|
||||
const tags: ApiFullTagDescription[] = [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
@ -183,7 +191,7 @@ export const modelsApi = api.injectEndpoints({
|
||||
{
|
||||
query: () => ({
|
||||
url: 'models/',
|
||||
params: { model_type: 'main', base_models: 'sdxl-refiner' },
|
||||
params: { model_type: 'main', base_models: ['sdxl-refiner'] },
|
||||
}),
|
||||
providesTags: (result, error, arg) => {
|
||||
const tags: ApiFullTagDescription[] = [
|
||||
|
Loading…
Reference in New Issue
Block a user