disable submodels that have incompatible base models

This commit is contained in:
Mary Hipp 2023-07-06 14:40:51 -04:00 committed by psychedelicious
parent 6356dc335f
commit b9a1aa38e3
7 changed files with 94 additions and 79 deletions

View File

@ -32,7 +32,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
const { colorMode } = useColorMode();
return (
<Tooltip label={tooltip} placement="top" hasArrow>
<Tooltip label={tooltip} placement="top" hasArrow isOpen={true}>
<MultiSelect
ref={inputRef}
searchable={searchable}

View File

@ -0,0 +1,40 @@
import { Tooltip, Text } from '@mantine/core';
import { forwardRef, memo } from 'react';
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
label: string;
description?: string;
tooltip?: string;
}
const IAIMantineSelectItemWithTooltip = forwardRef<HTMLDivElement, ItemProps>(
({ label, tooltip, description, ...others }: ItemProps, ref) => (
<div ref={ref} {...others}>
{tooltip ? (
<Tooltip.Floating label={tooltip}>
<div>
<Text>{label}</Text>
{description && (
<Text size="xs" color="base.600">
{description}
</Text>
)}
</div>
</Tooltip.Floating>
) : (
<div>
<Text>{label}</Text>
{description && (
<Text size="xs" color="base.600">
{description}
</Text>
)}
</div>
)}
</div>
)
);
IAIMantineSelectItemWithTooltip.displayName = 'IAIMantineSelectItemWithTooltip';
export default memo(IAIMantineSelectItemWithTooltip);

View File

@ -8,15 +8,12 @@ import {
} from '@chakra-ui/react';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
import { forEach } from 'lodash-es';
import {
PropsWithChildren,
forwardRef,
useCallback,
useMemo,
useRef,
} from 'react';
import { PropsWithChildren, useCallback, useMemo, useRef } from 'react';
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
import { RootState } from '../../../app/store/store';
import { useAppSelector } from '../../../app/store/storeHooks';
import IAIMantineSelectItemWithTooltip from '../../../common/components/IAIMantineSelectItemWithTooltip';
type EmbeddingSelectItem = {
label: string;
@ -35,6 +32,10 @@ const ParamEmbeddingPopover = (props: Props) => {
const { data: embeddingQueryData } = useGetTextualInversionModelsQuery();
const inputRef = useRef<HTMLInputElement>(null);
const currentMainModel = useAppSelector(
(state: RootState) => state.generation.model
);
const data = useMemo(() => {
if (!embeddingQueryData) {
return [];
@ -49,11 +50,14 @@ const ParamEmbeddingPopover = (props: Props) => {
value: embedding.name,
label: embedding.name,
description: embedding.description,
...(currentMainModel?.base_model !== embedding.base_model
? { disabled: true, tooltip: 'Incompatible base model' }
: {}),
});
});
return data;
}, [embeddingQueryData]);
}, [embeddingQueryData, currentMainModel?.base_model]);
const handleChange = useCallback(
(v: string[]) => {
@ -108,7 +112,7 @@ const ParamEmbeddingPopover = (props: Props) => {
data={data}
maxDropdownHeight={400}
nothingFound="No Matching Embeddings"
itemComponent={SelectItem}
itemComponent={IAIMantineSelectItemWithTooltip}
disabled={data.length === 0}
filter={(value, selected, item: EmbeddingSelectItem) =>
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
@ -124,28 +128,3 @@ const ParamEmbeddingPopover = (props: Props) => {
};
export default ParamEmbeddingPopover;
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
value: string;
label: string;
description?: string;
}
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
({ label, description, ...others }: ItemProps, ref) => {
return (
<div ref={ref} {...others}>
<div>
<Text>{label}</Text>
{description && (
<Text size="xs" color="base.600">
{description}
</Text>
)}
</div>
</div>
);
}
);
SelectItem.displayName = 'SelectItem';

View File

@ -1,13 +1,14 @@
import { Flex, Text } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { RootState, stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
import { forEach } from 'lodash-es';
import { forwardRef, useCallback, useMemo } from 'react';
import { useCallback, useMemo } from 'react';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
import { loraAdded } from '../store/loraSlice';
import IAIMantineSelectItemWithTooltip from '../../../common/components/IAIMantineSelectItemWithTooltip';
type LoraSelectItem = {
label: string;
@ -28,6 +29,10 @@ const ParamLoraSelect = () => {
const { loras } = useAppSelector(selector);
const { data: lorasQueryData } = useGetLoRAModelsQuery();
const currentMainModel = useAppSelector(
(state: RootState) => state.generation.model
);
const data = useMemo(() => {
if (!lorasQueryData) {
return [];
@ -43,12 +48,15 @@ const ParamLoraSelect = () => {
data.push({
value: id,
label: lora.name,
description: lora.description,
description: 'This is a lora',
...(currentMainModel?.base_model !== lora.base_model
? { disabled: true, tooltip: 'Incompatible base model' }
: {}),
});
});
return data;
}, [loras, lorasQueryData]);
}, [loras, lorasQueryData, currentMainModel?.base_model]);
const handleChange = useCallback(
(v: string[]) => {
@ -78,7 +86,7 @@ const ParamLoraSelect = () => {
data={data}
maxDropdownHeight={400}
nothingFound="No matching LoRAs"
itemComponent={SelectItem}
itemComponent={IAIMantineSelectItemWithTooltip}
disabled={data.length === 0}
filter={(value, selected, item: LoraSelectItem) =>
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
@ -89,29 +97,4 @@ const ParamLoraSelect = () => {
);
};
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
value: string;
label: string;
description?: string;
}
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
({ label, description, ...others }: ItemProps, ref) => {
return (
<div ref={ref} {...others}>
<div>
<Text>{label}</Text>
{description && (
<Text size="xs" color="base.600">
{description}
</Text>
)}
</div>
</div>
);
}
);
SelectItem.displayName = 'SelectItem';
export default ParamLoraSelect;

View File

@ -49,7 +49,7 @@ export interface GenerationState {
horizontalSymmetrySteps: number;
verticalSymmetrySteps: number;
model: ModelParam;
vae: ModelParam;
vae: VAEParam;
seamlessXAxis: boolean;
seamlessYAxis: boolean;
clipSkip: number;
@ -84,7 +84,7 @@ export const initialGenerationState: GenerationState = {
horizontalSymmetrySteps: 0,
verticalSymmetrySteps: 0,
model: null,
vae: null,
vae: '',
seamlessXAxis: false,
seamlessYAxis: false,
clipSkip: 0,
@ -224,8 +224,7 @@ export const generationSlice = createSlice({
state.model = { id: action.payload, base_model, name, type };
},
vaeSelected: (state, action: PayloadAction<string>) => {
const [base_model, type, name] = action.payload.split('/');
state.vae = { id: action.payload, base_model, name, type };
state.vae = action.payload;
},
setClipSkip: (state, action: PayloadAction<number>) => {
state.clipSkip = action.payload;

View File

@ -141,10 +141,15 @@ const zModel = z.object({
* Type alias for model parameter, inferred from its zod schema
*/
export type ModelParam = z.infer<typeof zModel> | null;
/**
* Zod schema for VAE parameter
* TODO: Make this a dynamically generated enum?
*/
export const zVAE = z.string();
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type VAEParam = z.infer<typeof zModel> | null;
export type VAEParam = z.infer<typeof zVAE>;
/**
* Validates/type-guards a value as a model parameter
*/

View File

@ -11,6 +11,7 @@ import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
import { RootState } from 'app/store/store';
import { vaeSelected } from 'features/parameters/store/generationSlice';
import { MODEL_TYPE_MAP } from './ModelSelect';
import IAIMantineSelectItemWithTooltip from '../../../common/components/IAIMantineSelectItemWithTooltip';
const VAESelect = () => {
const dispatch = useAppDispatch();
@ -18,7 +19,11 @@ const VAESelect = () => {
const { data: vaeModels } = useGetVaeModelsQuery();
const currentModel = useAppSelector(
const currentMainModel = useAppSelector(
(state: RootState) => state.generation.model
);
const selectedVae = useAppSelector(
(state: RootState) => state.generation.vae
);
@ -44,15 +49,18 @@ const VAESelect = () => {
value: id,
label: model.name,
group: MODEL_TYPE_MAP[model.base_model],
...(currentMainModel?.base_model !== model.base_model
? { disabled: true, tooltip: 'Incompatible base model' }
: {}),
});
});
return data;
}, [vaeModels]);
}, [vaeModels, currentMainModel?.base_model]);
const selectedModel = useMemo(
() => vaeModels?.entities[currentModel?.id || ''],
[vaeModels?.entities, currentModel]
const selectedVaeModel = useMemo(
() => vaeModels?.entities[selectedVae],
[vaeModels?.entities, selectedVae]
);
const handleChangeModel = useCallback(
@ -66,17 +74,18 @@ const VAESelect = () => {
);
useEffect(() => {
if (currentModel?.id && vaeModels?.ids.includes(currentModel?.id)) {
if (selectedVae && vaeModels?.ids.includes(selectedVae)) {
return;
}
handleChangeModel('auto');
}, [handleChangeModel, vaeModels?.ids, currentModel?.id]);
}, [handleChangeModel, vaeModels?.ids, selectedVae]);
return (
<IAIMantineSelect
tooltip={selectedModel?.description}
itemComponent={IAIMantineSelectItemWithTooltip}
tooltip={selectedVaeModel?.description}
label={t('modelManager.vae')}
value={currentModel?.id}
value={selectedVae}
placeholder="Pick one"
data={data}
onChange={handleChangeModel}