mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): use model names in badges
This commit is contained in:
parent
20a56bc757
commit
66ab56246a
@ -14,14 +14,16 @@ import type { LoRA } from 'features/lora/store/loraSlice';
|
|||||||
import { loraIsEnabledChanged, loraRemoved, loraWeightChanged } from 'features/lora/store/loraSlice';
|
import { loraIsEnabledChanged, loraRemoved, loraWeightChanged } from 'features/lora/store/loraSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { PiTrashSimpleBold } from 'react-icons/pi';
|
import { PiTrashSimpleBold } from 'react-icons/pi';
|
||||||
|
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
type LoRACardProps = {
|
type LoRACardProps = {
|
||||||
lora: LoRA;
|
lora: LoRA;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const LoRACard = memo((props: LoRACardProps) => {
|
export const LoRACard = memo((props: LoRACardProps) => {
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const { lora } = props;
|
const { lora } = props;
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { data: loraConfig } = useGetModelConfigQuery(lora.key);
|
||||||
|
|
||||||
const handleChange = useCallback(
|
const handleChange = useCallback(
|
||||||
(v: number) => {
|
(v: number) => {
|
||||||
@ -43,7 +45,7 @@ export const LoRACard = memo((props: LoRACardProps) => {
|
|||||||
<CardHeader>
|
<CardHeader>
|
||||||
<Flex alignItems="center" justifyContent="space-between" width="100%" gap={2}>
|
<Flex alignItems="center" justifyContent="space-between" width="100%" gap={2}>
|
||||||
<Text noOfLines={1} wordBreak="break-all" color={lora.isEnabled ? 'base.200' : 'base.500'}>
|
<Text noOfLines={1} wordBreak="break-all" color={lora.isEnabled ? 'base.200' : 'base.500'}>
|
||||||
{lora.key}
|
{loraConfig?.name ?? lora.key.substring(0, 8)}
|
||||||
</Text>
|
</Text>
|
||||||
<Flex alignItems="center" gap={2}>
|
<Flex alignItems="center" gap={2}>
|
||||||
<Switch size="sm" onChange={handleSetLoraToggle} isChecked={lora.isEnabled} />
|
<Switch size="sm" onChange={handleSetLoraToggle} isChecked={lora.isEnabled} />
|
||||||
|
@ -67,6 +67,8 @@ export const zModelName = z.string().min(3);
|
|||||||
export const zModelIdentifier = z.object({
|
export const zModelIdentifier = z.object({
|
||||||
key: z.string().min(1),
|
key: z.string().min(1),
|
||||||
});
|
});
|
||||||
|
export const isModelIdentifier = (field: unknown): field is ModelIdentifier =>
|
||||||
|
zModelIdentifier.safeParse(field).success;
|
||||||
export const zModelFieldBase = zModelIdentifier;
|
export const zModelFieldBase = zModelIdentifier;
|
||||||
export const zModelIdentifierWithBase = zModelIdentifier.extend({ base: zBaseModel });
|
export const zModelIdentifierWithBase = zModelIdentifier.extend({ base: zBaseModel });
|
||||||
export type BaseModel = z.infer<typeof zBaseModel>;
|
export type BaseModel = z.infer<typeof zBaseModel>;
|
||||||
@ -141,7 +143,7 @@ export type VAEField = z.infer<typeof zVAEField>;
|
|||||||
// #region Control Adapters
|
// #region Control Adapters
|
||||||
export const zControlField = z.object({
|
export const zControlField = z.object({
|
||||||
image: zImageField,
|
image: zImageField,
|
||||||
control_model: zControlNetModelField,
|
control_model: zModelFieldBase,
|
||||||
control_weight: z.union([z.number(), z.array(z.number())]).optional(),
|
control_weight: z.union([z.number(), z.array(z.number())]).optional(),
|
||||||
begin_step_percent: z.number().optional(),
|
begin_step_percent: z.number().optional(),
|
||||||
end_step_percent: z.number().optional(),
|
end_step_percent: z.number().optional(),
|
||||||
@ -152,7 +154,7 @@ export type ControlField = z.infer<typeof zControlField>;
|
|||||||
|
|
||||||
export const zIPAdapterField = z.object({
|
export const zIPAdapterField = z.object({
|
||||||
image: zImageField,
|
image: zImageField,
|
||||||
ip_adapter_model: zIPAdapterModelField,
|
ip_adapter_model: zModelFieldBase,
|
||||||
weight: z.number(),
|
weight: z.number(),
|
||||||
begin_step_percent: z.number().optional(),
|
begin_step_percent: z.number().optional(),
|
||||||
end_step_percent: z.number().optional(),
|
end_step_percent: z.number().optional(),
|
||||||
@ -161,7 +163,7 @@ export type IPAdapterField = z.infer<typeof zIPAdapterField>;
|
|||||||
|
|
||||||
export const zT2IAdapterField = z.object({
|
export const zT2IAdapterField = z.object({
|
||||||
image: zImageField,
|
image: zImageField,
|
||||||
t2i_adapter_model: zT2IAdapterModelField,
|
t2i_adapter_model: zModelFieldBase,
|
||||||
weight: z.union([z.number(), z.array(z.number())]).optional(),
|
weight: z.union([z.number(), z.array(z.number())]).optional(),
|
||||||
begin_step_percent: z.number().optional(),
|
begin_step_percent: z.number().optional(),
|
||||||
end_step_percent: z.number().optional(),
|
end_step_percent: z.number().optional(),
|
||||||
|
@ -4,7 +4,7 @@ import {
|
|||||||
zControlNetModelField,
|
zControlNetModelField,
|
||||||
zIPAdapterModelField,
|
zIPAdapterModelField,
|
||||||
zLoRAModelField,
|
zLoRAModelField,
|
||||||
zMainModelField,
|
zModelIdentifierWithBase,
|
||||||
zSchedulerField,
|
zSchedulerField,
|
||||||
zSDXLRefinerModelField,
|
zSDXLRefinerModelField,
|
||||||
zT2IAdapterModelField,
|
zT2IAdapterModelField,
|
||||||
@ -105,7 +105,7 @@ export const isParameterAspectRatio = (val: unknown): val is ParameterAspectRati
|
|||||||
// #endregion
|
// #endregion
|
||||||
|
|
||||||
// #region Model
|
// #region Model
|
||||||
export const zParameterModel = zMainModelField.extend({ base: zBaseModel });
|
export const zParameterModel = zModelIdentifierWithBase;
|
||||||
export type ParameterModel = z.infer<typeof zParameterModel>;
|
export type ParameterModel = z.infer<typeof zParameterModel>;
|
||||||
export const isParameterModel = (val: unknown): val is ParameterModel => zParameterModel.safeParse(val).success;
|
export const isParameterModel = (val: unknown): val is ParameterModel => zParameterModel.safeParse(val).success;
|
||||||
// #endregion
|
// #endregion
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import type { FormLabelProps } from '@invoke-ai/ui-library';
|
import type { FormLabelProps } from '@invoke-ai/ui-library';
|
||||||
import { Flex, FormControlGroup, StandaloneAccordion } from '@invoke-ai/ui-library';
|
import { Flex, FormControlGroup, StandaloneAccordion } from '@invoke-ai/ui-library';
|
||||||
|
import { skipToken } from '@reduxjs/toolkit/query';
|
||||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import ParamCFGRescaleMultiplier from 'features/parameters/components/Advanced/ParamCFGRescaleMultiplier';
|
import ParamCFGRescaleMultiplier from 'features/parameters/components/Advanced/ParamCFGRescaleMultiplier';
|
||||||
@ -10,8 +11,9 @@ import ParamVAEModelSelect from 'features/parameters/components/VAEModel/ParamVA
|
|||||||
import ParamVAEPrecision from 'features/parameters/components/VAEModel/ParamVAEPrecision';
|
import ParamVAEPrecision from 'features/parameters/components/VAEModel/ParamVAEPrecision';
|
||||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
||||||
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
|
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
|
||||||
import { memo } from 'react';
|
import { memo, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
const formLabelProps: FormLabelProps = {
|
const formLabelProps: FormLabelProps = {
|
||||||
minW: '9.2rem',
|
minW: '9.2rem',
|
||||||
@ -21,31 +23,35 @@ const formLabelProps2: FormLabelProps = {
|
|||||||
flexGrow: 1,
|
flexGrow: 1,
|
||||||
};
|
};
|
||||||
|
|
||||||
const selectBadges = createMemoizedSelector(selectGenerationSlice, (generation) => {
|
|
||||||
const badges: (string | number)[] = [];
|
|
||||||
if (generation.vae) {
|
|
||||||
// TODO(MM2): Fetch the vae name
|
|
||||||
let vaeBadge = generation.vae.key;
|
|
||||||
if (generation.vaePrecision === 'fp16') {
|
|
||||||
vaeBadge += ` ${generation.vaePrecision}`;
|
|
||||||
}
|
|
||||||
badges.push(vaeBadge);
|
|
||||||
} else if (generation.vaePrecision === 'fp16') {
|
|
||||||
badges.push(`VAE ${generation.vaePrecision}`);
|
|
||||||
}
|
|
||||||
if (generation.clipSkip) {
|
|
||||||
badges.push(`Skip ${generation.clipSkip}`);
|
|
||||||
}
|
|
||||||
if (generation.cfgRescaleMultiplier) {
|
|
||||||
badges.push(`Rescale ${generation.cfgRescaleMultiplier}`);
|
|
||||||
}
|
|
||||||
if (generation.seamlessXAxis || generation.seamlessYAxis) {
|
|
||||||
badges.push('seamless');
|
|
||||||
}
|
|
||||||
return badges;
|
|
||||||
});
|
|
||||||
|
|
||||||
export const AdvancedSettingsAccordion = memo(() => {
|
export const AdvancedSettingsAccordion = memo(() => {
|
||||||
|
const vaeKey = useAppSelector((state) => state.generation.vae?.key);
|
||||||
|
const { data: vaeConfig } = useGetModelConfigQuery(vaeKey ?? skipToken);
|
||||||
|
const selectBadges = useMemo(
|
||||||
|
() =>
|
||||||
|
createMemoizedSelector(selectGenerationSlice, (generation) => {
|
||||||
|
const badges: (string | number)[] = [];
|
||||||
|
if (vaeConfig) {
|
||||||
|
let vaeBadge = vaeConfig.name;
|
||||||
|
if (generation.vaePrecision === 'fp16') {
|
||||||
|
vaeBadge += ` ${generation.vaePrecision}`;
|
||||||
|
}
|
||||||
|
badges.push(vaeBadge);
|
||||||
|
} else if (generation.vaePrecision === 'fp16') {
|
||||||
|
badges.push(`VAE ${generation.vaePrecision}`);
|
||||||
|
}
|
||||||
|
if (generation.clipSkip) {
|
||||||
|
badges.push(`Skip ${generation.clipSkip}`);
|
||||||
|
}
|
||||||
|
if (generation.cfgRescaleMultiplier) {
|
||||||
|
badges.push(`Rescale ${generation.cfgRescaleMultiplier}`);
|
||||||
|
}
|
||||||
|
if (generation.seamlessXAxis || generation.seamlessYAxis) {
|
||||||
|
badges.push('seamless');
|
||||||
|
}
|
||||||
|
return badges;
|
||||||
|
}),
|
||||||
|
[vaeConfig]
|
||||||
|
);
|
||||||
const badges = useAppSelector(selectBadges);
|
const badges = useAppSelector(selectBadges);
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { isOpen, onToggle } = useStandaloneAccordionToggle({
|
const { isOpen, onToggle } = useStandaloneAccordionToggle({
|
||||||
|
@ -12,6 +12,7 @@ import {
|
|||||||
} from '@invoke-ai/ui-library';
|
} from '@invoke-ai/ui-library';
|
||||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { EMPTY_ARRAY } from 'app/store/util';
|
||||||
import { LoRAList } from 'features/lora/components/LoRAList';
|
import { LoRAList } from 'features/lora/components/LoRAList';
|
||||||
import LoRASelect from 'features/lora/components/LoRASelect';
|
import LoRASelect from 'features/lora/components/LoRASelect';
|
||||||
import { selectLoraSlice } from 'features/lora/store/loraSlice';
|
import { selectLoraSlice } from 'features/lora/store/loraSlice';
|
||||||
@ -20,33 +21,31 @@ import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
|
|||||||
import ParamScheduler from 'features/parameters/components/Core/ParamScheduler';
|
import ParamScheduler from 'features/parameters/components/Core/ParamScheduler';
|
||||||
import ParamSteps from 'features/parameters/components/Core/ParamSteps';
|
import ParamSteps from 'features/parameters/components/Core/ParamSteps';
|
||||||
import ParamMainModelSelect from 'features/parameters/components/MainModel/ParamMainModelSelect';
|
import ParamMainModelSelect from 'features/parameters/components/MainModel/ParamMainModelSelect';
|
||||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
|
||||||
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
|
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
|
||||||
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
|
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
|
||||||
import { filter } from 'lodash-es';
|
import { filter } from 'lodash-es';
|
||||||
import { memo } from 'react';
|
import { memo, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useSelectedModelConfig } from 'services/api/hooks/useSelectedModelConfig';
|
||||||
|
|
||||||
const formLabelProps: FormLabelProps = {
|
const formLabelProps: FormLabelProps = {
|
||||||
minW: '4rem',
|
minW: '4rem',
|
||||||
};
|
};
|
||||||
|
|
||||||
const badgesSelector = createMemoizedSelector(selectLoraSlice, selectGenerationSlice, (lora, generation) => {
|
|
||||||
const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length;
|
|
||||||
const loraTabBadges = enabledLoRAsCount ? [enabledLoRAsCount] : [];
|
|
||||||
const accordionBadges: (string | number)[] = [];
|
|
||||||
// TODO(MM2): fetch model name
|
|
||||||
if (generation.model) {
|
|
||||||
accordionBadges.push(generation.model.key);
|
|
||||||
accordionBadges.push(generation.model.base);
|
|
||||||
}
|
|
||||||
|
|
||||||
return { loraTabBadges, accordionBadges };
|
|
||||||
});
|
|
||||||
|
|
||||||
export const GenerationSettingsAccordion = memo(() => {
|
export const GenerationSettingsAccordion = memo(() => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { loraTabBadges, accordionBadges } = useAppSelector(badgesSelector);
|
const modelConfig = useSelectedModelConfig();
|
||||||
|
const selectBadges = useMemo(
|
||||||
|
() =>
|
||||||
|
createMemoizedSelector(selectLoraSlice, (lora) => {
|
||||||
|
const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length;
|
||||||
|
const loraTabBadges = enabledLoRAsCount ? [enabledLoRAsCount] : EMPTY_ARRAY;
|
||||||
|
const accordionBadges = modelConfig ? [modelConfig.name, modelConfig.base] : EMPTY_ARRAY;
|
||||||
|
return { loraTabBadges, accordionBadges };
|
||||||
|
}),
|
||||||
|
[modelConfig]
|
||||||
|
);
|
||||||
|
const { loraTabBadges, accordionBadges } = useAppSelector(selectBadges);
|
||||||
const { isOpen: isOpenExpander, onToggle: onToggleExpander } = useExpanderToggle({
|
const { isOpen: isOpenExpander, onToggle: onToggleExpander } = useExpanderToggle({
|
||||||
id: 'generation-settings-advanced',
|
id: 'generation-settings-advanced',
|
||||||
defaultIsOpen: false,
|
defaultIsOpen: false,
|
||||||
|
@ -236,6 +236,18 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
},
|
},
|
||||||
invalidatesTags: ['Model'],
|
invalidatesTags: ['Model'],
|
||||||
}),
|
}),
|
||||||
|
getModelConfig: build.query<AnyModelConfig, string>({
|
||||||
|
query: (key) => buildModelsUrl(`i/${key}`),
|
||||||
|
providesTags: (result) => {
|
||||||
|
const tags: ApiTagDescription[] = ['Model'];
|
||||||
|
|
||||||
|
if (result) {
|
||||||
|
tags.push({ type: 'ModelConfig', id: result.key });
|
||||||
|
}
|
||||||
|
|
||||||
|
return tags;
|
||||||
|
},
|
||||||
|
}),
|
||||||
syncModels: build.mutation<SyncModelsResponse, void>({
|
syncModels: build.mutation<SyncModelsResponse, void>({
|
||||||
query: () => {
|
query: () => {
|
||||||
return {
|
return {
|
||||||
@ -313,6 +325,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
});
|
});
|
||||||
|
|
||||||
export const {
|
export const {
|
||||||
|
useGetModelConfigQuery,
|
||||||
useGetMainModelsQuery,
|
useGetMainModelsQuery,
|
||||||
useGetControlNetModelsQuery,
|
useGetControlNetModelsQuery,
|
||||||
useGetIPAdapterModelsQuery,
|
useGetIPAdapterModelsQuery,
|
||||||
|
@ -0,0 +1,14 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { skipToken } from '@reduxjs/toolkit/query';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
||||||
|
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
|
const selectModelKey = createSelector(selectGenerationSlice, (generation) => generation.model?.key);
|
||||||
|
|
||||||
|
export const useSelectedModelConfig = () => {
|
||||||
|
const key = useAppSelector(selectModelKey);
|
||||||
|
const { currentData: modelConfig } = useGetModelConfigQuery(key ?? skipToken);
|
||||||
|
|
||||||
|
return modelConfig;
|
||||||
|
};
|
@ -26,6 +26,7 @@ export const tagTypes = [
|
|||||||
'BatchStatus',
|
'BatchStatus',
|
||||||
'InvocationCacheStatus',
|
'InvocationCacheStatus',
|
||||||
'Model',
|
'Model',
|
||||||
|
'ModelConfig',
|
||||||
'T2IAdapterModel',
|
'T2IAdapterModel',
|
||||||
'MainModel',
|
'MainModel',
|
||||||
'VaeModel',
|
'VaeModel',
|
||||||
|
Loading…
Reference in New Issue
Block a user