mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
disable submodels that have incompatible base models
This commit is contained in:
parent
6356dc335f
commit
b9a1aa38e3
@ -32,7 +32,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
|||||||
const { colorMode } = useColorMode();
|
const { colorMode } = useColorMode();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Tooltip label={tooltip} placement="top" hasArrow>
|
<Tooltip label={tooltip} placement="top" hasArrow isOpen={true}>
|
||||||
<MultiSelect
|
<MultiSelect
|
||||||
ref={inputRef}
|
ref={inputRef}
|
||||||
searchable={searchable}
|
searchable={searchable}
|
||||||
|
@ -0,0 +1,40 @@
|
|||||||
|
import { Tooltip, Text } from '@mantine/core';
|
||||||
|
import { forwardRef, memo } from 'react';
|
||||||
|
|
||||||
|
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
|
||||||
|
label: string;
|
||||||
|
description?: string;
|
||||||
|
tooltip?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
const IAIMantineSelectItemWithTooltip = forwardRef<HTMLDivElement, ItemProps>(
|
||||||
|
({ label, tooltip, description, ...others }: ItemProps, ref) => (
|
||||||
|
<div ref={ref} {...others}>
|
||||||
|
{tooltip ? (
|
||||||
|
<Tooltip.Floating label={tooltip}>
|
||||||
|
<div>
|
||||||
|
<Text>{label}</Text>
|
||||||
|
{description && (
|
||||||
|
<Text size="xs" color="base.600">
|
||||||
|
{description}
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</Tooltip.Floating>
|
||||||
|
) : (
|
||||||
|
<div>
|
||||||
|
<Text>{label}</Text>
|
||||||
|
{description && (
|
||||||
|
<Text size="xs" color="base.600">
|
||||||
|
{description}
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
IAIMantineSelectItemWithTooltip.displayName = 'IAIMantineSelectItemWithTooltip';
|
||||||
|
|
||||||
|
export default memo(IAIMantineSelectItemWithTooltip);
|
@ -8,15 +8,12 @@ import {
|
|||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import {
|
import { PropsWithChildren, useCallback, useMemo, useRef } from 'react';
|
||||||
PropsWithChildren,
|
|
||||||
forwardRef,
|
|
||||||
useCallback,
|
|
||||||
useMemo,
|
|
||||||
useRef,
|
|
||||||
} from 'react';
|
|
||||||
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||||
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
|
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
|
||||||
|
import { RootState } from '../../../app/store/store';
|
||||||
|
import { useAppSelector } from '../../../app/store/storeHooks';
|
||||||
|
import IAIMantineSelectItemWithTooltip from '../../../common/components/IAIMantineSelectItemWithTooltip';
|
||||||
|
|
||||||
type EmbeddingSelectItem = {
|
type EmbeddingSelectItem = {
|
||||||
label: string;
|
label: string;
|
||||||
@ -35,6 +32,10 @@ const ParamEmbeddingPopover = (props: Props) => {
|
|||||||
const { data: embeddingQueryData } = useGetTextualInversionModelsQuery();
|
const { data: embeddingQueryData } = useGetTextualInversionModelsQuery();
|
||||||
const inputRef = useRef<HTMLInputElement>(null);
|
const inputRef = useRef<HTMLInputElement>(null);
|
||||||
|
|
||||||
|
const currentMainModel = useAppSelector(
|
||||||
|
(state: RootState) => state.generation.model
|
||||||
|
);
|
||||||
|
|
||||||
const data = useMemo(() => {
|
const data = useMemo(() => {
|
||||||
if (!embeddingQueryData) {
|
if (!embeddingQueryData) {
|
||||||
return [];
|
return [];
|
||||||
@ -49,11 +50,14 @@ const ParamEmbeddingPopover = (props: Props) => {
|
|||||||
value: embedding.name,
|
value: embedding.name,
|
||||||
label: embedding.name,
|
label: embedding.name,
|
||||||
description: embedding.description,
|
description: embedding.description,
|
||||||
|
...(currentMainModel?.base_model !== embedding.base_model
|
||||||
|
? { disabled: true, tooltip: 'Incompatible base model' }
|
||||||
|
: {}),
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
return data;
|
return data;
|
||||||
}, [embeddingQueryData]);
|
}, [embeddingQueryData, currentMainModel?.base_model]);
|
||||||
|
|
||||||
const handleChange = useCallback(
|
const handleChange = useCallback(
|
||||||
(v: string[]) => {
|
(v: string[]) => {
|
||||||
@ -108,7 +112,7 @@ const ParamEmbeddingPopover = (props: Props) => {
|
|||||||
data={data}
|
data={data}
|
||||||
maxDropdownHeight={400}
|
maxDropdownHeight={400}
|
||||||
nothingFound="No Matching Embeddings"
|
nothingFound="No Matching Embeddings"
|
||||||
itemComponent={SelectItem}
|
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||||
disabled={data.length === 0}
|
disabled={data.length === 0}
|
||||||
filter={(value, selected, item: EmbeddingSelectItem) =>
|
filter={(value, selected, item: EmbeddingSelectItem) =>
|
||||||
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
|
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||||
@ -124,28 +128,3 @@ const ParamEmbeddingPopover = (props: Props) => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
export default ParamEmbeddingPopover;
|
export default ParamEmbeddingPopover;
|
||||||
|
|
||||||
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
|
|
||||||
value: string;
|
|
||||||
label: string;
|
|
||||||
description?: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
|
|
||||||
({ label, description, ...others }: ItemProps, ref) => {
|
|
||||||
return (
|
|
||||||
<div ref={ref} {...others}>
|
|
||||||
<div>
|
|
||||||
<Text>{label}</Text>
|
|
||||||
{description && (
|
|
||||||
<Text size="xs" color="base.600">
|
|
||||||
{description}
|
|
||||||
</Text>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
SelectItem.displayName = 'SelectItem';
|
|
||||||
|
@ -1,13 +1,14 @@
|
|||||||
import { Flex, Text } from '@chakra-ui/react';
|
import { Flex, Text } from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { stateSelector } from 'app/store/store';
|
import { RootState, stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import { forwardRef, useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
||||||
import { loraAdded } from '../store/loraSlice';
|
import { loraAdded } from '../store/loraSlice';
|
||||||
|
import IAIMantineSelectItemWithTooltip from '../../../common/components/IAIMantineSelectItemWithTooltip';
|
||||||
|
|
||||||
type LoraSelectItem = {
|
type LoraSelectItem = {
|
||||||
label: string;
|
label: string;
|
||||||
@ -28,6 +29,10 @@ const ParamLoraSelect = () => {
|
|||||||
const { loras } = useAppSelector(selector);
|
const { loras } = useAppSelector(selector);
|
||||||
const { data: lorasQueryData } = useGetLoRAModelsQuery();
|
const { data: lorasQueryData } = useGetLoRAModelsQuery();
|
||||||
|
|
||||||
|
const currentMainModel = useAppSelector(
|
||||||
|
(state: RootState) => state.generation.model
|
||||||
|
);
|
||||||
|
|
||||||
const data = useMemo(() => {
|
const data = useMemo(() => {
|
||||||
if (!lorasQueryData) {
|
if (!lorasQueryData) {
|
||||||
return [];
|
return [];
|
||||||
@ -43,12 +48,15 @@ const ParamLoraSelect = () => {
|
|||||||
data.push({
|
data.push({
|
||||||
value: id,
|
value: id,
|
||||||
label: lora.name,
|
label: lora.name,
|
||||||
description: lora.description,
|
description: 'This is a lora',
|
||||||
|
...(currentMainModel?.base_model !== lora.base_model
|
||||||
|
? { disabled: true, tooltip: 'Incompatible base model' }
|
||||||
|
: {}),
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
return data;
|
return data;
|
||||||
}, [loras, lorasQueryData]);
|
}, [loras, lorasQueryData, currentMainModel?.base_model]);
|
||||||
|
|
||||||
const handleChange = useCallback(
|
const handleChange = useCallback(
|
||||||
(v: string[]) => {
|
(v: string[]) => {
|
||||||
@ -78,7 +86,7 @@ const ParamLoraSelect = () => {
|
|||||||
data={data}
|
data={data}
|
||||||
maxDropdownHeight={400}
|
maxDropdownHeight={400}
|
||||||
nothingFound="No matching LoRAs"
|
nothingFound="No matching LoRAs"
|
||||||
itemComponent={SelectItem}
|
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||||
disabled={data.length === 0}
|
disabled={data.length === 0}
|
||||||
filter={(value, selected, item: LoraSelectItem) =>
|
filter={(value, selected, item: LoraSelectItem) =>
|
||||||
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
|
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||||
@ -89,29 +97,4 @@ const ParamLoraSelect = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
|
|
||||||
value: string;
|
|
||||||
label: string;
|
|
||||||
description?: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
|
|
||||||
({ label, description, ...others }: ItemProps, ref) => {
|
|
||||||
return (
|
|
||||||
<div ref={ref} {...others}>
|
|
||||||
<div>
|
|
||||||
<Text>{label}</Text>
|
|
||||||
{description && (
|
|
||||||
<Text size="xs" color="base.600">
|
|
||||||
{description}
|
|
||||||
</Text>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
SelectItem.displayName = 'SelectItem';
|
|
||||||
|
|
||||||
export default ParamLoraSelect;
|
export default ParamLoraSelect;
|
||||||
|
@ -49,7 +49,7 @@ export interface GenerationState {
|
|||||||
horizontalSymmetrySteps: number;
|
horizontalSymmetrySteps: number;
|
||||||
verticalSymmetrySteps: number;
|
verticalSymmetrySteps: number;
|
||||||
model: ModelParam;
|
model: ModelParam;
|
||||||
vae: ModelParam;
|
vae: VAEParam;
|
||||||
seamlessXAxis: boolean;
|
seamlessXAxis: boolean;
|
||||||
seamlessYAxis: boolean;
|
seamlessYAxis: boolean;
|
||||||
clipSkip: number;
|
clipSkip: number;
|
||||||
@ -84,7 +84,7 @@ export const initialGenerationState: GenerationState = {
|
|||||||
horizontalSymmetrySteps: 0,
|
horizontalSymmetrySteps: 0,
|
||||||
verticalSymmetrySteps: 0,
|
verticalSymmetrySteps: 0,
|
||||||
model: null,
|
model: null,
|
||||||
vae: null,
|
vae: '',
|
||||||
seamlessXAxis: false,
|
seamlessXAxis: false,
|
||||||
seamlessYAxis: false,
|
seamlessYAxis: false,
|
||||||
clipSkip: 0,
|
clipSkip: 0,
|
||||||
@ -224,8 +224,7 @@ export const generationSlice = createSlice({
|
|||||||
state.model = { id: action.payload, base_model, name, type };
|
state.model = { id: action.payload, base_model, name, type };
|
||||||
},
|
},
|
||||||
vaeSelected: (state, action: PayloadAction<string>) => {
|
vaeSelected: (state, action: PayloadAction<string>) => {
|
||||||
const [base_model, type, name] = action.payload.split('/');
|
state.vae = action.payload;
|
||||||
state.vae = { id: action.payload, base_model, name, type };
|
|
||||||
},
|
},
|
||||||
setClipSkip: (state, action: PayloadAction<number>) => {
|
setClipSkip: (state, action: PayloadAction<number>) => {
|
||||||
state.clipSkip = action.payload;
|
state.clipSkip = action.payload;
|
||||||
|
@ -141,10 +141,15 @@ const zModel = z.object({
|
|||||||
* Type alias for model parameter, inferred from its zod schema
|
* Type alias for model parameter, inferred from its zod schema
|
||||||
*/
|
*/
|
||||||
export type ModelParam = z.infer<typeof zModel> | null;
|
export type ModelParam = z.infer<typeof zModel> | null;
|
||||||
|
/**
|
||||||
|
* Zod schema for VAE parameter
|
||||||
|
* TODO: Make this a dynamically generated enum?
|
||||||
|
*/
|
||||||
|
export const zVAE = z.string();
|
||||||
/**
|
/**
|
||||||
* Type alias for model parameter, inferred from its zod schema
|
* Type alias for model parameter, inferred from its zod schema
|
||||||
*/
|
*/
|
||||||
export type VAEParam = z.infer<typeof zModel> | null;
|
export type VAEParam = z.infer<typeof zVAE>;
|
||||||
/**
|
/**
|
||||||
* Validates/type-guards a value as a model parameter
|
* Validates/type-guards a value as a model parameter
|
||||||
*/
|
*/
|
||||||
|
@ -11,6 +11,7 @@ import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { vaeSelected } from 'features/parameters/store/generationSlice';
|
import { vaeSelected } from 'features/parameters/store/generationSlice';
|
||||||
import { MODEL_TYPE_MAP } from './ModelSelect';
|
import { MODEL_TYPE_MAP } from './ModelSelect';
|
||||||
|
import IAIMantineSelectItemWithTooltip from '../../../common/components/IAIMantineSelectItemWithTooltip';
|
||||||
|
|
||||||
const VAESelect = () => {
|
const VAESelect = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
@ -18,7 +19,11 @@ const VAESelect = () => {
|
|||||||
|
|
||||||
const { data: vaeModels } = useGetVaeModelsQuery();
|
const { data: vaeModels } = useGetVaeModelsQuery();
|
||||||
|
|
||||||
const currentModel = useAppSelector(
|
const currentMainModel = useAppSelector(
|
||||||
|
(state: RootState) => state.generation.model
|
||||||
|
);
|
||||||
|
|
||||||
|
const selectedVae = useAppSelector(
|
||||||
(state: RootState) => state.generation.vae
|
(state: RootState) => state.generation.vae
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -44,15 +49,18 @@ const VAESelect = () => {
|
|||||||
value: id,
|
value: id,
|
||||||
label: model.name,
|
label: model.name,
|
||||||
group: MODEL_TYPE_MAP[model.base_model],
|
group: MODEL_TYPE_MAP[model.base_model],
|
||||||
|
...(currentMainModel?.base_model !== model.base_model
|
||||||
|
? { disabled: true, tooltip: 'Incompatible base model' }
|
||||||
|
: {}),
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
return data;
|
return data;
|
||||||
}, [vaeModels]);
|
}, [vaeModels, currentMainModel?.base_model]);
|
||||||
|
|
||||||
const selectedModel = useMemo(
|
const selectedVaeModel = useMemo(
|
||||||
() => vaeModels?.entities[currentModel?.id || ''],
|
() => vaeModels?.entities[selectedVae],
|
||||||
[vaeModels?.entities, currentModel]
|
[vaeModels?.entities, selectedVae]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleChangeModel = useCallback(
|
const handleChangeModel = useCallback(
|
||||||
@ -66,17 +74,18 @@ const VAESelect = () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (currentModel?.id && vaeModels?.ids.includes(currentModel?.id)) {
|
if (selectedVae && vaeModels?.ids.includes(selectedVae)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
handleChangeModel('auto');
|
handleChangeModel('auto');
|
||||||
}, [handleChangeModel, vaeModels?.ids, currentModel?.id]);
|
}, [handleChangeModel, vaeModels?.ids, selectedVae]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAIMantineSelect
|
<IAIMantineSelect
|
||||||
tooltip={selectedModel?.description}
|
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||||
|
tooltip={selectedVaeModel?.description}
|
||||||
label={t('modelManager.vae')}
|
label={t('modelManager.vae')}
|
||||||
value={currentModel?.id}
|
value={selectedVae}
|
||||||
placeholder="Pick one"
|
placeholder="Pick one"
|
||||||
data={data}
|
data={data}
|
||||||
onChange={handleChangeModel}
|
onChange={handleChangeModel}
|
||||||
|
Loading…
Reference in New Issue
Block a user