From 339e7ce2136c8e2bbe0a707b9e22451e76c05421 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 22 Jun 2023 17:48:57 +1000 Subject: [PATCH] feat(ui): initial implementation of model loading - Update model listing code to use `rtk-query` - Update all graph generation to use new `pipeline_model_loader` node --- .../frontend/web/src/app/components/App.tsx | 13 +++ .../enhancers/reduxRemember/serialize.ts | 3 - .../enhancers/reduxRemember/unserialize.ts | 4 - .../listeners/socketio/socketConnected.ts | 12 +-- invokeai/frontend/web/src/app/store/store.ts | 9 +- .../fields/ModelInputFieldComponent.tsx | 98 +++++++++++++++---- .../src/features/nodes/store/nodesSlice.ts | 15 --- .../buildCanvasImageToImageGraph.ts | 9 +- .../graphBuilders/buildCanvasInpaintGraph.ts | 9 +- .../buildCanvasTextToImageGraph.ts | 9 +- .../buildLinearImageToImageGraph.ts | 9 +- .../buildLinearTextToImageGraph.ts | 15 ++- .../util/graphBuilders/buildNodesGraph.ts | 9 +- .../nodes/util/graphBuilders/constants.ts | 2 +- .../nodes/util/modelIdToPipelineModelField.ts | 18 ++++ .../parameters/store/generationSlice.ts | 29 +----- .../parameters/store/parameterZodSchemas.ts | 14 +++ .../system/components/ModelSelect.tsx | 85 +++++++++++----- .../system/hooks/useIsApplicationReady.ts | 11 +-- .../features/system/store/modelSelectors.ts | 59 ----------- .../store/models/sd1PipelineModelSlice.ts | 56 ----------- .../store/models/sd2PipelineModelSlice.ts | 56 ----------- .../system/store/modelsPersistDenylist.ts | 9 -- .../src/features/system/store/systemSlice.ts | 8 -- .../frontend/web/src/services/apiSlice.ts | 48 ++++++++- .../frontend/web/src/services/thunks/model.ts | 58 ----------- 26 files changed, 281 insertions(+), 386 deletions(-) create mode 100644 invokeai/frontend/web/src/features/nodes/util/modelIdToPipelineModelField.ts delete mode 100644 invokeai/frontend/web/src/features/system/store/modelSelectors.ts delete mode 100644 invokeai/frontend/web/src/features/system/store/models/sd1PipelineModelSlice.ts delete mode 100644 invokeai/frontend/web/src/features/system/store/models/sd2PipelineModelSlice.ts delete mode 100644 invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts delete mode 100644 invokeai/frontend/web/src/services/thunks/model.ts diff --git a/invokeai/frontend/web/src/app/components/App.tsx b/invokeai/frontend/web/src/app/components/App.tsx index a11d8d048c..55fcc97745 100644 --- a/invokeai/frontend/web/src/app/components/App.tsx +++ b/invokeai/frontend/web/src/app/components/App.tsx @@ -24,6 +24,7 @@ import Toaster from './Toaster'; import DeleteImageModal from 'features/gallery/components/DeleteImageModal'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal'; +import { useListModelsQuery } from 'services/apiSlice'; const DEFAULT_CONFIG = {}; @@ -46,6 +47,18 @@ const App = ({ const isApplicationReady = useIsApplicationReady(); + const { data: pipelineModels } = useListModelsQuery({ + model_type: 'pipeline', + }); + const { data: controlnetModels } = useListModelsQuery({ + model_type: 'controlnet', + }); + const { data: vaeModels } = useListModelsQuery({ model_type: 'vae' }); + const { data: loraModels } = useListModelsQuery({ model_type: 'lora' }); + const { data: embeddingModels } = useListModelsQuery({ + model_type: 'embedding', + }); + const [loadingOverridden, setLoadingOverridden] = useState(false); const dispatch = useAppDispatch(); diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts index e498ecb749..cb18d48301 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts @@ -5,7 +5,6 @@ import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersist import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist'; import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist'; import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist'; -import { modelsPersistDenylist } from 'features/system/store/modelsPersistDenylist'; import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist'; import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist'; import { omit } from 'lodash-es'; @@ -18,8 +17,6 @@ const serializationDenylist: { gallery: galleryPersistDenylist, generation: generationPersistDenylist, lightbox: lightboxPersistDenylist, - sd1models: modelsPersistDenylist, - sd2models: modelsPersistDenylist, nodes: nodesPersistDenylist, postprocessing: postprocessingPersistDenylist, system: systemPersistDenylist, diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts index 649b56316d..8f40b0bb59 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts @@ -7,8 +7,6 @@ import { initialNodesState } from 'features/nodes/store/nodesSlice'; import { initialGenerationState } from 'features/parameters/store/generationSlice'; import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice'; import { initialConfigState } from 'features/system/store/configSlice'; -import { sd1InitialPipelineModelsState } from 'features/system/store/models/sd1PipelineModelSlice'; -import { sd2InitialPipelineModelsState } from 'features/system/store/models/sd2PipelineModelSlice'; import { initialSystemState } from 'features/system/store/systemSlice'; import { initialHotkeysState } from 'features/ui/store/hotkeysSlice'; import { initialUIState } from 'features/ui/store/uiSlice'; @@ -22,8 +20,6 @@ const initialStates: { gallery: initialGalleryState, generation: initialGenerationState, lightbox: initialLightboxState, - sd1PipelineModels: sd1InitialPipelineModelsState, - sd2PipelineModels: sd2InitialPipelineModelsState, nodes: initialNodesState, postprocessing: initialPostprocessingState, system: initialSystemState, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts index 0893066f1f..bf54e63836 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts @@ -1,7 +1,6 @@ import { log } from 'app/logging/useLogger'; import { appSocketConnected, socketConnected } from 'services/events/actions'; import { receivedPageOfImages } from 'services/thunks/image'; -import { receivedModels } from 'services/thunks/model'; import { receivedOpenAPISchema } from 'services/thunks/schema'; import { startAppListening } from '../..'; @@ -15,8 +14,7 @@ export const addSocketConnectedEventListener = () => { moduleLog.debug({ timestamp }, 'Connected'); - const { sd1pipelinemodels, sd2pipelinemodels, nodes, config, images } = - getState(); + const { nodes, config, images } = getState(); const { disabledTabs } = config; @@ -29,14 +27,6 @@ export const addSocketConnectedEventListener = () => { ); } - if (!sd1pipelinemodels.ids.length) { - dispatch(receivedModels({ baseModel: 'sd-1', modelType: 'pipeline' })); - } - - if (!sd2pipelinemodels.ids.length) { - dispatch(receivedModels({ baseModel: 'sd-2', modelType: 'pipeline' })); - } - if (!nodes.schema && !disabledTabs.includes('nodes')) { dispatch(receivedOpenAPISchema()); } diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index 8489de85f0..57a97168a3 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -28,11 +28,6 @@ import { listenerMiddleware } from './middleware/listenerMiddleware'; import { actionSanitizer } from './middleware/devtools/actionSanitizer'; import { actionsDenylist } from './middleware/devtools/actionsDenylist'; import { stateSanitizer } from './middleware/devtools/stateSanitizer'; - -// Model Reducers -import sd1PipelineModelReducer from 'features/system/store/models/sd1PipelineModelSlice'; -import sd2PipelineModelReducer from 'features/system/store/models/sd2PipelineModelSlice'; - import { LOCALSTORAGE_PREFIX } from './constants'; import { serialize } from './enhancers/reduxRemember/serialize'; import { unserialize } from './enhancers/reduxRemember/unserialize'; @@ -43,8 +38,6 @@ const allReducers = { gallery: galleryReducer, generation: generationReducer, lightbox: lightboxReducer, - sd1pipelinemodels: sd1PipelineModelReducer, - sd2pipelinemodels: sd2PipelineModelReducer, nodes: nodesReducer, postprocessing: postprocessingReducer, system: systemReducer, @@ -54,8 +47,8 @@ const allReducers = { images: imagesReducer, controlNet: controlNetReducer, boards: boardsReducer, - [api.reducerPath]: api.reducer, // session: sessionReducer, + [api.reducerPath]: api.reducer, }; const rootReducer = combineReducers(allReducers); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx index 480c8591bb..c274f23f26 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx @@ -1,14 +1,18 @@ -import { NativeSelect } from '@mantine/core'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { SelectItem } from '@mantine/core'; +import { useAppDispatch } from 'app/store/storeHooks'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { ModelInputFieldTemplate, ModelInputFieldValue, } from 'features/nodes/types/types'; -import { modelSelector } from 'features/system/store/modelSelectors'; -import { ChangeEvent, memo } from 'react'; +import { memo, useCallback, useEffect, useMemo } from 'react'; import { FieldComponentProps } from './types'; +import { forEach, isString } from 'lodash-es'; +import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect'; +import IAIMantineSelect from 'common/components/IAIMantineSelect'; +import { useTranslation } from 'react-i18next'; +import { useListModelsQuery } from 'services/apiSlice'; const ModelInputFieldComponent = ( props: FieldComponentProps @@ -16,26 +20,82 @@ const ModelInputFieldComponent = ( const { nodeId, field } = props; const dispatch = useAppDispatch(); + const { t } = useTranslation(); - const { sd1PipelineModelDropDownData, sd2PipelineModelDropdownData } = - useAppSelector(modelSelector); + const { data: pipelineModels } = useListModelsQuery({ + model_type: 'pipeline', + }); - const handleValueChanged = (e: ChangeEvent) => { - dispatch( - fieldValueChanged({ - nodeId, - fieldName: field.name, - value: e.target.value, - }) - ); - }; + const data = useMemo(() => { + if (!pipelineModels) { + return []; + } + + const data: SelectItem[] = []; + + forEach(pipelineModels.entities, (model, id) => { + if (!model) { + return; + } + + data.push({ + value: id, + label: model.name, + group: BASE_MODEL_NAME_MAP[model.base_model], + }); + }); + + return data; + }, [pipelineModels]); + + const selectedModel = useMemo( + () => pipelineModels?.entities[field.value ?? pipelineModels.ids[0]], + [pipelineModels?.entities, pipelineModels?.ids, field.value] + ); + + const handleValueChanged = useCallback( + (v: string | null) => { + if (!v) { + return; + } + + dispatch( + fieldValueChanged({ + nodeId, + fieldName: field.name, + value: v, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + + useEffect(() => { + if (field.value && pipelineModels?.ids.includes(field.value)) { + return; + } + + const firstModel = pipelineModels?.ids[0]; + + if (!isString(firstModel)) { + return; + } + + handleValueChanged(firstModel); + }, [field.value, handleValueChanged, pipelineModels?.ids]); return ( - + /> ); }; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 5425d1cfd5..341f0c467b 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -101,21 +101,6 @@ const nodesSlice = createSlice({ builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => { state.schema = action.payload; }); - - builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { - const { image_name, image_url, thumbnail_url } = action.payload; - - state.nodes.forEach((node) => { - forEach(node.data.inputs, (input) => { - if (input.type === 'image') { - if (input.value?.image_name === image_name) { - input.value.image_url = image_url; - input.value.thumbnail_url = thumbnail_url; - } - } - }); - }); - }); }, }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts index efaeaddff2..ccdc3e0a27 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts @@ -23,6 +23,7 @@ import { } from './constants'; import { set } from 'lodash-es'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; +import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; const moduleLog = log.child({ namespace: 'nodes' }); @@ -36,7 +37,7 @@ export const buildCanvasImageToImageGraph = ( const { positivePrompt, negativePrompt, - model: model_name, + model: modelId, cfgScale: cfg_scale, scheduler, steps, @@ -49,6 +50,8 @@ export const buildCanvasImageToImageGraph = ( // The bounding box determines width and height, not the width and height params const { width, height } = state.canvas.boundingBoxDimensions; + const model = modelIdToPipelineModelField(modelId); + /** * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * full graph here as a template. Then use the parameters from app state and set friendlier node @@ -85,9 +88,9 @@ export const buildCanvasImageToImageGraph = ( id: NOISE, }, [MODEL_LOADER]: { - type: 'sd1_model_loader', + type: 'pipeline_model_loader', id: MODEL_LOADER, - model_name, + model, }, [LATENTS_TO_IMAGE]: { type: 'l2i', diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts index 785e1d2fdb..9ffe85b3c9 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts @@ -17,6 +17,7 @@ import { INPAINT_GRAPH, INPAINT, } from './constants'; +import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; const moduleLog = log.child({ namespace: 'nodes' }); @@ -31,7 +32,7 @@ export const buildCanvasInpaintGraph = ( const { positivePrompt, negativePrompt, - model: model_name, + model: modelId, cfgScale: cfg_scale, scheduler, steps, @@ -54,6 +55,8 @@ export const buildCanvasInpaintGraph = ( // We may need to set the inpaint width and height to scale the image const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas; + const model = modelIdToPipelineModelField(modelId); + const graph: NonNullableGraph = { id: INPAINT_GRAPH, nodes: { @@ -99,9 +102,9 @@ export const buildCanvasInpaintGraph = ( prompt: negativePrompt, }, [MODEL_LOADER]: { - type: 'sd1_model_loader', + type: 'pipeline_model_loader', id: MODEL_LOADER, - model_name, + model, }, [RANGE_OF_SIZE]: { type: 'range_of_size', diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts index ca0e56e849..920cb5bf02 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts @@ -14,6 +14,7 @@ import { TEXT_TO_LATENTS, } from './constants'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; +import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; /** * Builds the Canvas tab's Text to Image graph. @@ -24,7 +25,7 @@ export const buildCanvasTextToImageGraph = ( const { positivePrompt, negativePrompt, - model: model_name, + model: modelId, cfgScale: cfg_scale, scheduler, steps, @@ -36,6 +37,8 @@ export const buildCanvasTextToImageGraph = ( // The bounding box determines width and height, not the width and height params const { width, height } = state.canvas.boundingBoxDimensions; + const model = modelIdToPipelineModelField(modelId); + /** * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * full graph here as a template. Then use the parameters from app state and set friendlier node @@ -80,9 +83,9 @@ export const buildCanvasTextToImageGraph = ( steps, }, [MODEL_LOADER]: { - type: 'sd1_model_loader', + type: 'pipeline_model_loader', id: MODEL_LOADER, - model_name, + model, }, [LATENTS_TO_IMAGE]: { type: 'l2i', diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts index 78a6623a16..8425ac043a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts @@ -22,6 +22,7 @@ import { } from './constants'; import { set } from 'lodash-es'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; +import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; const moduleLog = log.child({ namespace: 'nodes' }); @@ -34,7 +35,7 @@ export const buildLinearImageToImageGraph = ( const { positivePrompt, negativePrompt, - model: model_name, + model: modelId, cfgScale: cfg_scale, scheduler, steps, @@ -62,6 +63,8 @@ export const buildLinearImageToImageGraph = ( throw new Error('No initial image found in state'); } + const model = modelIdToPipelineModelField(modelId); + // copy-pasted graph from node editor, filled in with state values & friendly node ids const graph: NonNullableGraph = { id: IMAGE_TO_IMAGE_GRAPH, @@ -89,9 +92,9 @@ export const buildLinearImageToImageGraph = ( id: NOISE, }, [MODEL_LOADER]: { - type: 'sd1_model_loader', + type: 'pipeline_model_loader', id: MODEL_LOADER, - model_name, + model, }, [LATENTS_TO_IMAGE]: { type: 'l2i', diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts index c179a89504..973acdfb77 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts @@ -1,6 +1,10 @@ import { RootState } from 'app/store/store'; import { NonNullableGraph } from 'features/nodes/types/types'; -import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api'; +import { + BaseModelType, + RandomIntInvocation, + RangeOfSizeInvocation, +} from 'services/api'; import { ITERATE, LATENTS_TO_IMAGE, @@ -14,6 +18,7 @@ import { TEXT_TO_LATENTS, } from './constants'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; +import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; type TextToImageGraphOverrides = { width: number; @@ -27,7 +32,7 @@ export const buildLinearTextToImageGraph = ( const { positivePrompt, negativePrompt, - model: model_name, + model: modelId, cfgScale: cfg_scale, scheduler, steps, @@ -38,6 +43,8 @@ export const buildLinearTextToImageGraph = ( shouldRandomizeSeed, } = state.generation; + const model = modelIdToPipelineModelField(modelId); + /** * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * full graph here as a template. Then use the parameters from app state and set friendlier node @@ -82,9 +89,9 @@ export const buildLinearTextToImageGraph = ( steps, }, [MODEL_LOADER]: { - type: 'sd1_model_loader', + type: 'pipeline_model_loader', id: MODEL_LOADER, - model_name, + model, }, [LATENTS_TO_IMAGE]: { type: 'l2i', diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts index 6a700d4813..072b1a53fd 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts @@ -1,9 +1,10 @@ import { Graph } from 'services/api'; import { v4 as uuidv4 } from 'uuid'; -import { cloneDeep, forEach, omit, reduce, values } from 'lodash-es'; +import { cloneDeep, omit, reduce } from 'lodash-es'; import { RootState } from 'app/store/store'; import { InputFieldValue } from 'features/nodes/types/types'; import { AnyInvocation } from 'services/events/types'; +import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; /** * We need to do special handling for some fields @@ -24,6 +25,12 @@ export const parseFieldValue = (field: InputFieldValue) => { } } + if (field.type === 'model') { + if (field.value) { + return modelIdToPipelineModelField(field.value); + } + } + return field.value; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts index 39e0080d11..7d4469bc41 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts @@ -7,7 +7,7 @@ export const NOISE = 'noise'; export const RANDOM_INT = 'rand_int'; export const RANGE_OF_SIZE = 'range_of_size'; export const ITERATE = 'iterate'; -export const MODEL_LOADER = 'model_loader'; +export const MODEL_LOADER = 'pipeline_model_loader'; export const IMAGE_TO_LATENTS = 'image_to_latents'; export const LATENTS_TO_LATENTS = 'latents_to_latents'; export const RESIZE = 'resize_image'; diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToPipelineModelField.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToPipelineModelField.ts new file mode 100644 index 0000000000..bbcd8d9bc6 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/modelIdToPipelineModelField.ts @@ -0,0 +1,18 @@ +import { BaseModelType, PipelineModelField } from 'services/api'; + +/** + * Crudely converts a model id to a pipeline model field + * TODO: Make better + */ +export const modelIdToPipelineModelField = ( + modelId: string +): PipelineModelField => { + const [base_model, model_type, model_name] = modelId.split('/'); + + const field: PipelineModelField = { + base_model: base_model as BaseModelType, + model_name, + }; + + return field; +}; diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index e1de166b5c..e7dcbf0d83 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -1,12 +1,9 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; -import { DEFAULT_SCHEDULER_NAME, Scheduler } from 'app/constants'; -import { ModelLoaderTypes } from 'features/system/components/ModelSelect'; +import { DEFAULT_SCHEDULER_NAME } from 'app/constants'; import { configChanged } from 'features/system/store/configSlice'; -import { clamp, sortBy } from 'lodash-es'; +import { clamp } from 'lodash-es'; import { ImageDTO } from 'services/api'; -import { imageUrlsReceived } from 'services/thunks/image'; -import { receivedModels } from 'services/thunks/model'; import { CfgScaleParam, HeightParam, @@ -50,7 +47,6 @@ export interface GenerationState { horizontalSymmetrySteps: number; verticalSymmetrySteps: number; model: ModelParam; - currentModelType: ModelLoaderTypes; shouldUseSeamless: boolean; seamlessXAxis: boolean; seamlessYAxis: boolean; @@ -85,7 +81,6 @@ export const initialGenerationState: GenerationState = { horizontalSymmetrySteps: 0, verticalSymmetrySteps: 0, model: '', - currentModelType: 'sd1_model_loader', shouldUseSeamless: false, seamlessXAxis: true, seamlessYAxis: true, @@ -221,33 +216,14 @@ export const generationSlice = createSlice({ modelSelected: (state, action: PayloadAction) => { state.model = action.payload; }, - setCurrentModelType: (state, action: PayloadAction) => { - state.currentModelType = action.payload; - }, }, extraReducers: (builder) => { - builder.addCase(receivedModels.fulfilled, (state, action) => { - if (!state.model) { - const firstModel = sortBy(action.payload, 'name')[0]; - state.model = firstModel.name; - } - }); - builder.addCase(configChanged, (state, action) => { const defaultModel = action.payload.sd?.defaultModel; if (defaultModel && !state.model) { state.model = defaultModel; } }); - - // builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { - // const { image_name, image_url, thumbnail_url } = action.payload; - - // if (state.initialImage?.image_name === image_name) { - // state.initialImage.image_url = image_url; - // state.initialImage.thumbnail_url = thumbnail_url; - // } - // }); }, }); @@ -284,7 +260,6 @@ export const { setVerticalSymmetrySteps, initialImageChanged, modelSelected, - setCurrentModelType, setShouldUseNoiseSettings, setSeamless, setSeamlessXAxis, diff --git a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts index 61567d3fb8..48eb309e7d 100644 --- a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts @@ -154,3 +154,17 @@ export type StrengthParam = z.infer; */ export const isValidStrength = (val: unknown): val is StrengthParam => zStrength.safeParse(val).success; + +// /** +// * Zod schema for BaseModelType +// */ +// export const zBaseModelType = z.enum(['sd-1', 'sd-2']); +// /** +// * Type alias for base model type, inferred from its zod schema. Should be identical to the type alias from OpenAPI. +// */ +// export type BaseModelType = z.infer; +// /** +// * Validates/type-guards a value as a base model type +// */ +// export const isValidBaseModelType = (val: unknown): val is BaseModelType => +// zBaseModelType.safeParse(val).success; diff --git a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx index 813bd9fb70..43de14d507 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx +++ b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx @@ -1,39 +1,58 @@ -import { memo, useCallback, useEffect } from 'react'; +import { memo, useCallback, useEffect, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; -import { - modelSelected, - setCurrentModelType, -} from 'features/parameters/store/generationSlice'; +import { modelSelected } from 'features/parameters/store/generationSlice'; -import { modelSelector } from '../store/modelSelectors'; +import { forEach, isString } from 'lodash-es'; +import { SelectItem } from '@mantine/core'; +import { RootState } from 'app/store/store'; +import { useListModelsQuery } from 'services/apiSlice'; -export type ModelLoaderTypes = 'sd1_model_loader' | 'sd2_model_loader'; - -const MODEL_LOADER_MAP = { - 'sd-1': 'sd1_model_loader', - 'sd-2': 'sd2_model_loader', +export const MODEL_TYPE_MAP = { + 'sd-1': 'Stable Diffusion 1.x', + 'sd-2': 'Stable Diffusion 2.x', }; const ModelSelect = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { - selectedModel, - sd1PipelineModelDropDownData, - sd2PipelineModelDropdownData, - } = useAppSelector(modelSelector); - useEffect(() => { - if (selectedModel) - dispatch( - setCurrentModelType( - MODEL_LOADER_MAP[selectedModel?.base_model] as ModelLoaderTypes - ) - ); - }, [dispatch, selectedModel]); + const selectedModelId = useAppSelector( + (state: RootState) => state.generation.model + ); + + const { data: pipelineModels } = useListModelsQuery({ + model_type: 'pipeline', + }); + + const data = useMemo(() => { + if (!pipelineModels) { + return []; + } + + const data: SelectItem[] = []; + + forEach(pipelineModels.entities, (model, id) => { + if (!model) { + return; + } + + data.push({ + value: id, + label: model.name, + group: MODEL_TYPE_MAP[model.base_model], + }); + }); + + return data; + }, [pipelineModels]); + + const selectedModel = useMemo( + () => pipelineModels?.entities[selectedModelId], + [pipelineModels?.entities, selectedModelId] + ); const handleChangeModel = useCallback( (v: string | null) => { @@ -45,13 +64,27 @@ const ModelSelect = () => { [dispatch] ); + useEffect(() => { + if (selectedModelId && pipelineModels?.ids.includes(selectedModelId)) { + return; + } + + const firstModel = pipelineModels?.ids[0]; + + if (!isString(firstModel)) { + return; + } + + handleChangeModel(firstModel); + }, [handleChangeModel, pipelineModels?.ids, selectedModelId]); + return ( ); diff --git a/invokeai/frontend/web/src/features/system/hooks/useIsApplicationReady.ts b/invokeai/frontend/web/src/features/system/hooks/useIsApplicationReady.ts index 193420e29c..8ba5731a5b 100644 --- a/invokeai/frontend/web/src/features/system/hooks/useIsApplicationReady.ts +++ b/invokeai/frontend/web/src/features/system/hooks/useIsApplicationReady.ts @@ -7,13 +7,12 @@ import { systemSelector } from '../store/systemSelectors'; const isApplicationReadySelector = createSelector( [systemSelector, configSelector], (system, config) => { - const { wereModelsReceived, wasSchemaParsed } = system; + const { wasSchemaParsed } = system; const { disabledTabs } = config; return { disabledTabs, - wereModelsReceived, wasSchemaParsed, }; } @@ -23,21 +22,17 @@ const isApplicationReadySelector = createSelector( * Checks if the application is ready to be used, i.e. if the initial startup process is finished. */ export const useIsApplicationReady = () => { - const { disabledTabs, wereModelsReceived, wasSchemaParsed } = useAppSelector( + const { disabledTabs, wasSchemaParsed } = useAppSelector( isApplicationReadySelector ); const isApplicationReady = useMemo(() => { - if (!wereModelsReceived) { - return false; - } - if (!disabledTabs.includes('nodes') && !wasSchemaParsed) { return false; } return true; - }, [disabledTabs, wereModelsReceived, wasSchemaParsed]); + }, [disabledTabs, wasSchemaParsed]); return isApplicationReady; }; diff --git a/invokeai/frontend/web/src/features/system/store/modelSelectors.ts b/invokeai/frontend/web/src/features/system/store/modelSelectors.ts deleted file mode 100644 index b63c6d256c..0000000000 --- a/invokeai/frontend/web/src/features/system/store/modelSelectors.ts +++ /dev/null @@ -1,59 +0,0 @@ -import { createSelector } from '@reduxjs/toolkit'; -import { RootState } from 'app/store/store'; -import { IAISelectDataType } from 'common/components/IAIMantineSelect'; -import { generationSelector } from 'features/parameters/store/generationSelectors'; -import { isEqual } from 'lodash-es'; - -import { - selectAllSD1PipelineModels, - selectByIdSD1PipelineModels, -} from './models/sd1PipelineModelSlice'; - -import { - selectAllSD2PipelineModels, - selectByIdSD2PipelineModels, -} from './models/sd2PipelineModelSlice'; - -export const modelSelector = createSelector( - [(state: RootState) => state, generationSelector], - (state, generation) => { - let selectedModel = selectByIdSD1PipelineModels(state, generation.model); - if (selectedModel === undefined) - selectedModel = selectByIdSD2PipelineModels(state, generation.model); - - const sd1PipelineModels = selectAllSD1PipelineModels(state); - const sd2PipelineModels = selectAllSD2PipelineModels(state); - - const allPipelineModels = sd1PipelineModels.concat(sd2PipelineModels); - - const sd1PipelineModelDropDownData = selectAllSD1PipelineModels(state) - .map((m) => ({ - value: m.name, - label: m.name, - group: '1.x Models', - })) - .sort((a, b) => a.label.localeCompare(b.label)); - - const sd2PipelineModelDropdownData = selectAllSD2PipelineModels(state) - .map((m) => ({ - value: m.name, - label: m.name, - group: '2.x Models', - })) - .sort((a, b) => a.label.localeCompare(b.label)); - - return { - selectedModel, - allPipelineModels, - sd1PipelineModels, - sd2PipelineModels, - sd1PipelineModelDropDownData, - sd2PipelineModelDropdownData, - }; - }, - { - memoizeOptions: { - resultEqualityCheck: isEqual, - }, - } -); diff --git a/invokeai/frontend/web/src/features/system/store/models/sd1PipelineModelSlice.ts b/invokeai/frontend/web/src/features/system/store/models/sd1PipelineModelSlice.ts deleted file mode 100644 index 99f1514e6c..0000000000 --- a/invokeai/frontend/web/src/features/system/store/models/sd1PipelineModelSlice.ts +++ /dev/null @@ -1,56 +0,0 @@ -import { createEntityAdapter, createSlice } from '@reduxjs/toolkit'; -import { RootState } from 'app/store/store'; -import { - StableDiffusion1ModelCheckpointConfig, - StableDiffusion1ModelDiffusersConfig, -} from 'services/api'; - -import { receivedModels } from 'services/thunks/model'; - -export type SD1PipelineModel = ( - | StableDiffusion1ModelCheckpointConfig - | StableDiffusion1ModelDiffusersConfig -) & { - name: string; -}; - -export const sd1PipelineModelsAdapter = createEntityAdapter({ - selectId: (model) => model.name, - sortComparer: (a, b) => a.name.localeCompare(b.name), -}); - -export const sd1InitialPipelineModelsState = - sd1PipelineModelsAdapter.getInitialState(); - -export type SD1PipelineModelState = typeof sd1InitialPipelineModelsState; - -export const sd1PipelineModelsSlice = createSlice({ - name: 'sd1PipelineModels', - initialState: sd1InitialPipelineModelsState, - reducers: { - modelAdded: sd1PipelineModelsAdapter.upsertOne, - }, - extraReducers(builder) { - /** - * Received Models - FULFILLED - */ - builder.addCase(receivedModels.fulfilled, (state, action) => { - if (action.meta.arg.baseModel !== 'sd-1') return; - sd1PipelineModelsAdapter.setAll(state, action.payload); - }); - }, -}); - -export const { - selectAll: selectAllSD1PipelineModels, - selectById: selectByIdSD1PipelineModels, - selectEntities: selectEntitiesSD1PipelineModels, - selectIds: selectIdsSD1PipelineModels, - selectTotal: selectTotalSD1PipelineModels, -} = sd1PipelineModelsAdapter.getSelectors( - (state) => state.sd1pipelinemodels -); - -export const { modelAdded } = sd1PipelineModelsSlice.actions; - -export default sd1PipelineModelsSlice.reducer; diff --git a/invokeai/frontend/web/src/features/system/store/models/sd2PipelineModelSlice.ts b/invokeai/frontend/web/src/features/system/store/models/sd2PipelineModelSlice.ts deleted file mode 100644 index 69ff772222..0000000000 --- a/invokeai/frontend/web/src/features/system/store/models/sd2PipelineModelSlice.ts +++ /dev/null @@ -1,56 +0,0 @@ -import { createEntityAdapter, createSlice } from '@reduxjs/toolkit'; -import { RootState } from 'app/store/store'; -import { - StableDiffusion2ModelCheckpointConfig, - StableDiffusion2ModelDiffusersConfig, -} from 'services/api'; - -import { receivedModels } from 'services/thunks/model'; - -export type SD2PipelineModel = ( - | StableDiffusion2ModelCheckpointConfig - | StableDiffusion2ModelDiffusersConfig -) & { - name: string; -}; - -export const sd2PipelineModelsAdapater = createEntityAdapter({ - selectId: (model) => model.name, - sortComparer: (a, b) => a.name.localeCompare(b.name), -}); - -export const sd2InitialPipelineModelsState = - sd2PipelineModelsAdapater.getInitialState(); - -export type SD2PipelineModelState = typeof sd2InitialPipelineModelsState; - -export const sd2PipelineModelsSlice = createSlice({ - name: 'sd2PipelineModels', - initialState: sd2InitialPipelineModelsState, - reducers: { - modelAdded: sd2PipelineModelsAdapater.upsertOne, - }, - extraReducers(builder) { - /** - * Received Models - FULFILLED - */ - builder.addCase(receivedModels.fulfilled, (state, action) => { - if (action.meta.arg.baseModel !== 'sd-2') return; - sd2PipelineModelsAdapater.setAll(state, action.payload); - }); - }, -}); - -export const { - selectAll: selectAllSD2PipelineModels, - selectById: selectByIdSD2PipelineModels, - selectEntities: selectEntitiesSD2PipelineModels, - selectIds: selectIdsSD2PipelineModels, - selectTotal: selectTotalSD2PipelineModels, -} = sd2PipelineModelsAdapater.getSelectors( - (state) => state.sd2pipelinemodels -); - -export const { modelAdded } = sd2PipelineModelsSlice.actions; - -export default sd2PipelineModelsSlice.reducer; diff --git a/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts b/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts deleted file mode 100644 index 417a399cf2..0000000000 --- a/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts +++ /dev/null @@ -1,9 +0,0 @@ -import { SD1PipelineModelState } from './models/sd1PipelineModelSlice'; -import { SD2PipelineModelState } from './models/sd2PipelineModelSlice'; - -/** - * Models slice persist denylist - */ -export const modelsPersistDenylist: - | (keyof SD1PipelineModelState)[] - | (keyof SD2PipelineModelState)[] = ['entities', 'ids']; diff --git a/invokeai/frontend/web/src/features/system/store/systemSlice.ts b/invokeai/frontend/web/src/features/system/store/systemSlice.ts index 8a148ca38b..688f69c1f7 100644 --- a/invokeai/frontend/web/src/features/system/store/systemSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/systemSlice.ts @@ -20,7 +20,6 @@ import { } from 'services/events/actions'; import { ProgressImage } from 'services/events/types'; import { imageUploaded } from 'services/thunks/image'; -import { receivedModels } from 'services/thunks/model'; import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session'; import { makeToast } from '../../../app/components/Toaster'; import { LANGUAGES } from '../components/LanguagePicker'; @@ -377,13 +376,6 @@ export const systemSlice = createSlice({ ); }); - /** - * Received available models from the backend - */ - builder.addCase(receivedModels.fulfilled, (state) => { - state.wereModelsReceived = true; - }); - /** * OpenAPI schema was parsed */ diff --git a/invokeai/frontend/web/src/services/apiSlice.ts b/invokeai/frontend/web/src/services/apiSlice.ts index 2d42931b0b..e2d765dd90 100644 --- a/invokeai/frontend/web/src/services/apiSlice.ts +++ b/invokeai/frontend/web/src/services/apiSlice.ts @@ -13,23 +13,68 @@ import { TagTypesFrom, TagTypesFromApi, } from '@reduxjs/toolkit/dist/query/endpointDefinitions'; +import { EntityState, createEntityAdapter } from '@reduxjs/toolkit'; +import { BaseModelType } from './api/models/BaseModelType'; +import { ModelType } from './api/models/ModelType'; +import { ModelsList } from './api/models/ModelsList'; +import { keyBy } from 'lodash-es'; type ListBoardsArg = { offset: number; limit: number }; type UpdateBoardArg = { board_id: string; changes: BoardChanges }; type AddImageToBoardArg = { board_id: string; image_name: string }; type RemoveImageFromBoardArg = { board_id: string; image_name: string }; type ListBoardImagesArg = { board_id: string; offset: number; limit: number }; +type ListModelsArg = { base_model?: BaseModelType; model_type?: ModelType }; -const tagTypes = ['Board', 'Image']; +type ModelConfig = ModelsList['models'][number]; + +const tagTypes = ['Board', 'Image', 'Model']; type ApiFullTagDescription = FullTagDescription<(typeof tagTypes)[number]>; const LIST = 'LIST'; +const modelsAdapter = createEntityAdapter({ + selectId: (model) => getModelId(model), + sortComparer: (a, b) => a.name.localeCompare(b.name), +}); + +const getModelId = ({ base_model, type, name }: ModelConfig) => + `${base_model}/${type}/${name}`; + export const api = createApi({ baseQuery: fetchBaseQuery({ baseUrl: 'http://localhost:5173/api/v1/' }), reducerPath: 'api', tagTypes, endpoints: (build) => ({ + /** + * Models Queries + */ + + listModels: build.query, ListModelsArg>({ + query: (arg) => ({ url: 'models/', params: arg }), + providesTags: (result, error, arg) => { + // any list of boards + const tags: ApiFullTagDescription[] = [{ id: 'Model', type: LIST }]; + + if (result) { + // and individual tags for each board + tags.push( + ...result.ids.map((id) => ({ + type: 'Model' as const, + id, + })) + ); + } + + return tags; + }, + transformResponse: (response: ModelsList, meta, arg) => { + return modelsAdapter.addMany( + modelsAdapter.getInitialState(), + keyBy(response.models, getModelId) + ); + }, + }), /** * Boards Queries */ @@ -174,4 +219,5 @@ export const { useRemoveImageFromBoardMutation, useListBoardImagesQuery, useGetImageDTOQuery, + useListModelsQuery, } = api; diff --git a/invokeai/frontend/web/src/services/thunks/model.ts b/invokeai/frontend/web/src/services/thunks/model.ts deleted file mode 100644 index 619aa4b7b2..0000000000 --- a/invokeai/frontend/web/src/services/thunks/model.ts +++ /dev/null @@ -1,58 +0,0 @@ -import { log } from 'app/logging/useLogger'; -import { createAppAsyncThunk } from 'app/store/storeUtils'; -import { SD1PipelineModel } from 'features/system/store/models/sd1PipelineModelSlice'; -import { SD2PipelineModel } from 'features/system/store/models/sd2PipelineModelSlice'; -import { reduce, size } from 'lodash-es'; -import { BaseModelType, ModelType, ModelsService } from 'services/api'; - -const models = log.child({ namespace: 'model' }); - -export const IMAGES_PER_PAGE = 20; - -type receivedModelsArg = { - baseModel: BaseModelType | undefined; - modelType: ModelType | undefined; -}; - -export const receivedModels = createAppAsyncThunk( - 'models/receivedModels', - async (arg: receivedModelsArg) => { - const response = await ModelsService.listModels(arg); - - let deserializedModels = {}; - - if (arg.baseModel === undefined) return response.models; - if (arg.modelType === undefined) return response.models; - - if (arg.baseModel === 'sd-1') { - deserializedModels = reduce( - response.models[arg.baseModel][arg.modelType], - (modelsAccumulator, model, modelName) => { - modelsAccumulator[modelName] = { ...model, name: modelName }; - return modelsAccumulator; - }, - {} as Record - ); - } - - if (arg.baseModel === 'sd-2') { - deserializedModels = reduce( - response.models[arg.baseModel][arg.modelType], - (modelsAccumulator, model, modelName) => { - modelsAccumulator[modelName] = { ...model, name: modelName }; - return modelsAccumulator; - }, - {} as Record - ); - } - - models.info( - { response }, - `Received ${size(response.models[arg.baseModel][arg.modelType])} ${[ - arg.baseModel, - ]} models` - ); - - return deserializedModels; - } -);