mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): use model names in badges
This commit is contained in:
parent
20a56bc757
commit
66ab56246a
@ -14,14 +14,16 @@ import type { LoRA } from 'features/lora/store/loraSlice';
|
||||
import { loraIsEnabledChanged, loraRemoved, loraWeightChanged } from 'features/lora/store/loraSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { PiTrashSimpleBold } from 'react-icons/pi';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
|
||||
type LoRACardProps = {
|
||||
lora: LoRA;
|
||||
};
|
||||
|
||||
export const LoRACard = memo((props: LoRACardProps) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { lora } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { data: loraConfig } = useGetModelConfigQuery(lora.key);
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: number) => {
|
||||
@ -43,7 +45,7 @@ export const LoRACard = memo((props: LoRACardProps) => {
|
||||
<CardHeader>
|
||||
<Flex alignItems="center" justifyContent="space-between" width="100%" gap={2}>
|
||||
<Text noOfLines={1} wordBreak="break-all" color={lora.isEnabled ? 'base.200' : 'base.500'}>
|
||||
{lora.key}
|
||||
{loraConfig?.name ?? lora.key.substring(0, 8)}
|
||||
</Text>
|
||||
<Flex alignItems="center" gap={2}>
|
||||
<Switch size="sm" onChange={handleSetLoraToggle} isChecked={lora.isEnabled} />
|
||||
|
@ -67,6 +67,8 @@ export const zModelName = z.string().min(3);
|
||||
export const zModelIdentifier = z.object({
|
||||
key: z.string().min(1),
|
||||
});
|
||||
export const isModelIdentifier = (field: unknown): field is ModelIdentifier =>
|
||||
zModelIdentifier.safeParse(field).success;
|
||||
export const zModelFieldBase = zModelIdentifier;
|
||||
export const zModelIdentifierWithBase = zModelIdentifier.extend({ base: zBaseModel });
|
||||
export type BaseModel = z.infer<typeof zBaseModel>;
|
||||
@ -141,7 +143,7 @@ export type VAEField = z.infer<typeof zVAEField>;
|
||||
// #region Control Adapters
|
||||
export const zControlField = z.object({
|
||||
image: zImageField,
|
||||
control_model: zControlNetModelField,
|
||||
control_model: zModelFieldBase,
|
||||
control_weight: z.union([z.number(), z.array(z.number())]).optional(),
|
||||
begin_step_percent: z.number().optional(),
|
||||
end_step_percent: z.number().optional(),
|
||||
@ -152,7 +154,7 @@ export type ControlField = z.infer<typeof zControlField>;
|
||||
|
||||
export const zIPAdapterField = z.object({
|
||||
image: zImageField,
|
||||
ip_adapter_model: zIPAdapterModelField,
|
||||
ip_adapter_model: zModelFieldBase,
|
||||
weight: z.number(),
|
||||
begin_step_percent: z.number().optional(),
|
||||
end_step_percent: z.number().optional(),
|
||||
@ -161,7 +163,7 @@ export type IPAdapterField = z.infer<typeof zIPAdapterField>;
|
||||
|
||||
export const zT2IAdapterField = z.object({
|
||||
image: zImageField,
|
||||
t2i_adapter_model: zT2IAdapterModelField,
|
||||
t2i_adapter_model: zModelFieldBase,
|
||||
weight: z.union([z.number(), z.array(z.number())]).optional(),
|
||||
begin_step_percent: z.number().optional(),
|
||||
end_step_percent: z.number().optional(),
|
||||
|
@ -4,7 +4,7 @@ import {
|
||||
zControlNetModelField,
|
||||
zIPAdapterModelField,
|
||||
zLoRAModelField,
|
||||
zMainModelField,
|
||||
zModelIdentifierWithBase,
|
||||
zSchedulerField,
|
||||
zSDXLRefinerModelField,
|
||||
zT2IAdapterModelField,
|
||||
@ -105,7 +105,7 @@ export const isParameterAspectRatio = (val: unknown): val is ParameterAspectRati
|
||||
// #endregion
|
||||
|
||||
// #region Model
|
||||
export const zParameterModel = zMainModelField.extend({ base: zBaseModel });
|
||||
export const zParameterModel = zModelIdentifierWithBase;
|
||||
export type ParameterModel = z.infer<typeof zParameterModel>;
|
||||
export const isParameterModel = (val: unknown): val is ParameterModel => zParameterModel.safeParse(val).success;
|
||||
// #endregion
|
||||
|
@ -1,5 +1,6 @@
|
||||
import type { FormLabelProps } from '@invoke-ai/ui-library';
|
||||
import { Flex, FormControlGroup, StandaloneAccordion } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import ParamCFGRescaleMultiplier from 'features/parameters/components/Advanced/ParamCFGRescaleMultiplier';
|
||||
@ -10,8 +11,9 @@ import ParamVAEModelSelect from 'features/parameters/components/VAEModel/ParamVA
|
||||
import ParamVAEPrecision from 'features/parameters/components/VAEModel/ParamVAEPrecision';
|
||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
||||
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
|
||||
import { memo } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const formLabelProps: FormLabelProps = {
|
||||
minW: '9.2rem',
|
||||
@ -21,31 +23,35 @@ const formLabelProps2: FormLabelProps = {
|
||||
flexGrow: 1,
|
||||
};
|
||||
|
||||
const selectBadges = createMemoizedSelector(selectGenerationSlice, (generation) => {
|
||||
const badges: (string | number)[] = [];
|
||||
if (generation.vae) {
|
||||
// TODO(MM2): Fetch the vae name
|
||||
let vaeBadge = generation.vae.key;
|
||||
if (generation.vaePrecision === 'fp16') {
|
||||
vaeBadge += ` ${generation.vaePrecision}`;
|
||||
}
|
||||
badges.push(vaeBadge);
|
||||
} else if (generation.vaePrecision === 'fp16') {
|
||||
badges.push(`VAE ${generation.vaePrecision}`);
|
||||
}
|
||||
if (generation.clipSkip) {
|
||||
badges.push(`Skip ${generation.clipSkip}`);
|
||||
}
|
||||
if (generation.cfgRescaleMultiplier) {
|
||||
badges.push(`Rescale ${generation.cfgRescaleMultiplier}`);
|
||||
}
|
||||
if (generation.seamlessXAxis || generation.seamlessYAxis) {
|
||||
badges.push('seamless');
|
||||
}
|
||||
return badges;
|
||||
});
|
||||
|
||||
export const AdvancedSettingsAccordion = memo(() => {
|
||||
const vaeKey = useAppSelector((state) => state.generation.vae?.key);
|
||||
const { data: vaeConfig } = useGetModelConfigQuery(vaeKey ?? skipToken);
|
||||
const selectBadges = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectGenerationSlice, (generation) => {
|
||||
const badges: (string | number)[] = [];
|
||||
if (vaeConfig) {
|
||||
let vaeBadge = vaeConfig.name;
|
||||
if (generation.vaePrecision === 'fp16') {
|
||||
vaeBadge += ` ${generation.vaePrecision}`;
|
||||
}
|
||||
badges.push(vaeBadge);
|
||||
} else if (generation.vaePrecision === 'fp16') {
|
||||
badges.push(`VAE ${generation.vaePrecision}`);
|
||||
}
|
||||
if (generation.clipSkip) {
|
||||
badges.push(`Skip ${generation.clipSkip}`);
|
||||
}
|
||||
if (generation.cfgRescaleMultiplier) {
|
||||
badges.push(`Rescale ${generation.cfgRescaleMultiplier}`);
|
||||
}
|
||||
if (generation.seamlessXAxis || generation.seamlessYAxis) {
|
||||
badges.push('seamless');
|
||||
}
|
||||
return badges;
|
||||
}),
|
||||
[vaeConfig]
|
||||
);
|
||||
const badges = useAppSelector(selectBadges);
|
||||
const { t } = useTranslation();
|
||||
const { isOpen, onToggle } = useStandaloneAccordionToggle({
|
||||
|
@ -12,6 +12,7 @@ import {
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { EMPTY_ARRAY } from 'app/store/util';
|
||||
import { LoRAList } from 'features/lora/components/LoRAList';
|
||||
import LoRASelect from 'features/lora/components/LoRASelect';
|
||||
import { selectLoraSlice } from 'features/lora/store/loraSlice';
|
||||
@ -20,33 +21,31 @@ import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
|
||||
import ParamScheduler from 'features/parameters/components/Core/ParamScheduler';
|
||||
import ParamSteps from 'features/parameters/components/Core/ParamSteps';
|
||||
import ParamMainModelSelect from 'features/parameters/components/MainModel/ParamMainModelSelect';
|
||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
||||
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
|
||||
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
|
||||
import { filter } from 'lodash-es';
|
||||
import { memo } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useSelectedModelConfig } from 'services/api/hooks/useSelectedModelConfig';
|
||||
|
||||
const formLabelProps: FormLabelProps = {
|
||||
minW: '4rem',
|
||||
};
|
||||
|
||||
const badgesSelector = createMemoizedSelector(selectLoraSlice, selectGenerationSlice, (lora, generation) => {
|
||||
const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length;
|
||||
const loraTabBadges = enabledLoRAsCount ? [enabledLoRAsCount] : [];
|
||||
const accordionBadges: (string | number)[] = [];
|
||||
// TODO(MM2): fetch model name
|
||||
if (generation.model) {
|
||||
accordionBadges.push(generation.model.key);
|
||||
accordionBadges.push(generation.model.base);
|
||||
}
|
||||
|
||||
return { loraTabBadges, accordionBadges };
|
||||
});
|
||||
|
||||
export const GenerationSettingsAccordion = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const { loraTabBadges, accordionBadges } = useAppSelector(badgesSelector);
|
||||
const modelConfig = useSelectedModelConfig();
|
||||
const selectBadges = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectLoraSlice, (lora) => {
|
||||
const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length;
|
||||
const loraTabBadges = enabledLoRAsCount ? [enabledLoRAsCount] : EMPTY_ARRAY;
|
||||
const accordionBadges = modelConfig ? [modelConfig.name, modelConfig.base] : EMPTY_ARRAY;
|
||||
return { loraTabBadges, accordionBadges };
|
||||
}),
|
||||
[modelConfig]
|
||||
);
|
||||
const { loraTabBadges, accordionBadges } = useAppSelector(selectBadges);
|
||||
const { isOpen: isOpenExpander, onToggle: onToggleExpander } = useExpanderToggle({
|
||||
id: 'generation-settings-advanced',
|
||||
defaultIsOpen: false,
|
||||
|
@ -236,6 +236,18 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
getModelConfig: build.query<AnyModelConfig, string>({
|
||||
query: (key) => buildModelsUrl(`i/${key}`),
|
||||
providesTags: (result) => {
|
||||
const tags: ApiTagDescription[] = ['Model'];
|
||||
|
||||
if (result) {
|
||||
tags.push({ type: 'ModelConfig', id: result.key });
|
||||
}
|
||||
|
||||
return tags;
|
||||
},
|
||||
}),
|
||||
syncModels: build.mutation<SyncModelsResponse, void>({
|
||||
query: () => {
|
||||
return {
|
||||
@ -313,6 +325,7 @@ export const modelsApi = api.injectEndpoints({
|
||||
});
|
||||
|
||||
export const {
|
||||
useGetModelConfigQuery,
|
||||
useGetMainModelsQuery,
|
||||
useGetControlNetModelsQuery,
|
||||
useGetIPAdapterModelsQuery,
|
||||
|
@ -0,0 +1,14 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const selectModelKey = createSelector(selectGenerationSlice, (generation) => generation.model?.key);
|
||||
|
||||
export const useSelectedModelConfig = () => {
|
||||
const key = useAppSelector(selectModelKey);
|
||||
const { currentData: modelConfig } = useGetModelConfigQuery(key ?? skipToken);
|
||||
|
||||
return modelConfig;
|
||||
};
|
@ -26,6 +26,7 @@ export const tagTypes = [
|
||||
'BatchStatus',
|
||||
'InvocationCacheStatus',
|
||||
'Model',
|
||||
'ModelConfig',
|
||||
'T2IAdapterModel',
|
||||
'MainModel',
|
||||
'VaeModel',
|
||||
|
Loading…
Reference in New Issue
Block a user