mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): remove isRefinerAvailable
state, update refiner node
We can derive `isRefinerAvailable` from the query result (eg are there any refiner models installed). This is a piece of server state, so by using the list models response directly, we can avoid needing to manually keep the client in sync with the server. Created a `useIsRefinerAvailable()` hook to return this boolean wherever it is needed. Also updated the main models & refiner models endpoints to only return the appropriate models. Now we don't need to filter the data on these endpoints.
This commit is contained in:
parent
751c4407e4
commit
8e90f9024d
@ -138,7 +138,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
|||||||
"ui": {
|
"ui": {
|
||||||
"title": "SDXL Refiner Model Loader",
|
"title": "SDXL Refiner Model Loader",
|
||||||
"tags": ["model", "loader", "sdxl_refiner"],
|
"tags": ["model", "loader", "sdxl_refiner"],
|
||||||
"type_hints": {"model": "model"},
|
"type_hints": {"model": "refiner_model"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ import {
|
|||||||
} from 'features/parameters/types/parameterSchemas';
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
import {
|
import {
|
||||||
refinerModelChanged,
|
refinerModelChanged,
|
||||||
setIsRefinerAvailable,
|
setShouldUseSDXLRefiner,
|
||||||
} from 'features/sdxl/store/sdxlSlice';
|
} from 'features/sdxl/store/sdxlSlice';
|
||||||
import { forEach, some } from 'lodash-es';
|
import { forEach, some } from 'lodash-es';
|
||||||
import { modelsApi } from 'services/api/endpoints/models';
|
import { modelsApi } from 'services/api/endpoints/models';
|
||||||
@ -92,7 +92,7 @@ export const addModelsLoadedListener = () => {
|
|||||||
if (!firstModel) {
|
if (!firstModel) {
|
||||||
// No models loaded at all
|
// No models loaded at all
|
||||||
dispatch(refinerModelChanged(null));
|
dispatch(refinerModelChanged(null));
|
||||||
dispatch(setIsRefinerAvailable(false));
|
dispatch(setShouldUseSDXLRefiner(false));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -107,7 +107,6 @@ export const addModelsLoadedListener = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dispatch(refinerModelChanged(result.data));
|
dispatch(refinerModelChanged(result.data));
|
||||||
dispatch(setIsRefinerAvailable(true));
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
startAppListening({
|
startAppListening({
|
||||||
|
@ -20,6 +20,7 @@ import StringInputFieldComponent from './fields/StringInputFieldComponent';
|
|||||||
import UnetInputFieldComponent from './fields/UnetInputFieldComponent';
|
import UnetInputFieldComponent from './fields/UnetInputFieldComponent';
|
||||||
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
|
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
|
||||||
import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent';
|
import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent';
|
||||||
|
import RefinerModelInputFieldComponent from './fields/RefinerModelInputFieldComponent';
|
||||||
|
|
||||||
type InputFieldComponentProps = {
|
type InputFieldComponentProps = {
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
@ -155,6 +156,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (type === 'refiner_model' && template.type === 'refiner_model') {
|
||||||
|
return (
|
||||||
|
<RefinerModelInputFieldComponent
|
||||||
|
nodeId={nodeId}
|
||||||
|
field={field}
|
||||||
|
template={template}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
if (type === 'vae_model' && template.type === 'vae_model') {
|
if (type === 'vae_model' && template.type === 'vae_model') {
|
||||||
return (
|
return (
|
||||||
<VaeModelInputFieldComponent
|
<VaeModelInputFieldComponent
|
||||||
|
@ -0,0 +1,115 @@
|
|||||||
|
import { Box, Flex } from '@chakra-ui/react';
|
||||||
|
import { SelectItem } from '@mantine/core';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||||
|
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
|
import {
|
||||||
|
RefinerModelInputFieldTemplate,
|
||||||
|
RefinerModelInputFieldValue,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
|
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||||
|
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
|
||||||
|
import { forEach } from 'lodash-es';
|
||||||
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useGetSDXLRefinerModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
import { FieldComponentProps } from './types';
|
||||||
|
|
||||||
|
const RefinerModelInputFieldComponent = (
|
||||||
|
props: FieldComponentProps<
|
||||||
|
RefinerModelInputFieldValue,
|
||||||
|
RefinerModelInputFieldTemplate
|
||||||
|
>
|
||||||
|
) => {
|
||||||
|
const { nodeId, field } = props;
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const { data: refinerModels, isLoading } = useGetSDXLRefinerModelsQuery();
|
||||||
|
|
||||||
|
const data = useMemo(() => {
|
||||||
|
if (!refinerModels) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
|
forEach(refinerModels.entities, (model, id) => {
|
||||||
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
data.push({
|
||||||
|
value: id,
|
||||||
|
label: model.model_name,
|
||||||
|
group: MODEL_TYPE_MAP[model.base_model],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [refinerModels]);
|
||||||
|
|
||||||
|
// grab the full model entity from the RTK Query cache
|
||||||
|
// TODO: maybe we should just store the full model entity in state?
|
||||||
|
const selectedModel = useMemo(
|
||||||
|
() =>
|
||||||
|
refinerModels?.entities[
|
||||||
|
`${field.value?.base_model}/main/${field.value?.model_name}`
|
||||||
|
] ?? null,
|
||||||
|
[field.value?.base_model, field.value?.model_name, refinerModels?.entities]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleChangeModel = useCallback(
|
||||||
|
(v: string | null) => {
|
||||||
|
if (!v) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const newModel = modelIdToMainModelParam(v);
|
||||||
|
|
||||||
|
if (!newModel) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
fieldValueChanged({
|
||||||
|
nodeId,
|
||||||
|
fieldName: field.name,
|
||||||
|
value: newModel,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[dispatch, field.name, nodeId]
|
||||||
|
);
|
||||||
|
|
||||||
|
return isLoading ? (
|
||||||
|
<IAIMantineSearchableSelect
|
||||||
|
label={t('modelManager.model')}
|
||||||
|
placeholder="Loading..."
|
||||||
|
disabled={true}
|
||||||
|
data={[]}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<Flex w="100%" alignItems="center" gap={2}>
|
||||||
|
<IAIMantineSearchableSelect
|
||||||
|
tooltip={selectedModel?.description}
|
||||||
|
label={
|
||||||
|
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
|
||||||
|
}
|
||||||
|
value={selectedModel?.id}
|
||||||
|
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
|
||||||
|
data={data}
|
||||||
|
error={data.length === 0}
|
||||||
|
disabled={data.length === 0}
|
||||||
|
onChange={handleChangeModel}
|
||||||
|
/>
|
||||||
|
<Box mt={7}>
|
||||||
|
<SyncModelsButton iconMode />
|
||||||
|
</Box>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(RefinerModelInputFieldComponent);
|
@ -17,6 +17,7 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
|
|||||||
ClipField: 'clip',
|
ClipField: 'clip',
|
||||||
VaeField: 'vae',
|
VaeField: 'vae',
|
||||||
model: 'model',
|
model: 'model',
|
||||||
|
refiner_model: 'refiner_model',
|
||||||
vae_model: 'vae_model',
|
vae_model: 'vae_model',
|
||||||
lora_model: 'lora_model',
|
lora_model: 'lora_model',
|
||||||
controlnet_model: 'controlnet_model',
|
controlnet_model: 'controlnet_model',
|
||||||
@ -120,6 +121,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
|||||||
title: 'Model',
|
title: 'Model',
|
||||||
description: 'Models are models.',
|
description: 'Models are models.',
|
||||||
},
|
},
|
||||||
|
refiner_model: {
|
||||||
|
color: 'teal',
|
||||||
|
colorCssVar: getColorTokenCssVariable('teal'),
|
||||||
|
title: 'Refiner Model',
|
||||||
|
description: 'Models are models.',
|
||||||
|
},
|
||||||
vae_model: {
|
vae_model: {
|
||||||
color: 'teal',
|
color: 'teal',
|
||||||
colorCssVar: getColorTokenCssVariable('teal'),
|
colorCssVar: getColorTokenCssVariable('teal'),
|
||||||
|
@ -70,6 +70,7 @@ export type FieldType =
|
|||||||
| 'vae'
|
| 'vae'
|
||||||
| 'control'
|
| 'control'
|
||||||
| 'model'
|
| 'model'
|
||||||
|
| 'refiner_model'
|
||||||
| 'vae_model'
|
| 'vae_model'
|
||||||
| 'lora_model'
|
| 'lora_model'
|
||||||
| 'controlnet_model'
|
| 'controlnet_model'
|
||||||
@ -100,6 +101,7 @@ export type InputFieldValue =
|
|||||||
| ControlInputFieldValue
|
| ControlInputFieldValue
|
||||||
| EnumInputFieldValue
|
| EnumInputFieldValue
|
||||||
| MainModelInputFieldValue
|
| MainModelInputFieldValue
|
||||||
|
| RefinerModelInputFieldValue
|
||||||
| VaeModelInputFieldValue
|
| VaeModelInputFieldValue
|
||||||
| LoRAModelInputFieldValue
|
| LoRAModelInputFieldValue
|
||||||
| ControlNetModelInputFieldValue
|
| ControlNetModelInputFieldValue
|
||||||
@ -128,6 +130,7 @@ export type InputFieldTemplate =
|
|||||||
| ControlInputFieldTemplate
|
| ControlInputFieldTemplate
|
||||||
| EnumInputFieldTemplate
|
| EnumInputFieldTemplate
|
||||||
| ModelInputFieldTemplate
|
| ModelInputFieldTemplate
|
||||||
|
| RefinerModelInputFieldTemplate
|
||||||
| VaeModelInputFieldTemplate
|
| VaeModelInputFieldTemplate
|
||||||
| LoRAModelInputFieldTemplate
|
| LoRAModelInputFieldTemplate
|
||||||
| ControlNetModelInputFieldTemplate
|
| ControlNetModelInputFieldTemplate
|
||||||
@ -243,6 +246,11 @@ export type MainModelInputFieldValue = FieldValueBase & {
|
|||||||
value?: MainModelParam;
|
value?: MainModelParam;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type RefinerModelInputFieldValue = FieldValueBase & {
|
||||||
|
type: 'refiner_model';
|
||||||
|
value?: MainModelParam;
|
||||||
|
};
|
||||||
|
|
||||||
export type VaeModelInputFieldValue = FieldValueBase & {
|
export type VaeModelInputFieldValue = FieldValueBase & {
|
||||||
type: 'vae_model';
|
type: 'vae_model';
|
||||||
value?: VaeModelParam;
|
value?: VaeModelParam;
|
||||||
@ -367,6 +375,11 @@ export type ModelInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
type: 'model';
|
type: 'model';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type RefinerModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: string;
|
||||||
|
type: 'refiner_model';
|
||||||
|
};
|
||||||
|
|
||||||
export type VaeModelInputFieldTemplate = InputFieldTemplateBase & {
|
export type VaeModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: string;
|
default: string;
|
||||||
type: 'vae_model';
|
type: 'vae_model';
|
||||||
|
@ -22,6 +22,8 @@ import {
|
|||||||
LoRAModelInputFieldTemplate,
|
LoRAModelInputFieldTemplate,
|
||||||
ModelInputFieldTemplate,
|
ModelInputFieldTemplate,
|
||||||
OutputFieldTemplate,
|
OutputFieldTemplate,
|
||||||
|
RefinerModelInputFieldTemplate,
|
||||||
|
RefinerModelInputFieldValue,
|
||||||
StringInputFieldTemplate,
|
StringInputFieldTemplate,
|
||||||
TypeHints,
|
TypeHints,
|
||||||
UNetInputFieldTemplate,
|
UNetInputFieldTemplate,
|
||||||
@ -178,6 +180,21 @@ const buildModelInputFieldTemplate = ({
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildRefinerModelInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): RefinerModelInputFieldTemplate => {
|
||||||
|
const template: RefinerModelInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'refiner_model',
|
||||||
|
inputRequirement: 'always',
|
||||||
|
inputKind: 'direct',
|
||||||
|
default: schemaObject.default ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
const buildVaeModelInputFieldTemplate = ({
|
const buildVaeModelInputFieldTemplate = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
baseField,
|
baseField,
|
||||||
@ -492,6 +509,9 @@ export const buildInputFieldTemplate = (
|
|||||||
if (['model'].includes(fieldType)) {
|
if (['model'].includes(fieldType)) {
|
||||||
return buildModelInputFieldTemplate({ schemaObject, baseField });
|
return buildModelInputFieldTemplate({ schemaObject, baseField });
|
||||||
}
|
}
|
||||||
|
if (['refiner_model'].includes(fieldType)) {
|
||||||
|
return buildRefinerModelInputFieldTemplate({ schemaObject, baseField });
|
||||||
|
}
|
||||||
if (['vae_model'].includes(fieldType)) {
|
if (['vae_model'].includes(fieldType)) {
|
||||||
return buildVaeModelInputFieldTemplate({ schemaObject, baseField });
|
return buildVaeModelInputFieldTemplate({ schemaObject, baseField });
|
||||||
}
|
}
|
||||||
|
@ -76,6 +76,10 @@ export const buildInputFieldValue = (
|
|||||||
fieldValue.value = undefined;
|
fieldValue.value = undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (template.type === 'refiner_model') {
|
||||||
|
fieldValue.value = undefined;
|
||||||
|
}
|
||||||
|
|
||||||
if (template.type === 'vae_model') {
|
if (template.type === 'vae_model') {
|
||||||
fieldValue.value = undefined;
|
fieldValue.value = undefined;
|
||||||
}
|
}
|
||||||
|
@ -40,7 +40,7 @@ const ParamMainModelSelect = () => {
|
|||||||
const data: SelectItem[] = [];
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
forEach(mainModels.entities, (model, id) => {
|
forEach(mainModels.entities, (model, id) => {
|
||||||
if (!model || ['sdxl-refiner'].includes(model.base_model)) {
|
if (!model) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,17 +3,17 @@ import { stateSelector } from 'app/store/store';
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
|
||||||
import { setRefinerAestheticScore } from 'features/sdxl/store/sdxlSlice';
|
import { setRefinerAestheticScore } from 'features/sdxl/store/sdxlSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector],
|
[stateSelector],
|
||||||
({ sdxl, hotkeys }) => {
|
({ sdxl, hotkeys }) => {
|
||||||
const { refinerAestheticScore, isRefinerAvailable } = sdxl;
|
const { refinerAestheticScore } = sdxl;
|
||||||
const { shift } = hotkeys;
|
const { shift } = hotkeys;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
isRefinerAvailable,
|
|
||||||
refinerAestheticScore,
|
refinerAestheticScore,
|
||||||
shift,
|
shift,
|
||||||
};
|
};
|
||||||
@ -22,8 +22,10 @@ const selector = createSelector(
|
|||||||
);
|
);
|
||||||
|
|
||||||
const ParamSDXLRefinerAestheticScore = () => {
|
const ParamSDXLRefinerAestheticScore = () => {
|
||||||
const { refinerAestheticScore, shift, isRefinerAvailable } =
|
const { refinerAestheticScore, shift } = useAppSelector(selector);
|
||||||
useAppSelector(selector);
|
|
||||||
|
const isRefinerAvailable = useIsRefinerAvailable();
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const handleChange = useCallback(
|
const handleChange = useCallback(
|
||||||
|
@ -4,6 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAINumberInput from 'common/components/IAINumberInput';
|
import IAINumberInput from 'common/components/IAINumberInput';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
|
||||||
import { setRefinerCFGScale } from 'features/sdxl/store/sdxlSlice';
|
import { setRefinerCFGScale } from 'features/sdxl/store/sdxlSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -11,12 +12,11 @@ import { useTranslation } from 'react-i18next';
|
|||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector],
|
[stateSelector],
|
||||||
({ sdxl, ui, hotkeys }) => {
|
({ sdxl, ui, hotkeys }) => {
|
||||||
const { refinerCFGScale, isRefinerAvailable } = sdxl;
|
const { refinerCFGScale } = sdxl;
|
||||||
const { shouldUseSliders } = ui;
|
const { shouldUseSliders } = ui;
|
||||||
const { shift } = hotkeys;
|
const { shift } = hotkeys;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
isRefinerAvailable,
|
|
||||||
refinerCFGScale,
|
refinerCFGScale,
|
||||||
shouldUseSliders,
|
shouldUseSliders,
|
||||||
shift,
|
shift,
|
||||||
@ -26,8 +26,8 @@ const selector = createSelector(
|
|||||||
);
|
);
|
||||||
|
|
||||||
const ParamSDXLRefinerCFGScale = () => {
|
const ParamSDXLRefinerCFGScale = () => {
|
||||||
const { refinerCFGScale, shouldUseSliders, shift, isRefinerAvailable } =
|
const { refinerCFGScale, shouldUseSliders, shift } = useAppSelector(selector);
|
||||||
useAppSelector(selector);
|
const isRefinerAvailable = useIsRefinerAvailable();
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
@ -24,16 +24,16 @@ const ParamSDXLRefinerModelSelect = () => {
|
|||||||
|
|
||||||
const { model } = useAppSelector(selector);
|
const { model } = useAppSelector(selector);
|
||||||
|
|
||||||
const { data: sdxlModels, isLoading } = useGetSDXLRefinerModelsQuery();
|
const { data: refinerModels, isLoading } = useGetSDXLRefinerModelsQuery();
|
||||||
|
|
||||||
const data = useMemo(() => {
|
const data = useMemo(() => {
|
||||||
if (!sdxlModels) {
|
if (!refinerModels) {
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
|
|
||||||
const data: SelectItem[] = [];
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
forEach(sdxlModels.entities, (model, id) => {
|
forEach(refinerModels.entities, (model, id) => {
|
||||||
if (!model) {
|
if (!model) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -46,15 +46,16 @@ const ParamSDXLRefinerModelSelect = () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
return data;
|
return data;
|
||||||
}, [sdxlModels]);
|
}, [refinerModels]);
|
||||||
|
|
||||||
// grab the full model entity from the RTK Query cache
|
// grab the full model entity from the RTK Query cache
|
||||||
// TODO: maybe we should just store the full model entity in state?
|
// TODO: maybe we should just store the full model entity in state?
|
||||||
const selectedModel = useMemo(
|
const selectedModel = useMemo(
|
||||||
() =>
|
() =>
|
||||||
sdxlModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ??
|
refinerModels?.entities[
|
||||||
null,
|
`${model?.base_model}/main/${model?.model_name}`
|
||||||
[sdxlModels?.entities, model]
|
] ?? null,
|
||||||
|
[refinerModels?.entities, model]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleChangeModel = useCallback(
|
const handleChangeModel = useCallback(
|
||||||
|
@ -7,6 +7,7 @@ import {
|
|||||||
SCHEDULER_LABEL_MAP,
|
SCHEDULER_LABEL_MAP,
|
||||||
SchedulerParam,
|
SchedulerParam,
|
||||||
} from 'features/parameters/types/parameterSchemas';
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
|
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
|
||||||
import { setRefinerScheduler } from 'features/sdxl/store/sdxlSlice';
|
import { setRefinerScheduler } from 'features/sdxl/store/sdxlSlice';
|
||||||
import { map } from 'lodash-es';
|
import { map } from 'lodash-es';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
@ -15,7 +16,7 @@ import { useTranslation } from 'react-i18next';
|
|||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
stateSelector,
|
stateSelector,
|
||||||
({ ui, sdxl }) => {
|
({ ui, sdxl }) => {
|
||||||
const { refinerScheduler, isRefinerAvailable } = sdxl;
|
const { refinerScheduler } = sdxl;
|
||||||
const { favoriteSchedulers: enabledSchedulers } = ui;
|
const { favoriteSchedulers: enabledSchedulers } = ui;
|
||||||
|
|
||||||
const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({
|
const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({
|
||||||
@ -27,7 +28,6 @@ const selector = createSelector(
|
|||||||
})).sort((a, b) => a.label.localeCompare(b.label));
|
})).sort((a, b) => a.label.localeCompare(b.label));
|
||||||
|
|
||||||
return {
|
return {
|
||||||
isRefinerAvailable,
|
|
||||||
refinerScheduler,
|
refinerScheduler,
|
||||||
data,
|
data,
|
||||||
};
|
};
|
||||||
@ -38,9 +38,8 @@ const selector = createSelector(
|
|||||||
const ParamSDXLRefinerScheduler = () => {
|
const ParamSDXLRefinerScheduler = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { refinerScheduler, data, isRefinerAvailable } =
|
const { refinerScheduler, data } = useAppSelector(selector);
|
||||||
useAppSelector(selector);
|
const isRefinerAvailable = useIsRefinerAvailable();
|
||||||
|
|
||||||
const handleChange = useCallback(
|
const handleChange = useCallback(
|
||||||
(v: string | null) => {
|
(v: string | null) => {
|
||||||
if (!v) {
|
if (!v) {
|
||||||
|
@ -3,17 +3,17 @@ import { stateSelector } from 'app/store/store';
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
|
||||||
import { setRefinerStart } from 'features/sdxl/store/sdxlSlice';
|
import { setRefinerStart } from 'features/sdxl/store/sdxlSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector],
|
[stateSelector],
|
||||||
({ sdxl, hotkeys }) => {
|
({ sdxl, hotkeys }) => {
|
||||||
const { refinerStart, isRefinerAvailable } = sdxl;
|
const { refinerStart } = sdxl;
|
||||||
const { shift } = hotkeys;
|
const { shift } = hotkeys;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
isRefinerAvailable,
|
|
||||||
refinerStart,
|
refinerStart,
|
||||||
shift,
|
shift,
|
||||||
};
|
};
|
||||||
@ -22,9 +22,9 @@ const selector = createSelector(
|
|||||||
);
|
);
|
||||||
|
|
||||||
const ParamSDXLRefinerStart = () => {
|
const ParamSDXLRefinerStart = () => {
|
||||||
const { refinerStart, shift, isRefinerAvailable } = useAppSelector(selector);
|
const { refinerStart, shift } = useAppSelector(selector);
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
const isRefinerAvailable = useIsRefinerAvailable();
|
||||||
const handleChange = useCallback(
|
const handleChange = useCallback(
|
||||||
(v: number) => dispatch(setRefinerStart(v)),
|
(v: number) => dispatch(setRefinerStart(v)),
|
||||||
[dispatch]
|
[dispatch]
|
||||||
|
@ -3,8 +3,8 @@ import { stateSelector } from 'app/store/store';
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAINumberInput from 'common/components/IAINumberInput';
|
import IAINumberInput from 'common/components/IAINumberInput';
|
||||||
|
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
|
||||||
import { setRefinerSteps } from 'features/sdxl/store/sdxlSlice';
|
import { setRefinerSteps } from 'features/sdxl/store/sdxlSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -12,11 +12,10 @@ import { useTranslation } from 'react-i18next';
|
|||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector],
|
[stateSelector],
|
||||||
({ sdxl, ui }) => {
|
({ sdxl, ui }) => {
|
||||||
const { refinerSteps, isRefinerAvailable } = sdxl;
|
const { refinerSteps } = sdxl;
|
||||||
const { shouldUseSliders } = ui;
|
const { shouldUseSliders } = ui;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
isRefinerAvailable,
|
|
||||||
refinerSteps,
|
refinerSteps,
|
||||||
shouldUseSliders,
|
shouldUseSliders,
|
||||||
};
|
};
|
||||||
@ -25,8 +24,9 @@ const selector = createSelector(
|
|||||||
);
|
);
|
||||||
|
|
||||||
const ParamSDXLRefinerSteps = () => {
|
const ParamSDXLRefinerSteps = () => {
|
||||||
const { refinerSteps, shouldUseSliders, isRefinerAvailable } =
|
const { refinerSteps, shouldUseSliders } = useAppSelector(selector);
|
||||||
useAppSelector(selector);
|
const isRefinerAvailable = useIsRefinerAvailable();
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAISwitch from 'common/components/IAISwitch';
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
|
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
|
||||||
import { setShouldUseSDXLRefiner } from 'features/sdxl/store/sdxlSlice';
|
import { setShouldUseSDXLRefiner } from 'features/sdxl/store/sdxlSlice';
|
||||||
import { ChangeEvent } from 'react';
|
import { ChangeEvent } from 'react';
|
||||||
|
|
||||||
@ -8,10 +9,7 @@ export default function ParamUseSDXLRefiner() {
|
|||||||
const shouldUseSDXLRefiner = useAppSelector(
|
const shouldUseSDXLRefiner = useAppSelector(
|
||||||
(state: RootState) => state.sdxl.shouldUseSDXLRefiner
|
(state: RootState) => state.sdxl.shouldUseSDXLRefiner
|
||||||
);
|
);
|
||||||
|
const isRefinerAvailable = useIsRefinerAvailable();
|
||||||
const isRefinerAvailable = useAppSelector(
|
|
||||||
(state: RootState) => state.sdxl.isRefinerAvailable
|
|
||||||
);
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
@ -0,0 +1,11 @@
|
|||||||
|
import { useGetSDXLRefinerModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
|
export const useIsRefinerAvailable = () => {
|
||||||
|
const { isRefinerAvailable } = useGetSDXLRefinerModelsQuery(undefined, {
|
||||||
|
selectFromResult: ({ data }) => ({
|
||||||
|
isRefinerAvailable: data ? data.ids.length > 0 : false,
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
return isRefinerAvailable;
|
||||||
|
};
|
@ -10,7 +10,6 @@ import { MainModelField } from 'services/api/types';
|
|||||||
type SDXLInitialState = {
|
type SDXLInitialState = {
|
||||||
positiveStylePrompt: PositiveStylePromptSDXLParam;
|
positiveStylePrompt: PositiveStylePromptSDXLParam;
|
||||||
negativeStylePrompt: NegativeStylePromptSDXLParam;
|
negativeStylePrompt: NegativeStylePromptSDXLParam;
|
||||||
isRefinerAvailable: boolean;
|
|
||||||
shouldUseSDXLRefiner: boolean;
|
shouldUseSDXLRefiner: boolean;
|
||||||
refinerModel: MainModelField | null;
|
refinerModel: MainModelField | null;
|
||||||
refinerSteps: number;
|
refinerSteps: number;
|
||||||
@ -23,7 +22,6 @@ type SDXLInitialState = {
|
|||||||
const sdxlInitialState: SDXLInitialState = {
|
const sdxlInitialState: SDXLInitialState = {
|
||||||
positiveStylePrompt: '',
|
positiveStylePrompt: '',
|
||||||
negativeStylePrompt: '',
|
negativeStylePrompt: '',
|
||||||
isRefinerAvailable: false,
|
|
||||||
shouldUseSDXLRefiner: false,
|
shouldUseSDXLRefiner: false,
|
||||||
refinerModel: null,
|
refinerModel: null,
|
||||||
refinerSteps: 20,
|
refinerSteps: 20,
|
||||||
@ -43,9 +41,6 @@ const sdxlSlice = createSlice({
|
|||||||
setNegativeStylePromptSDXL: (state, action: PayloadAction<string>) => {
|
setNegativeStylePromptSDXL: (state, action: PayloadAction<string>) => {
|
||||||
state.negativeStylePrompt = action.payload;
|
state.negativeStylePrompt = action.payload;
|
||||||
},
|
},
|
||||||
setIsRefinerAvailable: (state, action: PayloadAction<boolean>) => {
|
|
||||||
state.isRefinerAvailable = action.payload;
|
|
||||||
},
|
|
||||||
setShouldUseSDXLRefiner: (state, action: PayloadAction<boolean>) => {
|
setShouldUseSDXLRefiner: (state, action: PayloadAction<boolean>) => {
|
||||||
state.shouldUseSDXLRefiner = action.payload;
|
state.shouldUseSDXLRefiner = action.payload;
|
||||||
},
|
},
|
||||||
@ -76,7 +71,6 @@ const sdxlSlice = createSlice({
|
|||||||
export const {
|
export const {
|
||||||
setPositiveStylePromptSDXL,
|
setPositiveStylePromptSDXL,
|
||||||
setNegativeStylePromptSDXL,
|
setNegativeStylePromptSDXL,
|
||||||
setIsRefinerAvailable,
|
|
||||||
setShouldUseSDXLRefiner,
|
setShouldUseSDXLRefiner,
|
||||||
refinerModelChanged,
|
refinerModelChanged,
|
||||||
setRefinerSteps,
|
setRefinerSteps,
|
||||||
|
@ -148,7 +148,15 @@ const createModelEntities = <T extends AnyModelConfigEntity>(
|
|||||||
export const modelsApi = api.injectEndpoints({
|
export const modelsApi = api.injectEndpoints({
|
||||||
endpoints: (build) => ({
|
endpoints: (build) => ({
|
||||||
getMainModels: build.query<EntityState<MainModelConfigEntity>, void>({
|
getMainModels: build.query<EntityState<MainModelConfigEntity>, void>({
|
||||||
query: () => ({ url: 'models/', params: { model_type: 'main' } }),
|
query: () => {
|
||||||
|
const params = {
|
||||||
|
model_type: 'main',
|
||||||
|
base_models: ['sd-1', 'sd-2', 'sdxl'],
|
||||||
|
};
|
||||||
|
|
||||||
|
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
||||||
|
return `models/?${query}`;
|
||||||
|
},
|
||||||
providesTags: (result, error, arg) => {
|
providesTags: (result, error, arg) => {
|
||||||
const tags: ApiFullTagDescription[] = [
|
const tags: ApiFullTagDescription[] = [
|
||||||
{ type: 'MainModel', id: LIST_TAG },
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
@ -183,7 +191,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
{
|
{
|
||||||
query: () => ({
|
query: () => ({
|
||||||
url: 'models/',
|
url: 'models/',
|
||||||
params: { model_type: 'main', base_models: 'sdxl-refiner' },
|
params: { model_type: 'main', base_models: ['sdxl-refiner'] },
|
||||||
}),
|
}),
|
||||||
providesTags: (result, error, arg) => {
|
providesTags: (result, error, arg) => {
|
||||||
const tags: ApiFullTagDescription[] = [
|
const tags: ApiFullTagDescription[] = [
|
||||||
|
Loading…
x
Reference in New Issue
Block a user