fix(ui): use model names in badges

This commit is contained in:
psychedelicious 2024-02-21 19:42:36 +11:00
parent 20a56bc757
commit 66ab56246a
8 changed files with 85 additions and 48 deletions

View File

@ -14,14 +14,16 @@ import type { LoRA } from 'features/lora/store/loraSlice';
import { loraIsEnabledChanged, loraRemoved, loraWeightChanged } from 'features/lora/store/loraSlice';
import { memo, useCallback } from 'react';
import { PiTrashSimpleBold } from 'react-icons/pi';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
type LoRACardProps = {
lora: LoRA;
};
export const LoRACard = memo((props: LoRACardProps) => {
const dispatch = useAppDispatch();
const { lora } = props;
const dispatch = useAppDispatch();
const { data: loraConfig } = useGetModelConfigQuery(lora.key);
const handleChange = useCallback(
(v: number) => {
@ -43,7 +45,7 @@ export const LoRACard = memo((props: LoRACardProps) => {
<CardHeader>
<Flex alignItems="center" justifyContent="space-between" width="100%" gap={2}>
<Text noOfLines={1} wordBreak="break-all" color={lora.isEnabled ? 'base.200' : 'base.500'}>
{lora.key}
{loraConfig?.name ?? lora.key.substring(0, 8)}
</Text>
<Flex alignItems="center" gap={2}>
<Switch size="sm" onChange={handleSetLoraToggle} isChecked={lora.isEnabled} />

View File

@ -67,6 +67,8 @@ export const zModelName = z.string().min(3);
export const zModelIdentifier = z.object({
key: z.string().min(1),
});
export const isModelIdentifier = (field: unknown): field is ModelIdentifier =>
zModelIdentifier.safeParse(field).success;
export const zModelFieldBase = zModelIdentifier;
export const zModelIdentifierWithBase = zModelIdentifier.extend({ base: zBaseModel });
export type BaseModel = z.infer<typeof zBaseModel>;
@ -141,7 +143,7 @@ export type VAEField = z.infer<typeof zVAEField>;
// #region Control Adapters
export const zControlField = z.object({
image: zImageField,
control_model: zControlNetModelField,
control_model: zModelFieldBase,
control_weight: z.union([z.number(), z.array(z.number())]).optional(),
begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(),
@ -152,7 +154,7 @@ export type ControlField = z.infer<typeof zControlField>;
export const zIPAdapterField = z.object({
image: zImageField,
ip_adapter_model: zIPAdapterModelField,
ip_adapter_model: zModelFieldBase,
weight: z.number(),
begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(),
@ -161,7 +163,7 @@ export type IPAdapterField = z.infer<typeof zIPAdapterField>;
export const zT2IAdapterField = z.object({
image: zImageField,
t2i_adapter_model: zT2IAdapterModelField,
t2i_adapter_model: zModelFieldBase,
weight: z.union([z.number(), z.array(z.number())]).optional(),
begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(),

View File

@ -4,7 +4,7 @@ import {
zControlNetModelField,
zIPAdapterModelField,
zLoRAModelField,
zMainModelField,
zModelIdentifierWithBase,
zSchedulerField,
zSDXLRefinerModelField,
zT2IAdapterModelField,
@ -105,7 +105,7 @@ export const isParameterAspectRatio = (val: unknown): val is ParameterAspectRati
// #endregion
// #region Model
export const zParameterModel = zMainModelField.extend({ base: zBaseModel });
export const zParameterModel = zModelIdentifierWithBase;
export type ParameterModel = z.infer<typeof zParameterModel>;
export const isParameterModel = (val: unknown): val is ParameterModel => zParameterModel.safeParse(val).success;
// #endregion

View File

@ -1,5 +1,6 @@
import type { FormLabelProps } from '@invoke-ai/ui-library';
import { Flex, FormControlGroup, StandaloneAccordion } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import ParamCFGRescaleMultiplier from 'features/parameters/components/Advanced/ParamCFGRescaleMultiplier';
@ -10,8 +11,9 @@ import ParamVAEModelSelect from 'features/parameters/components/VAEModel/ParamVA
import ParamVAEPrecision from 'features/parameters/components/VAEModel/ParamVAEPrecision';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
import { memo } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
const formLabelProps: FormLabelProps = {
minW: '9.2rem',
@ -21,31 +23,35 @@ const formLabelProps2: FormLabelProps = {
flexGrow: 1,
};
const selectBadges = createMemoizedSelector(selectGenerationSlice, (generation) => {
const badges: (string | number)[] = [];
if (generation.vae) {
// TODO(MM2): Fetch the vae name
let vaeBadge = generation.vae.key;
if (generation.vaePrecision === 'fp16') {
vaeBadge += ` ${generation.vaePrecision}`;
}
badges.push(vaeBadge);
} else if (generation.vaePrecision === 'fp16') {
badges.push(`VAE ${generation.vaePrecision}`);
}
if (generation.clipSkip) {
badges.push(`Skip ${generation.clipSkip}`);
}
if (generation.cfgRescaleMultiplier) {
badges.push(`Rescale ${generation.cfgRescaleMultiplier}`);
}
if (generation.seamlessXAxis || generation.seamlessYAxis) {
badges.push('seamless');
}
return badges;
});
export const AdvancedSettingsAccordion = memo(() => {
const vaeKey = useAppSelector((state) => state.generation.vae?.key);
const { data: vaeConfig } = useGetModelConfigQuery(vaeKey ?? skipToken);
const selectBadges = useMemo(
() =>
createMemoizedSelector(selectGenerationSlice, (generation) => {
const badges: (string | number)[] = [];
if (vaeConfig) {
let vaeBadge = vaeConfig.name;
if (generation.vaePrecision === 'fp16') {
vaeBadge += ` ${generation.vaePrecision}`;
}
badges.push(vaeBadge);
} else if (generation.vaePrecision === 'fp16') {
badges.push(`VAE ${generation.vaePrecision}`);
}
if (generation.clipSkip) {
badges.push(`Skip ${generation.clipSkip}`);
}
if (generation.cfgRescaleMultiplier) {
badges.push(`Rescale ${generation.cfgRescaleMultiplier}`);
}
if (generation.seamlessXAxis || generation.seamlessYAxis) {
badges.push('seamless');
}
return badges;
}),
[vaeConfig]
);
const badges = useAppSelector(selectBadges);
const { t } = useTranslation();
const { isOpen, onToggle } = useStandaloneAccordionToggle({

View File

@ -12,6 +12,7 @@ import {
} from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { EMPTY_ARRAY } from 'app/store/util';
import { LoRAList } from 'features/lora/components/LoRAList';
import LoRASelect from 'features/lora/components/LoRASelect';
import { selectLoraSlice } from 'features/lora/store/loraSlice';
@ -20,33 +21,31 @@ import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
import ParamScheduler from 'features/parameters/components/Core/ParamScheduler';
import ParamSteps from 'features/parameters/components/Core/ParamSteps';
import ParamMainModelSelect from 'features/parameters/components/MainModel/ParamMainModelSelect';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
import { filter } from 'lodash-es';
import { memo } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useSelectedModelConfig } from 'services/api/hooks/useSelectedModelConfig';
const formLabelProps: FormLabelProps = {
minW: '4rem',
};
const badgesSelector = createMemoizedSelector(selectLoraSlice, selectGenerationSlice, (lora, generation) => {
const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length;
const loraTabBadges = enabledLoRAsCount ? [enabledLoRAsCount] : [];
const accordionBadges: (string | number)[] = [];
// TODO(MM2): fetch model name
if (generation.model) {
accordionBadges.push(generation.model.key);
accordionBadges.push(generation.model.base);
}
return { loraTabBadges, accordionBadges };
});
export const GenerationSettingsAccordion = memo(() => {
const { t } = useTranslation();
const { loraTabBadges, accordionBadges } = useAppSelector(badgesSelector);
const modelConfig = useSelectedModelConfig();
const selectBadges = useMemo(
() =>
createMemoizedSelector(selectLoraSlice, (lora) => {
const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length;
const loraTabBadges = enabledLoRAsCount ? [enabledLoRAsCount] : EMPTY_ARRAY;
const accordionBadges = modelConfig ? [modelConfig.name, modelConfig.base] : EMPTY_ARRAY;
return { loraTabBadges, accordionBadges };
}),
[modelConfig]
);
const { loraTabBadges, accordionBadges } = useAppSelector(selectBadges);
const { isOpen: isOpenExpander, onToggle: onToggleExpander } = useExpanderToggle({
id: 'generation-settings-advanced',
defaultIsOpen: false,

View File

@ -236,6 +236,18 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['Model'],
}),
getModelConfig: build.query<AnyModelConfig, string>({
query: (key) => buildModelsUrl(`i/${key}`),
providesTags: (result) => {
const tags: ApiTagDescription[] = ['Model'];
if (result) {
tags.push({ type: 'ModelConfig', id: result.key });
}
return tags;
},
}),
syncModels: build.mutation<SyncModelsResponse, void>({
query: () => {
return {
@ -313,6 +325,7 @@ export const modelsApi = api.injectEndpoints({
});
export const {
useGetModelConfigQuery,
useGetMainModelsQuery,
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,

View File

@ -0,0 +1,14 @@
import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
const selectModelKey = createSelector(selectGenerationSlice, (generation) => generation.model?.key);
export const useSelectedModelConfig = () => {
const key = useAppSelector(selectModelKey);
const { currentData: modelConfig } = useGetModelConfigQuery(key ?? skipToken);
return modelConfig;
};

View File

@ -26,6 +26,7 @@ export const tagTypes = [
'BatchStatus',
'InvocationCacheStatus',
'Model',
'ModelConfig',
'T2IAdapterModel',
'MainModel',
'VaeModel',