feat(ui): remove isRefinerAvailable state, update refiner node

We can derive `isRefinerAvailable` from the query result (eg are there any refiner models installed). This is a piece of server state, so by using the list models response directly, we can avoid needing to manually keep the client in sync with the server.

Created a `useIsRefinerAvailable()` hook to return this boolean wherever it is needed.

Also updated the main models & refiner models endpoints to only return the appropriate models. Now we don't need to filter the data on these endpoints.
This commit is contained in:
psychedelicious 2023-07-25 22:08:25 +10:00
parent 751c4407e4
commit 8e90f9024d
19 changed files with 228 additions and 46 deletions

View File

@ -138,7 +138,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"ui": {
"title": "SDXL Refiner Model Loader",
"tags": ["model", "loader", "sdxl_refiner"],
"type_hints": {"model": "model"},
"type_hints": {"model": "refiner_model"},
},
}

View File

@ -11,7 +11,7 @@ import {
} from 'features/parameters/types/parameterSchemas';
import {
refinerModelChanged,
setIsRefinerAvailable,
setShouldUseSDXLRefiner,
} from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models';
@ -92,7 +92,7 @@ export const addModelsLoadedListener = () => {
if (!firstModel) {
// No models loaded at all
dispatch(refinerModelChanged(null));
dispatch(setIsRefinerAvailable(false));
dispatch(setShouldUseSDXLRefiner(false));
return;
}
@ -107,7 +107,6 @@ export const addModelsLoadedListener = () => {
}
dispatch(refinerModelChanged(result.data));
dispatch(setIsRefinerAvailable(true));
},
});
startAppListening({

View File

@ -20,6 +20,7 @@ import StringInputFieldComponent from './fields/StringInputFieldComponent';
import UnetInputFieldComponent from './fields/UnetInputFieldComponent';
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent';
import RefinerModelInputFieldComponent from './fields/RefinerModelInputFieldComponent';
type InputFieldComponentProps = {
nodeId: string;
@ -155,6 +156,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
);
}
if (type === 'refiner_model' && template.type === 'refiner_model') {
return (
<RefinerModelInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'vae_model' && template.type === 'vae_model') {
return (
<VaeModelInputFieldComponent

View File

@ -0,0 +1,115 @@
import { Box, Flex } from '@chakra-ui/react';
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
RefinerModelInputFieldTemplate,
RefinerModelInputFieldValue,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetSDXLRefinerModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
const RefinerModelInputFieldComponent = (
props: FieldComponentProps<
RefinerModelInputFieldValue,
RefinerModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data: refinerModels, isLoading } = useGetSDXLRefinerModelsQuery();
const data = useMemo(() => {
if (!refinerModels) {
return [];
}
const data: SelectItem[] = [];
forEach(refinerModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [refinerModels]);
// grab the full model entity from the RTK Query cache
// TODO: maybe we should just store the full model entity in state?
const selectedModel = useMemo(
() =>
refinerModels?.entities[
`${field.value?.base_model}/main/${field.value?.model_name}`
] ?? null,
[field.value?.base_model, field.value?.model_name, refinerModels?.entities]
);
const handleChangeModel = useCallback(
(v: string | null) => {
if (!v) {
return;
}
const newModel = modelIdToMainModelParam(v);
if (!newModel) {
return;
}
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: newModel,
})
);
},
[dispatch, field.name, nodeId]
);
return isLoading ? (
<IAIMantineSearchableSelect
label={t('modelManager.model')}
placeholder="Loading..."
disabled={true}
data={[]}
/>
) : (
<Flex w="100%" alignItems="center" gap={2}>
<IAIMantineSearchableSelect
tooltip={selectedModel?.description}
label={
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
}
value={selectedModel?.id}
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
data={data}
error={data.length === 0}
disabled={data.length === 0}
onChange={handleChangeModel}
/>
<Box mt={7}>
<SyncModelsButton iconMode />
</Box>
</Flex>
);
};
export default memo(RefinerModelInputFieldComponent);

View File

@ -17,6 +17,7 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
ClipField: 'clip',
VaeField: 'vae',
model: 'model',
refiner_model: 'refiner_model',
vae_model: 'vae_model',
lora_model: 'lora_model',
controlnet_model: 'controlnet_model',
@ -120,6 +121,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
title: 'Model',
description: 'Models are models.',
},
refiner_model: {
color: 'teal',
colorCssVar: getColorTokenCssVariable('teal'),
title: 'Refiner Model',
description: 'Models are models.',
},
vae_model: {
color: 'teal',
colorCssVar: getColorTokenCssVariable('teal'),

View File

@ -70,6 +70,7 @@ export type FieldType =
| 'vae'
| 'control'
| 'model'
| 'refiner_model'
| 'vae_model'
| 'lora_model'
| 'controlnet_model'
@ -100,6 +101,7 @@ export type InputFieldValue =
| ControlInputFieldValue
| EnumInputFieldValue
| MainModelInputFieldValue
| RefinerModelInputFieldValue
| VaeModelInputFieldValue
| LoRAModelInputFieldValue
| ControlNetModelInputFieldValue
@ -128,6 +130,7 @@ export type InputFieldTemplate =
| ControlInputFieldTemplate
| EnumInputFieldTemplate
| ModelInputFieldTemplate
| RefinerModelInputFieldTemplate
| VaeModelInputFieldTemplate
| LoRAModelInputFieldTemplate
| ControlNetModelInputFieldTemplate
@ -243,6 +246,11 @@ export type MainModelInputFieldValue = FieldValueBase & {
value?: MainModelParam;
};
export type RefinerModelInputFieldValue = FieldValueBase & {
type: 'refiner_model';
value?: MainModelParam;
};
export type VaeModelInputFieldValue = FieldValueBase & {
type: 'vae_model';
value?: VaeModelParam;
@ -367,6 +375,11 @@ export type ModelInputFieldTemplate = InputFieldTemplateBase & {
type: 'model';
};
export type RefinerModelInputFieldTemplate = InputFieldTemplateBase & {
default: string;
type: 'refiner_model';
};
export type VaeModelInputFieldTemplate = InputFieldTemplateBase & {
default: string;
type: 'vae_model';

View File

@ -22,6 +22,8 @@ import {
LoRAModelInputFieldTemplate,
ModelInputFieldTemplate,
OutputFieldTemplate,
RefinerModelInputFieldTemplate,
RefinerModelInputFieldValue,
StringInputFieldTemplate,
TypeHints,
UNetInputFieldTemplate,
@ -178,6 +180,21 @@ const buildModelInputFieldTemplate = ({
return template;
};
const buildRefinerModelInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): RefinerModelInputFieldTemplate => {
const template: RefinerModelInputFieldTemplate = {
...baseField,
type: 'refiner_model',
inputRequirement: 'always',
inputKind: 'direct',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildVaeModelInputFieldTemplate = ({
schemaObject,
baseField,
@ -492,6 +509,9 @@ export const buildInputFieldTemplate = (
if (['model'].includes(fieldType)) {
return buildModelInputFieldTemplate({ schemaObject, baseField });
}
if (['refiner_model'].includes(fieldType)) {
return buildRefinerModelInputFieldTemplate({ schemaObject, baseField });
}
if (['vae_model'].includes(fieldType)) {
return buildVaeModelInputFieldTemplate({ schemaObject, baseField });
}

View File

@ -76,6 +76,10 @@ export const buildInputFieldValue = (
fieldValue.value = undefined;
}
if (template.type === 'refiner_model') {
fieldValue.value = undefined;
}
if (template.type === 'vae_model') {
fieldValue.value = undefined;
}

View File

@ -40,7 +40,7 @@ const ParamMainModelSelect = () => {
const data: SelectItem[] = [];
forEach(mainModels.entities, (model, id) => {
if (!model || ['sdxl-refiner'].includes(model.base_model)) {
if (!model) {
return;
}

View File

@ -3,17 +3,17 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider';
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
import { setRefinerAestheticScore } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
const selector = createSelector(
[stateSelector],
({ sdxl, hotkeys }) => {
const { refinerAestheticScore, isRefinerAvailable } = sdxl;
const { refinerAestheticScore } = sdxl;
const { shift } = hotkeys;
return {
isRefinerAvailable,
refinerAestheticScore,
shift,
};
@ -22,8 +22,10 @@ const selector = createSelector(
);
const ParamSDXLRefinerAestheticScore = () => {
const { refinerAestheticScore, shift, isRefinerAvailable } =
useAppSelector(selector);
const { refinerAestheticScore, shift } = useAppSelector(selector);
const isRefinerAvailable = useIsRefinerAvailable();
const dispatch = useAppDispatch();
const handleChange = useCallback(

View File

@ -4,6 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider';
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
import { setRefinerCFGScale } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
@ -11,12 +12,11 @@ import { useTranslation } from 'react-i18next';
const selector = createSelector(
[stateSelector],
({ sdxl, ui, hotkeys }) => {
const { refinerCFGScale, isRefinerAvailable } = sdxl;
const { refinerCFGScale } = sdxl;
const { shouldUseSliders } = ui;
const { shift } = hotkeys;
return {
isRefinerAvailable,
refinerCFGScale,
shouldUseSliders,
shift,
@ -26,8 +26,8 @@ const selector = createSelector(
);
const ParamSDXLRefinerCFGScale = () => {
const { refinerCFGScale, shouldUseSliders, shift, isRefinerAvailable } =
useAppSelector(selector);
const { refinerCFGScale, shouldUseSliders, shift } = useAppSelector(selector);
const isRefinerAvailable = useIsRefinerAvailable();
const dispatch = useAppDispatch();
const { t } = useTranslation();

View File

@ -24,16 +24,16 @@ const ParamSDXLRefinerModelSelect = () => {
const { model } = useAppSelector(selector);
const { data: sdxlModels, isLoading } = useGetSDXLRefinerModelsQuery();
const { data: refinerModels, isLoading } = useGetSDXLRefinerModelsQuery();
const data = useMemo(() => {
if (!sdxlModels) {
if (!refinerModels) {
return [];
}
const data: SelectItem[] = [];
forEach(sdxlModels.entities, (model, id) => {
forEach(refinerModels.entities, (model, id) => {
if (!model) {
return;
}
@ -46,15 +46,16 @@ const ParamSDXLRefinerModelSelect = () => {
});
return data;
}, [sdxlModels]);
}, [refinerModels]);
// grab the full model entity from the RTK Query cache
// TODO: maybe we should just store the full model entity in state?
const selectedModel = useMemo(
() =>
sdxlModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ??
null,
[sdxlModels?.entities, model]
refinerModels?.entities[
`${model?.base_model}/main/${model?.model_name}`
] ?? null,
[refinerModels?.entities, model]
);
const handleChangeModel = useCallback(

View File

@ -7,6 +7,7 @@ import {
SCHEDULER_LABEL_MAP,
SchedulerParam,
} from 'features/parameters/types/parameterSchemas';
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
import { setRefinerScheduler } from 'features/sdxl/store/sdxlSlice';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react';
@ -15,7 +16,7 @@ import { useTranslation } from 'react-i18next';
const selector = createSelector(
stateSelector,
({ ui, sdxl }) => {
const { refinerScheduler, isRefinerAvailable } = sdxl;
const { refinerScheduler } = sdxl;
const { favoriteSchedulers: enabledSchedulers } = ui;
const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({
@ -27,7 +28,6 @@ const selector = createSelector(
})).sort((a, b) => a.label.localeCompare(b.label));
return {
isRefinerAvailable,
refinerScheduler,
data,
};
@ -38,9 +38,8 @@ const selector = createSelector(
const ParamSDXLRefinerScheduler = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { refinerScheduler, data, isRefinerAvailable } =
useAppSelector(selector);
const { refinerScheduler, data } = useAppSelector(selector);
const isRefinerAvailable = useIsRefinerAvailable();
const handleChange = useCallback(
(v: string | null) => {
if (!v) {

View File

@ -3,17 +3,17 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider';
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
import { setRefinerStart } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
const selector = createSelector(
[stateSelector],
({ sdxl, hotkeys }) => {
const { refinerStart, isRefinerAvailable } = sdxl;
const { refinerStart } = sdxl;
const { shift } = hotkeys;
return {
isRefinerAvailable,
refinerStart,
shift,
};
@ -22,9 +22,9 @@ const selector = createSelector(
);
const ParamSDXLRefinerStart = () => {
const { refinerStart, shift, isRefinerAvailable } = useAppSelector(selector);
const { refinerStart, shift } = useAppSelector(selector);
const dispatch = useAppDispatch();
const isRefinerAvailable = useIsRefinerAvailable();
const handleChange = useCallback(
(v: number) => dispatch(setRefinerStart(v)),
[dispatch]

View File

@ -3,8 +3,8 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider';
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
import { setRefinerSteps } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
@ -12,11 +12,10 @@ import { useTranslation } from 'react-i18next';
const selector = createSelector(
[stateSelector],
({ sdxl, ui }) => {
const { refinerSteps, isRefinerAvailable } = sdxl;
const { refinerSteps } = sdxl;
const { shouldUseSliders } = ui;
return {
isRefinerAvailable,
refinerSteps,
shouldUseSliders,
};
@ -25,8 +24,9 @@ const selector = createSelector(
);
const ParamSDXLRefinerSteps = () => {
const { refinerSteps, shouldUseSliders, isRefinerAvailable } =
useAppSelector(selector);
const { refinerSteps, shouldUseSliders } = useAppSelector(selector);
const isRefinerAvailable = useIsRefinerAvailable();
const dispatch = useAppDispatch();
const { t } = useTranslation();

View File

@ -1,6 +1,7 @@
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
import { setShouldUseSDXLRefiner } from 'features/sdxl/store/sdxlSlice';
import { ChangeEvent } from 'react';
@ -8,10 +9,7 @@ export default function ParamUseSDXLRefiner() {
const shouldUseSDXLRefiner = useAppSelector(
(state: RootState) => state.sdxl.shouldUseSDXLRefiner
);
const isRefinerAvailable = useAppSelector(
(state: RootState) => state.sdxl.isRefinerAvailable
);
const isRefinerAvailable = useIsRefinerAvailable();
const dispatch = useAppDispatch();

View File

@ -0,0 +1,11 @@
import { useGetSDXLRefinerModelsQuery } from 'services/api/endpoints/models';
export const useIsRefinerAvailable = () => {
const { isRefinerAvailable } = useGetSDXLRefinerModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
isRefinerAvailable: data ? data.ids.length > 0 : false,
}),
});
return isRefinerAvailable;
};

View File

@ -10,7 +10,6 @@ import { MainModelField } from 'services/api/types';
type SDXLInitialState = {
positiveStylePrompt: PositiveStylePromptSDXLParam;
negativeStylePrompt: NegativeStylePromptSDXLParam;
isRefinerAvailable: boolean;
shouldUseSDXLRefiner: boolean;
refinerModel: MainModelField | null;
refinerSteps: number;
@ -23,7 +22,6 @@ type SDXLInitialState = {
const sdxlInitialState: SDXLInitialState = {
positiveStylePrompt: '',
negativeStylePrompt: '',
isRefinerAvailable: false,
shouldUseSDXLRefiner: false,
refinerModel: null,
refinerSteps: 20,
@ -43,9 +41,6 @@ const sdxlSlice = createSlice({
setNegativeStylePromptSDXL: (state, action: PayloadAction<string>) => {
state.negativeStylePrompt = action.payload;
},
setIsRefinerAvailable: (state, action: PayloadAction<boolean>) => {
state.isRefinerAvailable = action.payload;
},
setShouldUseSDXLRefiner: (state, action: PayloadAction<boolean>) => {
state.shouldUseSDXLRefiner = action.payload;
},
@ -76,7 +71,6 @@ const sdxlSlice = createSlice({
export const {
setPositiveStylePromptSDXL,
setNegativeStylePromptSDXL,
setIsRefinerAvailable,
setShouldUseSDXLRefiner,
refinerModelChanged,
setRefinerSteps,

View File

@ -148,7 +148,15 @@ const createModelEntities = <T extends AnyModelConfigEntity>(
export const modelsApi = api.injectEndpoints({
endpoints: (build) => ({
getMainModels: build.query<EntityState<MainModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'main' } }),
query: () => {
const params = {
model_type: 'main',
base_models: ['sd-1', 'sd-2', 'sdxl'],
};
const query = queryString.stringify(params, { arrayFormat: 'none' });
return `models/?${query}`;
},
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ type: 'MainModel', id: LIST_TAG },
@ -183,7 +191,7 @@ export const modelsApi = api.injectEndpoints({
{
query: () => ({
url: 'models/',
params: { model_type: 'main', base_models: 'sdxl-refiner' },
params: { model_type: 'main', base_models: ['sdxl-refiner'] },
}),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [