disable submodels that have incompatible base models

This commit is contained in:
Mary Hipp 2023-07-06 14:40:51 -04:00 committed by psychedelicious
parent 6356dc335f
commit b9a1aa38e3
7 changed files with 94 additions and 79 deletions

View File

@ -32,7 +32,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
const { colorMode } = useColorMode(); const { colorMode } = useColorMode();
return ( return (
<Tooltip label={tooltip} placement="top" hasArrow> <Tooltip label={tooltip} placement="top" hasArrow isOpen={true}>
<MultiSelect <MultiSelect
ref={inputRef} ref={inputRef}
searchable={searchable} searchable={searchable}

View File

@ -0,0 +1,40 @@
import { Tooltip, Text } from '@mantine/core';
import { forwardRef, memo } from 'react';
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
label: string;
description?: string;
tooltip?: string;
}
const IAIMantineSelectItemWithTooltip = forwardRef<HTMLDivElement, ItemProps>(
({ label, tooltip, description, ...others }: ItemProps, ref) => (
<div ref={ref} {...others}>
{tooltip ? (
<Tooltip.Floating label={tooltip}>
<div>
<Text>{label}</Text>
{description && (
<Text size="xs" color="base.600">
{description}
</Text>
)}
</div>
</Tooltip.Floating>
) : (
<div>
<Text>{label}</Text>
{description && (
<Text size="xs" color="base.600">
{description}
</Text>
)}
</div>
)}
</div>
)
);
IAIMantineSelectItemWithTooltip.displayName = 'IAIMantineSelectItemWithTooltip';
export default memo(IAIMantineSelectItemWithTooltip);

View File

@ -8,15 +8,12 @@ import {
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
import { import { PropsWithChildren, useCallback, useMemo, useRef } from 'react';
PropsWithChildren,
forwardRef,
useCallback,
useMemo,
useRef,
} from 'react';
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models'; import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants'; import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
import { RootState } from '../../../app/store/store';
import { useAppSelector } from '../../../app/store/storeHooks';
import IAIMantineSelectItemWithTooltip from '../../../common/components/IAIMantineSelectItemWithTooltip';
type EmbeddingSelectItem = { type EmbeddingSelectItem = {
label: string; label: string;
@ -35,6 +32,10 @@ const ParamEmbeddingPopover = (props: Props) => {
const { data: embeddingQueryData } = useGetTextualInversionModelsQuery(); const { data: embeddingQueryData } = useGetTextualInversionModelsQuery();
const inputRef = useRef<HTMLInputElement>(null); const inputRef = useRef<HTMLInputElement>(null);
const currentMainModel = useAppSelector(
(state: RootState) => state.generation.model
);
const data = useMemo(() => { const data = useMemo(() => {
if (!embeddingQueryData) { if (!embeddingQueryData) {
return []; return [];
@ -49,11 +50,14 @@ const ParamEmbeddingPopover = (props: Props) => {
value: embedding.name, value: embedding.name,
label: embedding.name, label: embedding.name,
description: embedding.description, description: embedding.description,
...(currentMainModel?.base_model !== embedding.base_model
? { disabled: true, tooltip: 'Incompatible base model' }
: {}),
}); });
}); });
return data; return data;
}, [embeddingQueryData]); }, [embeddingQueryData, currentMainModel?.base_model]);
const handleChange = useCallback( const handleChange = useCallback(
(v: string[]) => { (v: string[]) => {
@ -108,7 +112,7 @@ const ParamEmbeddingPopover = (props: Props) => {
data={data} data={data}
maxDropdownHeight={400} maxDropdownHeight={400}
nothingFound="No Matching Embeddings" nothingFound="No Matching Embeddings"
itemComponent={SelectItem} itemComponent={IAIMantineSelectItemWithTooltip}
disabled={data.length === 0} disabled={data.length === 0}
filter={(value, selected, item: EmbeddingSelectItem) => filter={(value, selected, item: EmbeddingSelectItem) =>
item.label.toLowerCase().includes(value.toLowerCase().trim()) || item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
@ -124,28 +128,3 @@ const ParamEmbeddingPopover = (props: Props) => {
}; };
export default ParamEmbeddingPopover; export default ParamEmbeddingPopover;
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
value: string;
label: string;
description?: string;
}
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
({ label, description, ...others }: ItemProps, ref) => {
return (
<div ref={ref} {...others}>
<div>
<Text>{label}</Text>
{description && (
<Text size="xs" color="base.600">
{description}
</Text>
)}
</div>
</div>
);
}
);
SelectItem.displayName = 'SelectItem';

View File

@ -1,13 +1,14 @@
import { Flex, Text } from '@chakra-ui/react'; import { Flex, Text } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store'; import { RootState, stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
import { forwardRef, useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
import { loraAdded } from '../store/loraSlice'; import { loraAdded } from '../store/loraSlice';
import IAIMantineSelectItemWithTooltip from '../../../common/components/IAIMantineSelectItemWithTooltip';
type LoraSelectItem = { type LoraSelectItem = {
label: string; label: string;
@ -28,6 +29,10 @@ const ParamLoraSelect = () => {
const { loras } = useAppSelector(selector); const { loras } = useAppSelector(selector);
const { data: lorasQueryData } = useGetLoRAModelsQuery(); const { data: lorasQueryData } = useGetLoRAModelsQuery();
const currentMainModel = useAppSelector(
(state: RootState) => state.generation.model
);
const data = useMemo(() => { const data = useMemo(() => {
if (!lorasQueryData) { if (!lorasQueryData) {
return []; return [];
@ -43,12 +48,15 @@ const ParamLoraSelect = () => {
data.push({ data.push({
value: id, value: id,
label: lora.name, label: lora.name,
description: lora.description, description: 'This is a lora',
...(currentMainModel?.base_model !== lora.base_model
? { disabled: true, tooltip: 'Incompatible base model' }
: {}),
}); });
}); });
return data; return data;
}, [loras, lorasQueryData]); }, [loras, lorasQueryData, currentMainModel?.base_model]);
const handleChange = useCallback( const handleChange = useCallback(
(v: string[]) => { (v: string[]) => {
@ -78,7 +86,7 @@ const ParamLoraSelect = () => {
data={data} data={data}
maxDropdownHeight={400} maxDropdownHeight={400}
nothingFound="No matching LoRAs" nothingFound="No matching LoRAs"
itemComponent={SelectItem} itemComponent={IAIMantineSelectItemWithTooltip}
disabled={data.length === 0} disabled={data.length === 0}
filter={(value, selected, item: LoraSelectItem) => filter={(value, selected, item: LoraSelectItem) =>
item.label.toLowerCase().includes(value.toLowerCase().trim()) || item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
@ -89,29 +97,4 @@ const ParamLoraSelect = () => {
); );
}; };
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
value: string;
label: string;
description?: string;
}
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
({ label, description, ...others }: ItemProps, ref) => {
return (
<div ref={ref} {...others}>
<div>
<Text>{label}</Text>
{description && (
<Text size="xs" color="base.600">
{description}
</Text>
)}
</div>
</div>
);
}
);
SelectItem.displayName = 'SelectItem';
export default ParamLoraSelect; export default ParamLoraSelect;

View File

@ -49,7 +49,7 @@ export interface GenerationState {
horizontalSymmetrySteps: number; horizontalSymmetrySteps: number;
verticalSymmetrySteps: number; verticalSymmetrySteps: number;
model: ModelParam; model: ModelParam;
vae: ModelParam; vae: VAEParam;
seamlessXAxis: boolean; seamlessXAxis: boolean;
seamlessYAxis: boolean; seamlessYAxis: boolean;
clipSkip: number; clipSkip: number;
@ -84,7 +84,7 @@ export const initialGenerationState: GenerationState = {
horizontalSymmetrySteps: 0, horizontalSymmetrySteps: 0,
verticalSymmetrySteps: 0, verticalSymmetrySteps: 0,
model: null, model: null,
vae: null, vae: '',
seamlessXAxis: false, seamlessXAxis: false,
seamlessYAxis: false, seamlessYAxis: false,
clipSkip: 0, clipSkip: 0,
@ -224,8 +224,7 @@ export const generationSlice = createSlice({
state.model = { id: action.payload, base_model, name, type }; state.model = { id: action.payload, base_model, name, type };
}, },
vaeSelected: (state, action: PayloadAction<string>) => { vaeSelected: (state, action: PayloadAction<string>) => {
const [base_model, type, name] = action.payload.split('/'); state.vae = action.payload;
state.vae = { id: action.payload, base_model, name, type };
}, },
setClipSkip: (state, action: PayloadAction<number>) => { setClipSkip: (state, action: PayloadAction<number>) => {
state.clipSkip = action.payload; state.clipSkip = action.payload;

View File

@ -141,10 +141,15 @@ const zModel = z.object({
* Type alias for model parameter, inferred from its zod schema * Type alias for model parameter, inferred from its zod schema
*/ */
export type ModelParam = z.infer<typeof zModel> | null; export type ModelParam = z.infer<typeof zModel> | null;
/**
* Zod schema for VAE parameter
* TODO: Make this a dynamically generated enum?
*/
export const zVAE = z.string();
/** /**
* Type alias for model parameter, inferred from its zod schema * Type alias for model parameter, inferred from its zod schema
*/ */
export type VAEParam = z.infer<typeof zModel> | null; export type VAEParam = z.infer<typeof zVAE>;
/** /**
* Validates/type-guards a value as a model parameter * Validates/type-guards a value as a model parameter
*/ */

View File

@ -11,6 +11,7 @@ import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { vaeSelected } from 'features/parameters/store/generationSlice'; import { vaeSelected } from 'features/parameters/store/generationSlice';
import { MODEL_TYPE_MAP } from './ModelSelect'; import { MODEL_TYPE_MAP } from './ModelSelect';
import IAIMantineSelectItemWithTooltip from '../../../common/components/IAIMantineSelectItemWithTooltip';
const VAESelect = () => { const VAESelect = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
@ -18,7 +19,11 @@ const VAESelect = () => {
const { data: vaeModels } = useGetVaeModelsQuery(); const { data: vaeModels } = useGetVaeModelsQuery();
const currentModel = useAppSelector( const currentMainModel = useAppSelector(
(state: RootState) => state.generation.model
);
const selectedVae = useAppSelector(
(state: RootState) => state.generation.vae (state: RootState) => state.generation.vae
); );
@ -44,15 +49,18 @@ const VAESelect = () => {
value: id, value: id,
label: model.name, label: model.name,
group: MODEL_TYPE_MAP[model.base_model], group: MODEL_TYPE_MAP[model.base_model],
...(currentMainModel?.base_model !== model.base_model
? { disabled: true, tooltip: 'Incompatible base model' }
: {}),
}); });
}); });
return data; return data;
}, [vaeModels]); }, [vaeModels, currentMainModel?.base_model]);
const selectedModel = useMemo( const selectedVaeModel = useMemo(
() => vaeModels?.entities[currentModel?.id || ''], () => vaeModels?.entities[selectedVae],
[vaeModels?.entities, currentModel] [vaeModels?.entities, selectedVae]
); );
const handleChangeModel = useCallback( const handleChangeModel = useCallback(
@ -66,17 +74,18 @@ const VAESelect = () => {
); );
useEffect(() => { useEffect(() => {
if (currentModel?.id && vaeModels?.ids.includes(currentModel?.id)) { if (selectedVae && vaeModels?.ids.includes(selectedVae)) {
return; return;
} }
handleChangeModel('auto'); handleChangeModel('auto');
}, [handleChangeModel, vaeModels?.ids, currentModel?.id]); }, [handleChangeModel, vaeModels?.ids, selectedVae]);
return ( return (
<IAIMantineSelect <IAIMantineSelect
tooltip={selectedModel?.description} itemComponent={IAIMantineSelectItemWithTooltip}
tooltip={selectedVaeModel?.description}
label={t('modelManager.vae')} label={t('modelManager.vae')}
value={currentModel?.id} value={selectedVae}
placeholder="Pick one" placeholder="Pick one"
data={data} data={data}
onChange={handleChangeModel} onChange={handleChangeModel}