mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
disable submodels that have incompatible base models
This commit is contained in:
parent
6356dc335f
commit
b9a1aa38e3
@ -32,7 +32,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
return (
|
||||
<Tooltip label={tooltip} placement="top" hasArrow>
|
||||
<Tooltip label={tooltip} placement="top" hasArrow isOpen={true}>
|
||||
<MultiSelect
|
||||
ref={inputRef}
|
||||
searchable={searchable}
|
||||
|
@ -0,0 +1,40 @@
|
||||
import { Tooltip, Text } from '@mantine/core';
|
||||
import { forwardRef, memo } from 'react';
|
||||
|
||||
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
|
||||
label: string;
|
||||
description?: string;
|
||||
tooltip?: string;
|
||||
}
|
||||
|
||||
const IAIMantineSelectItemWithTooltip = forwardRef<HTMLDivElement, ItemProps>(
|
||||
({ label, tooltip, description, ...others }: ItemProps, ref) => (
|
||||
<div ref={ref} {...others}>
|
||||
{tooltip ? (
|
||||
<Tooltip.Floating label={tooltip}>
|
||||
<div>
|
||||
<Text>{label}</Text>
|
||||
{description && (
|
||||
<Text size="xs" color="base.600">
|
||||
{description}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
</Tooltip.Floating>
|
||||
) : (
|
||||
<div>
|
||||
<Text>{label}</Text>
|
||||
{description && (
|
||||
<Text size="xs" color="base.600">
|
||||
{description}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
);
|
||||
|
||||
IAIMantineSelectItemWithTooltip.displayName = 'IAIMantineSelectItemWithTooltip';
|
||||
|
||||
export default memo(IAIMantineSelectItemWithTooltip);
|
@ -8,15 +8,12 @@ import {
|
||||
} from '@chakra-ui/react';
|
||||
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
||||
import { forEach } from 'lodash-es';
|
||||
import {
|
||||
PropsWithChildren,
|
||||
forwardRef,
|
||||
useCallback,
|
||||
useMemo,
|
||||
useRef,
|
||||
} from 'react';
|
||||
import { PropsWithChildren, useCallback, useMemo, useRef } from 'react';
|
||||
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
|
||||
import { RootState } from '../../../app/store/store';
|
||||
import { useAppSelector } from '../../../app/store/storeHooks';
|
||||
import IAIMantineSelectItemWithTooltip from '../../../common/components/IAIMantineSelectItemWithTooltip';
|
||||
|
||||
type EmbeddingSelectItem = {
|
||||
label: string;
|
||||
@ -35,6 +32,10 @@ const ParamEmbeddingPopover = (props: Props) => {
|
||||
const { data: embeddingQueryData } = useGetTextualInversionModelsQuery();
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const currentMainModel = useAppSelector(
|
||||
(state: RootState) => state.generation.model
|
||||
);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!embeddingQueryData) {
|
||||
return [];
|
||||
@ -49,11 +50,14 @@ const ParamEmbeddingPopover = (props: Props) => {
|
||||
value: embedding.name,
|
||||
label: embedding.name,
|
||||
description: embedding.description,
|
||||
...(currentMainModel?.base_model !== embedding.base_model
|
||||
? { disabled: true, tooltip: 'Incompatible base model' }
|
||||
: {}),
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [embeddingQueryData]);
|
||||
}, [embeddingQueryData, currentMainModel?.base_model]);
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: string[]) => {
|
||||
@ -108,7 +112,7 @@ const ParamEmbeddingPopover = (props: Props) => {
|
||||
data={data}
|
||||
maxDropdownHeight={400}
|
||||
nothingFound="No Matching Embeddings"
|
||||
itemComponent={SelectItem}
|
||||
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||
disabled={data.length === 0}
|
||||
filter={(value, selected, item: EmbeddingSelectItem) =>
|
||||
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||
@ -124,28 +128,3 @@ const ParamEmbeddingPopover = (props: Props) => {
|
||||
};
|
||||
|
||||
export default ParamEmbeddingPopover;
|
||||
|
||||
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
|
||||
value: string;
|
||||
label: string;
|
||||
description?: string;
|
||||
}
|
||||
|
||||
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
|
||||
({ label, description, ...others }: ItemProps, ref) => {
|
||||
return (
|
||||
<div ref={ref} {...others}>
|
||||
<div>
|
||||
<Text>{label}</Text>
|
||||
{description && (
|
||||
<Text size="xs" color="base.600">
|
||||
{description}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
SelectItem.displayName = 'SelectItem';
|
||||
|
@ -1,13 +1,14 @@
|
||||
import { Flex, Text } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { RootState, stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { forwardRef, useCallback, useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
||||
import { loraAdded } from '../store/loraSlice';
|
||||
import IAIMantineSelectItemWithTooltip from '../../../common/components/IAIMantineSelectItemWithTooltip';
|
||||
|
||||
type LoraSelectItem = {
|
||||
label: string;
|
||||
@ -28,6 +29,10 @@ const ParamLoraSelect = () => {
|
||||
const { loras } = useAppSelector(selector);
|
||||
const { data: lorasQueryData } = useGetLoRAModelsQuery();
|
||||
|
||||
const currentMainModel = useAppSelector(
|
||||
(state: RootState) => state.generation.model
|
||||
);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!lorasQueryData) {
|
||||
return [];
|
||||
@ -43,12 +48,15 @@ const ParamLoraSelect = () => {
|
||||
data.push({
|
||||
value: id,
|
||||
label: lora.name,
|
||||
description: lora.description,
|
||||
description: 'This is a lora',
|
||||
...(currentMainModel?.base_model !== lora.base_model
|
||||
? { disabled: true, tooltip: 'Incompatible base model' }
|
||||
: {}),
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [loras, lorasQueryData]);
|
||||
}, [loras, lorasQueryData, currentMainModel?.base_model]);
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: string[]) => {
|
||||
@ -78,7 +86,7 @@ const ParamLoraSelect = () => {
|
||||
data={data}
|
||||
maxDropdownHeight={400}
|
||||
nothingFound="No matching LoRAs"
|
||||
itemComponent={SelectItem}
|
||||
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||
disabled={data.length === 0}
|
||||
filter={(value, selected, item: LoraSelectItem) =>
|
||||
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||
@ -89,29 +97,4 @@ const ParamLoraSelect = () => {
|
||||
);
|
||||
};
|
||||
|
||||
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
|
||||
value: string;
|
||||
label: string;
|
||||
description?: string;
|
||||
}
|
||||
|
||||
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
|
||||
({ label, description, ...others }: ItemProps, ref) => {
|
||||
return (
|
||||
<div ref={ref} {...others}>
|
||||
<div>
|
||||
<Text>{label}</Text>
|
||||
{description && (
|
||||
<Text size="xs" color="base.600">
|
||||
{description}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
SelectItem.displayName = 'SelectItem';
|
||||
|
||||
export default ParamLoraSelect;
|
||||
|
@ -49,7 +49,7 @@ export interface GenerationState {
|
||||
horizontalSymmetrySteps: number;
|
||||
verticalSymmetrySteps: number;
|
||||
model: ModelParam;
|
||||
vae: ModelParam;
|
||||
vae: VAEParam;
|
||||
seamlessXAxis: boolean;
|
||||
seamlessYAxis: boolean;
|
||||
clipSkip: number;
|
||||
@ -84,7 +84,7 @@ export const initialGenerationState: GenerationState = {
|
||||
horizontalSymmetrySteps: 0,
|
||||
verticalSymmetrySteps: 0,
|
||||
model: null,
|
||||
vae: null,
|
||||
vae: '',
|
||||
seamlessXAxis: false,
|
||||
seamlessYAxis: false,
|
||||
clipSkip: 0,
|
||||
@ -224,8 +224,7 @@ export const generationSlice = createSlice({
|
||||
state.model = { id: action.payload, base_model, name, type };
|
||||
},
|
||||
vaeSelected: (state, action: PayloadAction<string>) => {
|
||||
const [base_model, type, name] = action.payload.split('/');
|
||||
state.vae = { id: action.payload, base_model, name, type };
|
||||
state.vae = action.payload;
|
||||
},
|
||||
setClipSkip: (state, action: PayloadAction<number>) => {
|
||||
state.clipSkip = action.payload;
|
||||
|
@ -141,10 +141,15 @@ const zModel = z.object({
|
||||
* Type alias for model parameter, inferred from its zod schema
|
||||
*/
|
||||
export type ModelParam = z.infer<typeof zModel> | null;
|
||||
/**
|
||||
* Zod schema for VAE parameter
|
||||
* TODO: Make this a dynamically generated enum?
|
||||
*/
|
||||
export const zVAE = z.string();
|
||||
/**
|
||||
* Type alias for model parameter, inferred from its zod schema
|
||||
*/
|
||||
export type VAEParam = z.infer<typeof zModel> | null;
|
||||
export type VAEParam = z.infer<typeof zVAE>;
|
||||
/**
|
||||
* Validates/type-guards a value as a model parameter
|
||||
*/
|
||||
|
@ -11,6 +11,7 @@ import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { vaeSelected } from 'features/parameters/store/generationSlice';
|
||||
import { MODEL_TYPE_MAP } from './ModelSelect';
|
||||
import IAIMantineSelectItemWithTooltip from '../../../common/components/IAIMantineSelectItemWithTooltip';
|
||||
|
||||
const VAESelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
@ -18,7 +19,11 @@ const VAESelect = () => {
|
||||
|
||||
const { data: vaeModels } = useGetVaeModelsQuery();
|
||||
|
||||
const currentModel = useAppSelector(
|
||||
const currentMainModel = useAppSelector(
|
||||
(state: RootState) => state.generation.model
|
||||
);
|
||||
|
||||
const selectedVae = useAppSelector(
|
||||
(state: RootState) => state.generation.vae
|
||||
);
|
||||
|
||||
@ -44,15 +49,18 @@ const VAESelect = () => {
|
||||
value: id,
|
||||
label: model.name,
|
||||
group: MODEL_TYPE_MAP[model.base_model],
|
||||
...(currentMainModel?.base_model !== model.base_model
|
||||
? { disabled: true, tooltip: 'Incompatible base model' }
|
||||
: {}),
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [vaeModels]);
|
||||
}, [vaeModels, currentMainModel?.base_model]);
|
||||
|
||||
const selectedModel = useMemo(
|
||||
() => vaeModels?.entities[currentModel?.id || ''],
|
||||
[vaeModels?.entities, currentModel]
|
||||
const selectedVaeModel = useMemo(
|
||||
() => vaeModels?.entities[selectedVae],
|
||||
[vaeModels?.entities, selectedVae]
|
||||
);
|
||||
|
||||
const handleChangeModel = useCallback(
|
||||
@ -66,17 +74,18 @@ const VAESelect = () => {
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (currentModel?.id && vaeModels?.ids.includes(currentModel?.id)) {
|
||||
if (selectedVae && vaeModels?.ids.includes(selectedVae)) {
|
||||
return;
|
||||
}
|
||||
handleChangeModel('auto');
|
||||
}, [handleChangeModel, vaeModels?.ids, currentModel?.id]);
|
||||
}, [handleChangeModel, vaeModels?.ids, selectedVae]);
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
tooltip={selectedModel?.description}
|
||||
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||
tooltip={selectedVaeModel?.description}
|
||||
label={t('modelManager.vae')}
|
||||
value={currentModel?.id}
|
||||
value={selectedVae}
|
||||
placeholder="Pick one"
|
||||
data={data}
|
||||
onChange={handleChangeModel}
|
||||
|
Loading…
Reference in New Issue
Block a user