mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
change model store to object, update main model and vae dropdowns
This commit is contained in:
parent
909fe047e4
commit
6356dc335f
@ -16,10 +16,12 @@ export const addVAEToGraph = (
|
|||||||
graph: NonNullableGraph,
|
graph: NonNullableGraph,
|
||||||
state: RootState
|
state: RootState
|
||||||
): void => {
|
): void => {
|
||||||
const { vae: vaeId } = state.generation;
|
const { vae } = state.generation;
|
||||||
const vae_model = modelIdToVAEModelField(vaeId);
|
const vae_model = modelIdToVAEModelField(vae?.id || '');
|
||||||
|
|
||||||
if (vaeId !== 'auto') {
|
const isAutoVae = vae?.id === 'auto';
|
||||||
|
|
||||||
|
if (!isAutoVae) {
|
||||||
graph.nodes[VAE_LOADER] = {
|
graph.nodes[VAE_LOADER] = {
|
||||||
type: 'vae_loader',
|
type: 'vae_loader',
|
||||||
id: VAE_LOADER,
|
id: VAE_LOADER,
|
||||||
@ -30,7 +32,7 @@ export const addVAEToGraph = (
|
|||||||
if (graph.id === TEXT_TO_IMAGE_GRAPH || graph.id === IMAGE_TO_IMAGE_GRAPH) {
|
if (graph.id === TEXT_TO_IMAGE_GRAPH || graph.id === IMAGE_TO_IMAGE_GRAPH) {
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: {
|
source: {
|
||||||
node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER,
|
node_id: isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
|
||||||
field: 'vae',
|
field: 'vae',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -43,7 +45,7 @@ export const addVAEToGraph = (
|
|||||||
if (graph.id === IMAGE_TO_IMAGE_GRAPH) {
|
if (graph.id === IMAGE_TO_IMAGE_GRAPH) {
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: {
|
source: {
|
||||||
node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER,
|
node_id: isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
|
||||||
field: 'vae',
|
field: 'vae',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -56,7 +58,7 @@ export const addVAEToGraph = (
|
|||||||
if (graph.id === INPAINT_GRAPH) {
|
if (graph.id === INPAINT_GRAPH) {
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: {
|
source: {
|
||||||
node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER,
|
node_id: isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
|
||||||
field: 'vae',
|
field: 'vae',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
|
@ -36,7 +36,7 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
negativePrompt,
|
negativePrompt,
|
||||||
model: modelId,
|
model: currentModel,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
@ -50,7 +50,7 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
// The bounding box determines width and height, not the width and height params
|
// The bounding box determines width and height, not the width and height params
|
||||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||||
|
|
||||||
const model = modelIdToMainModelField(modelId);
|
const model = modelIdToMainModelField(currentModel?.id || '');
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||||
|
@ -35,7 +35,7 @@ export const buildCanvasInpaintGraph = (
|
|||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
negativePrompt,
|
negativePrompt,
|
||||||
model: modelId,
|
model: currentModel,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
@ -59,7 +59,7 @@ export const buildCanvasInpaintGraph = (
|
|||||||
// We may need to set the inpaint width and height to scale the image
|
// We may need to set the inpaint width and height to scale the image
|
||||||
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
|
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
|
||||||
|
|
||||||
const model = modelIdToMainModelField(modelId);
|
const model = modelIdToMainModelField(currentModel?.id || '');
|
||||||
|
|
||||||
const graph: NonNullableGraph = {
|
const graph: NonNullableGraph = {
|
||||||
id: INPAINT_GRAPH,
|
id: INPAINT_GRAPH,
|
||||||
|
@ -25,7 +25,7 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
negativePrompt,
|
negativePrompt,
|
||||||
model: modelId,
|
model: currentModel,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
@ -38,7 +38,7 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
// The bounding box determines width and height, not the width and height params
|
// The bounding box determines width and height, not the width and height params
|
||||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||||
|
|
||||||
const model = modelIdToMainModelField(modelId);
|
const model = modelIdToMainModelField(currentModel?.id || '');
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||||
|
@ -38,7 +38,7 @@ export const buildLinearImageToImageGraph = (
|
|||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
negativePrompt,
|
negativePrompt,
|
||||||
model: modelId,
|
model: currentModel,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
@ -73,7 +73,7 @@ export const buildLinearImageToImageGraph = (
|
|||||||
throw new Error('No initial image found in state');
|
throw new Error('No initial image found in state');
|
||||||
}
|
}
|
||||||
|
|
||||||
const model = modelIdToMainModelField(modelId);
|
const model = modelIdToMainModelField(currentModel?.id || '');
|
||||||
|
|
||||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||||
const graph: NonNullableGraph = {
|
const graph: NonNullableGraph = {
|
||||||
|
@ -22,7 +22,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
negativePrompt,
|
negativePrompt,
|
||||||
model: modelId,
|
model: currentModel,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
@ -31,7 +31,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
clipSkip,
|
clipSkip,
|
||||||
} = state.generation;
|
} = state.generation;
|
||||||
|
|
||||||
const model = modelIdToMainModelField(modelId);
|
const model = modelIdToMainModelField(currentModel?.id || '');
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||||
|
@ -16,7 +16,6 @@ import {
|
|||||||
SeedParam,
|
SeedParam,
|
||||||
StepsParam,
|
StepsParam,
|
||||||
StrengthParam,
|
StrengthParam,
|
||||||
VAEParam,
|
|
||||||
WidthParam,
|
WidthParam,
|
||||||
} from './parameterZodSchemas';
|
} from './parameterZodSchemas';
|
||||||
|
|
||||||
@ -50,7 +49,7 @@ export interface GenerationState {
|
|||||||
horizontalSymmetrySteps: number;
|
horizontalSymmetrySteps: number;
|
||||||
verticalSymmetrySteps: number;
|
verticalSymmetrySteps: number;
|
||||||
model: ModelParam;
|
model: ModelParam;
|
||||||
vae: VAEParam;
|
vae: ModelParam;
|
||||||
seamlessXAxis: boolean;
|
seamlessXAxis: boolean;
|
||||||
seamlessYAxis: boolean;
|
seamlessYAxis: boolean;
|
||||||
clipSkip: number;
|
clipSkip: number;
|
||||||
@ -84,8 +83,8 @@ export const initialGenerationState: GenerationState = {
|
|||||||
shouldUseSymmetry: false,
|
shouldUseSymmetry: false,
|
||||||
horizontalSymmetrySteps: 0,
|
horizontalSymmetrySteps: 0,
|
||||||
verticalSymmetrySteps: 0,
|
verticalSymmetrySteps: 0,
|
||||||
model: '',
|
model: null,
|
||||||
vae: '',
|
vae: null,
|
||||||
seamlessXAxis: false,
|
seamlessXAxis: false,
|
||||||
seamlessYAxis: false,
|
seamlessYAxis: false,
|
||||||
clipSkip: 0,
|
clipSkip: 0,
|
||||||
@ -216,16 +215,17 @@ export const generationSlice = createSlice({
|
|||||||
state.initialImage = { imageName: image_name, width, height };
|
state.initialImage = { imageName: image_name, width, height };
|
||||||
},
|
},
|
||||||
modelSelected: (state, action: PayloadAction<string>) => {
|
modelSelected: (state, action: PayloadAction<string>) => {
|
||||||
state.model = action.payload;
|
const [base_model, type, name] = action.payload.split('/');
|
||||||
|
|
||||||
// Clamp ClipSkip Based On Selected Model
|
// Clamp ClipSkip Based On Selected Model
|
||||||
const clipSkipMax =
|
const { maxClip } = clipSkipMap[base_model as keyof typeof clipSkipMap];
|
||||||
clipSkipMap[action.payload.split('/')[0] as keyof typeof clipSkipMap]
|
state.clipSkip = clamp(state.clipSkip, 0, maxClip);
|
||||||
.maxClip;
|
|
||||||
state.clipSkip = clamp(state.clipSkip, 0, clipSkipMax);
|
state.model = { id: action.payload, base_model, name, type };
|
||||||
},
|
},
|
||||||
vaeSelected: (state, action: PayloadAction<string>) => {
|
vaeSelected: (state, action: PayloadAction<string>) => {
|
||||||
state.vae = action.payload;
|
const [base_model, type, name] = action.payload.split('/');
|
||||||
|
state.vae = { id: action.payload, base_model, name, type };
|
||||||
},
|
},
|
||||||
setClipSkip: (state, action: PayloadAction<number>) => {
|
setClipSkip: (state, action: PayloadAction<number>) => {
|
||||||
state.clipSkip = action.payload;
|
state.clipSkip = action.payload;
|
||||||
@ -235,7 +235,13 @@ export const generationSlice = createSlice({
|
|||||||
builder.addCase(configChanged, (state, action) => {
|
builder.addCase(configChanged, (state, action) => {
|
||||||
const defaultModel = action.payload.sd?.defaultModel;
|
const defaultModel = action.payload.sd?.defaultModel;
|
||||||
if (defaultModel && !state.model) {
|
if (defaultModel && !state.model) {
|
||||||
state.model = defaultModel;
|
const [base_model, model_type, model_name] = defaultModel.split('/');
|
||||||
|
state.model = {
|
||||||
|
id: defaultModel,
|
||||||
|
name: model_name,
|
||||||
|
type: model_type,
|
||||||
|
base_model: base_model,
|
||||||
|
};
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
builder.addCase(setShouldShowAdvancedOptions, (state, action) => {
|
builder.addCase(setShouldShowAdvancedOptions, (state, action) => {
|
||||||
|
@ -130,20 +130,21 @@ export const isValidHeight = (val: unknown): val is HeightParam =>
|
|||||||
* Zod schema for model parameter
|
* Zod schema for model parameter
|
||||||
* TODO: Make this a dynamically generated enum?
|
* TODO: Make this a dynamically generated enum?
|
||||||
*/
|
*/
|
||||||
export const zModel = z.string();
|
const zModel = z.object({
|
||||||
|
id: z.string(),
|
||||||
|
name: z.string(),
|
||||||
|
type: z.string(),
|
||||||
|
base_model: z.string(),
|
||||||
|
});
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Type alias for model parameter, inferred from its zod schema
|
* Type alias for model parameter, inferred from its zod schema
|
||||||
*/
|
*/
|
||||||
export type ModelParam = z.infer<typeof zModel>;
|
export type ModelParam = z.infer<typeof zModel> | null;
|
||||||
/**
|
|
||||||
* Zod schema for VAE parameter
|
|
||||||
* TODO: Make this a dynamically generated enum?
|
|
||||||
*/
|
|
||||||
export const zVAE = z.string();
|
|
||||||
/**
|
/**
|
||||||
* Type alias for model parameter, inferred from its zod schema
|
* Type alias for model parameter, inferred from its zod schema
|
||||||
*/
|
*/
|
||||||
export type VAEParam = z.infer<typeof zVAE>;
|
export type VAEParam = z.infer<typeof zModel> | null;
|
||||||
/**
|
/**
|
||||||
* Validates/type-guards a value as a model parameter
|
* Validates/type-guards a value as a model parameter
|
||||||
*/
|
*/
|
||||||
|
@ -19,7 +19,7 @@ const ModelSelect = () => {
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const selectedModelId = useAppSelector(
|
const currentModel = useAppSelector(
|
||||||
(state: RootState) => state.generation.model
|
(state: RootState) => state.generation.model
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -48,8 +48,8 @@ const ModelSelect = () => {
|
|||||||
}, [mainModels]);
|
}, [mainModels]);
|
||||||
|
|
||||||
const selectedModel = useMemo(
|
const selectedModel = useMemo(
|
||||||
() => mainModels?.entities[selectedModelId],
|
() => mainModels?.entities[currentModel?.id || ''],
|
||||||
[mainModels?.entities, selectedModelId]
|
[mainModels?.entities, currentModel]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleChangeModel = useCallback(
|
const handleChangeModel = useCallback(
|
||||||
@ -63,10 +63,6 @@ const ModelSelect = () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (selectedModelId && mainModels?.ids.includes(selectedModelId)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const firstModel = mainModels?.ids[0];
|
const firstModel = mainModels?.ids[0];
|
||||||
|
|
||||||
if (!isString(firstModel)) {
|
if (!isString(firstModel)) {
|
||||||
@ -74,7 +70,7 @@ const ModelSelect = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
handleChangeModel(firstModel);
|
handleChangeModel(firstModel);
|
||||||
}, [handleChangeModel, mainModels?.ids, selectedModelId]);
|
}, [handleChangeModel, mainModels?.ids]);
|
||||||
|
|
||||||
return isLoading ? (
|
return isLoading ? (
|
||||||
<IAIMantineSelect
|
<IAIMantineSelect
|
||||||
@ -87,7 +83,7 @@ const ModelSelect = () => {
|
|||||||
<IAIMantineSelect
|
<IAIMantineSelect
|
||||||
tooltip={selectedModel?.description}
|
tooltip={selectedModel?.description}
|
||||||
label={t('modelManager.model')}
|
label={t('modelManager.model')}
|
||||||
value={selectedModelId}
|
value={selectedModel?.id}
|
||||||
placeholder={data.length > 0 ? 'Select a model' : 'No models detected!'}
|
placeholder={data.length > 0 ? 'Select a model' : 'No models detected!'}
|
||||||
data={data}
|
data={data}
|
||||||
error={data.length === 0}
|
error={data.length === 0}
|
||||||
|
@ -18,7 +18,7 @@ const VAESelect = () => {
|
|||||||
|
|
||||||
const { data: vaeModels } = useGetVaeModelsQuery();
|
const { data: vaeModels } = useGetVaeModelsQuery();
|
||||||
|
|
||||||
const selectedModelId = useAppSelector(
|
const currentModel = useAppSelector(
|
||||||
(state: RootState) => state.generation.vae
|
(state: RootState) => state.generation.vae
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -51,8 +51,8 @@ const VAESelect = () => {
|
|||||||
}, [vaeModels]);
|
}, [vaeModels]);
|
||||||
|
|
||||||
const selectedModel = useMemo(
|
const selectedModel = useMemo(
|
||||||
() => vaeModels?.entities[selectedModelId],
|
() => vaeModels?.entities[currentModel?.id || ''],
|
||||||
[vaeModels?.entities, selectedModelId]
|
[vaeModels?.entities, currentModel]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleChangeModel = useCallback(
|
const handleChangeModel = useCallback(
|
||||||
@ -66,17 +66,17 @@ const VAESelect = () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (selectedModelId && vaeModels?.ids.includes(selectedModelId)) {
|
if (currentModel?.id && vaeModels?.ids.includes(currentModel?.id)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
handleChangeModel('auto');
|
handleChangeModel('auto');
|
||||||
}, [handleChangeModel, vaeModels?.ids, selectedModelId]);
|
}, [handleChangeModel, vaeModels?.ids, currentModel?.id]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAIMantineSelect
|
<IAIMantineSelect
|
||||||
tooltip={selectedModel?.description}
|
tooltip={selectedModel?.description}
|
||||||
label={t('modelManager.vae')}
|
label={t('modelManager.vae')}
|
||||||
value={selectedModelId}
|
value={currentModel?.id}
|
||||||
placeholder="Pick one"
|
placeholder="Pick one"
|
||||||
data={data}
|
data={data}
|
||||||
onChange={handleChangeModel}
|
onChange={handleChangeModel}
|
||||||
|
Loading…
Reference in New Issue
Block a user