diff --git a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx index 05a3c14cb4..b508c9424d 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx @@ -14,14 +14,16 @@ import type { LoRA } from 'features/lora/store/loraSlice'; import { loraIsEnabledChanged, loraRemoved, loraWeightChanged } from 'features/lora/store/loraSlice'; import { memo, useCallback } from 'react'; import { PiTrashSimpleBold } from 'react-icons/pi'; +import { useGetModelConfigQuery } from 'services/api/endpoints/models'; type LoRACardProps = { lora: LoRA; }; export const LoRACard = memo((props: LoRACardProps) => { - const dispatch = useAppDispatch(); const { lora } = props; + const dispatch = useAppDispatch(); + const { data: loraConfig } = useGetModelConfigQuery(lora.key); const handleChange = useCallback( (v: number) => { @@ -43,7 +45,7 @@ export const LoRACard = memo((props: LoRACardProps) => { - {lora.key} + {loraConfig?.name ?? lora.key.substring(0, 8)} diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index d5d04deaa5..b195ce4434 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -67,6 +67,8 @@ export const zModelName = z.string().min(3); export const zModelIdentifier = z.object({ key: z.string().min(1), }); +export const isModelIdentifier = (field: unknown): field is ModelIdentifier => + zModelIdentifier.safeParse(field).success; export const zModelFieldBase = zModelIdentifier; export const zModelIdentifierWithBase = zModelIdentifier.extend({ base: zBaseModel }); export type BaseModel = z.infer; @@ -141,7 +143,7 @@ export type VAEField = z.infer; // #region Control Adapters export const zControlField = z.object({ image: zImageField, - control_model: zControlNetModelField, + control_model: zModelFieldBase, control_weight: z.union([z.number(), z.array(z.number())]).optional(), begin_step_percent: z.number().optional(), end_step_percent: z.number().optional(), @@ -152,7 +154,7 @@ export type ControlField = z.infer; export const zIPAdapterField = z.object({ image: zImageField, - ip_adapter_model: zIPAdapterModelField, + ip_adapter_model: zModelFieldBase, weight: z.number(), begin_step_percent: z.number().optional(), end_step_percent: z.number().optional(), @@ -161,7 +163,7 @@ export type IPAdapterField = z.infer; export const zT2IAdapterField = z.object({ image: zImageField, - t2i_adapter_model: zT2IAdapterModelField, + t2i_adapter_model: zModelFieldBase, weight: z.union([z.number(), z.array(z.number())]).optional(), begin_step_percent: z.number().optional(), end_step_percent: z.number().optional(), diff --git a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts index abd8ee2810..b30d5df147 100644 --- a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts @@ -4,7 +4,7 @@ import { zControlNetModelField, zIPAdapterModelField, zLoRAModelField, - zMainModelField, + zModelIdentifierWithBase, zSchedulerField, zSDXLRefinerModelField, zT2IAdapterModelField, @@ -105,7 +105,7 @@ export const isParameterAspectRatio = (val: unknown): val is ParameterAspectRati // #endregion // #region Model -export const zParameterModel = zMainModelField.extend({ base: zBaseModel }); +export const zParameterModel = zModelIdentifierWithBase; export type ParameterModel = z.infer; export const isParameterModel = (val: unknown): val is ParameterModel => zParameterModel.safeParse(val).success; // #endregion diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx index fc8c54576c..8b10d9bddd 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx @@ -1,5 +1,6 @@ import type { FormLabelProps } from '@invoke-ai/ui-library'; import { Flex, FormControlGroup, StandaloneAccordion } from '@invoke-ai/ui-library'; +import { skipToken } from '@reduxjs/toolkit/query'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import ParamCFGRescaleMultiplier from 'features/parameters/components/Advanced/ParamCFGRescaleMultiplier'; @@ -10,8 +11,9 @@ import ParamVAEModelSelect from 'features/parameters/components/VAEModel/ParamVA import ParamVAEPrecision from 'features/parameters/components/VAEModel/ParamVAEPrecision'; import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle'; -import { memo } from 'react'; +import { memo, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; +import { useGetModelConfigQuery } from 'services/api/endpoints/models'; const formLabelProps: FormLabelProps = { minW: '9.2rem', @@ -21,31 +23,35 @@ const formLabelProps2: FormLabelProps = { flexGrow: 1, }; -const selectBadges = createMemoizedSelector(selectGenerationSlice, (generation) => { - const badges: (string | number)[] = []; - if (generation.vae) { - // TODO(MM2): Fetch the vae name - let vaeBadge = generation.vae.key; - if (generation.vaePrecision === 'fp16') { - vaeBadge += ` ${generation.vaePrecision}`; - } - badges.push(vaeBadge); - } else if (generation.vaePrecision === 'fp16') { - badges.push(`VAE ${generation.vaePrecision}`); - } - if (generation.clipSkip) { - badges.push(`Skip ${generation.clipSkip}`); - } - if (generation.cfgRescaleMultiplier) { - badges.push(`Rescale ${generation.cfgRescaleMultiplier}`); - } - if (generation.seamlessXAxis || generation.seamlessYAxis) { - badges.push('seamless'); - } - return badges; -}); - export const AdvancedSettingsAccordion = memo(() => { + const vaeKey = useAppSelector((state) => state.generation.vae?.key); + const { data: vaeConfig } = useGetModelConfigQuery(vaeKey ?? skipToken); + const selectBadges = useMemo( + () => + createMemoizedSelector(selectGenerationSlice, (generation) => { + const badges: (string | number)[] = []; + if (vaeConfig) { + let vaeBadge = vaeConfig.name; + if (generation.vaePrecision === 'fp16') { + vaeBadge += ` ${generation.vaePrecision}`; + } + badges.push(vaeBadge); + } else if (generation.vaePrecision === 'fp16') { + badges.push(`VAE ${generation.vaePrecision}`); + } + if (generation.clipSkip) { + badges.push(`Skip ${generation.clipSkip}`); + } + if (generation.cfgRescaleMultiplier) { + badges.push(`Rescale ${generation.cfgRescaleMultiplier}`); + } + if (generation.seamlessXAxis || generation.seamlessYAxis) { + badges.push('seamless'); + } + return badges; + }), + [vaeConfig] + ); const badges = useAppSelector(selectBadges); const { t } = useTranslation(); const { isOpen, onToggle } = useStandaloneAccordionToggle({ diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx index cda7dcf6e9..d57e48f11e 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx @@ -12,6 +12,7 @@ import { } from '@invoke-ai/ui-library'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; +import { EMPTY_ARRAY } from 'app/store/util'; import { LoRAList } from 'features/lora/components/LoRAList'; import LoRASelect from 'features/lora/components/LoRASelect'; import { selectLoraSlice } from 'features/lora/store/loraSlice'; @@ -20,33 +21,31 @@ import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale'; import ParamScheduler from 'features/parameters/components/Core/ParamScheduler'; import ParamSteps from 'features/parameters/components/Core/ParamSteps'; import ParamMainModelSelect from 'features/parameters/components/MainModel/ParamMainModelSelect'; -import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle'; import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle'; import { filter } from 'lodash-es'; -import { memo } from 'react'; +import { memo, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; +import { useSelectedModelConfig } from 'services/api/hooks/useSelectedModelConfig'; const formLabelProps: FormLabelProps = { minW: '4rem', }; -const badgesSelector = createMemoizedSelector(selectLoraSlice, selectGenerationSlice, (lora, generation) => { - const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length; - const loraTabBadges = enabledLoRAsCount ? [enabledLoRAsCount] : []; - const accordionBadges: (string | number)[] = []; - // TODO(MM2): fetch model name - if (generation.model) { - accordionBadges.push(generation.model.key); - accordionBadges.push(generation.model.base); - } - - return { loraTabBadges, accordionBadges }; -}); - export const GenerationSettingsAccordion = memo(() => { const { t } = useTranslation(); - const { loraTabBadges, accordionBadges } = useAppSelector(badgesSelector); + const modelConfig = useSelectedModelConfig(); + const selectBadges = useMemo( + () => + createMemoizedSelector(selectLoraSlice, (lora) => { + const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length; + const loraTabBadges = enabledLoRAsCount ? [enabledLoRAsCount] : EMPTY_ARRAY; + const accordionBadges = modelConfig ? [modelConfig.name, modelConfig.base] : EMPTY_ARRAY; + return { loraTabBadges, accordionBadges }; + }), + [modelConfig] + ); + const { loraTabBadges, accordionBadges } = useAppSelector(selectBadges); const { isOpen: isOpenExpander, onToggle: onToggleExpander } = useExpanderToggle({ id: 'generation-settings-advanced', defaultIsOpen: false, diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 46be42d9e5..666e0c707d 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -236,6 +236,18 @@ export const modelsApi = api.injectEndpoints({ }, invalidatesTags: ['Model'], }), + getModelConfig: build.query({ + query: (key) => buildModelsUrl(`i/${key}`), + providesTags: (result) => { + const tags: ApiTagDescription[] = ['Model']; + + if (result) { + tags.push({ type: 'ModelConfig', id: result.key }); + } + + return tags; + }, + }), syncModels: build.mutation({ query: () => { return { @@ -313,6 +325,7 @@ export const modelsApi = api.injectEndpoints({ }); export const { + useGetModelConfigQuery, useGetMainModelsQuery, useGetControlNetModelsQuery, useGetIPAdapterModelsQuery, diff --git a/invokeai/frontend/web/src/services/api/hooks/useSelectedModelConfig.ts b/invokeai/frontend/web/src/services/api/hooks/useSelectedModelConfig.ts new file mode 100644 index 0000000000..4a8d8d72e2 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/hooks/useSelectedModelConfig.ts @@ -0,0 +1,14 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { skipToken } from '@reduxjs/toolkit/query'; +import { useAppSelector } from 'app/store/storeHooks'; +import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; +import { useGetModelConfigQuery } from 'services/api/endpoints/models'; + +const selectModelKey = createSelector(selectGenerationSlice, (generation) => generation.model?.key); + +export const useSelectedModelConfig = () => { + const key = useAppSelector(selectModelKey); + const { currentData: modelConfig } = useGetModelConfigQuery(key ?? skipToken); + + return modelConfig; +}; diff --git a/invokeai/frontend/web/src/services/api/index.ts b/invokeai/frontend/web/src/services/api/index.ts index 1f567d7905..879cd1f8c4 100644 --- a/invokeai/frontend/web/src/services/api/index.ts +++ b/invokeai/frontend/web/src/services/api/index.ts @@ -26,6 +26,7 @@ export const tagTypes = [ 'BatchStatus', 'InvocationCacheStatus', 'Model', + 'ModelConfig', 'T2IAdapterModel', 'MainModel', 'VaeModel',