mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): finalize base model compatibility for lora, ti, vae
This commit is contained in:
parent
a9a4081f51
commit
8457fcf7d3
@ -1,11 +1,12 @@
|
|||||||
|
import { makeToast } from 'app/components/Toaster';
|
||||||
|
import { modelSelected } from 'features/parameters/store/actions';
|
||||||
import {
|
import {
|
||||||
modelChanged,
|
modelChanged,
|
||||||
vaeSelected,
|
vaeSelected,
|
||||||
} from 'features/parameters/store/generationSlice';
|
} from 'features/parameters/store/generationSlice';
|
||||||
|
import { zMainModel } from 'features/parameters/store/parameterZodSchemas';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
import { modelSelected } from 'features/parameters/store/actions';
|
|
||||||
import { makeToast } from 'app/components/Toaster';
|
|
||||||
import { lorasCleared } from '../../../../../features/lora/store/loraSlice';
|
import { lorasCleared } from '../../../../../features/lora/store/loraSlice';
|
||||||
|
|
||||||
export const addModelSelectedListener = () => {
|
export const addModelSelectedListener = () => {
|
||||||
@ -24,12 +25,18 @@ export const addModelSelectedListener = () => {
|
|||||||
})
|
})
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
dispatch(vaeSelected('auto'));
|
dispatch(vaeSelected(null));
|
||||||
dispatch(lorasCleared());
|
dispatch(lorasCleared());
|
||||||
// TODO: controlnet cleared
|
// TODO: controlnet cleared
|
||||||
}
|
}
|
||||||
|
|
||||||
dispatch(modelChanged({ id: action.payload, base_model, name, type }));
|
const newModel = zMainModel.parse({
|
||||||
|
id: action.payload,
|
||||||
|
base_model,
|
||||||
|
name,
|
||||||
|
});
|
||||||
|
|
||||||
|
dispatch(modelChanged(newModel));
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -66,6 +66,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
|||||||
'&[data-disabled]': {
|
'&[data-disabled]': {
|
||||||
backgroundColor: mode(base300, base700)(colorMode),
|
backgroundColor: mode(base300, base700)(colorMode),
|
||||||
color: mode(base600, base400)(colorMode),
|
color: mode(base600, base400)(colorMode),
|
||||||
|
cursor: 'not-allowed',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
value: {
|
value: {
|
||||||
@ -108,6 +109,10 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
|||||||
color: mode('white', base50)(colorMode),
|
color: mode('white', base50)(colorMode),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
'&[data-disabled]': {
|
||||||
|
color: mode(base500, base600)(colorMode),
|
||||||
|
cursor: 'not-allowed',
|
||||||
|
},
|
||||||
},
|
},
|
||||||
rightSection: {
|
rightSection: {
|
||||||
width: 24,
|
width: 24,
|
||||||
|
@ -67,6 +67,7 @@ const IAIMantineSelect = (props: IAISelectProps) => {
|
|||||||
'&[data-disabled]': {
|
'&[data-disabled]': {
|
||||||
backgroundColor: mode(base300, base700)(colorMode),
|
backgroundColor: mode(base300, base700)(colorMode),
|
||||||
color: mode(base600, base400)(colorMode),
|
color: mode(base600, base400)(colorMode),
|
||||||
|
cursor: 'not-allowed',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
value: {
|
value: {
|
||||||
@ -109,6 +110,10 @@ const IAIMantineSelect = (props: IAISelectProps) => {
|
|||||||
color: mode('white', base50)(colorMode),
|
color: mode('white', base50)(colorMode),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
'&[data-disabled]': {
|
||||||
|
color: mode(base500, base600)(colorMode),
|
||||||
|
cursor: 'not-allowed',
|
||||||
|
},
|
||||||
},
|
},
|
||||||
rightSection: {
|
rightSection: {
|
||||||
width: 32,
|
width: 32,
|
||||||
|
@ -1,37 +1,28 @@
|
|||||||
import { Tooltip, Text } from '@mantine/core';
|
import { Box, Tooltip } from '@chakra-ui/react';
|
||||||
|
import { Text } from '@mantine/core';
|
||||||
import { forwardRef, memo } from 'react';
|
import { forwardRef, memo } from 'react';
|
||||||
|
|
||||||
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
|
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
|
||||||
label: string;
|
label: string;
|
||||||
description?: string;
|
description?: string;
|
||||||
tooltip?: string;
|
tooltip?: string;
|
||||||
|
disabled?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
const IAIMantineSelectItemWithTooltip = forwardRef<HTMLDivElement, ItemProps>(
|
const IAIMantineSelectItemWithTooltip = forwardRef<HTMLDivElement, ItemProps>(
|
||||||
({ label, tooltip, description, ...others }: ItemProps, ref) => (
|
({ label, tooltip, description, disabled, ...others }: ItemProps, ref) => (
|
||||||
<div ref={ref} {...others}>
|
<Tooltip label={tooltip} placement="top" hasArrow>
|
||||||
{tooltip ? (
|
<Box ref={ref} {...others}>
|
||||||
<Tooltip.Floating label={tooltip}>
|
<Box>
|
||||||
<div>
|
|
||||||
<Text>{label}</Text>
|
|
||||||
{description && (
|
|
||||||
<Text size="xs" color="base.600">
|
|
||||||
{description}
|
|
||||||
</Text>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
</Tooltip.Floating>
|
|
||||||
) : (
|
|
||||||
<div>
|
|
||||||
<Text>{label}</Text>
|
<Text>{label}</Text>
|
||||||
{description && (
|
{description && (
|
||||||
<Text size="xs" color="base.600">
|
<Text size="xs" color="base.600">
|
||||||
{description}
|
{description}
|
||||||
</Text>
|
</Text>
|
||||||
)}
|
)}
|
||||||
</div>
|
</Box>
|
||||||
)}
|
</Box>
|
||||||
</div>
|
</Tooltip>
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -6,20 +6,16 @@ import {
|
|||||||
PopoverTrigger,
|
PopoverTrigger,
|
||||||
Text,
|
Text,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
|
import { SelectItem } from '@mantine/core';
|
||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
||||||
|
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||||
|
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
|
||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import { PropsWithChildren, useCallback, useMemo, useRef } from 'react';
|
import { PropsWithChildren, useCallback, useMemo, useRef } from 'react';
|
||||||
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||||
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
|
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
|
||||||
import { RootState } from '../../../app/store/store';
|
|
||||||
import { useAppSelector } from '../../../app/store/storeHooks';
|
|
||||||
import IAIMantineSelectItemWithTooltip from '../../../common/components/IAIMantineSelectItemWithTooltip';
|
|
||||||
|
|
||||||
type EmbeddingSelectItem = {
|
|
||||||
label: string;
|
|
||||||
value: string;
|
|
||||||
description?: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
type Props = PropsWithChildren & {
|
type Props = PropsWithChildren & {
|
||||||
onSelect: (v: string) => void;
|
onSelect: (v: string) => void;
|
||||||
@ -41,22 +37,27 @@ const ParamEmbeddingPopover = (props: Props) => {
|
|||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
|
|
||||||
const data: EmbeddingSelectItem[] = [];
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
forEach(embeddingQueryData.entities, (embedding, _) => {
|
forEach(embeddingQueryData.entities, (embedding, _) => {
|
||||||
if (!embedding) return;
|
if (!embedding) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const disabled = currentMainModel?.base_model !== embedding.base_model;
|
||||||
|
|
||||||
data.push({
|
data.push({
|
||||||
value: embedding.name,
|
value: embedding.name,
|
||||||
label: embedding.name,
|
label: embedding.name,
|
||||||
description: embedding.description,
|
group: MODEL_TYPE_MAP[embedding.base_model],
|
||||||
...(currentMainModel?.base_model !== embedding.base_model
|
disabled,
|
||||||
? { disabled: true, tooltip: 'Incompatible base model' }
|
tooltip: disabled
|
||||||
: {}),
|
? `Incompatible base model: ${embedding.base_model}`
|
||||||
|
: undefined,
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
return data;
|
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
|
||||||
}, [embeddingQueryData, currentMainModel?.base_model]);
|
}, [embeddingQueryData, currentMainModel?.base_model]);
|
||||||
|
|
||||||
const handleChange = useCallback(
|
const handleChange = useCallback(
|
||||||
@ -114,8 +115,10 @@ const ParamEmbeddingPopover = (props: Props) => {
|
|||||||
nothingFound="No Matching Embeddings"
|
nothingFound="No Matching Embeddings"
|
||||||
itemComponent={IAIMantineSelectItemWithTooltip}
|
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||||
disabled={data.length === 0}
|
disabled={data.length === 0}
|
||||||
filter={(value, selected, item: EmbeddingSelectItem) =>
|
filter={(value, selected, item: SelectItem) =>
|
||||||
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
|
item.label
|
||||||
|
?.toLowerCase()
|
||||||
|
.includes(value.toLowerCase().trim()) ||
|
||||||
item.value.toLowerCase().includes(value.toLowerCase().trim())
|
item.value.toLowerCase().includes(value.toLowerCase().trim())
|
||||||
}
|
}
|
||||||
onChange={handleChange}
|
onChange={handleChange}
|
||||||
|
@ -1,20 +1,16 @@
|
|||||||
import { Flex, Text } from '@chakra-ui/react';
|
import { Flex, Text } from '@chakra-ui/react';
|
||||||
|
import { SelectItem } from '@mantine/core';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { RootState, stateSelector } from 'app/store/store';
|
import { RootState, stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
||||||
|
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||||
|
import { loraAdded } from 'features/lora/store/loraSlice';
|
||||||
|
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
|
||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
||||||
import { loraAdded } from '../store/loraSlice';
|
|
||||||
import IAIMantineSelectItemWithTooltip from '../../../common/components/IAIMantineSelectItemWithTooltip';
|
|
||||||
|
|
||||||
type LoraSelectItem = {
|
|
||||||
label: string;
|
|
||||||
value: string;
|
|
||||||
description?: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
stateSelector,
|
stateSelector,
|
||||||
@ -38,24 +34,27 @@ const ParamLoraSelect = () => {
|
|||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
|
|
||||||
const data: LoraSelectItem[] = [];
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
forEach(lorasQueryData.entities, (lora, id) => {
|
forEach(lorasQueryData.entities, (lora, id) => {
|
||||||
if (!lora || Boolean(id in loras)) {
|
if (!lora || Boolean(id in loras)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const disabled = currentMainModel?.base_model !== lora.base_model;
|
||||||
|
|
||||||
data.push({
|
data.push({
|
||||||
value: id,
|
value: id,
|
||||||
label: lora.name,
|
label: lora.name,
|
||||||
description: lora.description,
|
disabled,
|
||||||
...(currentMainModel?.base_model !== lora.base_model
|
group: MODEL_TYPE_MAP[lora.base_model],
|
||||||
? { disabled: true, tooltip: 'Incompatible base model' }
|
tooltip: disabled
|
||||||
: {}),
|
? `Incompatible base model: ${lora.base_model}`
|
||||||
|
: undefined,
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
return data;
|
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
|
||||||
}, [loras, lorasQueryData, currentMainModel?.base_model]);
|
}, [loras, lorasQueryData, currentMainModel?.base_model]);
|
||||||
|
|
||||||
const handleChange = useCallback(
|
const handleChange = useCallback(
|
||||||
@ -88,8 +87,8 @@ const ParamLoraSelect = () => {
|
|||||||
nothingFound="No matching LoRAs"
|
nothingFound="No matching LoRAs"
|
||||||
itemComponent={IAIMantineSelectItemWithTooltip}
|
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||||
disabled={data.length === 0}
|
disabled={data.length === 0}
|
||||||
filter={(value, selected, item: LoraSelectItem) =>
|
filter={(value, selected, item: SelectItem) =>
|
||||||
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
|
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||||
item.value.toLowerCase().includes(value.toLowerCase().trim())
|
item.value.toLowerCase().includes(value.toLowerCase().trim())
|
||||||
}
|
}
|
||||||
onChange={handleChange}
|
onChange={handleChange}
|
||||||
|
@ -1,18 +1,21 @@
|
|||||||
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||||
|
import { LoRAModelParam } from 'features/parameters/store/parameterZodSchemas';
|
||||||
import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
|
import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
|
||||||
|
import { BaseModelType } from 'services/api/types';
|
||||||
|
|
||||||
export type Lora = {
|
export type Lora = {
|
||||||
id: string;
|
id: string;
|
||||||
|
base_model: BaseModelType;
|
||||||
name: string;
|
name: string;
|
||||||
weight: number;
|
weight: number;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const defaultLoRAConfig: Omit<Lora, 'id' | 'name'> = {
|
export const defaultLoRAConfig = {
|
||||||
weight: 0.75,
|
weight: 0.75,
|
||||||
};
|
};
|
||||||
|
|
||||||
export type LoraState = {
|
export type LoraState = {
|
||||||
loras: Record<string, Lora>;
|
loras: Record<string, LoRAModelParam & { weight: number }>;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const intialLoraState: LoraState = {
|
export const intialLoraState: LoraState = {
|
||||||
@ -24,14 +27,14 @@ export const loraSlice = createSlice({
|
|||||||
initialState: intialLoraState,
|
initialState: intialLoraState,
|
||||||
reducers: {
|
reducers: {
|
||||||
loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => {
|
loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => {
|
||||||
const { name, id } = action.payload;
|
const { name, id, base_model } = action.payload;
|
||||||
state.loras[id] = { id, name, ...defaultLoRAConfig };
|
state.loras[id] = { id, name, base_model, ...defaultLoRAConfig };
|
||||||
},
|
},
|
||||||
loraRemoved: (state, action: PayloadAction<string>) => {
|
loraRemoved: (state, action: PayloadAction<string>) => {
|
||||||
const id = action.payload;
|
const id = action.payload;
|
||||||
delete state.loras[id];
|
delete state.loras[id];
|
||||||
},
|
},
|
||||||
lorasCleared: (state, action: PayloadAction<>) => {
|
lorasCleared: (state) => {
|
||||||
state.loras = {};
|
state.loras = {};
|
||||||
},
|
},
|
||||||
loraWeightChanged: (
|
loraWeightChanged: (
|
||||||
|
@ -19,7 +19,7 @@ export const addVAEToGraph = (
|
|||||||
const { vae } = state.generation;
|
const { vae } = state.generation;
|
||||||
const vae_model = modelIdToVAEModelField(vae?.id || '');
|
const vae_model = modelIdToVAEModelField(vae?.id || '');
|
||||||
|
|
||||||
const isAutoVae = vae?.id === 'auto';
|
const isAutoVae = !vae;
|
||||||
|
|
||||||
if (!isAutoVae) {
|
if (!isAutoVae) {
|
||||||
graph.nodes[VAE_LOADER] = {
|
graph.nodes[VAE_LOADER] = {
|
||||||
|
@ -1,6 +1,10 @@
|
|||||||
|
import { useAppToaster } from 'app/components/Toaster';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { isImageField } from 'services/api/guards';
|
||||||
|
import { ImageDTO } from 'services/api/types';
|
||||||
|
import { initialImageSelected, modelSelected } from '../store/actions';
|
||||||
import {
|
import {
|
||||||
setCfgScale,
|
setCfgScale,
|
||||||
setHeight,
|
setHeight,
|
||||||
@ -12,14 +16,10 @@ import {
|
|||||||
setSteps,
|
setSteps,
|
||||||
setWidth,
|
setWidth,
|
||||||
} from '../store/generationSlice';
|
} from '../store/generationSlice';
|
||||||
import { isImageField } from 'services/api/guards';
|
|
||||||
import { initialImageSelected, modelSelected } from '../store/actions';
|
|
||||||
import { useAppToaster } from 'app/components/Toaster';
|
|
||||||
import { ImageDTO } from 'services/api/types';
|
|
||||||
import {
|
import {
|
||||||
isValidCfgScale,
|
isValidCfgScale,
|
||||||
isValidHeight,
|
isValidHeight,
|
||||||
isValidModel,
|
isValidMainModel,
|
||||||
isValidNegativePrompt,
|
isValidNegativePrompt,
|
||||||
isValidPositivePrompt,
|
isValidPositivePrompt,
|
||||||
isValidScheduler,
|
isValidScheduler,
|
||||||
@ -158,7 +158,7 @@ export const useRecallParameters = () => {
|
|||||||
*/
|
*/
|
||||||
const recallModel = useCallback(
|
const recallModel = useCallback(
|
||||||
(model: unknown) => {
|
(model: unknown) => {
|
||||||
if (!isValidModel(model)) {
|
if (!isValidMainModel(model)) {
|
||||||
parameterNotSetToast();
|
parameterNotSetToast();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -295,7 +295,7 @@ export const useRecallParameters = () => {
|
|||||||
if (isValidCfgScale(cfg_scale)) {
|
if (isValidCfgScale(cfg_scale)) {
|
||||||
dispatch(setCfgScale(cfg_scale));
|
dispatch(setCfgScale(cfg_scale));
|
||||||
}
|
}
|
||||||
if (isValidModel(model)) {
|
if (isValidMainModel(model)) {
|
||||||
dispatch(modelSelected(model));
|
dispatch(modelSelected(model));
|
||||||
}
|
}
|
||||||
if (isValidPositivePrompt(positive_conditioning)) {
|
if (isValidPositivePrompt(positive_conditioning)) {
|
||||||
|
@ -9,14 +9,16 @@ import { clipSkipMap } from '../components/Parameters/Advanced/ParamClipSkip';
|
|||||||
import {
|
import {
|
||||||
CfgScaleParam,
|
CfgScaleParam,
|
||||||
HeightParam,
|
HeightParam,
|
||||||
ModelParam,
|
MainModelParam,
|
||||||
NegativePromptParam,
|
NegativePromptParam,
|
||||||
PositivePromptParam,
|
PositivePromptParam,
|
||||||
SchedulerParam,
|
SchedulerParam,
|
||||||
SeedParam,
|
SeedParam,
|
||||||
StepsParam,
|
StepsParam,
|
||||||
StrengthParam,
|
StrengthParam,
|
||||||
|
VaeModelParam,
|
||||||
WidthParam,
|
WidthParam,
|
||||||
|
zMainModel,
|
||||||
} from './parameterZodSchemas';
|
} from './parameterZodSchemas';
|
||||||
|
|
||||||
export interface GenerationState {
|
export interface GenerationState {
|
||||||
@ -48,8 +50,8 @@ export interface GenerationState {
|
|||||||
shouldUseSymmetry: boolean;
|
shouldUseSymmetry: boolean;
|
||||||
horizontalSymmetrySteps: number;
|
horizontalSymmetrySteps: number;
|
||||||
verticalSymmetrySteps: number;
|
verticalSymmetrySteps: number;
|
||||||
model: ModelParam;
|
model: MainModelParam | null;
|
||||||
vae: VAEParam;
|
vae: VaeModelParam | null;
|
||||||
seamlessXAxis: boolean;
|
seamlessXAxis: boolean;
|
||||||
seamlessYAxis: boolean;
|
seamlessYAxis: boolean;
|
||||||
clipSkip: number;
|
clipSkip: number;
|
||||||
@ -84,7 +86,7 @@ export const initialGenerationState: GenerationState = {
|
|||||||
horizontalSymmetrySteps: 0,
|
horizontalSymmetrySteps: 0,
|
||||||
verticalSymmetrySteps: 0,
|
verticalSymmetrySteps: 0,
|
||||||
model: null,
|
model: null,
|
||||||
vae: '',
|
vae: null,
|
||||||
seamlessXAxis: false,
|
seamlessXAxis: false,
|
||||||
seamlessYAxis: false,
|
seamlessYAxis: false,
|
||||||
clipSkip: 0,
|
clipSkip: 0,
|
||||||
@ -221,12 +223,17 @@ export const generationSlice = createSlice({
|
|||||||
const { maxClip } = clipSkipMap[base_model as keyof typeof clipSkipMap];
|
const { maxClip } = clipSkipMap[base_model as keyof typeof clipSkipMap];
|
||||||
state.clipSkip = clamp(state.clipSkip, 0, maxClip);
|
state.clipSkip = clamp(state.clipSkip, 0, maxClip);
|
||||||
|
|
||||||
state.model = { id: action.payload, base_model, name, type };
|
state.model = zMainModel.parse({
|
||||||
|
id: action.payload,
|
||||||
|
base_model,
|
||||||
|
name,
|
||||||
|
type,
|
||||||
|
});
|
||||||
},
|
},
|
||||||
modelChanged: (state, action: PayloadAction<ModelParam>) => {
|
modelChanged: (state, action: PayloadAction<MainModelParam>) => {
|
||||||
state.model = action.payload;
|
state.model = action.payload;
|
||||||
},
|
},
|
||||||
vaeSelected: (state, action: PayloadAction<string>) => {
|
vaeSelected: (state, action: PayloadAction<VaeModelParam | null>) => {
|
||||||
state.vae = action.payload;
|
state.vae = action.payload;
|
||||||
},
|
},
|
||||||
setClipSkip: (state, action: PayloadAction<number>) => {
|
setClipSkip: (state, action: PayloadAction<number>) => {
|
||||||
@ -236,14 +243,14 @@ export const generationSlice = createSlice({
|
|||||||
extraReducers: (builder) => {
|
extraReducers: (builder) => {
|
||||||
builder.addCase(configChanged, (state, action) => {
|
builder.addCase(configChanged, (state, action) => {
|
||||||
const defaultModel = action.payload.sd?.defaultModel;
|
const defaultModel = action.payload.sd?.defaultModel;
|
||||||
|
|
||||||
if (defaultModel && !state.model) {
|
if (defaultModel && !state.model) {
|
||||||
const [base_model, model_type, model_name] = defaultModel.split('/');
|
const [base_model, model_type, model_name] = defaultModel.split('/');
|
||||||
state.model = {
|
state.model = zMainModel.parse({
|
||||||
id: defaultModel,
|
id: defaultModel,
|
||||||
name: model_name,
|
name: model_name,
|
||||||
type: model_type,
|
base_model,
|
||||||
base_model: base_model,
|
});
|
||||||
};
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
builder.addCase(setShouldShowAdvancedOptions, (state, action) => {
|
builder.addCase(setShouldShowAdvancedOptions, (state, action) => {
|
||||||
|
@ -126,35 +126,63 @@ export type HeightParam = z.infer<typeof zHeight>;
|
|||||||
export const isValidHeight = (val: unknown): val is HeightParam =>
|
export const isValidHeight = (val: unknown): val is HeightParam =>
|
||||||
zHeight.safeParse(val).success;
|
zHeight.safeParse(val).success;
|
||||||
|
|
||||||
|
const zBaseModel = z.enum(['sd-1', 'sd-2']);
|
||||||
|
|
||||||
|
export type BaseModelParam = z.infer<typeof zBaseModel>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Zod schema for model parameter
|
* Zod schema for model parameter
|
||||||
* TODO: Make this a dynamically generated enum?
|
* TODO: Make this a dynamically generated enum?
|
||||||
*/
|
*/
|
||||||
const zModel = z.object({
|
export const zMainModel = z.object({
|
||||||
id: z.string(),
|
id: z.string(),
|
||||||
name: z.string(),
|
name: z.string(),
|
||||||
type: z.string(),
|
base_model: zBaseModel,
|
||||||
base_model: z.string(),
|
|
||||||
});
|
});
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Type alias for model parameter, inferred from its zod schema
|
* Type alias for model parameter, inferred from its zod schema
|
||||||
*/
|
*/
|
||||||
export type ModelParam = z.infer<typeof zModel> | null;
|
export type MainModelParam = z.infer<typeof zMainModel>;
|
||||||
/**
|
|
||||||
* Zod schema for VAE parameter
|
|
||||||
* TODO: Make this a dynamically generated enum?
|
|
||||||
*/
|
|
||||||
export const zVAE = z.string();
|
|
||||||
/**
|
|
||||||
* Type alias for model parameter, inferred from its zod schema
|
|
||||||
*/
|
|
||||||
export type VAEParam = z.infer<typeof zVAE>;
|
|
||||||
/**
|
/**
|
||||||
* Validates/type-guards a value as a model parameter
|
* Validates/type-guards a value as a model parameter
|
||||||
*/
|
*/
|
||||||
export const isValidModel = (val: unknown): val is ModelParam =>
|
export const isValidMainModel = (val: unknown): val is MainModelParam =>
|
||||||
zModel.safeParse(val).success;
|
zMainModel.safeParse(val).success;
|
||||||
|
/**
|
||||||
|
* Zod schema for VAE parameter
|
||||||
|
*/
|
||||||
|
export const zVaeModel = z.object({
|
||||||
|
id: z.string(),
|
||||||
|
name: z.string(),
|
||||||
|
base_model: zBaseModel,
|
||||||
|
});
|
||||||
|
/**
|
||||||
|
* Type alias for model parameter, inferred from its zod schema
|
||||||
|
*/
|
||||||
|
export type VaeModelParam = z.infer<typeof zVaeModel>;
|
||||||
|
/**
|
||||||
|
* Validates/type-guards a value as a model parameter
|
||||||
|
*/
|
||||||
|
export const isValidVaeModel = (val: unknown): val is VaeModelParam =>
|
||||||
|
zVaeModel.safeParse(val).success;
|
||||||
|
/**
|
||||||
|
* Zod schema for LoRA
|
||||||
|
*/
|
||||||
|
export const zLoRAModel = z.object({
|
||||||
|
id: z.string(),
|
||||||
|
name: z.string(),
|
||||||
|
base_model: zBaseModel,
|
||||||
|
});
|
||||||
|
/**
|
||||||
|
* Type alias for model parameter, inferred from its zod schema
|
||||||
|
*/
|
||||||
|
export type LoRAModelParam = z.infer<typeof zLoRAModel>;
|
||||||
|
/**
|
||||||
|
* Validates/type-guards a value as a model parameter
|
||||||
|
*/
|
||||||
|
export const isValidLoRAModel = (val: unknown): val is LoRAModelParam =>
|
||||||
|
zLoRAModel.safeParse(val).success;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Zod schema for l2l strength parameter
|
* Zod schema for l2l strength parameter
|
||||||
|
@ -6,9 +6,9 @@ import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
|||||||
|
|
||||||
import { SelectItem } from '@mantine/core';
|
import { SelectItem } from '@mantine/core';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
|
import { modelSelected } from 'features/parameters/store/actions';
|
||||||
import { forEach, isString } from 'lodash-es';
|
import { forEach, isString } from 'lodash-es';
|
||||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||||
import { modelSelected } from '../../parameters/store/actions';
|
|
||||||
|
|
||||||
export const MODEL_TYPE_MAP = {
|
export const MODEL_TYPE_MAP = {
|
||||||
'sd-1': 'Stable Diffusion 1.x',
|
'sd-1': 'Stable Diffusion 1.x',
|
||||||
@ -63,6 +63,16 @@ const ModelSelect = () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
|
if (isLoading) {
|
||||||
|
// return early here to avoid resetting model selection before we've loaded the available models
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (selectedModel && mainModels?.ids.includes(selectedModel?.id)) {
|
||||||
|
// the selected model is an available model, no need to change it
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const firstModel = mainModels?.ids[0];
|
const firstModel = mainModels?.ids[0];
|
||||||
|
|
||||||
if (!isString(firstModel)) {
|
if (!isString(firstModel)) {
|
||||||
@ -70,7 +80,7 @@ const ModelSelect = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
handleChangeModel(firstModel);
|
handleChangeModel(firstModel);
|
||||||
}, [handleChangeModel, mainModels?.ids]);
|
}, [handleChangeModel, isLoading, mainModels?.ids, selectedModel]);
|
||||||
|
|
||||||
return isLoading ? (
|
return isLoading ? (
|
||||||
<IAIMantineSelect
|
<IAIMantineSelect
|
||||||
|
@ -9,9 +9,10 @@ import { forEach } from 'lodash-es';
|
|||||||
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
|
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||||
import { vaeSelected } from 'features/parameters/store/generationSlice';
|
import { vaeSelected } from 'features/parameters/store/generationSlice';
|
||||||
|
import { zVaeModel } from 'features/parameters/store/parameterZodSchemas';
|
||||||
import { MODEL_TYPE_MAP } from './ModelSelect';
|
import { MODEL_TYPE_MAP } from './ModelSelect';
|
||||||
import IAIMantineSelectItemWithTooltip from '../../../common/components/IAIMantineSelectItemWithTooltip';
|
|
||||||
|
|
||||||
const VAESelect = () => {
|
const VAESelect = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
@ -34,8 +35,8 @@ const VAESelect = () => {
|
|||||||
|
|
||||||
const data: SelectItem[] = [
|
const data: SelectItem[] = [
|
||||||
{
|
{
|
||||||
value: 'auto',
|
value: 'default',
|
||||||
label: 'Automatic',
|
label: 'Default',
|
||||||
group: 'Default',
|
group: 'Default',
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
@ -45,50 +46,65 @@ const VAESelect = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const disabled = currentMainModel?.base_model !== model.base_model;
|
||||||
|
|
||||||
data.push({
|
data.push({
|
||||||
value: id,
|
value: id,
|
||||||
label: model.name,
|
label: model.name,
|
||||||
group: MODEL_TYPE_MAP[model.base_model],
|
group: MODEL_TYPE_MAP[model.base_model],
|
||||||
...(currentMainModel?.base_model !== model.base_model
|
disabled,
|
||||||
? { disabled: true, tooltip: 'Incompatible base model' }
|
tooltip: disabled
|
||||||
: {}),
|
? `Incompatible base model: ${model.base_model}`
|
||||||
|
: undefined,
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
return data;
|
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
|
||||||
}, [vaeModels, currentMainModel?.base_model]);
|
}, [vaeModels, currentMainModel?.base_model]);
|
||||||
|
|
||||||
const selectedVaeModel = useMemo(
|
const selectedVaeModel = useMemo(
|
||||||
() => vaeModels?.entities[selectedVae],
|
() => (selectedVae?.id ? vaeModels?.entities[selectedVae?.id] : null),
|
||||||
[vaeModels?.entities, selectedVae]
|
[vaeModels?.entities, selectedVae]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleChangeModel = useCallback(
|
const handleChangeModel = useCallback(
|
||||||
(v: string | null) => {
|
(v: string | null) => {
|
||||||
if (!v) {
|
if (!v || v === 'default') {
|
||||||
|
dispatch(vaeSelected(null));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
dispatch(vaeSelected(v));
|
|
||||||
|
const [base_model, type, name] = v.split('/');
|
||||||
|
|
||||||
|
const model = zVaeModel.parse({
|
||||||
|
id: v,
|
||||||
|
name,
|
||||||
|
base_model,
|
||||||
|
});
|
||||||
|
|
||||||
|
dispatch(vaeSelected(model));
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (selectedVae && vaeModels?.ids.includes(selectedVae)) {
|
if (selectedVae && vaeModels?.ids.includes(selectedVae.id)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
handleChangeModel('auto');
|
dispatch(vaeSelected(null));
|
||||||
}, [handleChangeModel, vaeModels?.ids, selectedVae]);
|
}, [handleChangeModel, vaeModels?.ids, selectedVae, dispatch]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAIMantineSelect
|
<IAIMantineSelect
|
||||||
itemComponent={IAIMantineSelectItemWithTooltip}
|
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||||
tooltip={selectedVaeModel?.description}
|
tooltip={selectedVaeModel?.description}
|
||||||
label={t('modelManager.vae')}
|
label={t('modelManager.vae')}
|
||||||
value={selectedVae}
|
value={selectedVae?.id ?? 'default'}
|
||||||
placeholder="Pick one"
|
placeholder="Default"
|
||||||
data={data}
|
data={data}
|
||||||
onChange={handleChangeModel}
|
onChange={handleChangeModel}
|
||||||
|
disabled={data.length === 0}
|
||||||
|
clearable
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user