diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index aff409e9e5..21fd09e051 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -1,22 +1,20 @@ import io from typing import Optional +from PIL import Image from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile from fastapi.responses import FileResponse from fastapi.routing import APIRouter -from PIL import Image -from pydantic import BaseModel, Field +from pydantic import BaseModel from invokeai.app.invocations.metadata import ImageMetadata from invokeai.app.models.image import ImageCategory, ResourceOrigin from invokeai.app.services.image_record_storage import OffsetPaginatedResults -from invokeai.app.services.item_storage import PaginatedResults from invokeai.app.services.models.image_record import ( ImageDTO, ImageRecordChanges, ImageUrlsDTO, ) - from ..dependencies import ApiDependencies images_router = APIRouter(prefix="/v1/images", tags=["images"]) @@ -152,8 +150,9 @@ async def get_image_metadata( raise HTTPException(status_code=404) -@images_router.get( +@images_router.api_route( "/i/{image_name}/full", + methods=["GET", "HEAD"], operation_id="get_image_full", response_class=Response, responses={ diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index ee2e3b8076..67e053db27 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -228,19 +228,19 @@ the root is the InvokeAI ROOTDIR. """ from __future__ import annotations -import os import hashlib +import os import textwrap -import yaml +import types from dataclasses import dataclass from pathlib import Path -from typing import Literal, Optional, List, Tuple, Union, Dict, Set, Callable, types from shutil import rmtree, move +from typing import Optional, List, Literal, Tuple, Union, Dict, Set, Callable import torch +import yaml from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig - from pydantic import BaseModel, Field import invokeai.backend.util.logging as logger @@ -259,6 +259,7 @@ from .models import ( ModelNotFoundException, InvalidModelException, DuplicateModelException, + ModelBase, ) # We are only starting to number the config file with release 3. @@ -361,7 +362,7 @@ class ModelManager(object): if model_key.startswith("_"): continue model_name, base_model, model_type = self.parse_key(model_key) - model_class = MODEL_CLASSES[base_model][model_type] + model_class = self._get_implementation(base_model, model_type) # alias for config file model_config["model_format"] = model_config.pop("format") self.models[model_key] = model_class.create_config(**model_config) @@ -381,18 +382,24 @@ class ModelManager(object): # causing otherwise unreferenced models to be removed from memory self._read_models() - def model_exists( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ) -> bool: + def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, *, rescan=False) -> bool: """ - Given a model name, returns True if it is a valid - identifier. + Given a model name, returns True if it is a valid identifier. + + :param model_name: symbolic name of the model in models.yaml + :param model_type: ModelType enum indicating the type of model to return + :param base_model: BaseModelType enum indicating the base model used by this model + :param rescan: if True, scan_models_directory """ model_key = self.create_key(model_name, base_model, model_type) - return model_key in self.models + exists = model_key in self.models + + # if model not found try to find it (maybe file just pasted) + if rescan and not exists: + self.scan_models_directory(base_model=base_model, model_type=model_type) + exists = self.model_exists(model_name, base_model, model_type, rescan=False) + + return exists @classmethod def create_key( @@ -443,39 +450,32 @@ class ModelManager(object): :param model_name: symbolic name of the model in models.yaml :param model_type: ModelType enum indicating the type of model to return :param base_model: BaseModelType enum indicating the base model used by this model - :param submode_typel: an ModelType enum indicating the portion of + :param submodel_type: an ModelType enum indicating the portion of the model to retrieve (e.g. ModelType.Vae) """ - model_class = MODEL_CLASSES[base_model][model_type] model_key = self.create_key(model_name, base_model, model_type) - # if model not found try to find it (maybe file just pasted) - if model_key not in self.models: - self.scan_models_directory(base_model=base_model, model_type=model_type) - if model_key not in self.models: - raise ModelNotFoundException(f"Model not found - {model_key}") + if not self.model_exists(model_name, base_model, model_type, rescan=True): + raise ModelNotFoundException(f"Model not found - {model_key}") - model_config = self.models[model_key] - model_path = self.resolve_model_path(model_config.path) + model_config = self._get_model_config(base_model, model_name, model_type) + + model_path, is_submodel_override = self._get_model_path(model_config, submodel_type) + + if is_submodel_override: + model_type = submodel_type + submodel_type = None + + model_class = self._get_implementation(base_model, model_type) if not model_path.exists(): if model_class.save_to_config: self.models[model_key].error = ModelError.NotFound - raise Exception(f'Files for model "{model_key}" not found') + raise Exception(f'Files for model "{model_key}" not found at {model_path}') else: self.models.pop(model_key, None) - raise ModelNotFoundException(f"Model not found - {model_key}") - - # vae/movq override - # TODO: - if submodel_type is not None and hasattr(model_config, submodel_type): - override_path = getattr(model_config, submodel_type) - if override_path: - model_path = self.resolve_path(override_path) - model_type = submodel_type - submodel_type = None - model_class = MODEL_CLASSES[base_model][model_type] + raise ModelNotFoundException(f'Files for model "{model_key}" not found at {model_path}') # TODO: path # TODO: is it accurate to use path as id @@ -513,6 +513,55 @@ class ModelManager(object): _cache=self.cache, ) + def _get_model_path( + self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None + ) -> (Path, bool): + """Extract a model's filesystem path from its config. + + :return: The fully qualified Path of the module (or submodule). + """ + model_path = model_config.path + is_submodel_override = False + + # Does the config explicitly override the submodel? + if submodel_type is not None and hasattr(model_config, submodel_type): + submodel_path = getattr(model_config, submodel_type) + if submodel_path is not None: + model_path = getattr(model_config, submodel_type) + is_submodel_override = True + + model_path = self.resolve_model_path(model_path) + return model_path, is_submodel_override + + def _get_model_config(self, base_model: BaseModelType, model_name: str, model_type: ModelType) -> ModelConfigBase: + """Get a model's config object.""" + model_key = self.create_key(model_name, base_model, model_type) + try: + model_config = self.models[model_key] + except KeyError: + raise ModelNotFoundException(f"Model not found - {model_key}") + return model_config + + def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]: + """Get the concrete implementation class for a specific model type.""" + model_class = MODEL_CLASSES[base_model][model_type] + return model_class + + def _instantiate( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + submodel_type: Optional[SubModelType] = None, + ) -> ModelBase: + """Make a new instance of this model, without loading it.""" + model_config = self._get_model_config(base_model, model_name, model_type) + model_path, is_submodel_override = self._get_model_path(model_config, submodel_type) + # FIXME: do non-overriden submodels get the right class? + constructor = self._get_implementation(base_model, model_type) + instance = constructor(model_path, base_model, model_type) + return instance + def model_info( self, model_name: str, @@ -661,7 +710,7 @@ class ModelManager(object): if path := model_attributes.get("path"): model_attributes["path"] = str(self.relative_model_path(Path(path))) - model_class = MODEL_CLASSES[base_model][model_type] + model_class = self._get_implementation(base_model, model_type) model_config = model_class.create_config(**model_attributes) model_key = self.create_key(model_name, base_model, model_type) @@ -852,7 +901,7 @@ class ModelManager(object): for model_key, model_config in self.models.items(): model_name, base_model, model_type = self.parse_key(model_key) - model_class = MODEL_CLASSES[base_model][model_type] + model_class = self._get_implementation(base_model, model_type) if model_class.save_to_config: # TODO: or exclude_unset better fits here? data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"}) @@ -910,7 +959,7 @@ class ModelManager(object): model_path = self.resolve_model_path(model_config.path).absolute() if not model_path.exists(): - model_class = MODEL_CLASSES[cur_base_model][cur_model_type] + model_class = self._get_implementation(cur_base_model, cur_model_type) if model_class.save_to_config: model_config.error = ModelError.NotFound self.models.pop(model_key, None) @@ -926,7 +975,7 @@ class ModelManager(object): for cur_model_type in ModelType: if model_type is not None and cur_model_type != model_type: continue - model_class = MODEL_CLASSES[cur_base_model][cur_model_type] + model_class = self._get_implementation(cur_base_model, cur_model_type) models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value)) if not models_dir.exists(): diff --git a/invokeai/backend/model_management/models/vae.py b/invokeai/backend/model_management/models/vae.py index b15844bcf8..957a102ffb 100644 --- a/invokeai/backend/model_management/models/vae.py +++ b/invokeai/backend/model_management/models/vae.py @@ -1,9 +1,14 @@ import os -import torch -import safetensors from enum import Enum from pathlib import Path -from typing import Optional, Union, Literal +from typing import Optional + +import safetensors +import torch +from diffusers.utils import is_safetensors_available +from omegaconf import OmegaConf + +from invokeai.app.services.config import InvokeAIAppConfig from .base import ( ModelBase, ModelConfigBase, @@ -18,9 +23,6 @@ from .base import ( InvalidModelException, ModelNotFoundException, ) -from invokeai.app.services.config import InvokeAIAppConfig -from diffusers.utils import is_safetensors_available -from omegaconf import OmegaConf class VaeModelFormat(str, Enum): @@ -80,7 +82,7 @@ class VaeModel(ModelBase): @classmethod def detect_format(cls, path: str): if not os.path.exists(path): - raise ModelNotFoundException() + raise ModelNotFoundException(f"Does not exist as local file: {path}") if os.path.isdir(path): if os.path.exists(os.path.join(path, "config.json")): diff --git a/invokeai/frontend/web/src/app/types/invokeai.ts b/invokeai/frontend/web/src/app/types/invokeai.ts index b38790e0c9..827424fa7f 100644 --- a/invokeai/frontend/web/src/app/types/invokeai.ts +++ b/invokeai/frontend/web/src/app/types/invokeai.ts @@ -96,7 +96,8 @@ export type AppFeature = | 'consoleLogging' | 'dynamicPrompting' | 'batches' - | 'syncModels'; + | 'syncModels' + | 'multiselect'; /** * A disable-able Stable Diffusion feature diff --git a/invokeai/frontend/web/src/features/gallery/hooks/useMultiselect.ts.ts b/invokeai/frontend/web/src/features/gallery/hooks/useMultiselect.ts.ts index b59a2f3d6f..a162c6788d 100644 --- a/invokeai/frontend/web/src/features/gallery/hooks/useMultiselect.ts.ts +++ b/invokeai/frontend/web/src/features/gallery/hooks/useMultiselect.ts.ts @@ -9,6 +9,7 @@ import { useListImagesQuery } from 'services/api/endpoints/images'; import { ImageDTO } from 'services/api/types'; import { selectionChanged } from '../store/gallerySlice'; import { imagesSelectors } from 'services/api/util'; +import { useFeatureStatus } from '../../system/hooks/useFeatureStatus'; const selector = createSelector( [stateSelector, selectListImagesBaseQueryArgs], @@ -33,11 +34,18 @@ export const useMultiselect = (imageDTO?: ImageDTO) => { }), }); + const isMultiSelectEnabled = useFeatureStatus('multiselect').isFeatureEnabled; + const handleClick = useCallback( (e: MouseEvent) => { if (!imageDTO) { return; } + if (!isMultiSelectEnabled) { + dispatch(selectionChanged([imageDTO])); + return; + } + if (e.shiftKey) { const rangeEndImageName = imageDTO.image_name; const lastSelectedImage = selection[selection.length - 1]?.image_name; @@ -71,7 +79,7 @@ export const useMultiselect = (imageDTO?: ImageDTO) => { dispatch(selectionChanged([imageDTO])); } }, - [dispatch, imageDTO, imageDTOs, selection] + [dispatch, imageDTO, imageDTOs, selection, isMultiSelectEnabled] ); const isSelected = useMemo( diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx index e212efbfa2..c2edd94106 100644 --- a/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx +++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx @@ -31,7 +31,7 @@ const ParamLoraCollapse = () => { } return ( - + diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx index 835c315e5c..f10084e585 100644 --- a/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx +++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx @@ -1,3 +1,4 @@ +import { Divider } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; @@ -8,20 +9,21 @@ import ParamLora from './ParamLora'; const selector = createSelector( stateSelector, ({ lora }) => { - const { loras } = lora; - - return { loras }; + return { lorasArray: map(lora.loras) }; }, defaultSelectorOptions ); const ParamLoraList = () => { - const { loras } = useAppSelector(selector); + const { lorasArray } = useAppSelector(selector); return ( <> - {map(loras, (lora) => ( - + {lorasArray.map((lora, i) => ( + <> + {i > 0 && } + + ))} ); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts index bc0bfee8fd..cdd91d6e4f 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts @@ -9,7 +9,6 @@ import { CLIP_SKIP, LORA_LOADER, MAIN_MODEL_LOADER, - ONNX_MODEL_LOADER, METADATA_ACCUMULATOR, NEGATIVE_CONDITIONING, POSITIVE_CONDITIONING, @@ -36,15 +35,11 @@ export const addLoRAsToGraph = ( | undefined; if (loraCount > 0) { - // Remove MAIN_MODEL_LOADER unet connection to feed it to LoRAs + // Remove modelLoaderNodeId unet connection to feed it to LoRAs graph.edges = graph.edges.filter( (e) => !( - e.source.node_id === MAIN_MODEL_LOADER && - ['unet'].includes(e.source.field) - ) && - !( - e.source.node_id === ONNX_MODEL_LOADER && + e.source.node_id === modelLoaderNodeId && ['unet'].includes(e.source.field) ) ); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLLoRAstoGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLLoRAstoGraph.ts new file mode 100644 index 0000000000..c0f7f7ca82 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLLoRAstoGraph.ts @@ -0,0 +1,212 @@ +import { RootState } from 'app/store/store'; +import { NonNullableGraph } from 'features/nodes/types/types'; +import { forEach, size } from 'lodash-es'; +import { + MetadataAccumulatorInvocation, + SDXLLoraLoaderInvocation, +} from 'services/api/types'; +import { + LORA_LOADER, + METADATA_ACCUMULATOR, + NEGATIVE_CONDITIONING, + POSITIVE_CONDITIONING, + SDXL_MODEL_LOADER, +} from './constants'; + +export const addSDXLLoRAsToGraph = ( + state: RootState, + graph: NonNullableGraph, + baseNodeId: string, + modelLoaderNodeId: string = SDXL_MODEL_LOADER +): void => { + /** + * LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them. + * They then output the UNet and CLIP models references on to either the next LoRA in the chain, + * or to the inference/conditioning nodes. + * + * So we need to inject a LoRA chain into the graph. + */ + + const { loras } = state.lora; + const loraCount = size(loras); + const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as + | MetadataAccumulatorInvocation + | undefined; + + if (loraCount > 0) { + // Remove modelLoaderNodeId unet/clip/clip2 connections to feed it to LoRAs + graph.edges = graph.edges.filter( + (e) => + !( + e.source.node_id === modelLoaderNodeId && + ['unet'].includes(e.source.field) + ) && + !( + e.source.node_id === modelLoaderNodeId && + ['clip'].includes(e.source.field) + ) && + !( + e.source.node_id === modelLoaderNodeId && + ['clip2'].includes(e.source.field) + ) + ); + } + + // we need to remember the last lora so we can chain from it + let lastLoraNodeId = ''; + let currentLoraIndex = 0; + + forEach(loras, (lora) => { + const { model_name, base_model, weight } = lora; + const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`; + + const loraLoaderNode: SDXLLoraLoaderInvocation = { + type: 'sdxl_lora_loader', + id: currentLoraNodeId, + is_intermediate: true, + lora: { model_name, base_model }, + weight, + }; + + // add the lora to the metadata accumulator + if (metadataAccumulator) { + metadataAccumulator.loras.push({ + lora: { model_name, base_model }, + weight, + }); + } + + // add to graph + graph.nodes[currentLoraNodeId] = loraLoaderNode; + if (currentLoraIndex === 0) { + // first lora = start the lora chain, attach directly to model loader + graph.edges.push({ + source: { + node_id: modelLoaderNodeId, + field: 'unet', + }, + destination: { + node_id: currentLoraNodeId, + field: 'unet', + }, + }); + + graph.edges.push({ + source: { + node_id: modelLoaderNodeId, + field: 'clip', + }, + destination: { + node_id: currentLoraNodeId, + field: 'clip', + }, + }); + + graph.edges.push({ + source: { + node_id: modelLoaderNodeId, + field: 'clip2', + }, + destination: { + node_id: currentLoraNodeId, + field: 'clip2', + }, + }); + } else { + // we are in the middle of the lora chain, instead connect to the previous lora + graph.edges.push({ + source: { + node_id: lastLoraNodeId, + field: 'unet', + }, + destination: { + node_id: currentLoraNodeId, + field: 'unet', + }, + }); + graph.edges.push({ + source: { + node_id: lastLoraNodeId, + field: 'clip', + }, + destination: { + node_id: currentLoraNodeId, + field: 'clip', + }, + }); + + graph.edges.push({ + source: { + node_id: lastLoraNodeId, + field: 'clip2', + }, + destination: { + node_id: currentLoraNodeId, + field: 'clip2', + }, + }); + } + + if (currentLoraIndex === loraCount - 1) { + // final lora, end the lora chain - we need to connect up to inference and conditioning nodes + graph.edges.push({ + source: { + node_id: currentLoraNodeId, + field: 'unet', + }, + destination: { + node_id: baseNodeId, + field: 'unet', + }, + }); + + graph.edges.push({ + source: { + node_id: currentLoraNodeId, + field: 'clip', + }, + destination: { + node_id: POSITIVE_CONDITIONING, + field: 'clip', + }, + }); + + graph.edges.push({ + source: { + node_id: currentLoraNodeId, + field: 'clip', + }, + destination: { + node_id: NEGATIVE_CONDITIONING, + field: 'clip', + }, + }); + + graph.edges.push({ + source: { + node_id: currentLoraNodeId, + field: 'clip2', + }, + destination: { + node_id: POSITIVE_CONDITIONING, + field: 'clip2', + }, + }); + + graph.edges.push({ + source: { + node_id: currentLoraNodeId, + field: 'clip2', + }, + destination: { + node_id: NEGATIVE_CONDITIONING, + field: 'clip2', + }, + }); + } + + // increment the lora for the next one in the chain + lastLoraNodeId = currentLoraNodeId; + currentLoraIndex += 1; + }); +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph.ts index a260dbc467..0ec4e096d9 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph.ts @@ -22,6 +22,7 @@ import { SDXL_LATENTS_TO_LATENTS, SDXL_MODEL_LOADER, } from './constants'; +import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; /** * Builds the Image to Image tab graph. @@ -364,6 +365,8 @@ export const buildLinearSDXLImageToImageGraph = ( }, }); + addSDXLLoRAsToGraph(state, graph, SDXL_LATENTS_TO_LATENTS, SDXL_MODEL_LOADER); + // Add Refiner if enabled if (shouldUseSDXLRefiner) { addSDXLRefinerToGraph(state, graph, SDXL_LATENTS_TO_LATENTS); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph.ts index c10e7831d3..21b7c1e0ac 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph.ts @@ -4,6 +4,7 @@ import { NonNullableGraph } from 'features/nodes/types/types'; import { initialGenerationState } from 'features/parameters/store/generationSlice'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; +import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; import { addWatermarkerToGraph } from './addWatermarkerToGraph'; import { @@ -246,6 +247,8 @@ export const buildLinearSDXLTextToImageGraph = ( }, }); + addSDXLLoRAsToGraph(state, graph, SDXL_TEXT_TO_LATENTS, SDXL_MODEL_LOADER); + // Add Refiner if enabled if (shouldUseSDXLRefiner) { addSDXLRefinerToGraph(state, graph, SDXL_TEXT_TO_LATENTS); diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLImageToImageTabParameters.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLImageToImageTabParameters.tsx index c0b143a557..edc92a56c8 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLImageToImageTabParameters.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLImageToImageTabParameters.tsx @@ -4,6 +4,7 @@ import ProcessButtons from 'features/parameters/components/ProcessButtons/Proces import ParamSDXLPromptArea from './ParamSDXLPromptArea'; import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse'; import SDXLImageToImageTabCoreParameters from './SDXLImageToImageTabCoreParameters'; +import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse'; const SDXLImageToImageTabParameters = () => { return ( @@ -12,6 +13,7 @@ const SDXLImageToImageTabParameters = () => { + diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLTextToImageTabParameters.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLTextToImageTabParameters.tsx index 35bc0b4284..325fd7d881 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLTextToImageTabParameters.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLTextToImageTabParameters.tsx @@ -4,6 +4,7 @@ import ProcessButtons from 'features/parameters/components/ProcessButtons/Proces import TextToImageTabCoreParameters from 'features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters'; import ParamSDXLPromptArea from './ParamSDXLPromptArea'; import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse'; +import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse'; const SDXLTextToImageTabParameters = () => { return ( @@ -12,6 +13,7 @@ const SDXLTextToImageTabParameters = () => { + diff --git a/invokeai/frontend/web/src/services/api/endpoints/images.ts b/invokeai/frontend/web/src/services/api/endpoints/images.ts index e093c1c33a..0c52258f9d 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/images.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/images.ts @@ -4,6 +4,7 @@ import { ASSETS_CATEGORIES, BoardId, IMAGE_CATEGORIES, + IMAGE_LIMIT, } from 'features/gallery/store/types'; import { keyBy } from 'lodash'; import { ApiFullTagDescription, LIST_TAG, api } from '..'; @@ -167,7 +168,14 @@ export const imagesApi = api.injectEndpoints({ }, }; }, - invalidatesTags: (result, error, imageDTOs) => [], + invalidatesTags: (result, error, { imageDTOs }) => { + // for now, assume bulk delete is all on one board + const boardId = imageDTOs[0]?.board_id; + return [ + { type: 'BoardImagesTotal', id: boardId ?? 'none' }, + { type: 'BoardAssetsTotal', id: boardId ?? 'none' }, + ]; + }, async onQueryStarted({ imageDTOs }, { dispatch, queryFulfilled }) { /** * Cache changes for `deleteImages`: @@ -889,18 +897,25 @@ export const imagesApi = api.injectEndpoints({ board_id, }, }), - invalidatesTags: (result, error, { board_id }) => [ - // update the destination board - { type: 'Board', id: board_id ?? 'none' }, - // update old board totals - { type: 'BoardImagesTotal', id: board_id ?? 'none' }, - { type: 'BoardAssetsTotal', id: board_id ?? 'none' }, - // update the no_board totals - { type: 'BoardImagesTotal', id: 'none' }, - { type: 'BoardAssetsTotal', id: 'none' }, - ], + invalidatesTags: (result, error, { imageDTOs, board_id }) => { + //assume all images are being moved from one board for now + const oldBoardId = imageDTOs[0]?.board_id; + return [ + // update the destination board + { type: 'Board', id: board_id ?? 'none' }, + // update new board totals + { type: 'BoardImagesTotal', id: board_id ?? 'none' }, + { type: 'BoardAssetsTotal', id: board_id ?? 'none' }, + // update old board totals + { type: 'BoardImagesTotal', id: oldBoardId ?? 'none' }, + { type: 'BoardAssetsTotal', id: oldBoardId ?? 'none' }, + // update the no_board totals + { type: 'BoardImagesTotal', id: 'none' }, + { type: 'BoardAssetsTotal', id: 'none' }, + ]; + }, async onQueryStarted( - { board_id, imageDTOs }, + { board_id: new_board_id, imageDTOs }, { dispatch, queryFulfilled, getState } ) { try { @@ -920,7 +935,7 @@ export const imagesApi = api.injectEndpoints({ 'getImageDTO', image_name, (draft) => { - draft.board_id = board_id; + draft.board_id = new_board_id; } ) ); @@ -946,7 +961,7 @@ export const imagesApi = api.injectEndpoints({ ); const queryArgs = { - board_id, + board_id: new_board_id, categories, }; @@ -954,23 +969,24 @@ export const imagesApi = api.injectEndpoints({ queryArgs )(getState()); - const { data: total } = IMAGE_CATEGORIES.includes( + const { data: previousTotal } = IMAGE_CATEGORIES.includes( imageDTO.image_category ) ? boardsApi.endpoints.getBoardImagesTotal.select( - imageDTO.board_id ?? 'none' + new_board_id ?? 'none' )(getState()) : boardsApi.endpoints.getBoardAssetsTotal.select( - imageDTO.board_id ?? 'none' + new_board_id ?? 'none' )(getState()); const isCacheFullyPopulated = - currentCache.data && currentCache.data.ids.length >= (total ?? 0); + currentCache.data && + currentCache.data.ids.length >= (previousTotal ?? 0); - const isInDateRange = getIsImageInDateRange( - currentCache.data, - imageDTO - ); + const isInDateRange = + (previousTotal || 0) >= IMAGE_LIMIT + ? getIsImageInDateRange(currentCache.data, imageDTO) + : true; if (isCacheFullyPopulated || isInDateRange) { // *upsert* to $cache @@ -981,7 +997,7 @@ export const imagesApi = api.injectEndpoints({ (draft) => { imagesAdapter.upsertOne(draft, { ...imageDTO, - board_id, + board_id: new_board_id, }); } ) @@ -1097,10 +1113,10 @@ export const imagesApi = api.injectEndpoints({ const isCacheFullyPopulated = currentCache.data && currentCache.data.ids.length >= (total ?? 0); - const isInDateRange = getIsImageInDateRange( - currentCache.data, - imageDTO - ); + const isInDateRange = + (total || 0) >= IMAGE_LIMIT + ? getIsImageInDateRange(currentCache.data, imageDTO) + : true; if (isCacheFullyPopulated || isInDateRange) { // *upsert* to $cache @@ -1111,7 +1127,7 @@ export const imagesApi = api.injectEndpoints({ (draft) => { imagesAdapter.upsertOne(draft, { ...imageDTO, - board_id: undefined, + board_id: 'none', }); } ) diff --git a/invokeai/frontend/web/src/services/api/schema.d.ts b/invokeai/frontend/web/src/services/api/schema.d.ts index 6574ec4909..fc3397820e 100644 --- a/invokeai/frontend/web/src/services/api/schema.d.ts +++ b/invokeai/frontend/web/src/services/api/schema.d.ts @@ -1443,7 +1443,7 @@ export type components = { * @description The nodes in this graph */ nodes?: { - [key: string]: (components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined; + [key: string]: (components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageLuminosityAdjustmentInvocation"] | components["schemas"]["ImageSaturationAdjustmentInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined; }; /** * Edges @@ -1486,7 +1486,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["MetadataAccumulatorOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined; + [key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["MetadataAccumulatorOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined; }; /** * Errors @@ -1904,6 +1904,40 @@ export type components = { */ image_name: string; }; + /** + * ImageHueAdjustmentInvocation + * @description Adjusts the Hue of an image. + */ + ImageHueAdjustmentInvocation: { + /** + * Id + * @description The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this node is an intermediate node. + * @default false + */ + is_intermediate?: boolean; + /** + * Type + * @default img_hue_adjust + * @enum {string} + */ + type?: "img_hue_adjust"; + /** + * Image + * @description The image to adjust + */ + image?: components["schemas"]["ImageField"]; + /** + * Hue + * @description The degrees by which to rotate the hue, 0-360 + * @default 0 + */ + hue?: number; + }; /** * ImageInverseLerpInvocation * @description Inverse linear interpolation of all pixels of an image @@ -1984,6 +2018,40 @@ export type components = { */ max?: number; }; + /** + * ImageLuminosityAdjustmentInvocation + * @description Adjusts the Luminosity (Value) of an image. + */ + ImageLuminosityAdjustmentInvocation: { + /** + * Id + * @description The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this node is an intermediate node. + * @default false + */ + is_intermediate?: boolean; + /** + * Type + * @default img_luminosity_adjust + * @enum {string} + */ + type?: "img_luminosity_adjust"; + /** + * Image + * @description The image to adjust + */ + image?: components["schemas"]["ImageField"]; + /** + * Luminosity + * @description The factor by which to adjust the luminosity (value) + * @default 1 + */ + luminosity?: number; + }; /** * ImageMetadata * @description An image's generation metadata @@ -2239,6 +2307,40 @@ export type components = { */ resample_mode?: "nearest" | "box" | "bilinear" | "hamming" | "bicubic" | "lanczos"; }; + /** + * ImageSaturationAdjustmentInvocation + * @description Adjusts the Saturation of an image. + */ + ImageSaturationAdjustmentInvocation: { + /** + * Id + * @description The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this node is an intermediate node. + * @default false + */ + is_intermediate?: boolean; + /** + * Type + * @default img_saturation_adjust + * @enum {string} + */ + type?: "img_saturation_adjust"; + /** + * Image + * @description The image to adjust + */ + image?: components["schemas"]["ImageField"]; + /** + * Saturation + * @description The factor by which to adjust the saturation + * @default 1 + */ + saturation?: number; + }; /** * ImageScaleInvocation * @description Scales an image by a factor @@ -4912,6 +5014,82 @@ export type components = { */ denoising_end?: number; }; + /** + * SDXLLoraLoaderInvocation + * @description Apply selected lora to unet and text_encoder. + */ + SDXLLoraLoaderInvocation: { + /** + * Id + * @description The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this node is an intermediate node. + * @default false + */ + is_intermediate?: boolean; + /** + * Type + * @default sdxl_lora_loader + * @enum {string} + */ + type?: "sdxl_lora_loader"; + /** + * Lora + * @description Lora model name + */ + lora?: components["schemas"]["LoRAModelField"]; + /** + * Weight + * @description With what weight to apply lora + * @default 0.75 + */ + weight?: number; + /** + * Unet + * @description UNet model for applying lora + */ + unet?: components["schemas"]["UNetField"]; + /** + * Clip + * @description Clip model for applying lora + */ + clip?: components["schemas"]["ClipField"]; + /** + * Clip2 + * @description Clip2 model for applying lora + */ + clip2?: components["schemas"]["ClipField"]; + }; + /** + * SDXLLoraLoaderOutput + * @description Model loader output + */ + SDXLLoraLoaderOutput: { + /** + * Type + * @default sdxl_lora_loader_output + * @enum {string} + */ + type?: "sdxl_lora_loader_output"; + /** + * Unet + * @description UNet submodel + */ + unet?: components["schemas"]["UNetField"]; + /** + * Clip + * @description Tokenizer and text_encoder submodels + */ + clip?: components["schemas"]["ClipField"]; + /** + * Clip2 + * @description Tokenizer2 and text_encoder2 submodels + */ + clip2?: components["schemas"]["ClipField"]; + }; /** * SDXLModelLoaderInvocation * @description Loads an sdxl base model, outputting its submodels. @@ -5961,6 +6139,24 @@ export type components = { */ image?: components["schemas"]["ImageField"]; }; + /** + * ControlNetModelFormat + * @description An enumeration. + * @enum {string} + */ + ControlNetModelFormat: "checkpoint" | "diffusers"; + /** + * StableDiffusionXLModelFormat + * @description An enumeration. + * @enum {string} + */ + StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; + /** + * StableDiffusion1ModelFormat + * @description An enumeration. + * @enum {string} + */ + StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; /** * StableDiffusionOnnxModelFormat * @description An enumeration. @@ -5973,24 +6169,6 @@ export type components = { * @enum {string} */ StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; - /** - * StableDiffusion1ModelFormat - * @description An enumeration. - * @enum {string} - */ - StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; - /** - * StableDiffusionXLModelFormat - * @description An enumeration. - * @enum {string} - */ - StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; - /** - * ControlNetModelFormat - * @description An enumeration. - * @enum {string} - */ - ControlNetModelFormat: "checkpoint" | "diffusers"; }; responses: never; parameters: never; @@ -6101,7 +6279,7 @@ export type operations = { }; requestBody: { content: { - "application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]; + "application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageLuminosityAdjustmentInvocation"] | components["schemas"]["ImageSaturationAdjustmentInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]; }; }; responses: { @@ -6138,7 +6316,7 @@ export type operations = { }; requestBody: { content: { - "application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]; + "application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageLuminosityAdjustmentInvocation"] | components["schemas"]["ImageSaturationAdjustmentInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]; }; }; responses: { diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index ca9dbb3aeb..e7e3accdad 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -166,6 +166,9 @@ export type OnnxModelLoaderInvocation = TypeReq< export type LoraLoaderInvocation = TypeReq< components['schemas']['LoraLoaderInvocation'] >; +export type SDXLLoraLoaderInvocation = TypeReq< + components['schemas']['SDXLLoraLoaderInvocation'] +>; export type MetadataAccumulatorInvocation = TypeReq< components['schemas']['MetadataAccumulatorInvocation'] >; diff --git a/pyproject.toml b/pyproject.toml index b3f12481a8..2ae297a6da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,7 +100,7 @@ dependencies = [ "dev" = [ "pudb", ] -"test" = ["pytest>6.0.0", "pytest-cov", "black"] +"test" = ["pytest>6.0.0", "pytest-cov", "pytest-datadir", "black"] "xformers" = [ "xformers~=0.0.19; sys_platform!='darwin'", "triton; sys_platform=='linux'", diff --git a/tests/test_model_manager.py b/tests/test_model_manager.py new file mode 100644 index 0000000000..4314bad595 --- /dev/null +++ b/tests/test_model_manager.py @@ -0,0 +1,38 @@ +from pathlib import Path + +import pytest + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend import ModelManager, BaseModelType, ModelType, SubModelType + +BASIC_MODEL_NAME = ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main) +VAE_OVERRIDE_MODEL_NAME = ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main) + + +@pytest.fixture +def model_manager(datadir) -> ModelManager: + InvokeAIAppConfig.get_config(root=datadir) + return ModelManager(datadir / "configs" / "relative_sub.models.yaml") + + +def test_get_model_names(model_manager: ModelManager): + names = model_manager.model_names() + assert names[:2] == [BASIC_MODEL_NAME, VAE_OVERRIDE_MODEL_NAME] + + +def test_get_model_path_for_diffusers(model_manager: ModelManager, datadir: Path): + model_config = model_manager._get_model_config(BASIC_MODEL_NAME[1], BASIC_MODEL_NAME[0], BASIC_MODEL_NAME[2]) + top_model_path, is_override = model_manager._get_model_path(model_config) + expected_model_path = datadir / "models" / "sdxl" / "main" / "SDXL base 1_0" + assert top_model_path == expected_model_path + assert not is_override + + +def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir: Path): + model_config = model_manager._get_model_config( + VAE_OVERRIDE_MODEL_NAME[1], VAE_OVERRIDE_MODEL_NAME[0], VAE_OVERRIDE_MODEL_NAME[2] + ) + vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae) + expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix" + assert vae_model_path == expected_vae_path + assert is_override diff --git a/tests/test_model_manager/configs/relative_sub.models.yaml b/tests/test_model_manager/configs/relative_sub.models.yaml new file mode 100644 index 0000000000..3ec7a3adff --- /dev/null +++ b/tests/test_model_manager/configs/relative_sub.models.yaml @@ -0,0 +1,15 @@ +__metadata__: + version: 3.0.0 + +sdxl/main/SDXL base: + path: sdxl/main/SDXL base 1_0 + description: SDXL base v1.0 + variant: normal + format: diffusers + +sdxl/main/SDXL with VAE: + path: sdxl/main/SDXL base 1_0 + description: SDXL with customized VAE + vae: sdxl/vae/sdxl-vae-fp16-fix/ + variant: normal + format: diffusers diff --git a/tests/test_model_manager/models/sdxl/main/SDXL base 1_0/model_index.json b/tests/test_model_manager/models/sdxl/main/SDXL base 1_0/model_index.json new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_model_manager/models/sdxl/vae/sdxl-vae-fp16-fix/config.json b/tests/test_model_manager/models/sdxl/vae/sdxl-vae-fp16-fix/config.json new file mode 100644 index 0000000000..e69de29bb2