mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: Enable 2.x Model Generation in Linear UI
This commit is contained in:
parent
bf0577c882
commit
61c426f502
@ -40,6 +40,7 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
|
|||||||
positivePrompt,
|
positivePrompt,
|
||||||
negativePrompt,
|
negativePrompt,
|
||||||
model,
|
model,
|
||||||
|
currentModelType,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
@ -66,7 +67,7 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
|
|||||||
// Create the model loader node
|
// Create the model loader node
|
||||||
const modelLoaderNode: SD1ModelLoaderInvocation | SD2ModelLoaderInvocation = {
|
const modelLoaderNode: SD1ModelLoaderInvocation | SD2ModelLoaderInvocation = {
|
||||||
id: MODEL_LOADER,
|
id: MODEL_LOADER,
|
||||||
type: 'sd1_model_loader',
|
type: currentModelType,
|
||||||
model_name: model,
|
model_name: model,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -30,6 +30,7 @@ const ITERATE = 'iterate';
|
|||||||
export const buildTextToImageGraph = (state: RootState): Graph => {
|
export const buildTextToImageGraph = (state: RootState): Graph => {
|
||||||
const {
|
const {
|
||||||
model,
|
model,
|
||||||
|
currentModelType,
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
negativePrompt,
|
negativePrompt,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
@ -50,7 +51,7 @@ export const buildTextToImageGraph = (state: RootState): Graph => {
|
|||||||
// Create the model loader node
|
// Create the model loader node
|
||||||
const modelLoaderNode: SD1ModelLoaderInvocation | SD2ModelLoaderInvocation = {
|
const modelLoaderNode: SD1ModelLoaderInvocation | SD2ModelLoaderInvocation = {
|
||||||
id: MODEL_LOADER,
|
id: MODEL_LOADER,
|
||||||
type: 'sd1_model_loader',
|
type: currentModelType,
|
||||||
model_name: model,
|
model_name: model,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
import { Scheduler } from 'app/constants';
|
import { Scheduler } from 'app/constants';
|
||||||
|
import { ModelLoaderTypes } from 'features/system/components/ModelSelect';
|
||||||
import { configChanged } from 'features/system/store/configSlice';
|
import { configChanged } from 'features/system/store/configSlice';
|
||||||
import { clamp, sortBy } from 'lodash-es';
|
import { clamp, sortBy } from 'lodash-es';
|
||||||
import { ImageDTO } from 'services/api';
|
import { ImageDTO } from 'services/api';
|
||||||
@ -49,6 +50,7 @@ export interface GenerationState {
|
|||||||
horizontalSymmetrySteps: number;
|
horizontalSymmetrySteps: number;
|
||||||
verticalSymmetrySteps: number;
|
verticalSymmetrySteps: number;
|
||||||
model: ModelParam;
|
model: ModelParam;
|
||||||
|
currentModelType: ModelLoaderTypes;
|
||||||
shouldUseSeamless: boolean;
|
shouldUseSeamless: boolean;
|
||||||
seamlessXAxis: boolean;
|
seamlessXAxis: boolean;
|
||||||
seamlessYAxis: boolean;
|
seamlessYAxis: boolean;
|
||||||
@ -83,6 +85,7 @@ export const initialGenerationState: GenerationState = {
|
|||||||
horizontalSymmetrySteps: 0,
|
horizontalSymmetrySteps: 0,
|
||||||
verticalSymmetrySteps: 0,
|
verticalSymmetrySteps: 0,
|
||||||
model: '',
|
model: '',
|
||||||
|
currentModelType: 'sd1_model_loader',
|
||||||
shouldUseSeamless: false,
|
shouldUseSeamless: false,
|
||||||
seamlessXAxis: true,
|
seamlessXAxis: true,
|
||||||
seamlessYAxis: true,
|
seamlessYAxis: true,
|
||||||
@ -217,6 +220,9 @@ export const generationSlice = createSlice({
|
|||||||
modelSelected: (state, action: PayloadAction<string>) => {
|
modelSelected: (state, action: PayloadAction<string>) => {
|
||||||
state.model = action.payload;
|
state.model = action.payload;
|
||||||
},
|
},
|
||||||
|
setCurrentModelType: (state, action: PayloadAction<ModelLoaderTypes>) => {
|
||||||
|
state.currentModelType = action.payload;
|
||||||
|
},
|
||||||
},
|
},
|
||||||
extraReducers: (builder) => {
|
extraReducers: (builder) => {
|
||||||
builder.addCase(getModels.fulfilled, (state, action) => {
|
builder.addCase(getModels.fulfilled, (state, action) => {
|
||||||
@ -277,6 +283,7 @@ export const {
|
|||||||
setVerticalSymmetrySteps,
|
setVerticalSymmetrySteps,
|
||||||
initialImageChanged,
|
initialImageChanged,
|
||||||
modelSelected,
|
modelSelected,
|
||||||
|
setCurrentModelType,
|
||||||
setShouldUseNoiseSettings,
|
setShouldUseNoiseSettings,
|
||||||
setSeamless,
|
setSeamless,
|
||||||
setSeamlessXAxis,
|
setSeamlessXAxis,
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback, useEffect } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
@ -9,7 +9,11 @@ import IAIMantineSelect, {
|
|||||||
IAISelectDataType,
|
IAISelectDataType,
|
||||||
} from 'common/components/IAIMantineSelect';
|
} from 'common/components/IAIMantineSelect';
|
||||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||||
import { modelSelected } from 'features/parameters/store/generationSlice';
|
import {
|
||||||
|
modelSelected,
|
||||||
|
setCurrentModelType,
|
||||||
|
} from 'features/parameters/store/generationSlice';
|
||||||
|
|
||||||
import {
|
import {
|
||||||
selectAllSD1Models,
|
selectAllSD1Models,
|
||||||
selectByIdSD1Models,
|
selectByIdSD1Models,
|
||||||
@ -55,12 +59,28 @@ export const modelSelector = createSelector(
|
|||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
export type ModelLoaderTypes = 'sd1_model_loader' | 'sd2_model_loader';
|
||||||
|
|
||||||
|
const MODEL_LOADER_MAP = {
|
||||||
|
'sd-1': 'sd1_model_loader',
|
||||||
|
'sd-2': 'sd2_model_loader',
|
||||||
|
};
|
||||||
|
|
||||||
const ModelSelect = () => {
|
const ModelSelect = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { selectedModel, sd1ModelData, sd2ModelData } =
|
const { selectedModel, sd1ModelData, sd2ModelData } =
|
||||||
useAppSelector(modelSelector);
|
useAppSelector(modelSelector);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (selectedModel)
|
||||||
|
dispatch(
|
||||||
|
setCurrentModelType(
|
||||||
|
MODEL_LOADER_MAP[selectedModel?.base_model] as ModelLoaderTypes
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}, [dispatch, selectedModel]);
|
||||||
|
|
||||||
const handleChangeModel = useCallback(
|
const handleChangeModel = useCallback(
|
||||||
(v: string | null) => {
|
(v: string | null) => {
|
||||||
if (!v) {
|
if (!v) {
|
||||||
|
Loading…
Reference in New Issue
Block a user