fix(ui): use model names in badges

This commit is contained in:
psychedelicious 2024-02-21 19:42:36 +11:00 committed by Brandon Rising
parent 6577250523
commit 1ced80d492
8 changed files with 85 additions and 48 deletions

View File

@ -15,14 +15,16 @@ import type { LoRA } from 'features/lora/store/loraSlice';
import { loraIsEnabledChanged, loraRemoved, loraWeightChanged } from 'features/lora/store/loraSlice'; import { loraIsEnabledChanged, loraRemoved, loraWeightChanged } from 'features/lora/store/loraSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { PiTrashSimpleBold } from 'react-icons/pi'; import { PiTrashSimpleBold } from 'react-icons/pi';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
type LoRACardProps = { type LoRACardProps = {
lora: LoRA; lora: LoRA;
}; };
export const LoRACard = memo((props: LoRACardProps) => { export const LoRACard = memo((props: LoRACardProps) => {
const dispatch = useAppDispatch();
const { lora } = props; const { lora } = props;
const dispatch = useAppDispatch();
const { data: loraConfig } = useGetModelConfigQuery(lora.key);
const handleChange = useCallback( const handleChange = useCallback(
(v: number) => { (v: number) => {
@ -44,7 +46,7 @@ export const LoRACard = memo((props: LoRACardProps) => {
<CardHeader> <CardHeader>
<Flex alignItems="center" justifyContent="space-between" width="100%" gap={2}> <Flex alignItems="center" justifyContent="space-between" width="100%" gap={2}>
<Text noOfLines={1} wordBreak="break-all" color={lora.isEnabled ? 'base.200' : 'base.500'}> <Text noOfLines={1} wordBreak="break-all" color={lora.isEnabled ? 'base.200' : 'base.500'}>
{lora.key} {loraConfig?.name ?? lora.key.substring(0, 8)}
</Text> </Text>
<Flex alignItems="center" gap={2}> <Flex alignItems="center" gap={2}>
<Switch size="sm" onChange={handleSetLoraToggle} isChecked={lora.isEnabled} /> <Switch size="sm" onChange={handleSetLoraToggle} isChecked={lora.isEnabled} />

View File

@ -67,6 +67,8 @@ export const zModelName = z.string().min(3);
export const zModelIdentifier = z.object({ export const zModelIdentifier = z.object({
key: z.string().min(1), key: z.string().min(1),
}); });
export const isModelIdentifier = (field: unknown): field is ModelIdentifier =>
zModelIdentifier.safeParse(field).success;
export const zModelFieldBase = zModelIdentifier; export const zModelFieldBase = zModelIdentifier;
export const zModelIdentifierWithBase = zModelIdentifier.extend({ base: zBaseModel }); export const zModelIdentifierWithBase = zModelIdentifier.extend({ base: zBaseModel });
export type BaseModel = z.infer<typeof zBaseModel>; export type BaseModel = z.infer<typeof zBaseModel>;
@ -141,7 +143,7 @@ export type VAEField = z.infer<typeof zVAEField>;
// #region Control Adapters // #region Control Adapters
export const zControlField = z.object({ export const zControlField = z.object({
image: zImageField, image: zImageField,
control_model: zControlNetModelField, control_model: zModelFieldBase,
control_weight: z.union([z.number(), z.array(z.number())]).optional(), control_weight: z.union([z.number(), z.array(z.number())]).optional(),
begin_step_percent: z.number().optional(), begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(), end_step_percent: z.number().optional(),
@ -152,7 +154,7 @@ export type ControlField = z.infer<typeof zControlField>;
export const zIPAdapterField = z.object({ export const zIPAdapterField = z.object({
image: zImageField, image: zImageField,
ip_adapter_model: zIPAdapterModelField, ip_adapter_model: zModelFieldBase,
weight: z.number(), weight: z.number(),
begin_step_percent: z.number().optional(), begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(), end_step_percent: z.number().optional(),
@ -161,7 +163,7 @@ export type IPAdapterField = z.infer<typeof zIPAdapterField>;
export const zT2IAdapterField = z.object({ export const zT2IAdapterField = z.object({
image: zImageField, image: zImageField,
t2i_adapter_model: zT2IAdapterModelField, t2i_adapter_model: zModelFieldBase,
weight: z.union([z.number(), z.array(z.number())]).optional(), weight: z.union([z.number(), z.array(z.number())]).optional(),
begin_step_percent: z.number().optional(), begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(), end_step_percent: z.number().optional(),

View File

@ -4,7 +4,7 @@ import {
zControlNetModelField, zControlNetModelField,
zIPAdapterModelField, zIPAdapterModelField,
zLoRAModelField, zLoRAModelField,
zMainModelField, zModelIdentifierWithBase,
zSchedulerField, zSchedulerField,
zSDXLRefinerModelField, zSDXLRefinerModelField,
zT2IAdapterModelField, zT2IAdapterModelField,
@ -105,7 +105,7 @@ export const isParameterAspectRatio = (val: unknown): val is ParameterAspectRati
// #endregion // #endregion
// #region Model // #region Model
export const zParameterModel = zMainModelField.extend({ base: zBaseModel }); export const zParameterModel = zModelIdentifierWithBase;
export type ParameterModel = z.infer<typeof zParameterModel>; export type ParameterModel = z.infer<typeof zParameterModel>;
export const isParameterModel = (val: unknown): val is ParameterModel => zParameterModel.safeParse(val).success; export const isParameterModel = (val: unknown): val is ParameterModel => zParameterModel.safeParse(val).success;
// #endregion // #endregion

View File

@ -1,5 +1,6 @@
import type { FormLabelProps } from '@invoke-ai/ui-library'; import type { FormLabelProps } from '@invoke-ai/ui-library';
import { Flex, FormControlGroup, StandaloneAccordion } from '@invoke-ai/ui-library'; import { Flex, FormControlGroup, StandaloneAccordion } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import ParamCFGRescaleMultiplier from 'features/parameters/components/Advanced/ParamCFGRescaleMultiplier'; import ParamCFGRescaleMultiplier from 'features/parameters/components/Advanced/ParamCFGRescaleMultiplier';
@ -10,8 +11,9 @@ import ParamVAEModelSelect from 'features/parameters/components/VAEModel/ParamVA
import ParamVAEPrecision from 'features/parameters/components/VAEModel/ParamVAEPrecision'; import ParamVAEPrecision from 'features/parameters/components/VAEModel/ParamVAEPrecision';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle'; import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
import { memo } from 'react'; import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
const formLabelProps: FormLabelProps = { const formLabelProps: FormLabelProps = {
minW: '9.2rem', minW: '9.2rem',
@ -21,31 +23,35 @@ const formLabelProps2: FormLabelProps = {
flexGrow: 1, flexGrow: 1,
}; };
const selectBadges = createMemoizedSelector(selectGenerationSlice, (generation) => {
const badges: (string | number)[] = [];
if (generation.vae) {
// TODO(MM2): Fetch the vae name
let vaeBadge = generation.vae.key;
if (generation.vaePrecision === 'fp16') {
vaeBadge += ` ${generation.vaePrecision}`;
}
badges.push(vaeBadge);
} else if (generation.vaePrecision === 'fp16') {
badges.push(`VAE ${generation.vaePrecision}`);
}
if (generation.clipSkip) {
badges.push(`Skip ${generation.clipSkip}`);
}
if (generation.cfgRescaleMultiplier) {
badges.push(`Rescale ${generation.cfgRescaleMultiplier}`);
}
if (generation.seamlessXAxis || generation.seamlessYAxis) {
badges.push('seamless');
}
return badges;
});
export const AdvancedSettingsAccordion = memo(() => { export const AdvancedSettingsAccordion = memo(() => {
const vaeKey = useAppSelector((state) => state.generation.vae?.key);
const { data: vaeConfig } = useGetModelConfigQuery(vaeKey ?? skipToken);
const selectBadges = useMemo(
() =>
createMemoizedSelector(selectGenerationSlice, (generation) => {
const badges: (string | number)[] = [];
if (vaeConfig) {
let vaeBadge = vaeConfig.name;
if (generation.vaePrecision === 'fp16') {
vaeBadge += ` ${generation.vaePrecision}`;
}
badges.push(vaeBadge);
} else if (generation.vaePrecision === 'fp16') {
badges.push(`VAE ${generation.vaePrecision}`);
}
if (generation.clipSkip) {
badges.push(`Skip ${generation.clipSkip}`);
}
if (generation.cfgRescaleMultiplier) {
badges.push(`Rescale ${generation.cfgRescaleMultiplier}`);
}
if (generation.seamlessXAxis || generation.seamlessYAxis) {
badges.push('seamless');
}
return badges;
}),
[vaeConfig]
);
const badges = useAppSelector(selectBadges); const badges = useAppSelector(selectBadges);
const { t } = useTranslation(); const { t } = useTranslation();
const { isOpen, onToggle } = useStandaloneAccordionToggle({ const { isOpen, onToggle } = useStandaloneAccordionToggle({

View File

@ -12,6 +12,7 @@ import {
} from '@invoke-ai/ui-library'; } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { EMPTY_ARRAY } from 'app/store/util';
import { LoRAList } from 'features/lora/components/LoRAList'; import { LoRAList } from 'features/lora/components/LoRAList';
import LoRASelect from 'features/lora/components/LoRASelect'; import LoRASelect from 'features/lora/components/LoRASelect';
import { selectLoraSlice } from 'features/lora/store/loraSlice'; import { selectLoraSlice } from 'features/lora/store/loraSlice';
@ -20,33 +21,31 @@ import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
import ParamScheduler from 'features/parameters/components/Core/ParamScheduler'; import ParamScheduler from 'features/parameters/components/Core/ParamScheduler';
import ParamSteps from 'features/parameters/components/Core/ParamSteps'; import ParamSteps from 'features/parameters/components/Core/ParamSteps';
import ParamMainModelSelect from 'features/parameters/components/MainModel/ParamMainModelSelect'; import ParamMainModelSelect from 'features/parameters/components/MainModel/ParamMainModelSelect';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle'; import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle'; import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
import { filter } from 'lodash-es'; import { filter } from 'lodash-es';
import { memo } from 'react'; import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useSelectedModelConfig } from 'services/api/hooks/useSelectedModelConfig';
const formLabelProps: FormLabelProps = { const formLabelProps: FormLabelProps = {
minW: '4rem', minW: '4rem',
}; };
const badgesSelector = createMemoizedSelector(selectLoraSlice, selectGenerationSlice, (lora, generation) => {
const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length;
const loraTabBadges = enabledLoRAsCount ? [enabledLoRAsCount] : [];
const accordionBadges: (string | number)[] = [];
// TODO(MM2): fetch model name
if (generation.model) {
accordionBadges.push(generation.model.key);
accordionBadges.push(generation.model.base);
}
return { loraTabBadges, accordionBadges };
});
export const GenerationSettingsAccordion = memo(() => { export const GenerationSettingsAccordion = memo(() => {
const { t } = useTranslation(); const { t } = useTranslation();
const { loraTabBadges, accordionBadges } = useAppSelector(badgesSelector); const modelConfig = useSelectedModelConfig();
const selectBadges = useMemo(
() =>
createMemoizedSelector(selectLoraSlice, (lora) => {
const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length;
const loraTabBadges = enabledLoRAsCount ? [enabledLoRAsCount] : EMPTY_ARRAY;
const accordionBadges = modelConfig ? [modelConfig.name, modelConfig.base] : EMPTY_ARRAY;
return { loraTabBadges, accordionBadges };
}),
[modelConfig]
);
const { loraTabBadges, accordionBadges } = useAppSelector(selectBadges);
const { isOpen: isOpenExpander, onToggle: onToggleExpander } = useExpanderToggle({ const { isOpen: isOpenExpander, onToggle: onToggleExpander } = useExpanderToggle({
id: 'generation-settings-advanced', id: 'generation-settings-advanced',
defaultIsOpen: false, defaultIsOpen: false,

View File

@ -236,6 +236,18 @@ export const modelsApi = api.injectEndpoints({
}, },
invalidatesTags: ['Model'], invalidatesTags: ['Model'],
}), }),
getModelConfig: build.query<AnyModelConfig, string>({
query: (key) => buildModelsUrl(`i/${key}`),
providesTags: (result) => {
const tags: ApiTagDescription[] = ['Model'];
if (result) {
tags.push({ type: 'ModelConfig', id: result.key });
}
return tags;
},
}),
syncModels: build.mutation<SyncModelsResponse, void>({ syncModels: build.mutation<SyncModelsResponse, void>({
query: () => { query: () => {
return { return {
@ -313,6 +325,7 @@ export const modelsApi = api.injectEndpoints({
}); });
export const { export const {
useGetModelConfigQuery,
useGetMainModelsQuery, useGetMainModelsQuery,
useGetControlNetModelsQuery, useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery, useGetIPAdapterModelsQuery,

View File

@ -0,0 +1,14 @@
import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
const selectModelKey = createSelector(selectGenerationSlice, (generation) => generation.model?.key);
export const useSelectedModelConfig = () => {
const key = useAppSelector(selectModelKey);
const { currentData: modelConfig } = useGetModelConfigQuery(key ?? skipToken);
return modelConfig;
};

View File

@ -26,6 +26,7 @@ export const tagTypes = [
'BatchStatus', 'BatchStatus',
'InvocationCacheStatus', 'InvocationCacheStatus',
'Model', 'Model',
'ModelConfig',
'T2IAdapterModel', 'T2IAdapterModel',
'MainModel', 'MainModel',
'VaeModel', 'VaeModel',