From 0257b4a611baf312d20a92cf8b5cd5ce10c55c78 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 15 Jul 2023 00:13:45 +1000 Subject: [PATCH 01/37] fix(ui): fix mouse interactions --- .../components/IAINode/IAINodeHeader.tsx | 2 ++ .../components/IAINode/IAINodeInputs.tsx | 27 ++++++++++--------- .../nodes/components/InvocationComponent.tsx | 11 +++++++- .../nodes/components/ProgressImageNode.tsx | 2 ++ .../nodes/hooks/useBuildInvocation.ts | 8 ++++++ 5 files changed, 36 insertions(+), 14 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeHeader.tsx b/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeHeader.tsx index 73705769b6..226aaed7be 100644 --- a/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeHeader.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeHeader.tsx @@ -1,4 +1,5 @@ import { Flex, Heading, Icon, Tooltip } from '@chakra-ui/react'; +import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/hooks/useBuildInvocation'; import { memo } from 'react'; import { FaInfoCircle } from 'react-icons/fa'; @@ -12,6 +13,7 @@ const IAINodeHeader = (props: IAINodeHeaderProps) => { const { nodeId, title, description } = props; return ( { }); return ( - + {IAINodeInputsToRender} ); diff --git a/invokeai/frontend/web/src/features/nodes/components/InvocationComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/InvocationComponent.tsx index 12817679e2..3a08b46dde 100644 --- a/invokeai/frontend/web/src/features/nodes/components/InvocationComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/InvocationComponent.tsx @@ -23,7 +23,14 @@ export const InvocationComponent = memo((props: NodeProps) => { if (!template) { return ( - + ) => { description={template.description} /> diff --git a/invokeai/frontend/web/src/features/nodes/components/ProgressImageNode.tsx b/invokeai/frontend/web/src/features/nodes/components/ProgressImageNode.tsx index 6424d4f76c..2975cd820c 100644 --- a/invokeai/frontend/web/src/features/nodes/components/ProgressImageNode.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/ProgressImageNode.tsx @@ -21,7 +21,9 @@ const ProgressImageNode = (props: NodeProps) => { /> nodes.invocationTemplates ); +export const DRAG_HANDLE_CLASSNAME = 'node-drag-handle'; + +export const SHARED_NODE_PROPERTIES: Partial = { + dragHandle: `.${DRAG_HANDLE_CLASSNAME}`, +}; + export const useBuildInvocation = () => { const invocationTemplates = useAppSelector(templatesSelector); @@ -32,6 +38,7 @@ export const useBuildInvocation = () => { }); const node: Node = { + ...SHARED_NODE_PROPERTIES, id: 'progress_image', type: 'progress_image', position: { x: x, y: y }, @@ -91,6 +98,7 @@ export const useBuildInvocation = () => { }); const invocation: Node = { + ...SHARED_NODE_PROPERTIES, id: nodeId, type: 'invocation', position: { x: x, y: y }, From 30e45eaf47c3cced85ed5175c8430c0c0d05d441 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 15 Jul 2023 00:45:26 +1000 Subject: [PATCH 02/37] feat(ui): hold shift to make nodes draggable from anywhere --- .../src/features/nodes/components/InvocationComponent.tsx | 4 ++-- .../web/src/features/nodes/components/NodeWrapper.tsx | 5 +++++ .../web/src/features/nodes/components/ProgressImageNode.tsx | 1 - 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/InvocationComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/InvocationComponent.tsx index 3a08b46dde..608f98d6d2 100644 --- a/invokeai/frontend/web/src/features/nodes/components/InvocationComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/InvocationComponent.tsx @@ -53,14 +53,14 @@ export const InvocationComponent = memo((props: NodeProps) => { description={template.description} /> diff --git a/invokeai/frontend/web/src/features/nodes/components/NodeWrapper.tsx b/invokeai/frontend/web/src/features/nodes/components/NodeWrapper.tsx index 7a76cd5902..dc5a94c267 100644 --- a/invokeai/frontend/web/src/features/nodes/components/NodeWrapper.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/NodeWrapper.tsx @@ -2,6 +2,8 @@ import { Box, useToken } from '@chakra-ui/react'; import { NODE_MIN_WIDTH } from 'app/constants'; import { PropsWithChildren } from 'react'; +import { DRAG_HANDLE_CLASSNAME } from '../hooks/useBuildInvocation'; +import { useAppSelector } from 'app/store/storeHooks'; type NodeWrapperProps = PropsWithChildren & { selected: boolean; @@ -13,8 +15,11 @@ const NodeWrapper = (props: NodeWrapperProps) => { 'dark-lg', ]); + const shift = useAppSelector((state) => state.hotkeys.shift); + return ( ) => { Date: Sat, 15 Jul 2023 01:04:33 +1000 Subject: [PATCH 03/37] fix(ui): allow decimals in number inputs still some jank but eh --- .../src/common/components/IAINumberInput.tsx | 2 +- .../fields/NumberInputFieldComponent.tsx | 36 ++++++++++++++++--- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/invokeai/frontend/web/src/common/components/IAINumberInput.tsx b/invokeai/frontend/web/src/common/components/IAINumberInput.tsx index 8f675cc148..de3b44564a 100644 --- a/invokeai/frontend/web/src/common/components/IAINumberInput.tsx +++ b/invokeai/frontend/web/src/common/components/IAINumberInput.tsx @@ -28,7 +28,7 @@ import { useState, } from 'react'; -const numberStringRegex = /^-?(0\.)?\.?$/; +export const numberStringRegex = /^-?(0\.)?\.?$/; interface Props extends Omit { label?: string; diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/NumberInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/NumberInputFieldComponent.tsx index f5df8989f5..50d69a6496 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/NumberInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/NumberInputFieldComponent.tsx @@ -6,6 +6,7 @@ import { NumberInputStepper, } from '@chakra-ui/react'; import { useAppDispatch } from 'app/store/storeHooks'; +import { numberStringRegex } from 'common/components/IAINumberInput'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { FloatInputFieldTemplate, @@ -13,7 +14,7 @@ import { IntegerInputFieldTemplate, IntegerInputFieldValue, } from 'features/nodes/types/types'; -import { memo } from 'react'; +import { memo, useEffect, useState } from 'react'; import { FieldComponentProps } from './types'; const NumberInputFieldComponent = ( @@ -23,17 +24,42 @@ const NumberInputFieldComponent = ( > ) => { const { nodeId, field } = props; - const dispatch = useAppDispatch(); + const [valueAsString, setValueAsString] = useState( + String(field.value) + ); - const handleValueChanged = (_: string, value: number) => { - dispatch(fieldValueChanged({ nodeId, fieldName: field.name, value })); + const handleValueChanged = (v: string) => { + setValueAsString(v); + // This allows negatives and decimals e.g. '-123', `.5`, `-0.2`, etc. + if (!v.match(numberStringRegex)) { + // Cast the value to number. Floor it if it should be an integer. + dispatch( + fieldValueChanged({ + nodeId, + fieldName: field.name, + value: + props.template.type === 'integer' + ? Math.floor(Number(v)) + : Number(v), + }) + ); + } }; + useEffect(() => { + if ( + !valueAsString.match(numberStringRegex) && + field.value !== Number(valueAsString) + ) { + setValueAsString(String(field.value)); + } + }, [field.value, valueAsString]); + return ( From ad076b1174109332a0c00acd420caabcd1158fc4 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 14 Jul 2023 11:14:33 -0400 Subject: [PATCH 04/37] add model directory search route --- invokeai/app/api/routers/models.py | 20 +++- .../app/services/model_manager_service.py | 18 ++- .../backend/model_management/model_manager.py | 86 ++++++--------- .../backend/model_management/model_search.py | 103 ++++++++++++++++++ 4 files changed, 172 insertions(+), 55 deletions(-) create mode 100644 invokeai/backend/model_management/model_search.py diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 8dbeaa3d05..8d97a1bda4 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -1,6 +1,7 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein +import pathlib from typing import Literal, List, Optional, Union from fastapi import Body, Path, Query, Response @@ -191,6 +192,23 @@ async def convert_model( except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) return response + +@models_router.get( + "/search", + operation_id="search_for_models", + responses={ + 200: { "description": "Directory searched successfully" }, + 404: { "description": "Invalid directory path" }, + }, + status_code = 200, + response_model = List[pathlib.Path] +) +async def search_for_models( + search_path: pathlib.Path = Query(description="Directory path to search for models") +)->List[pathlib.Path]: + if not search_path.is_dir(): + raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory") + return ApiDependencies.invoker.services.model_manager.search_for_models([search_path]) @models_router.put( "/merge/{base_model}", diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 1b1c43dc11..9a6ba77c13 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -19,7 +19,7 @@ from invokeai.backend.model_management import ( ModelMerger, MergeInterpolationMethod, ) - +from invokeai.backend.model_management.model_search import FindModels import torch from invokeai.app.models.exceptions import CanceledException @@ -230,7 +230,14 @@ class ModelManagerServiceBase(ABC): :param interp: Interpolation method. None (default) """ pass - + + @abstractmethod + def search_for_models(self, directory: Path)->List[Path]: + """ + Return list of all models found in the designated directory. + """ + pass + @abstractmethod def commit(self, conf_file: Optional[Path] = None) -> None: """ @@ -558,3 +565,10 @@ class ModelManagerService(ModelManagerServiceBase): except AssertionError as e: raise ValueError(e) return result + + def search_for_models(self, directory: Path)->List[Path]: + """ + Return list of all models found in the designated directory. + """ + search = FindModels(directory,self.logger) + return search.list_models() diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 0476425c8b..0363b858cf 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -247,6 +247,7 @@ import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.util import CUDA_DEVICE, Chdir from .model_cache import ModelCache, ModelLocker +from .model_search import ModelSearch from .models import ( BaseModelType, ModelType, SubModelType, ModelError, SchedulerPredictionType, MODEL_CLASSES, @@ -823,6 +824,7 @@ class ModelManager(object): if (new_models_found or imported_models) and self.config_path: self.commit() + def autoimport(self)->Dict[str, AddModelResult]: ''' Scan the autoimport directory (if defined) and import new models, delete defunct models. @@ -830,63 +832,42 @@ class ModelManager(object): # avoid circular import from invokeai.backend.install.model_install_backend import ModelInstall from invokeai.frontend.install.model_install import ask_user_for_prediction_type - + + + class ScanAndImport(ModelSearch): + def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall): + super().__init__(directories, logger) + self.installer = installer + self.ignore = ignore + + def on_search_started(self): + self.new_models_found = dict() + + def on_model_found(self, model: Path): + if model not in self.ignore: + self.new_models_found.update(self.installer.heuristic_import(model)) + + def on_search_completed(self): + self.logger.info(f'Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models') + + def models_found(self): + return self.new_models_found + + installer = ModelInstall(config = self.app_config, model_manager = self, prediction_type_helper = ask_user_for_prediction_type, ) - - scanned_dirs = set() - config = self.app_config - known_paths = {(self.app_config.root_path / x['path']) for x in self.list_models()} - - for autodir in [config.autoimport_dir, - config.lora_dir, - config.embedding_dir, - config.controlnet_dir]: - if autodir is None: - continue - - self.logger.info(f'Scanning {autodir} for models to import') - installed = dict() - - autodir = self.app_config.root_path / autodir - if not autodir.exists(): - continue - - items_scanned = 0 - new_models_found = dict() - - for root, dirs, files in os.walk(autodir): - items_scanned += len(dirs) + len(files) - for d in dirs: - path = Path(root) / d - if path in known_paths or path.parent in scanned_dirs: - scanned_dirs.add(path) - continue - if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]): - try: - new_models_found.update(installer.heuristic_import(path)) - scanned_dirs.add(path) - except ValueError as e: - self.logger.warning(str(e)) - - for f in files: - path = Path(root) / f - if path in known_paths or path.parent in scanned_dirs: - continue - if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}: - try: - import_result = installer.heuristic_import(path) - new_models_found.update(import_result) - except ValueError as e: - self.logger.warning(str(e)) - - self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models') - installed.update(new_models_found) - - return installed + known_paths = {config.root_path / x['path'] for x in self.list_models()} + directories = {config.root_path / x for x in [config.autoimport_dir, + config.lora_dir, + config.embedding_dir, + config.controlnet_dir] + } + scanner = ScanAndImport(directories, self.logger, ignore=known_paths, installer=installer) + scanner.search() + return scanner.models_found() def heuristic_import(self, items_to_import: Set[str], @@ -924,3 +905,4 @@ class ModelManager(object): successfully_installed.update(installed) self.commit() return successfully_installed + diff --git a/invokeai/backend/model_management/model_search.py b/invokeai/backend/model_management/model_search.py new file mode 100644 index 0000000000..1e282b4bb8 --- /dev/null +++ b/invokeai/backend/model_management/model_search.py @@ -0,0 +1,103 @@ +# Copyright 2023, Lincoln D. Stein and the InvokeAI Team +""" +Abstract base class for recursive directory search for models. +""" + +import os +from abc import ABC, abstractmethod +from typing import List, Set, types +from pathlib import Path + +import invokeai.backend.util.logging as logger + +class ModelSearch(ABC): + def __init__(self, directories: List[Path], logger: types.ModuleType=logger): + """ + Initialize a recursive model directory search. + :param directories: List of directory Paths to recurse through + :param logger: Logger to use + """ + self.directories = directories + self.logger = logger + self._items_scanned = 0 + self._models_found = 0 + self._scanned_dirs = set() + self._scanned_paths = set() + self._pruned_paths = set() + + @abstractmethod + def on_search_started(self): + """ + Called before the scan starts. + """ + pass + + @abstractmethod + def on_model_found(self, model: Path): + """ + Process a found model. Raise an exception if something goes wrong. + :param model: Model to process - could be a directory or checkpoint. + """ + pass + + @abstractmethod + def on_search_completed(self): + """ + Perform some activity when the scan is completed. May use instance + variables, items_scanned and models_found + """ + pass + + def search(self): + self.on_search_started() + for dir in self.directories: + self.walk_directory(dir) + self.on_search_completed() + + def walk_directory(self, path: Path): + for root, dirs, files in os.walk(path): + if str(Path(root).name).startswith('.'): + self._pruned_paths.add(root) + if any([Path(root).is_relative_to(x) for x in self._pruned_paths]): + continue + + self._items_scanned += len(dirs) + len(files) + for d in dirs: + path = Path(root) / d + if path in self._scanned_paths or path.parent in self._scanned_dirs: + self._scanned_dirs.add(path) + continue + if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]): + try: + self.on_model_found(path) + self._models_found += 1 + self._scanned_dirs.add(path) + except Exception as e: + self.logger.warning(str(e)) + + for f in files: + path = Path(root) / f + if path.parent in self._scanned_dirs: + continue + if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}: + try: + self.on_model_found(path) + self._models_found += 1 + except Exception as e: + self.logger.warning(str(e)) + +class FindModels(ModelSearch): + def on_search_started(self): + self.models_found: Set[Path] = set() + + def on_model_found(self,model: Path): + self.models_found.add(model) + + def on_search_completed(self): + pass + + def list_models(self) -> List[Path]: + self.search() + return self.models_found + + From 8600aad12ba0bb7e9b2329b6549aa7937e3ded9a Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 14 Jul 2023 13:45:16 -0400 Subject: [PATCH 05/37] multiple enhancements to model manager REACT API 1. add a /sync route for synchronizing the in-memory model lists to models.yaml, the models directory, and the autoimport directories. 2. add optional destination_directories to convert_model and merge_model operations. 3. add /ckpt_confs route for retrieving known legacy checkpoint configuration files. 4. add /search route for finding all models in a directory located in the server filesystem --- invokeai/app/api/routers/models.py | 61 +++++++++++++++---- .../app/services/model_manager_service.py | 47 +++++++++++++- .../backend/model_management/model_manager.py | 51 ++++++++++++---- .../backend/model_management/model_merge.py | 7 ++- 4 files changed, 135 insertions(+), 31 deletions(-) diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 8d97a1bda4..d0e0361ad9 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -132,13 +132,11 @@ async def import_model( "/{base_model}/{model_type}/{model_name}", operation_id="del_model", responses={ - 204: { - "description": "Model deleted successfully" - }, - 404: { - "description": "Model not found" - } + 204: { "description": "Model deleted successfully" }, + 404: { "description": "Model not found" } }, + status_code = 204, + response_model = None, ) async def delete_model( base_model: BaseModelType = Path(description="Base model"), @@ -174,14 +172,17 @@ async def convert_model( base_model: BaseModelType = Path(description="Base model"), model_type: ModelType = Path(description="The type of model"), model_name: str = Path(description="model name"), + convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"), ) -> ConvertModelResponse: - """Convert a checkpoint model into a diffusers model""" + """Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none.""" logger = ApiDependencies.invoker.services.logger try: logger.info(f"Converting model: {model_name}") + dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None ApiDependencies.invoker.services.model_manager.convert_model(model_name, base_model = base_model, - model_type = model_type + model_type = model_type, + convert_dest_directory = dest, ) model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name, base_model = base_model, @@ -209,6 +210,36 @@ async def search_for_models( if not search_path.is_dir(): raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory") return ApiDependencies.invoker.services.model_manager.search_for_models([search_path]) + +@models_router.get( + "/ckpt_confs", + operation_id="list_ckpt_configs", + responses={ + 200: { "description" : "paths retrieved successfully" }, + }, + status_code = 200, + response_model = List[pathlib.Path] +) +async def list_ckpt_configs( +)->List[pathlib.Path]: + """Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT.""" + return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs() + + +@models_router.get( + "/sync", + operation_id="sync_to_config", + responses={ + 201: { "description": "synchronization successful" }, + }, + status_code = 201, + response_model = None +) +async def sync_to_config( +)->None: + """Call after making changes to models.yaml, autoimport directories or models directory to synchronize + in-memory data structures with disk data structures.""" + return ApiDependencies.invoker.services.model_manager.sync_to_config() @models_router.put( "/merge/{base_model}", @@ -228,17 +259,21 @@ async def merge_models( alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5), interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"), force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False), + merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None) ) -> MergeModelResponse: """Convert a checkpoint model into a diffusers model""" logger = ApiDependencies.invoker.services.logger try: - logger.info(f"Merging models: {model_names}") + logger.info(f"Merging models: {model_names} into {merge_dest_directory or ''}/{merged_model_name}") + dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None result = ApiDependencies.invoker.services.model_manager.merge_models(model_names, base_model, - merged_model_name or "+".join(model_names), - alpha, - interp, - force) + merged_model_name=merged_model_name or "+".join(model_names), + alpha=alpha, + interp=interp, + force=force, + merge_dest_directory = dest + ) model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name, base_model = base_model, model_type = ModelType.Main, diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 9a6ba77c13..3c5dad7b3e 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -167,6 +167,15 @@ class ModelManagerServiceBase(ABC): """ pass + @abstractmethod + def list_checkpoint_configs( + self + )->List[Path]: + """ + List the checkpoint config paths from ROOT/configs/stable-diffusion. + """ + pass + @abstractmethod def convert_model( self, @@ -220,6 +229,7 @@ class ModelManagerServiceBase(ABC): alpha: Optional[float] = 0.5, interp: Optional[MergeInterpolationMethod] = None, force: Optional[bool] = False, + merge_dest_directory: Optional[Path] = None ) -> AddModelResult: """ Merge two to three diffusrs pipeline models and save as a new model. @@ -228,6 +238,7 @@ class ModelManagerServiceBase(ABC): :param merged_model_name: Name of destination merged model :param alpha: Alpha strength to apply to 2d and 3d model :param interp: Interpolation method. None (default) + :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) """ pass @@ -238,6 +249,15 @@ class ModelManagerServiceBase(ABC): """ pass + @abstractmethod + def sync_to_config(self): + """ + Re-read models.yaml, rescan the models directory, and reimport models + in the autoimport directories. Call after making changes outside the + model manager API. + """ + pass + @abstractmethod def commit(self, conf_file: Optional[Path] = None) -> None: """ @@ -438,16 +458,18 @@ class ModelManagerService(ModelManagerServiceBase): """ Delete the named model from configuration. If delete_files is true, then the underlying weight file or diffusers directory will be deleted - as well. Call commit() to write to disk. + as well. """ self.logger.debug(f'delete model {model_name}') self.mgr.del_model(model_name, base_model, model_type) + self.mgr.commit() def convert_model( self, model_name: str, base_model: BaseModelType, model_type: Union[ModelType.Main,ModelType.Vae], + convert_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"), ) -> AddModelResult: """ Convert a checkpoint file into a diffusers folder, deleting the cached @@ -456,13 +478,14 @@ class ModelManagerService(ModelManagerServiceBase): :param model_name: Name of the model to convert :param base_model: Base model type :param model_type: Type of model ['vae' or 'main'] + :param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default) This will raise a ValueError unless the model is not a checkpoint. It will also raise a ValueError in the event that there is a similarly-named diffusers directory already in place. """ self.logger.debug(f'convert model {model_name}') - return self.mgr.convert_model(model_name, base_model, model_type) + return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory) def commit(self, conf_file: Optional[Path]=None): """ @@ -543,6 +566,7 @@ class ModelManagerService(ModelManagerServiceBase): alpha: Optional[float] = 0.5, interp: Optional[MergeInterpolationMethod] = None, force: Optional[bool] = False, + merge_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"), ) -> AddModelResult: """ Merge two to three diffusrs pipeline models and save as a new model. @@ -551,6 +575,7 @@ class ModelManagerService(ModelManagerServiceBase): :param merged_model_name: Name of destination merged model :param alpha: Alpha strength to apply to 2d and 3d model :param interp: Interpolation method. None (default) + :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) """ merger = ModelMerger(self.mgr) try: @@ -561,6 +586,7 @@ class ModelManagerService(ModelManagerServiceBase): alpha = alpha, interp = interp, force = force, + merge_dest_directory=merge_dest_directory, ) except AssertionError as e: raise ValueError(e) @@ -572,3 +598,20 @@ class ModelManagerService(ModelManagerServiceBase): """ search = FindModels(directory,self.logger) return search.list_models() + + def sync_to_config(self): + """ + Re-read models.yaml, rescan the models directory, and reimport models + in the autoimport directories. Call after making changes outside the + model manager API. + """ + return self.mgr.sync_to_config() + + def list_checkpoint_configs(self)->List[Path]: + """ + List the checkpoint config paths from ROOT/configs/stable-diffusion. + """ + config = self.mgr.app_config + conf_path = config.legacy_conf_path + root_path = config.root_path + return [(conf_path / x).relative_to(root_path) for x in conf_path.glob('**/*.yaml')] diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 0363b858cf..d28df6e900 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -323,16 +323,7 @@ class ModelManager(object): self.config_meta = ConfigMeta(**config.pop("__metadata__")) # TODO: metadata not found # TODO: version check - - self.models = dict() - for model_key, model_config in config.items(): - model_name, base_model, model_type = self.parse_key(model_key) - model_class = MODEL_CLASSES[base_model][model_type] - # alias for config file - model_config["model_format"] = model_config.pop("format") - self.models[model_key] = model_class.create_config(**model_config) - - # check config version number and update on disk/RAM if necessary + self.app_config = InvokeAIAppConfig.get_config() self.logger = logger self.cache = ModelCache( @@ -343,11 +334,41 @@ class ModelManager(object): sequential_offload = sequential_offload, logger = logger, ) + + self._read_models(config) + + def _read_models(self, config: Optional[DictConfig] = None): + if not config: + if self.config_path: + config = OmegaConf.load(self.config_path) + else: + return + + self.models = dict() + for model_key, model_config in config.items(): + if model_key.startswith('_'): + continue + model_name, base_model, model_type = self.parse_key(model_key) + model_class = MODEL_CLASSES[base_model][model_type] + # alias for config file + model_config["model_format"] = model_config.pop("format") + self.models[model_key] = model_class.create_config(**model_config) + + # check config version number and update on disk/RAM if necessary self.cache_keys = dict() # add controlnet, lora and textual_inversion models from disk self.scan_models_directory() + def sync_to_config(self): + """ + Call this when `models.yaml` has been changed externally. + This will reinitialize internal data structures + """ + # Reread models directory; note that this will reinitialize the cache, + # causing otherwise unreferenced models to be removed from memory + self._read_models() + def model_exists( self, model_name: str, @@ -528,7 +549,10 @@ class ModelManager(object): model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold) models = [] for model_key in model_keys: - model_config = self.models[model_key] + model_config = self.models.get(model_key) + if not model_config: + self.logger.error(f'Unknown model {model_name}') + raise KeyError(f'Unknown model {model_name}') cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key) if base_model is not None and cur_base_model != base_model: @@ -651,6 +675,7 @@ class ModelManager(object): model_name: str, base_model: BaseModelType, model_type: Union[ModelType.Main,ModelType.Vae], + dest_directory: Optional[Path]=None, ) -> AddModelResult: ''' Convert a checkpoint file into a diffusers folder, deleting the cached @@ -677,14 +702,14 @@ class ModelManager(object): ) checkpoint_path = self.app_config.root_path / info["path"] old_diffusers_path = self.app_config.models_path / model.location - new_diffusers_path = self.app_config.models_path / base_model.value / model_type.value / model_name + new_diffusers_path = (dest_directory or self.app_config.models_path / base_model.value / model_type.value) / model_name if new_diffusers_path.exists(): raise ValueError(f"A diffusers model already exists at {new_diffusers_path}") try: move(old_diffusers_path,new_diffusers_path) info["model_format"] = "diffusers" - info["path"] = str(new_diffusers_path.relative_to(self.app_config.root_path)) + info["path"] = str(new_diffusers_path) if dest_directory else str(new_diffusers_path.relative_to(self.app_config.root_path)) info.pop('config') result = self.add_model(model_name, base_model, model_type, diff --git a/invokeai/backend/model_management/model_merge.py b/invokeai/backend/model_management/model_merge.py index 39f951d2b4..6427b9e430 100644 --- a/invokeai/backend/model_management/model_merge.py +++ b/invokeai/backend/model_management/model_merge.py @@ -11,7 +11,7 @@ from enum import Enum from pathlib import Path from diffusers import DiffusionPipeline from diffusers import logging as dlogging -from typing import List, Union +from typing import List, Union, Optional import invokeai.backend.util.logging as logger @@ -74,6 +74,7 @@ class ModelMerger(object): alpha: float = 0.5, interp: MergeInterpolationMethod = None, force: bool = False, + merge_dest_directory: Optional[Path] = None, **kwargs, ) -> AddModelResult: """ @@ -85,7 +86,7 @@ class ModelMerger(object): :param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None. Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C). :param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False. - + :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) **kwargs - the default DiffusionPipeline.get_config_dict kwargs: cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map """ @@ -111,7 +112,7 @@ class ModelMerger(object): merged_pipe = self.merge_diffusion_models( model_paths, alpha, merge_method, force, **kwargs ) - dump_path = config.models_path / base_model.value / ModelType.Main.value + dump_path = Path(merge_dest_directory) if merge_dest_directory else config.models_path / base_model.value / ModelType.Main.value dump_path.mkdir(parents=True, exist_ok=True) dump_path = dump_path / merged_model_name From 7093e5d033726582730cb017efff2e9d968993c9 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sat, 15 Jul 2023 00:52:54 +0300 Subject: [PATCH 06/37] Pad conditionings using zeros and encoder_attention_mask --- invokeai/app/invocations/compel.py | 5 +-- .../diffusion/shared_invokeai_diffusion.py | 36 ++++++++++++++++++- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 303e0a0c84..a5a9701149 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -100,7 +100,7 @@ class CompelInvocation(BaseInvocation): text_encoder=text_encoder, textual_inversion_manager=ti_manager, dtype_for_device_getter=torch_dtype, - truncate_long_prompts=True, # TODO: + truncate_long_prompts=False, ) conjunction = Compel.parse_prompt_string(self.prompt) @@ -112,9 +112,6 @@ class CompelInvocation(BaseInvocation): c, options = compel.build_conditioning_tensor_for_prompt_object( prompt) - # TODO: long prompt support - # if not self.truncate_long_prompts: - # [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc]) ec = InvokeAIDiffuserComponent.ExtraConditioningInfo( tokens_count_including_eos_bos=get_max_token_count( tokenizer, conjunction), diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index 1175475bba..307e949ef8 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -241,11 +241,45 @@ class InvokeAIDiffuserComponent: def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs): # fast batched path + + def _pad_conditioning(cond, target_len, encoder_attention_mask): + conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype) + + if cond.shape[1] < max_len: + conditioning_attention_mask = torch.cat([ + conditioning_attention_mask, + torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype), + ], dim=1) + + cond = torch.cat([ + cond, + torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype), + ], dim=1) + + if encoder_attention_mask is None: + encoder_attention_mask = conditioning_attention_mask + else: + encoder_attention_mask = torch.cat([ + encoder_attention_mask, + conditioning_attention_mask, + ]) + + return cond, encoder_attention_mask + x_twice = torch.cat([x] * 2) sigma_twice = torch.cat([sigma] * 2) + + encoder_attention_mask = None + if unconditioning.shape[1] != conditioning.shape[1]: + max_len = max(unconditioning.shape[1], conditioning.shape[1]) + unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask) + conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask) + both_conditionings = torch.cat([unconditioning, conditioning]) both_results = self.model_forward_callback( - x_twice, sigma_twice, both_conditionings, **kwargs, + x_twice, sigma_twice, both_conditionings, + encoder_attention_mask=encoder_attention_mask, + **kwargs, ) unconditioned_next_x, conditioned_next_x = both_results.chunk(2) return unconditioned_next_x, conditioned_next_x From 8cb19578c27bdf2cea0b87952623d4c2585c1256 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 15 Jul 2023 11:07:13 +1000 Subject: [PATCH 07/37] fix(ui): fix crash on LoRA remove / weight change --- invokeai/frontend/web/src/features/lora/store/loraSlice.ts | 3 ++- .../src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts index 2dc739a737..f0067a85a2 100644 --- a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts +++ b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts @@ -3,6 +3,7 @@ import { LoRAModelParam } from 'features/parameters/types/parameterSchemas'; import { LoRAModelConfigEntity } from 'services/api/endpoints/models'; export type LoRA = LoRAModelParam & { + id: string; weight: number; }; @@ -24,7 +25,7 @@ export const loraSlice = createSlice({ reducers: { loraAdded: (state, action: PayloadAction) => { const { model_name, id, base_model } = action.payload; - state.loras[id] = { model_name, base_model, ...defaultLoRAConfig }; + state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig }; }, loraRemoved: (state, action: PayloadAction) => { const id = action.payload; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts index 5d1f3d05d2..a2cf1477f2 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts @@ -60,7 +60,7 @@ export const addLoRAsToGraph = ( const loraLoaderNode: LoraLoaderInvocation = { type: 'lora_loader', id: currentLoraNodeId, - lora, + lora: { model_name, base_model }, weight, }; From 194434dbfa23706270a5a637c790918887a1e124 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Fri, 14 Jul 2023 13:14:41 -0400 Subject: [PATCH 08/37] restore scrollbar --- .../components/ImageGrid/GalleryImageGrid.tsx | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImageGrid.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImageGrid.tsx index 858eeedaa3..8b44b39ae9 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImageGrid.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImageGrid.tsx @@ -118,6 +118,20 @@ const GalleryImageGrid = () => { ); }, [dispatch, imageNames.length, galleryView]); + useEffect(() => { + // Set up gallery scroler + const { current: root } = rootRef; + if (scroller && root) { + initialize({ + target: root, + elements: { + viewport: scroller, + }, + }); + } + return () => osInstance()?.destroy(); + }, [scroller, initialize, osInstance]); + const handleEndReached = useMemo(() => { if (areMoreAvailable) { return handleLoadMoreImages; From 2faa7cee37f715df6de448bf72176309ead44070 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 14 Jul 2023 23:03:18 -0400 Subject: [PATCH 09/37] add rename_model route --- invokeai/app/api/routers/models.py | 89 ++++++++++++++++++- .../app/services/model_manager_service.py | 35 ++++++++ .../backend/install/model_install_backend.py | 2 - .../backend/model_management/model_manager.py | 49 ++++++++++ .../model_management/models/__init__.py | 4 +- .../backend/model_management/models/base.py | 1 - .../models/stable_diffusion.py | 3 +- 7 files changed, 175 insertions(+), 8 deletions(-) diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index d0e0361ad9..c298114cbc 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -23,6 +23,7 @@ UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] +ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)] class ModelsList(BaseModel): models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]] @@ -79,7 +80,7 @@ async def update_model( return model_response @models_router.post( - "/", + "/import", operation_id="import_model", responses= { 201: {"description" : "The model imported successfully"}, @@ -95,7 +96,7 @@ async def import_model( prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \ Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"), ) -> ImportModelResponse: - """ Add a model using its local path, repo_id, or remote URL """ + """ Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """ items_to_import = {location} prediction_types = { x.value: x for x in SchedulerPredictionType } @@ -127,7 +128,91 @@ async def import_model( logger.error(str(e)) raise HTTPException(status_code=409, detail=str(e)) +@models_router.post( + "/add", + operation_id="add_model", + responses= { + 201: {"description" : "The model added successfully"}, + 404: {"description" : "The model could not be found"}, + 424: {"description" : "The model appeared to add successfully, but could not be found in the model manager"}, + 409: {"description" : "There is already a model corresponding to this path or repo_id"}, + }, + status_code=201, + response_model=ImportModelResponse +) +async def add_model( + info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), +) -> ImportModelResponse: + """ Add a model using the configuration information appropriate for its type. Only local models can be added by path""" + + logger = ApiDependencies.invoker.services.logger + try: + ApiDependencies.invoker.services.model_manager.add_model( + info.model_name, + info.base_model, + info.model_type, + model_attributes = info.dict() + ) + logger.info(f'Successfully added {info.model_name}') + model_raw = ApiDependencies.invoker.services.model_manager.list_model( + model_name=info.model_name, + base_model=info.base_model, + model_type=info.model_type + ) + return parse_obj_as(ImportModelResponse, model_raw) + except KeyError as e: + logger.error(str(e)) + raise HTTPException(status_code=404, detail=str(e)) + except ValueError as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + +@models_router.post( + "/rename/{base_model}/{model_type}/{model_name}", + operation_id="rename_model", + responses= { + 201: {"description" : "The model was renamed successfully"}, + 404: {"description" : "The model could not be found"}, + 409: {"description" : "There is already a model corresponding to the new name"}, + }, + status_code=201, + response_model=ImportModelResponse +) +async def rename_model( + base_model: BaseModelType = Path(description="Base model"), + model_type: ModelType = Path(description="The type of model"), + model_name: str = Path(description="current model name"), + new_name: Optional[str] = Query(description="new model name", default=None), + new_base: Optional[BaseModelType] = Query(description="new model base", default=None), +) -> ImportModelResponse: + """ Rename a model""" + + logger = ApiDependencies.invoker.services.logger + + try: + result = ApiDependencies.invoker.services.model_manager.rename_model( + base_model = base_model, + model_type = model_type, + model_name = model_name, + new_name = new_name, + new_base = new_base, + ) + logger.debug(result) + logger.info(f'Successfully renamed {model_name}=>{new_name}') + model_raw = ApiDependencies.invoker.services.model_manager.list_model( + model_name=new_name or model_name, + base_model=new_base or base_model, + model_type=model_type + ) + return parse_obj_as(ImportModelResponse, model_raw) + except KeyError as e: + logger.error(str(e)) + raise HTTPException(status_code=404, detail=str(e)) + except ValueError as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + @models_router.delete( "/{base_model}/{model_type}/{model_name}", operation_id="del_model", diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 3c5dad7b3e..67db5c9478 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -167,6 +167,18 @@ class ModelManagerServiceBase(ABC): """ pass + @abstractmethod + def rename_model(self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + new_name: str, + ): + """ + Rename the indicated model. + """ + pass + @abstractmethod def list_checkpoint_configs( self @@ -615,3 +627,26 @@ class ModelManagerService(ModelManagerServiceBase): conf_path = config.legacy_conf_path root_path = config.root_path return [(conf_path / x).relative_to(root_path) for x in conf_path.glob('**/*.yaml')] + + def rename_model(self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + new_name: str = None, + new_base: BaseModelType = None, + ): + """ + Rename the indicated model. Can provide a new name and/or a new base. + :param model_name: Current name of the model + :param base_model: Current base of the model + :param model_type: Model type (can't be changed) + :param new_name: New name for the model + :param new_base: New base for the model + """ + self.mgr.rename_model(base_model = base_model, + model_type = model_type, + model_name = model_name, + new_name = new_name, + new_base = new_base, + ) + diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py index b6f6d62d97..2e537313ac 100644 --- a/invokeai/backend/install/model_install_backend.py +++ b/invokeai/backend/install/model_install_backend.py @@ -71,8 +71,6 @@ class ModelInstallList: class InstallSelections(): install_models: List[str]= field(default_factory=list) remove_models: List[str]=field(default_factory=list) -# scan_directory: Path = None -# autoscan_on_startup: bool=False @dataclass class ModelLoadInfo(): diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index f4485bf67a..55f6de9b5b 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -671,6 +671,55 @@ class ModelManager(object): config = model_config, ) + def rename_model( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + new_name: str = None, + new_base: BaseModelType = None, + ): + ''' + Rename or rebase a model. + ''' + if new_name is None and new_base is None: + self.logger.error("rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.") + return + + model_key = self.create_key(model_name, base_model, model_type) + model_cfg = self.models.get(model_key, None) + if not model_cfg: + raise KeyError(f"Unknown model: {model_key}") + + old_path = self.app_config.root_path / model_cfg.path + new_name = new_name or model_name + new_base = new_base or base_model + new_key = self.create_key(new_name, new_base, model_type) + if new_key in self.models: + raise ValueError(f'Attempt to overwrite existing model definition "{new_key}"') + + # if this is a model file/directory that we manage ourselves, we need to move it + if old_path.is_relative_to(self.app_config.models_path): + new_path = self.app_config.root_path / 'models' / new_base.value / model_type.value / new_name + move(old_path, new_path) + model_cfg.path = str(new_path.relative_to(self.app_config.root_path)) + + # clean up caches + old_model_cache = self._get_model_cache_path(old_path) + if old_model_cache.exists(): + if old_model_cache.is_dir(): + rmtree(str(old_model_cache)) + else: + old_model_cache.unlink() + + cache_ids = self.cache_keys.pop(model_key, []) + for cache_id in cache_ids: + self.cache.uncache_model(cache_id) + + self.models.pop(model_key, None) # delete + self.models[new_key] = model_cfg + self.commit() + def convert_model ( self, model_name: str, diff --git a/invokeai/backend/model_management/models/__init__.py b/invokeai/backend/model_management/models/__init__.py index 1c573b26b6..e404c56bdf 100644 --- a/invokeai/backend/model_management/models/__init__.py +++ b/invokeai/backend/model_management/models/__init__.py @@ -48,7 +48,9 @@ for base_model, models in MODEL_CLASSES.items(): model_configs.discard(None) MODEL_CONFIGS.extend(model_configs) - for cfg in model_configs: + # LS: sort to get the checkpoint configs first, which makes + # for a better template in the Swagger docs + for cfg in sorted(model_configs, key=lambda x: str(x)): model_name, cfg_name = cfg.__qualname__.split('.')[-2:] openapi_cfg_name = model_name + cfg_name if openapi_cfg_name in vars(): diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py index ddbc401e5b..c569872a81 100644 --- a/invokeai/backend/model_management/models/base.py +++ b/invokeai/backend/model_management/models/base.py @@ -59,7 +59,6 @@ class ModelConfigBase(BaseModel): path: str # or Path description: Optional[str] = Field(None) model_format: Optional[str] = Field(None) - # do not save to config error: Optional[ModelError] = Field(None) class Config: diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_management/models/stable_diffusion.py index 74751a40dd..3d2e50d8fb 100644 --- a/invokeai/backend/model_management/models/stable_diffusion.py +++ b/invokeai/backend/model_management/models/stable_diffusion.py @@ -37,8 +37,7 @@ class StableDiffusion1Model(DiffusersModel): vae: Optional[str] = Field(None) config: str variant: ModelVariantType - - + def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert base_model == BaseModelType.StableDiffusion1 assert model_type == ModelType.Main From 6ab9a5e108b09df8a9295f324ef15fbc74b67149 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Wed, 5 Jul 2023 20:00:43 +0300 Subject: [PATCH 10/37] Draft --- .../controlnet_image_processors.py | 17 ++- invokeai/app/invocations/latent.py | 101 +++++++++++------- 2 files changed, 77 insertions(+), 41 deletions(-) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index c37dcda998..c9fad11987 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -9,6 +9,7 @@ from typing import Literal, Optional, Union, List, Dict from PIL import Image from pydantic import BaseModel, Field, validator +from ...backend.model_management import BaseModelType, ModelType from ..models.image import ImageField, ImageCategory, ResourceOrigin from .baseinvocation import ( BaseInvocation, @@ -105,9 +106,15 @@ CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control # CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])] +class ControlNetModelField(BaseModel): + """ControlNet model field""" + + model_name: str = Field(description="Name of the ControlNet model") + base_model: BaseModelType = Field(description="Base model") + class ControlField(BaseModel): image: ImageField = Field(default=None, description="The control image") - control_model: Optional[str] = Field(default=None, description="The ControlNet model to use") + control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use") # control_weight: Optional[float] = Field(default=1, description="weight given to controlnet") control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") begin_step_percent: float = Field(default=0, ge=0, le=1, @@ -154,7 +161,7 @@ class ControlNetInvocation(BaseInvocation): type: Literal["controlnet"] = "controlnet" # Inputs image: ImageField = Field(default=None, description="The control image") - control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny", + control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny", description="control model used") control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet") begin_step_percent: float = Field(default=0, ge=0, le=1, @@ -182,7 +189,11 @@ class ControlNetInvocation(BaseInvocation): return ControlOutput( control=ControlField( image=self.image, - control_model=self.control_model, + #control_model=self.control_model, + control_model=ControlNetModelField( + model_name="canny", + base_model=BaseModelType.StableDiffusion1, + ), control_weight=self.control_weight, begin_step_percent=self.begin_step_percent, end_step_percent=self.end_step_percent, diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index b3f95f3658..1e41a9c96f 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -71,16 +71,21 @@ def get_scheduler( scheduler_name: str, ) -> Scheduler: scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get( - scheduler_name, SCHEDULER_MAP['ddim']) + scheduler_name, SCHEDULER_MAP['ddim'] + ) orig_scheduler_info = context.services.model_manager.get_model( - **scheduler_info.dict()) + **scheduler_info.dict() + ) with orig_scheduler_info as orig_scheduler: scheduler_config = orig_scheduler.config if "_backup" in scheduler_config: scheduler_config = scheduler_config["_backup"] - scheduler_config = {**scheduler_config, ** - scheduler_extra_config, "_backup": scheduler_config} + scheduler_config = { + **scheduler_config, + **scheduler_extra_config, + "_backup": scheduler_config, + } scheduler = scheduler_class.from_config(scheduler_config) # hack copied over from generate.py @@ -137,8 +142,11 @@ class TextToLatentsInvocation(BaseInvocation): # TODO: pass this an emitter method or something? or a session for dispatching? def dispatch_progress( - self, context: InvocationContext, source_node_id: str, - intermediate_state: PipelineIntermediateState) -> None: + self, + context: InvocationContext, + source_node_id: str, + intermediate_state: PipelineIntermediateState, + ) -> None: stable_diffusion_step_callback( context=context, intermediate_state=intermediate_state, @@ -147,11 +155,16 @@ class TextToLatentsInvocation(BaseInvocation): ) def get_conditioning_data( - self, context: InvocationContext, scheduler) -> ConditioningData: + self, + context: InvocationContext, + scheduler, + ) -> ConditioningData: c, extra_conditioning_info = context.services.latents.get( - self.positive_conditioning.conditioning_name) + self.positive_conditioning.conditioning_name + ) uc, _ = context.services.latents.get( - self.negative_conditioning.conditioning_name) + self.negative_conditioning.conditioning_name + ) conditioning_data = ConditioningData( unconditioned_embeddings=uc, @@ -178,7 +191,10 @@ class TextToLatentsInvocation(BaseInvocation): return conditioning_data def create_pipeline( - self, unet, scheduler) -> StableDiffusionGeneratorPipeline: + self, + unet, + scheduler, + ) -> StableDiffusionGeneratorPipeline: # TODO: # configure_model_padding( # unet, @@ -213,6 +229,7 @@ class TextToLatentsInvocation(BaseInvocation): model: StableDiffusionGeneratorPipeline, control_input: List[ControlField], latents_shape: List[int], + exit_stack: ExitStack, do_classifier_free_guidance: bool = True, ) -> List[ControlNetData]: @@ -238,25 +255,19 @@ class TextToLatentsInvocation(BaseInvocation): control_data = [] control_models = [] for control_info in control_list: - # handle control models - if ("," in control_info.control_model): - control_model_split = control_info.control_model.split(",") - control_name = control_model_split[0] - control_subfolder = control_model_split[1] - print("Using HF model subfolders") - print(" control_name: ", control_name) - print(" control_subfolder: ", control_subfolder) - control_model = ControlNetModel.from_pretrained( - control_name, subfolder=control_subfolder, - torch_dtype=model.unet.dtype).to( - model.device) - else: - control_model = ControlNetModel.from_pretrained( - control_info.control_model, torch_dtype=model.unet.dtype).to(model.device) + control_model = exit_stack.enter_context( + context.model_manager.get_model( + model_name=control_info.control_model.model_name, + model_type=ModelType.ControlNet, + base_model=control_info.control_model.base_model, + ) + ) + control_models.append(control_model) control_image_field = control_info.image input_image = context.services.images.get_pil_image( - control_image_field.image_name) + control_image_field.image_name + ) # self.image.image_type, self.image.image_name # FIXME: still need to test with different widths, heights, devices, dtypes # and add in batch_size, num_images_per_prompt? @@ -278,7 +289,8 @@ class TextToLatentsInvocation(BaseInvocation): weight=control_info.control_weight, begin_step_percent=control_info.begin_step_percent, end_step_percent=control_info.end_step_percent, - control_mode=control_info.control_mode,) + control_mode=control_info.control_mode, + ) control_data.append(control_item) # MultiControlNetModel has been refactored out, just need list[ControlNetData] return control_data @@ -289,7 +301,8 @@ class TextToLatentsInvocation(BaseInvocation): # Get the source node id (we are invoking the prepared node) graph_execution_state = context.services.graph_execution_manager.get( - context.graph_execution_state_id) + context.graph_execution_state_id + ) source_node_id = graph_execution_state.prepared_source_mapping[self.id] def step_callback(state: PipelineIntermediateState): @@ -298,14 +311,17 @@ class TextToLatentsInvocation(BaseInvocation): def _lora_loader(): for lora in self.unet.loras: lora_info = context.services.model_manager.get_model( - **lora.dict(exclude={"weight"})) + **lora.dict(exclude={"weight"}) + ) yield (lora_info.context.model, lora.weight) del lora_info return unet_info = context.services.model_manager.get_model( - **self.unet.unet.dict()) - with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ + **self.unet.unet.dict() + ) + with ExitStack() as exit_stack,\ + ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ unet_info as unet: scheduler = get_scheduler( @@ -322,6 +338,7 @@ class TextToLatentsInvocation(BaseInvocation): latents_shape=noise.shape, # do_classifier_free_guidance=(self.cfg_scale >= 1.0)) do_classifier_free_guidance=True, + exit_stack=exit_stack, ) # TODO: Verify the noise is the right size @@ -374,7 +391,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): # Get the source node id (we are invoking the prepared node) graph_execution_state = context.services.graph_execution_manager.get( - context.graph_execution_state_id) + context.graph_execution_state_id + ) source_node_id = graph_execution_state.prepared_source_mapping[self.id] def step_callback(state: PipelineIntermediateState): @@ -383,14 +401,17 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): def _lora_loader(): for lora in self.unet.loras: lora_info = context.services.model_manager.get_model( - **lora.dict(exclude={"weight"})) + **lora.dict(exclude={"weight"}) + ) yield (lora_info.context.model, lora.weight) del lora_info return unet_info = context.services.model_manager.get_model( - **self.unet.unet.dict()) - with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ + **self.unet.unet.dict() + ) + with ExitStack() as exit_stack,\ + ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ unet_info as unet: scheduler = get_scheduler( @@ -407,11 +428,13 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): latents_shape=noise.shape, # do_classifier_free_guidance=(self.cfg_scale >= 1.0)) do_classifier_free_guidance=True, + exit_stack=exit_stack, ) # TODO: Verify the noise is the right size initial_latents = latent if self.strength < 1.0 else torch.zeros_like( - latent, device=unet.device, dtype=latent.dtype) + latent, device=unet.device, dtype=latent.dtype + ) timesteps, _ = pipeline.get_img2img_timesteps( self.steps, @@ -535,7 +558,8 @@ class ResizeLatentsInvocation(BaseInvocation): resized_latents = torch.nn.functional.interpolate( latents, size=(self.height // 8, self.width // 8), mode=self.mode, antialias=self.antialias - if self.mode in ["bilinear", "bicubic"] else False,) + if self.mode in ["bilinear", "bicubic"] else False, + ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -569,7 +593,8 @@ class ScaleLatentsInvocation(BaseInvocation): resized_latents = torch.nn.functional.interpolate( latents, scale_factor=self.scale_factor, mode=self.mode, antialias=self.antialias - if self.mode in ["bilinear", "bicubic"] else False,) + if self.mode in ["bilinear", "bicubic"] else False, + ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() From 788dcbde7010eafb6d2431909b051193fe38b618 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 8 Jul 2023 19:28:26 +1000 Subject: [PATCH 11/37] fix(nodes): add missing import --- invokeai/app/invocations/latent.py | 1 + 1 file changed, 1 insertion(+) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 1e41a9c96f..dc929b5833 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -1,5 +1,6 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) +from contextlib import ExitStack from typing import List, Literal, Optional, Union import einops From 82fa39b531460a36862d76b65add0578c2969e42 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 8 Jul 2023 19:31:32 +1000 Subject: [PATCH 12/37] feat(nodes): add controlnet nodes type hint --- invokeai/app/invocations/controlnet_image_processors.py | 1 + 1 file changed, 1 insertion(+) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index c9fad11987..06b4d2b2ef 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -141,6 +141,7 @@ class ControlField(BaseModel): "ui": { "type_hints": { "control_weight": "float", + "control_model": "controlnet_model", # "control_weight": "number", } } From 96c9db6d2e0c3b044c7ca1656e5ce6f7946b31bd Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 8 Jul 2023 19:31:45 +1000 Subject: [PATCH 13/37] chore(ui): typegen --- .../frontend/web/src/services/api/schema.d.ts | 44 ++++++++++++------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/schema.d.ts b/invokeai/frontend/web/src/services/api/schema.d.ts index acbed14eac..2a9275de73 100644 --- a/invokeai/frontend/web/src/services/api/schema.d.ts +++ b/invokeai/frontend/web/src/services/api/schema.d.ts @@ -734,7 +734,7 @@ export type components = { * Control Model * @description The ControlNet model to use */ - control_model: string; + control_model: components["schemas"]["ControlNetModelField"]; /** * Control Weight * @description The weight given to the ControlNet @@ -791,10 +791,9 @@ export type components = { /** * Control Model * @description control model used - * @default lllyasviel/sd-controlnet-canny - * @enum {string} + * @default lllyasviel/sd-controlnet-canny */ - control_model?: "lllyasviel/sd-controlnet-canny" | "lllyasviel/sd-controlnet-depth" | "lllyasviel/sd-controlnet-hed" | "lllyasviel/sd-controlnet-seg" | "lllyasviel/sd-controlnet-openpose" | "lllyasviel/sd-controlnet-scribble" | "lllyasviel/sd-controlnet-normal" | "lllyasviel/sd-controlnet-mlsd" | "lllyasviel/control_v11p_sd15_canny" | "lllyasviel/control_v11p_sd15_openpose" | "lllyasviel/control_v11p_sd15_seg" | "lllyasviel/control_v11f1p_sd15_depth" | "lllyasviel/control_v11p_sd15_normalbae" | "lllyasviel/control_v11p_sd15_scribble" | "lllyasviel/control_v11p_sd15_mlsd" | "lllyasviel/control_v11p_sd15_softedge" | "lllyasviel/control_v11p_sd15s2_lineart_anime" | "lllyasviel/control_v11p_sd15_lineart" | "lllyasviel/control_v11p_sd15_inpaint" | "lllyasviel/control_v11e_sd15_shuffle" | "lllyasviel/control_v11e_sd15_ip2p" | "lllyasviel/control_v11f1e_sd15_tile" | "thibaud/controlnet-sd21-openpose-diffusers" | "thibaud/controlnet-sd21-canny-diffusers" | "thibaud/controlnet-sd21-depth-diffusers" | "thibaud/controlnet-sd21-scribble-diffusers" | "thibaud/controlnet-sd21-hed-diffusers" | "thibaud/controlnet-sd21-zoedepth-diffusers" | "thibaud/controlnet-sd21-color-diffusers" | "thibaud/controlnet-sd21-openposev2-diffusers" | "thibaud/controlnet-sd21-lineart-diffusers" | "thibaud/controlnet-sd21-normalbae-diffusers" | "thibaud/controlnet-sd21-ade20k-diffusers" | "CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15" | "CrucibleAI/ControlNetMediaPipeFace"; + control_model?: components["schemas"]["ControlNetModelField"]; /** * Control Weight * @description The weight given to the ControlNet @@ -838,6 +837,19 @@ export type components = { model_format: components["schemas"]["ControlNetModelFormat"]; error?: components["schemas"]["ModelError"]; }; + /** + * ControlNetModelField + * @description ControlNet model field + */ + ControlNetModelField: { + /** + * Model Name + * @description Name of the ControlNet model + */ + model_name: string; + /** @description Base model */ + base_model: components["schemas"]["BaseModelType"]; + }; /** * ControlNetModelFormat * @description An enumeration. @@ -3290,7 +3302,7 @@ export type components = { /** ModelsList */ ModelsList: { /** Models */ - models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[]; + models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[]; }; /** * MultiplyInvocation @@ -4605,18 +4617,18 @@ export type components = { */ image?: components["schemas"]["ImageField"]; }; - /** - * StableDiffusion2ModelFormat - * @description An enumeration. - * @enum {string} - */ - StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; /** * StableDiffusion1ModelFormat * @description An enumeration. * @enum {string} */ StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; + /** + * StableDiffusion2ModelFormat + * @description An enumeration. + * @enum {string} + */ + StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; }; responses: never; parameters: never; @@ -4997,7 +5009,7 @@ export type operations = { /** @description The model imported successfully */ 201: { content: { - "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"]; + "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; }; }; /** @description The model could not be found */ @@ -5065,14 +5077,14 @@ export type operations = { }; requestBody: { content: { - "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"]; + "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; }; }; responses: { /** @description The model was updated successfully */ 200: { content: { - "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"]; + "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; }; }; /** @description Bad request */ @@ -5106,7 +5118,7 @@ export type operations = { /** @description Model converted successfully */ 200: { content: { - "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"]; + "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; }; }; /** @description Bad request */ @@ -5141,7 +5153,7 @@ export type operations = { /** @description Model converted successfully */ 200: { content: { - "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"]; + "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; }; }; /** @description Incompatible models */ From 29b2e59e655ec931a1d284cfcd0a85d77457a94e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 8 Jul 2023 19:47:27 +1000 Subject: [PATCH 14/37] fix(nodes): fix ref to ctx mgr service, missing import --- invokeai/app/invocations/latent.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index dc929b5833..baf78c7c23 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -12,6 +12,7 @@ from pydantic import BaseModel, Field, validator from invokeai.app.invocations.metadata import CoreMetadata from invokeai.app.util.step_callback import stable_diffusion_step_callback +from invokeai.backend.model_management.models.base import ModelType from ...backend.model_management.lora import ModelPatcher from ...backend.stable_diffusion import PipelineIntermediateState @@ -257,7 +258,7 @@ class TextToLatentsInvocation(BaseInvocation): control_models = [] for control_info in control_list: control_model = exit_stack.enter_context( - context.model_manager.get_model( + context.services.model_manager.get_model( model_name=control_info.control_model.model_name, model_type=ModelType.ControlNet, base_model=control_info.control_model.base_model, From 5ac114576fc57f05c7885a326bf1d955487338c0 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 8 Jul 2023 19:47:52 +1000 Subject: [PATCH 15/37] feat(ui): add controlnet field to nodes --- .../nodes/components/InputFieldComponent.tsx | 11 ++ .../ControlNetModelInputFieldComponent.tsx | 102 ++++++++++++++++++ .../web/src/features/nodes/types/constants.ts | 8 ++ .../web/src/features/nodes/types/types.ts | 13 +++ .../nodes/util/fieldTemplateBuilders.ts | 19 ++++ .../features/nodes/util/fieldValueBuilders.ts | 4 + .../util/modelIdToControlNetModelField.ts | 14 +++ .../frontend/web/src/services/api/types.d.ts | 2 + 8 files changed, 173 insertions(+) create mode 100644 invokeai/frontend/web/src/features/nodes/components/fields/ControlNetModelInputFieldComponent.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/util/modelIdToControlNetModelField.ts diff --git a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx index b179adff23..23effc5375 100644 --- a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx @@ -7,6 +7,7 @@ import ClipInputFieldComponent from './fields/ClipInputFieldComponent'; import ColorInputFieldComponent from './fields/ColorInputFieldComponent'; import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent'; import ControlInputFieldComponent from './fields/ControlInputFieldComponent'; +import ControlNetModelInputFieldComponent from './fields/ControlNetModelInputFieldComponent'; import EnumInputFieldComponent from './fields/EnumInputFieldComponent'; import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent'; import ImageInputFieldComponent from './fields/ImageInputFieldComponent'; @@ -174,6 +175,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => { ); } + if (type === 'controlnet_model' && template.type === 'controlnet_model') { + return ( + + ); + } + if (type === 'array' && template.type === 'array') { return ( +) => { + const { nodeId, field } = props; + + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + const { data: controlNetModels } = useGetControlNetModelsQuery(); + + const selectedModel = useMemo( + () => controlNetModels?.entities[field.value ?? controlNetModels.ids[0]], + [controlNetModels?.entities, controlNetModels?.ids, field.value] + ); + + const data = useMemo(() => { + if (!controlNetModels) { + return []; + } + + const data: SelectItem[] = []; + + forEach(controlNetModels.entities, (model, id) => { + if (!model) { + return; + } + + data.push({ + value: id, + label: model.name, + group: BASE_MODEL_NAME_MAP[model.base_model], + }); + }); + + return data; + }, [controlNetModels]); + + const handleValueChanged = useCallback( + (v: string | null) => { + if (!v) { + return; + } + + dispatch( + fieldValueChanged({ + nodeId, + fieldName: field.name, + value: v, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + + useEffect(() => { + if (field.value && controlNetModels?.ids.includes(field.value)) { + return; + } + + const firstLora = controlNetModels?.ids[0]; + + if (!isString(firstLora)) { + return; + } + + handleValueChanged(firstLora); + }, [field.value, handleValueChanged, controlNetModels?.ids]); + + return ( + + ); +}; + +export default memo(ControlNetModelInputFieldComponent); diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index 5fe780a286..3a70e52ee5 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -19,6 +19,8 @@ export const FIELD_TYPE_MAP: Record = { model: 'model', vae_model: 'vae_model', lora_model: 'lora_model', + controlnet_model: 'controlnet_model', + ControlNetModelField: 'controlnet_model', array: 'array', item: 'item', ColorField: 'color', @@ -130,6 +132,12 @@ export const FIELDS: Record = { title: 'LoRA', description: 'Models are models.', }, + controlnet_model: { + color: 'teal', + colorCssVar: getColorTokenCssVariable('teal'), + title: 'ControlNet', + description: 'Models are models.', + }, array: { color: 'gray', colorCssVar: getColorTokenCssVariable('gray'), diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index 4c47c63068..18b837a98e 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -71,6 +71,7 @@ export type FieldType = | 'model' | 'vae_model' | 'lora_model' + | 'controlnet_model' | 'array' | 'item' | 'color' @@ -100,6 +101,7 @@ export type InputFieldValue = | MainModelInputFieldValue | VaeModelInputFieldValue | LoRAModelInputFieldValue + | ControlNetModelInputFieldValue | ArrayInputFieldValue | ItemInputFieldValue | ColorInputFieldValue @@ -127,6 +129,7 @@ export type InputFieldTemplate = | ModelInputFieldTemplate | VaeModelInputFieldTemplate | LoRAModelInputFieldTemplate + | ControlNetModelInputFieldTemplate | ArrayInputFieldTemplate | ItemInputFieldTemplate | ColorInputFieldTemplate @@ -249,6 +252,11 @@ export type LoRAModelInputFieldValue = FieldValueBase & { value?: LoRAModelParam; }; +export type ControlNetModelInputFieldValue = FieldValueBase & { + type: 'controlnet_model'; + value?: string; +}; + export type ArrayInputFieldValue = FieldValueBase & { type: 'array'; value?: (string | number)[]; @@ -368,6 +376,11 @@ export type LoRAModelInputFieldTemplate = InputFieldTemplateBase & { type: 'lora_model'; }; +export type ControlNetModelInputFieldTemplate = InputFieldTemplateBase & { + default: string; + type: 'controlnet_model'; +}; + export type ArrayInputFieldTemplate = InputFieldTemplateBase & { default: []; type: 'array'; diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts index 1c2dbc0c3e..eaa7fe66fc 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts @@ -9,6 +9,7 @@ import { ColorInputFieldTemplate, ConditioningInputFieldTemplate, ControlInputFieldTemplate, + ControlNetModelInputFieldTemplate, EnumInputFieldTemplate, FieldType, FloatInputFieldTemplate, @@ -207,6 +208,21 @@ const buildLoRAModelInputFieldTemplate = ({ return template; }; +const buildControlNetModelInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): ControlNetModelInputFieldTemplate => { + const template: ControlNetModelInputFieldTemplate = { + ...baseField, + type: 'controlnet_model', + inputRequirement: 'always', + inputKind: 'direct', + default: schemaObject.default ?? undefined, + }; + + return template; +}; + const buildImageInputFieldTemplate = ({ schemaObject, baseField, @@ -479,6 +495,9 @@ export const buildInputFieldTemplate = ( if (['lora_model'].includes(fieldType)) { return buildLoRAModelInputFieldTemplate({ schemaObject, baseField }); } + if (['controlnet_model'].includes(fieldType)) { + return buildControlNetModelInputFieldTemplate({ schemaObject, baseField }); + } if (['enum'].includes(fieldType)) { return buildEnumInputFieldTemplate({ schemaObject, baseField }); } diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts index 950038b691..f54a7640bd 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts @@ -83,6 +83,10 @@ export const buildInputFieldValue = ( if (template.type === 'lora_model') { fieldValue.value = undefined; } + + if (template.type === 'controlnet_model') { + fieldValue.value = undefined; + } } return fieldValue; diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToControlNetModelField.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToControlNetModelField.ts new file mode 100644 index 0000000000..655d5cd5df --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/modelIdToControlNetModelField.ts @@ -0,0 +1,14 @@ +import { BaseModelType, ControlNetModelField } from 'services/api/types'; + +export const modelIdToControlNetModelField = ( + controlNetModelId: string +): ControlNetModelField => { + const [base_model, model_type, model_name] = controlNetModelId.split('/'); + + const field: ControlNetModelField = { + base_model: base_model as BaseModelType, + model_name, + }; + + return field; +}; diff --git a/invokeai/frontend/web/src/services/api/types.d.ts b/invokeai/frontend/web/src/services/api/types.d.ts index fcbbd1a6a0..37faae592f 100644 --- a/invokeai/frontend/web/src/services/api/types.d.ts +++ b/invokeai/frontend/web/src/services/api/types.d.ts @@ -32,6 +32,8 @@ export type BaseModelType = components['schemas']['BaseModelType']; export type MainModelField = components['schemas']['MainModelField']; export type VAEModelField = components['schemas']['VAEModelField']; export type LoRAModelField = components['schemas']['LoRAModelField']; +export type ControlNetModelField = + components['schemas']['ControlNetModelField']; export type ModelsList = components['schemas']['ModelsList']; export type ControlField = components['schemas']['ControlField']; From 76dc47e88df9c817346b1a4409e2847eddef3dd3 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Tue, 11 Jul 2023 16:18:38 -0400 Subject: [PATCH 16/37] remove frontend constants, use backend response for controlnet models. add disabled state if base model is not compatible. clear control net model if main base model changes. add logic to guess processor and move it up in UI --- .../frontend/web/src/app/types/invokeai.ts | 4 +- .../controlNet/components/ControlNet.tsx | 19 ++-- .../parameters/ParamControlNetModel.tsx | 78 ++++++++------ .../features/controlNet/store/constants.ts | 100 +++--------------- .../controlNet/store/controlNetSlice.ts | 33 ++++-- 5 files changed, 97 insertions(+), 137 deletions(-) diff --git a/invokeai/frontend/web/src/app/types/invokeai.ts b/invokeai/frontend/web/src/app/types/invokeai.ts index 40b8c1c73a..be642a6435 100644 --- a/invokeai/frontend/web/src/app/types/invokeai.ts +++ b/invokeai/frontend/web/src/app/types/invokeai.ts @@ -1,5 +1,5 @@ import { - CONTROLNET_MODELS, + // CONTROLNET_MODELS, CONTROLNET_PROCESSORS, } from 'features/controlNet/store/constants'; import { InvokeTabName } from 'features/ui/store/tabMap'; @@ -128,7 +128,7 @@ export type AppConfig = { canRestoreDeletedImagesFromBin: boolean; sd: { defaultModel?: string; - disabledControlNetModels: (keyof typeof CONTROLNET_MODELS)[]; + disabledControlNetModels: string[]; disabledControlNetProcessors: (keyof typeof CONTROLNET_PROCESSORS)[]; iterations: { initial: number; diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx index bb01416e1d..e25c320cd6 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx @@ -124,6 +124,7 @@ const ControlNet = (props: ControlNetProps) => { /> } /> + {!shouldAutoConfig && ( { /> )} + + + + {isEnabled && ( <> @@ -196,18 +207,10 @@ const ControlNet = (props: ControlNetProps) => { height={96} /> - - )} diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx index ddf266ccfc..eda3cde5d2 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx @@ -1,55 +1,71 @@ -import { createSelector } from '@reduxjs/toolkit'; +import { SelectItem } from '@mantine/core'; +import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import IAIMantineSearchableSelect, { - IAISelectDataType, -} from 'common/components/IAIMantineSearchableSelect'; +import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; +import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; -import { - CONTROLNET_MODELS, - ControlNetModelName, -} from 'features/controlNet/store/constants'; import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice'; -import { configSelector } from 'features/system/store/configSelectors'; -import { map } from 'lodash-es'; -import { memo, useCallback } from 'react'; +import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; +import { forEach } from 'lodash-es'; +import { memo, useCallback, useMemo } from 'react'; +import { useGetControlNetModelsQuery } from 'services/api/endpoints/models'; type ParamControlNetModelProps = { controlNetId: string; - model: ControlNetModelName; + model: string; }; -const selector = createSelector(configSelector, (config) => { - const controlNetModels: IAISelectDataType[] = map(CONTROLNET_MODELS, (m) => ({ - label: m.label, - value: m.type, - })).filter( - (d) => - !config.sd.disabledControlNetModels.includes( - d.value as ControlNetModelName - ) - ); - - return controlNetModels; -}); - const ParamControlNetModel = (props: ParamControlNetModelProps) => { const { controlNetId, model } = props; - const controlNetModels = useAppSelector(selector); const dispatch = useAppDispatch(); const isReady = useIsReadyToInvoke(); + const currentMainModel = useAppSelector( + (state: RootState) => state.generation.model + ); + + const { data: controlNetModels } = useGetControlNetModelsQuery(); + const handleModelChanged = useCallback( (val: string | null) => { - // TODO: do not cast - const model = val as ControlNetModelName; - dispatch(controlNetModelChanged({ controlNetId, model })); + if (!val) return; + dispatch(controlNetModelChanged({ controlNetId, model: val })); }, [controlNetId, dispatch] ); + const data = useMemo(() => { + if (!controlNetModels) { + return []; + } + + const data: SelectItem[] = []; + + forEach(controlNetModels.entities, (model, id) => { + if (!model) { + return; + } + + const disabled = currentMainModel?.base_model !== model.base_model; + + data.push({ + value: id, + label: model.model_name, + group: MODEL_TYPE_MAP[model.base_model], + disabled, + tooltip: disabled + ? `Incompatible base model: ${model.base_model}` + : undefined, + }); + }); + + return data; + }, [controlNetModels, currentMainModel?.base_model]); + return ( ; - -type ControlNetModel = { - type: string; - label: string; - description?: string; - defaultProcessor?: ControlNetProcessorType; +export const CONTROLNET_MODEL_DEFAULT_PROCESSORS: { + [key: string]: ControlNetProcessorType; +} = { + canny: 'canny_image_processor', + mlsd: 'mlsd_image_processor', + depth: 'midas_depth_image_processor', + bae: 'normalbae_image_processor', + lineart: 'lineart_image_processor', + lineart_anime: 'lineart_anime_image_processor', + softedge: 'hed_image_processor', + shuffle: 'content_shuffle_image_processor', + openpose: 'openpose_image_processor', + mediapipe: 'mediapipe_face_processor', }; - -export const CONTROLNET_MODELS: ControlNetModelsDict = { - 'lllyasviel/control_v11p_sd15_canny': { - type: 'lllyasviel/control_v11p_sd15_canny', - label: 'Canny', - defaultProcessor: 'canny_image_processor', - }, - 'lllyasviel/control_v11p_sd15_inpaint': { - type: 'lllyasviel/control_v11p_sd15_inpaint', - label: 'Inpaint', - defaultProcessor: 'none', - }, - 'lllyasviel/control_v11p_sd15_mlsd': { - type: 'lllyasviel/control_v11p_sd15_mlsd', - label: 'M-LSD', - defaultProcessor: 'mlsd_image_processor', - }, - 'lllyasviel/control_v11f1p_sd15_depth': { - type: 'lllyasviel/control_v11f1p_sd15_depth', - label: 'Depth', - defaultProcessor: 'midas_depth_image_processor', - }, - 'lllyasviel/control_v11p_sd15_normalbae': { - type: 'lllyasviel/control_v11p_sd15_normalbae', - label: 'Normal Map (BAE)', - defaultProcessor: 'normalbae_image_processor', - }, - 'lllyasviel/control_v11p_sd15_seg': { - type: 'lllyasviel/control_v11p_sd15_seg', - label: 'Segmentation', - defaultProcessor: 'none', - }, - 'lllyasviel/control_v11p_sd15_lineart': { - type: 'lllyasviel/control_v11p_sd15_lineart', - label: 'Lineart', - defaultProcessor: 'lineart_image_processor', - }, - 'lllyasviel/control_v11p_sd15s2_lineart_anime': { - type: 'lllyasviel/control_v11p_sd15s2_lineart_anime', - label: 'Lineart Anime', - defaultProcessor: 'lineart_anime_image_processor', - }, - 'lllyasviel/control_v11p_sd15_scribble': { - type: 'lllyasviel/control_v11p_sd15_scribble', - label: 'Scribble', - defaultProcessor: 'none', - }, - 'lllyasviel/control_v11p_sd15_softedge': { - type: 'lllyasviel/control_v11p_sd15_softedge', - label: 'Soft Edge', - defaultProcessor: 'hed_image_processor', - }, - 'lllyasviel/control_v11e_sd15_shuffle': { - type: 'lllyasviel/control_v11e_sd15_shuffle', - label: 'Content Shuffle', - defaultProcessor: 'content_shuffle_image_processor', - }, - 'lllyasviel/control_v11p_sd15_openpose': { - type: 'lllyasviel/control_v11p_sd15_openpose', - label: 'Openpose', - defaultProcessor: 'openpose_image_processor', - }, - 'lllyasviel/control_v11f1e_sd15_tile': { - type: 'lllyasviel/control_v11f1e_sd15_tile', - label: 'Tile (experimental)', - defaultProcessor: 'none', - }, - 'lllyasviel/control_v11e_sd15_ip2p': { - type: 'lllyasviel/control_v11e_sd15_ip2p', - label: 'Pix2Pix (experimental)', - defaultProcessor: 'none', - }, - 'CrucibleAI/ControlNetMediaPipeFace': { - type: 'CrucibleAI/ControlNetMediaPipeFace', - label: 'Mediapipe Face', - defaultProcessor: 'mediapipe_face_processor', - }, -}; - -export type ControlNetModelName = keyof typeof CONTROLNET_MODELS; diff --git a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts index d1c69566e9..39a321b282 100644 --- a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts +++ b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts @@ -8,9 +8,10 @@ import { RequiredControlNetProcessorNode, } from './types'; import { - CONTROLNET_MODELS, + CONTROLNET_MODEL_DEFAULT_PROCESSORS, + // CONTROLNET_MODELS, CONTROLNET_PROCESSORS, - ControlNetModelName, + // ControlNetModelName, } from './constants'; import { controlNetImageProcessed } from './actions'; import { imageDeleted, imageUrlsReceived } from 'services/api/thunks/image'; @@ -26,7 +27,7 @@ export type ControlModes = export const initialControlNet: Omit = { isEnabled: true, - model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type, + model: '', weight: 1, beginStepPct: 0, endStepPct: 1, @@ -42,7 +43,7 @@ export const initialControlNet: Omit = { export type ControlNetConfig = { controlNetId: string; isEnabled: boolean; - model: ControlNetModelName; + model: string; weight: number; beginStepPct: number; endStepPct: number; @@ -147,7 +148,7 @@ export const controlNetSlice = createSlice({ state, action: PayloadAction<{ controlNetId: string; - model: ControlNetModelName; + model: string; }> ) => { const { controlNetId, model } = action.payload; @@ -155,7 +156,15 @@ export const controlNetSlice = createSlice({ state.controlNets[controlNetId].processedControlImage = null; if (state.controlNets[controlNetId].shouldAutoConfig) { - const processorType = CONTROLNET_MODELS[model].defaultProcessor; + let processorType: ControlNetProcessorType | undefined = undefined; + + for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) { + if (model.includes(modelSubstring)) { + processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring]; + break; + } + } + if (processorType) { state.controlNets[controlNetId].processorType = processorType; state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[ @@ -241,9 +250,15 @@ export const controlNetSlice = createSlice({ if (newShouldAutoConfig) { // manage the processor for the user - const processorType = - CONTROLNET_MODELS[state.controlNets[controlNetId].model] - .defaultProcessor; + let processorType: ControlNetProcessorType | undefined = undefined; + + for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) { + if (state.controlNets[controlNetId].model.includes(modelSubstring)) { + processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring]; + break; + } + } + if (processorType) { state.controlNets[controlNetId].processorType = processorType; state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[ From 0d413464175fd826c2c69ca8c67259c654704f73 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 15 Jul 2023 12:19:24 +1000 Subject: [PATCH 17/37] feat(ui): fix controlNet models - update controlnet state to use object format for model - update model-parsing helper functions to log errors - update nodes components, types and state - remove controlnets from state when models are loaded and the controlnet's model is not available --- .../listeners/modelSelected.ts | 10 ++- .../listeners/modelsLoaded.ts | 18 +++- .../src/common/hooks/useIsReadyToInvoke.ts | 8 ++ .../controlNet/components/ControlNet.tsx | 2 +- .../ParamControlNetIsPreprocessed.tsx | 36 -------- .../parameters/ParamControlNetModel.tsx | 82 ++++++++++++++----- .../controlNet/store/controlNetSlice.ts | 57 +++++-------- .../ControlNetModelInputFieldComponent.tsx | 53 ++++++------ .../src/features/nodes/store/nodesSlice.ts | 4 +- .../web/src/features/nodes/types/types.ts | 3 +- .../util/modelIdToControlNetModelField.ts | 14 ---- .../VAEModel/ParamVAEModelSelect.tsx | 1 + .../parameters/types/parameterSchemas.ts | 17 ++++ .../util/modelIdToControlNetModelParam.ts | 30 +++++++ .../util/modelIdToLoRAModelParam.ts | 14 +++- .../util/modelIdToMainModelParam.ts | 14 +++- .../parameters/util/modelIdToVAEModelParam.ts | 14 +++- .../frontend/web/src/services/api/schema.d.ts | 26 +++--- 18 files changed, 249 insertions(+), 154 deletions(-) delete mode 100644 invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetIsPreprocessed.tsx delete mode 100644 invokeai/frontend/web/src/features/nodes/util/modelIdToControlNetModelField.ts create mode 100644 invokeai/frontend/web/src/features/parameters/util/modelIdToControlNetModelParam.ts diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts index ee879a8915..05076960fb 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts @@ -10,6 +10,7 @@ import { zMainModel } from 'features/parameters/types/parameterSchemas'; import { addToast } from 'features/system/store/systemSlice'; import { forEach } from 'lodash-es'; import { startAppListening } from '..'; +import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice'; const moduleLog = log.child({ module: 'models' }); @@ -51,7 +52,14 @@ export const addModelSelectedListener = () => { modelsCleared += 1; } - // TODO: handle incompatible controlnet; pending model manager support + const { controlNets } = state.controlNet; + forEach(controlNets, (controlNet, controlNetId) => { + if (controlNet.model?.base_model !== base_model) { + dispatch(controlNetRemoved({ controlNetId })); + modelsCleared += 1; + } + }); + if (modelsCleared > 0) { dispatch( addToast( diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts index f8abcfa758..5e3caa7c99 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts @@ -11,6 +11,7 @@ import { import { forEach, some } from 'lodash-es'; import { modelsApi } from 'services/api/endpoints/models'; import { startAppListening } from '..'; +import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice'; const moduleLog = log.child({ module: 'models' }); @@ -127,7 +128,22 @@ export const addModelsLoadedListener = () => { matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled, effect: async (action, { getState, dispatch }) => { // ControlNet models loaded - need to remove missing ControlNets from state - // TODO: pending model manager controlnet support + const controlNets = getState().controlNet.controlNets; + + forEach(controlNets, (controlNet, controlNetId) => { + const isControlNetAvailable = some( + action.payload.entities, + (m) => + m?.model_name === controlNet?.model?.model_name && + m?.base_model === controlNet?.model?.base_model + ); + + if (isControlNetAvailable) { + return; + } + + dispatch(controlNetRemoved({ controlNetId })); + }); }, }); }; diff --git a/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts b/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts index 3b1476fb1f..580206266d 100644 --- a/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts +++ b/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts @@ -5,6 +5,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { validateSeedWeights } from 'common/util/seedWeightPairs'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { modelsApi } from '../../services/api/endpoints/models'; +import { forEach } from 'lodash-es'; const readinessSelector = createSelector( [stateSelector, activeTabNameSelector], @@ -52,6 +53,13 @@ const readinessSelector = createSelector( reasonsWhyNotReady.push('Seed-Weights badly formatted.'); } + forEach(state.controlNet.controlNets, (controlNet, id) => { + if (!controlNet.model) { + isReady = false; + reasonsWhyNotReady.push('ControlNet ${id} has no model selected.'); + } + }); + // All good return { isReady, reasonsWhyNotReady }; }, diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx index e25c320cd6..cd3cafa12a 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx @@ -90,7 +90,7 @@ const ControlNet = (props: ControlNetProps) => { transitionDuration: '0.1s', }} > - + { - const { controlNetId, isControlImageProcessed } = props; - const dispatch = useAppDispatch(); - - const handleIsControlImageProcessedToggled = useCallback(() => { - dispatch( - isControlNetImagePreprocessedToggled({ - controlNetId, - }) - ); - }, [controlNetId, dispatch]); - - return ( - - ); -}; - -export default memo(ParamControlNetIsEnabled); diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx index eda3cde5d2..548acfaea7 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx @@ -1,39 +1,45 @@ import { SelectItem } from '@mantine/core'; -import { RootState } from 'app/store/store'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; -import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; +import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam'; +import { selectIsBusy } from 'features/system/store/systemSelectors'; import { forEach } from 'lodash-es'; import { memo, useCallback, useMemo } from 'react'; import { useGetControlNetModelsQuery } from 'services/api/endpoints/models'; type ParamControlNetModelProps = { controlNetId: string; - model: string; }; const ParamControlNetModel = (props: ParamControlNetModelProps) => { - const { controlNetId, model } = props; + const { controlNetId } = props; const dispatch = useAppDispatch(); - const isReady = useIsReadyToInvoke(); + const isBusy = useAppSelector(selectIsBusy); - const currentMainModel = useAppSelector( - (state: RootState) => state.generation.model + const selector = useMemo( + () => + createSelector( + stateSelector, + ({ generation, controlNet }) => { + const { model } = generation; + const controlNetModel = controlNet.controlNets[controlNetId]?.model; + return { mainModel: model, controlNetModel }; + }, + defaultSelectorOptions + ), + [controlNetId] ); + const { mainModel, controlNetModel } = useAppSelector(selector); + const { data: controlNetModels } = useGetControlNetModelsQuery(); - const handleModelChanged = useCallback( - (val: string | null) => { - if (!val) return; - dispatch(controlNetModelChanged({ controlNetId, model: val })); - }, - [controlNetId, dispatch] - ); - const data = useMemo(() => { if (!controlNetModels) { return []; @@ -46,7 +52,7 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => { return; } - const disabled = currentMainModel?.base_model !== model.base_model; + const disabled = model?.base_model !== mainModel?.base_model; data.push({ value: id, @@ -60,16 +66,52 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => { }); return data; - }, [controlNetModels, currentMainModel?.base_model]); + }, [controlNetModels, mainModel?.base_model]); + + // grab the full model entity from the RTK Query cache + const selectedModel = useMemo( + () => + controlNetModels?.entities[ + `${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}` + ] ?? null, + [ + controlNetModel?.base_model, + controlNetModel?.model_name, + controlNetModels?.entities, + ] + ); + + const handleModelChanged = useCallback( + (v: string | null) => { + if (!v) { + return; + } + + const newControlNetModel = modelIdToControlNetModelParam(v); + + if (!newControlNetModel) { + return; + } + + dispatch( + controlNetModelChanged({ controlNetId, model: newControlNetModel }) + ); + }, + [controlNetId, dispatch] + ); return ( ); }; diff --git a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts index 39a321b282..c735a47510 100644 --- a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts +++ b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts @@ -1,23 +1,20 @@ -import { PayloadAction } from '@reduxjs/toolkit'; -import { createSlice } from '@reduxjs/toolkit'; +import { PayloadAction, createSlice } from '@reduxjs/toolkit'; import { RootState } from 'app/store/store'; -import { ImageDTO } from 'services/api/types'; +import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas'; +import { forEach } from 'lodash-es'; +import { imageDeleted } from 'services/api/thunks/image'; +import { isAnySessionRejected } from 'services/api/thunks/session'; +import { appSocketInvocationError } from 'services/events/actions'; +import { controlNetImageProcessed } from './actions'; +import { + CONTROLNET_MODEL_DEFAULT_PROCESSORS, + CONTROLNET_PROCESSORS, +} from './constants'; import { ControlNetProcessorType, RequiredCannyImageProcessorInvocation, RequiredControlNetProcessorNode, } from './types'; -import { - CONTROLNET_MODEL_DEFAULT_PROCESSORS, - // CONTROLNET_MODELS, - CONTROLNET_PROCESSORS, - // ControlNetModelName, -} from './constants'; -import { controlNetImageProcessed } from './actions'; -import { imageDeleted, imageUrlsReceived } from 'services/api/thunks/image'; -import { forEach } from 'lodash-es'; -import { isAnySessionRejected } from 'services/api/thunks/session'; -import { appSocketInvocationError } from 'services/events/actions'; export type ControlModes = | 'balanced' @@ -27,7 +24,7 @@ export type ControlModes = export const initialControlNet: Omit = { isEnabled: true, - model: '', + model: null, weight: 1, beginStepPct: 0, endStepPct: 1, @@ -43,7 +40,7 @@ export const initialControlNet: Omit = { export type ControlNetConfig = { controlNetId: string; isEnabled: boolean; - model: string; + model: ControlNetModelParam | null; weight: number; beginStepPct: number; endStepPct: number; @@ -148,7 +145,7 @@ export const controlNetSlice = createSlice({ state, action: PayloadAction<{ controlNetId: string; - model: string; + model: ControlNetModelParam; }> ) => { const { controlNetId, model } = action.payload; @@ -159,7 +156,7 @@ export const controlNetSlice = createSlice({ let processorType: ControlNetProcessorType | undefined = undefined; for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) { - if (model.includes(modelSubstring)) { + if (model.model_name.includes(modelSubstring)) { processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring]; break; } @@ -253,7 +250,11 @@ export const controlNetSlice = createSlice({ let processorType: ControlNetProcessorType | undefined = undefined; for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) { - if (state.controlNets[controlNetId].model.includes(modelSubstring)) { + if ( + state.controlNets[controlNetId].model?.model_name.includes( + modelSubstring + ) + ) { processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring]; break; } @@ -287,7 +288,8 @@ export const controlNetSlice = createSlice({ }); builder.addCase(imageDeleted.pending, (state, action) => { - // Preemptively remove the image from the gallery + // Preemptively remove the image from all controlnets + // TODO: doesn't the imageusage stuff do this for us? const { image_name } = action.meta.arg; forEach(state.controlNets, (c) => { if (c.controlImage === image_name) { @@ -300,21 +302,6 @@ export const controlNetSlice = createSlice({ }); }); - // builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { - // const { image_name, image_url, thumbnail_url } = action.payload; - - // forEach(state.controlNets, (c) => { - // if (c.controlImage?.image_name === image_name) { - // c.controlImage.image_url = image_url; - // c.controlImage.thumbnail_url = thumbnail_url; - // } - // if (c.processedControlImage?.image_name === image_name) { - // c.processedControlImage.image_url = image_url; - // c.processedControlImage.thumbnail_url = thumbnail_url; - // } - // }); - // }); - builder.addCase(appSocketInvocationError, (state, action) => { state.pendingControlImages = []; }); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ControlNetModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ControlNetModelInputFieldComponent.tsx index f8c42de60e..b5d9fef312 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/ControlNetModelInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/ControlNetModelInputFieldComponent.tsx @@ -6,9 +6,10 @@ import { ControlNetModelInputFieldTemplate, ControlNetModelInputFieldValue, } from 'features/nodes/types/types'; -import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect'; -import { forEach, isString } from 'lodash-es'; -import { memo, useCallback, useEffect, useMemo } from 'react'; +import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; +import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam'; +import { forEach } from 'lodash-es'; +import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetControlNetModelsQuery } from 'services/api/endpoints/models'; import { FieldComponentProps } from './types'; @@ -20,15 +21,23 @@ const ControlNetModelInputFieldComponent = ( > ) => { const { nodeId, field } = props; - + const controlNetModel = field.value; const dispatch = useAppDispatch(); const { t } = useTranslation(); const { data: controlNetModels } = useGetControlNetModelsQuery(); + // grab the full model entity from the RTK Query cache const selectedModel = useMemo( - () => controlNetModels?.entities[field.value ?? controlNetModels.ids[0]], - [controlNetModels?.entities, controlNetModels?.ids, field.value] + () => + controlNetModels?.entities[ + `${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}` + ] ?? null, + [ + controlNetModel?.base_model, + controlNetModel?.model_name, + controlNetModels?.entities, + ] ); const data = useMemo(() => { @@ -45,8 +54,8 @@ const ControlNetModelInputFieldComponent = ( data.push({ value: id, - label: model.name, - group: BASE_MODEL_NAME_MAP[model.base_model], + label: model.model_name, + group: MODEL_TYPE_MAP[model.base_model], }); }); @@ -59,40 +68,32 @@ const ControlNetModelInputFieldComponent = ( return; } + const newControlNetModel = modelIdToControlNetModelParam(v); + + if (!newControlNetModel) { + return; + } + dispatch( fieldValueChanged({ nodeId, fieldName: field.name, - value: v, + value: newControlNetModel, }) ); }, [dispatch, field.name, nodeId] ); - useEffect(() => { - if (field.value && controlNetModels?.ids.includes(field.value)) { - return; - } - - const firstLora = controlNetModels?.ids[0]; - - if (!isString(firstLora)) { - return; - } - - handleValueChanged(firstLora); - }, [field.value, handleValueChanged, controlNetModels?.ids]); - return ( diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 8255c65045..eac272ea01 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -1,6 +1,7 @@ import { createSlice, PayloadAction } from '@reduxjs/toolkit'; import { RootState } from 'app/store/store'; import { + ControlNetModelParam, LoRAModelParam, MainModelParam, VaeModelParam, @@ -81,7 +82,8 @@ const nodesSlice = createSlice({ | ImageField[] | MainModelParam | VaeModelParam - | LoRAModelParam; + | LoRAModelParam + | ControlNetModelParam; }> ) => { const { nodeId, fieldName, value } = action.payload; diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index 18b837a98e..f111155a39 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -1,4 +1,5 @@ import { + ControlNetModelParam, LoRAModelParam, MainModelParam, VaeModelParam, @@ -254,7 +255,7 @@ export type LoRAModelInputFieldValue = FieldValueBase & { export type ControlNetModelInputFieldValue = FieldValueBase & { type: 'controlnet_model'; - value?: string; + value?: ControlNetModelParam; }; export type ArrayInputFieldValue = FieldValueBase & { diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToControlNetModelField.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToControlNetModelField.ts deleted file mode 100644 index 655d5cd5df..0000000000 --- a/invokeai/frontend/web/src/features/nodes/util/modelIdToControlNetModelField.ts +++ /dev/null @@ -1,14 +0,0 @@ -import { BaseModelType, ControlNetModelField } from 'services/api/types'; - -export const modelIdToControlNetModelField = ( - controlNetModelId: string -): ControlNetModelField => { - const [base_model, model_type, model_name] = controlNetModelId.split('/'); - - const field: ControlNetModelField = { - base_model: base_model as BaseModelType, - model_name, - }; - - return field; -}; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/VAEModel/ParamVAEModelSelect.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/VAEModel/ParamVAEModelSelect.tsx index ee9f7a87bb..f82b02b5af 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/VAEModel/ParamVAEModelSelect.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/VAEModel/ParamVAEModelSelect.tsx @@ -37,6 +37,7 @@ const ParamVAEModelSelect = () => { return []; } + // add a "default" option, this means use the main model's included VAE const data: SelectItem[] = [ { value: 'default', diff --git a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts index aa2c60f3a8..9a4b71ce40 100644 --- a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts @@ -180,6 +180,23 @@ export type LoRAModelParam = z.infer; */ export const isValidLoRAModel = (val: unknown): val is LoRAModelParam => zLoRAModel.safeParse(val).success; +/** + * Zod schema for ControlNet models + */ +export const zControlNetModel = z.object({ + model_name: z.string().min(1), + base_model: zBaseModel, +}); +/** + * Type alias for model parameter, inferred from its zod schema + */ +export type ControlNetModelParam = z.infer; +/** + * Validates/type-guards a value as a model parameter + */ +export const isValidControlNetModel = ( + val: unknown +): val is ControlNetModelParam => zControlNetModel.safeParse(val).success; /** * Zod schema for l2l strength parameter diff --git a/invokeai/frontend/web/src/features/parameters/util/modelIdToControlNetModelParam.ts b/invokeai/frontend/web/src/features/parameters/util/modelIdToControlNetModelParam.ts new file mode 100644 index 0000000000..c08bca0bbc --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/util/modelIdToControlNetModelParam.ts @@ -0,0 +1,30 @@ +import { log } from 'app/logging/useLogger'; +import { zControlNetModel } from 'features/parameters/types/parameterSchemas'; +import { ControlNetModelField } from 'services/api/types'; + +const moduleLog = log.child({ module: 'models' }); + +export const modelIdToControlNetModelParam = ( + controlNetModelId: string +): ControlNetModelField | undefined => { + const [base_model, model_type, model_name] = controlNetModelId.split('/'); + + const result = zControlNetModel.safeParse({ + base_model, + model_name, + }); + + if (!result.success) { + moduleLog.error( + { + controlNetModelId, + errors: result.error.format(), + }, + 'Failed to parse ControlNet model id' + ); + + return; + } + + return result.data; +}; diff --git a/invokeai/frontend/web/src/features/parameters/util/modelIdToLoRAModelParam.ts b/invokeai/frontend/web/src/features/parameters/util/modelIdToLoRAModelParam.ts index 2ea7cacb5d..206246c79e 100644 --- a/invokeai/frontend/web/src/features/parameters/util/modelIdToLoRAModelParam.ts +++ b/invokeai/frontend/web/src/features/parameters/util/modelIdToLoRAModelParam.ts @@ -1,9 +1,12 @@ import { LoRAModelParam, zLoRAModel } from '../types/parameterSchemas'; +import { log } from 'app/logging/useLogger'; + +const moduleLog = log.child({ module: 'models' }); export const modelIdToLoRAModelParam = ( - loraId: string + loraModelId: string ): LoRAModelParam | undefined => { - const [base_model, model_type, model_name] = loraId.split('/'); + const [base_model, model_type, model_name] = loraModelId.split('/'); const result = zLoRAModel.safeParse({ base_model, @@ -11,6 +14,13 @@ export const modelIdToLoRAModelParam = ( }); if (!result.success) { + moduleLog.error( + { + loraModelId, + errors: result.error.format(), + }, + 'Failed to parse LoRA model id' + ); return; } diff --git a/invokeai/frontend/web/src/features/parameters/util/modelIdToMainModelParam.ts b/invokeai/frontend/web/src/features/parameters/util/modelIdToMainModelParam.ts index b73d3c5f0d..70fb219bed 100644 --- a/invokeai/frontend/web/src/features/parameters/util/modelIdToMainModelParam.ts +++ b/invokeai/frontend/web/src/features/parameters/util/modelIdToMainModelParam.ts @@ -2,11 +2,14 @@ import { MainModelParam, zMainModel, } from 'features/parameters/types/parameterSchemas'; +import { log } from 'app/logging/useLogger'; + +const moduleLog = log.child({ module: 'models' }); export const modelIdToMainModelParam = ( - modelId: string + mainModelId: string ): MainModelParam | undefined => { - const [base_model, model_type, model_name] = modelId.split('/'); + const [base_model, model_type, model_name] = mainModelId.split('/'); const result = zMainModel.safeParse({ base_model, @@ -14,6 +17,13 @@ export const modelIdToMainModelParam = ( }); if (!result.success) { + moduleLog.error( + { + mainModelId, + errors: result.error.format(), + }, + 'Failed to parse main model id' + ); return; } diff --git a/invokeai/frontend/web/src/features/parameters/util/modelIdToVAEModelParam.ts b/invokeai/frontend/web/src/features/parameters/util/modelIdToVAEModelParam.ts index 49856531d6..eb57d07f0e 100644 --- a/invokeai/frontend/web/src/features/parameters/util/modelIdToVAEModelParam.ts +++ b/invokeai/frontend/web/src/features/parameters/util/modelIdToVAEModelParam.ts @@ -1,9 +1,12 @@ import { VaeModelParam, zVaeModel } from '../types/parameterSchemas'; +import { log } from 'app/logging/useLogger'; + +const moduleLog = log.child({ module: 'models' }); export const modelIdToVAEModelParam = ( - modelId: string + vaeModelId: string ): VaeModelParam | undefined => { - const [base_model, model_type, model_name] = modelId.split('/'); + const [base_model, model_type, model_name] = vaeModelId.split('/'); const result = zVaeModel.safeParse({ base_model, @@ -11,6 +14,13 @@ export const modelIdToVAEModelParam = ( }); if (!result.success) { + moduleLog.error( + { + vaeModelId, + errors: result.error.format(), + }, + 'Failed to parse VAE model id' + ); return; } diff --git a/invokeai/frontend/web/src/services/api/schema.d.ts b/invokeai/frontend/web/src/services/api/schema.d.ts index 2a9275de73..dc68200408 100644 --- a/invokeai/frontend/web/src/services/api/schema.d.ts +++ b/invokeai/frontend/web/src/services/api/schema.d.ts @@ -1935,12 +1935,12 @@ export type components = { * Width * @description The width to resize to (px) */ - width: number; + width?: number; /** * Height * @description The height to resize to (px) */ - height: number; + height?: number; /** * Resample Mode * @description The resampling mode @@ -3302,7 +3302,7 @@ export type components = { /** ModelsList */ ModelsList: { /** Models */ - models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[]; + models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[]; }; /** * MultiplyInvocation @@ -3922,14 +3922,16 @@ export type components = { latents?: components["schemas"]["LatentsField"]; /** * Width - * @description The width to resize to (px) + * @description The width to resize to (px) + * @default 512 */ - width: number; + width?: number; /** * Height - * @description The height to resize to (px) + * @description The height to resize to (px) + * @default 512 */ - height: number; + height?: number; /** * Mode * @description The interpolation mode @@ -5009,7 +5011,7 @@ export type operations = { /** @description The model imported successfully */ 201: { content: { - "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; + "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"]; }; }; /** @description The model could not be found */ @@ -5077,14 +5079,14 @@ export type operations = { }; requestBody: { content: { - "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; + "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"]; }; }; responses: { /** @description The model was updated successfully */ 200: { content: { - "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; + "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"]; }; }; /** @description Bad request */ @@ -5118,7 +5120,7 @@ export type operations = { /** @description Model converted successfully */ 200: { content: { - "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; + "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"]; }; }; /** @description Bad request */ @@ -5153,7 +5155,7 @@ export type operations = { /** @description Model converted successfully */ 200: { content: { - "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; + "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"]; }; }; /** @description Incompatible models */ From ae72f372be4d1e45027f5f6fa6d5564ae53fdd40 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 15 Jul 2023 13:00:46 +1000 Subject: [PATCH 18/37] fix(nodes): do not use hardcoded controlnet model --- invokeai/app/invocations/controlnet_image_processors.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 06b4d2b2ef..e38417efa1 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -190,11 +190,7 @@ class ControlNetInvocation(BaseInvocation): return ControlOutput( control=ControlField( image=self.image, - #control_model=self.control_model, - control_model=ControlNetModelField( - model_name="canny", - base_model=BaseModelType.StableDiffusion1, - ), + control_model=self.control_model, control_weight=self.control_weight, begin_step_percent=self.begin_step_percent, end_step_percent=self.end_step_percent, From d270f21c8509d1e34d075a01e87182dc0aa2b951 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 15 Jul 2023 17:06:17 +1000 Subject: [PATCH 19/37] feat(nodes): valid controlnet weights are -1 to 2 --- .../app/invocations/controlnet_image_processors.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index e38417efa1..7eff62a8a5 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -125,15 +125,15 @@ class ControlField(BaseModel): # resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use") @validator("control_weight") - def abs_le_one(cls, v): - """validate that all abs(values) are <=1""" + def validate_control_weight(cls, v): + """Validate that all control weights in the valid range""" if isinstance(v, list): for i in v: - if abs(i) > 1: - raise ValueError('all abs(control_weight) must be <= 1') + if i < -1 or i > 2: + raise ValueError('Control weights must be within -1 to 2 range') else: - if abs(v) > 1: - raise ValueError('abs(control_weight) must be <= 1') + if v < -1 or v > 2: + raise ValueError('Control weights must be within -1 to 2 range') return v class Config: schema_extra = { @@ -165,7 +165,7 @@ class ControlNetInvocation(BaseInvocation): control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny", description="control model used") control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet") - begin_step_percent: float = Field(default=0, ge=0, le=1, + begin_step_percent: float = Field(default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)") end_step_percent: float = Field(default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)") From 8f66d826a53bdd8d1d26f3c82d0211205e04a8b2 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 15 Jul 2023 17:40:40 +1000 Subject: [PATCH 20/37] feat(ui): refactor controlnet UI components to use local memoized selectors makes them more portable and easier to reason about --- .../controlNet/components/ControlNet.tsx | 79 ++++++++----------- .../components/ControlNetImagePreview.tsx | 55 +++++++------ .../ControlNetProcessorComponent.tsx | 29 +++++-- .../ParamControlNetShouldAutoConfig.tsx | 29 +++++-- .../parameters/ParamControlNetBeginEnd.tsx | 29 +++++-- .../parameters/ParamControlNetControlMode.tsx | 21 ++++- .../ParamControlNetProcessorSelect.tsx | 29 ++++--- .../parameters/ParamControlNetWeight.tsx | 20 ++++- .../controlNet/store/controlNetSlice.ts | 16 +++- .../ControlNet/ParamControlNetCollapse.tsx | 2 +- 10 files changed, 200 insertions(+), 109 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx index cd3cafa12a..8f90797fd9 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx @@ -1,10 +1,9 @@ -import { Box, ChakraProps, Flex, useColorMode } from '@chakra-ui/react'; -import { useAppDispatch } from 'app/store/storeHooks'; +import { Box, Flex, useColorMode } from '@chakra-ui/react'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { memo, useCallback } from 'react'; import { FaCopy, FaTrash } from 'react-icons/fa'; import { - ControlNetConfig, - controlNetAdded, + controlNetDuplicated, controlNetRemoved, controlNetToggled, } from '../store/controlNetSlice'; @@ -12,9 +11,13 @@ import ParamControlNetModel from './parameters/ParamControlNetModel'; import ParamControlNetWeight from './parameters/ParamControlNetWeight'; import { ChevronUpIcon } from '@chakra-ui/icons'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIIconButton from 'common/components/IAIIconButton'; import IAISwitch from 'common/components/IAISwitch'; import { useToggle } from 'react-use'; +import { mode } from 'theme/util/mode'; import { v4 as uuidv4 } from 'uuid'; import ControlNetImagePreview from './ControlNetImagePreview'; import ControlNetProcessorComponent from './ControlNetProcessorComponent'; @@ -22,30 +25,28 @@ import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig'; import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd'; import ParamControlNetControlMode from './parameters/ParamControlNetControlMode'; import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect'; -import { mode } from 'theme/util/mode'; - -const expandedControlImageSx: ChakraProps['sx'] = { maxH: 96 }; type ControlNetProps = { - controlNet: ControlNetConfig; + controlNetId: string; }; const ControlNet = (props: ControlNetProps) => { - const { - controlNetId, - isEnabled, - model, - weight, - beginStepPct, - endStepPct, - controlMode, - controlImage, - processedControlImage, - processorNode, - processorType, - shouldAutoConfig, - } = props.controlNet; + const { controlNetId } = props; const dispatch = useAppDispatch(); + + const selector = createSelector( + stateSelector, + ({ controlNet }) => { + const { isEnabled, shouldAutoConfig } = + controlNet.controlNets[controlNetId]; + + return { isEnabled, shouldAutoConfig }; + }, + defaultSelectorOptions + ); + + const { isEnabled, shouldAutoConfig } = useAppSelector(selector); + const [isExpanded, toggleIsExpanded] = useToggle(false); const { colorMode } = useColorMode(); const handleDelete = useCallback(() => { @@ -54,9 +55,12 @@ const ControlNet = (props: ControlNetProps) => { const handleDuplicate = useCallback(() => { dispatch( - controlNetAdded({ controlNetId: uuidv4(), controlNet: props.controlNet }) + controlNetDuplicated({ + sourceControlNetId: controlNetId, + newControlNetId: uuidv4(), + }) ); - }, [dispatch, props.controlNet]); + }, [controlNetId, dispatch]); const handleToggleIsEnabled = useCallback(() => { dispatch(controlNetToggled({ controlNetId })); @@ -140,14 +144,8 @@ const ControlNet = (props: ControlNetProps) => { )} - - + + {isEnabled && ( <> @@ -166,13 +164,10 @@ const ControlNet = (props: ControlNetProps) => { > @@ -187,30 +182,24 @@ const ControlNet = (props: ControlNetProps) => { }} > )} - + {isExpanded && ( <> - + )} diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx index 1321f7b0c0..f82e39d34e 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx @@ -5,42 +5,51 @@ import { TypesafeDraggableData, TypesafeDroppableData, } from 'app/components/ImageDnd/typesafeDnd'; +import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIDndImage from 'common/components/IAIDndImage'; import { memo, useCallback, useMemo, useState } from 'react'; import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { PostUploadAction } from 'services/api/thunks/image'; -import { - ControlNetConfig, - controlNetImageChanged, - controlNetSelector, -} from '../store/controlNetSlice'; - -const selector = createSelector( - controlNetSelector, - (controlNet) => { - const { pendingControlImages } = controlNet; - return { pendingControlImages }; - }, - defaultSelectorOptions -); +import { controlNetImageChanged } from '../store/controlNetSlice'; type Props = { - controlNet: ControlNetConfig; + controlNetId: string; height: SystemStyleObject['h']; }; const ControlNetImagePreview = (props: Props) => { - const { height } = props; - const { - controlNetId, - controlImage: controlImageName, - processedControlImage: processedControlImageName, - processorType, - } = props.controlNet; + const { height, controlNetId } = props; const dispatch = useAppDispatch(); - const { pendingControlImages } = useAppSelector(selector); + + const selector = useMemo( + () => + createSelector( + stateSelector, + ({ controlNet }) => { + const { pendingControlImages } = controlNet; + const { controlImage, processedControlImage, processorType } = + controlNet.controlNets[controlNetId]; + + return { + controlImageName: controlImage, + processedControlImageName: processedControlImage, + processorType, + pendingControlImages, + }; + }, + defaultSelectorOptions + ), + [controlNetId] + ); + + const { + controlImageName, + processedControlImageName, + processorType, + pendingControlImages, + } = useAppSelector(selector); const [isMouseOverImage, setIsMouseOverImage] = useState(false); diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNetProcessorComponent.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNetProcessorComponent.tsx index 4649f89b35..863e42632c 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNetProcessorComponent.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNetProcessorComponent.tsx @@ -1,10 +1,13 @@ -import { memo } from 'react'; -import { RequiredControlNetProcessorNode } from '../store/types'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import { memo, useMemo } from 'react'; import CannyProcessor from './processors/CannyProcessor'; -import HedProcessor from './processors/HedProcessor'; -import LineartProcessor from './processors/LineartProcessor'; -import LineartAnimeProcessor from './processors/LineartAnimeProcessor'; import ContentShuffleProcessor from './processors/ContentShuffleProcessor'; +import HedProcessor from './processors/HedProcessor'; +import LineartAnimeProcessor from './processors/LineartAnimeProcessor'; +import LineartProcessor from './processors/LineartProcessor'; import MediapipeFaceProcessor from './processors/MediapipeFaceProcessor'; import MidasDepthProcessor from './processors/MidasDepthProcessor'; import MlsdImageProcessor from './processors/MlsdImageProcessor'; @@ -15,11 +18,23 @@ import ZoeDepthProcessor from './processors/ZoeDepthProcessor'; export type ControlNetProcessorProps = { controlNetId: string; - processorNode: RequiredControlNetProcessorNode; }; const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => { - const { controlNetId, processorNode } = props; + const { controlNetId } = props; + + const selector = useMemo( + () => + createSelector( + stateSelector, + ({ controlNet }) => controlNet.controlNets[controlNetId]?.processorNode, + defaultSelectorOptions + ), + [controlNetId] + ); + + const processorNode = useAppSelector(selector); + if (processorNode.type === 'canny_image_processor') { return ( { - const { controlNetId, shouldAutoConfig } = props; + const { controlNetId } = props; const dispatch = useAppDispatch(); - const isReady = useIsReadyToInvoke(); + const selector = useMemo( + () => + createSelector( + stateSelector, + ({ controlNet }) => + controlNet.controlNets[controlNetId]?.shouldAutoConfig, + defaultSelectorOptions + ), + [controlNetId] + ); + + const shouldAutoConfig = useAppSelector(selector); + const isBusy = useAppSelector(selectIsBusy); + const handleShouldAutoConfigChanged = useCallback(() => { dispatch(controlNetAutoConfigToggled({ controlNetId })); }, [controlNetId, dispatch]); @@ -23,7 +38,7 @@ const ParamControlNetShouldAutoConfig = (props: Props) => { aria-label="Auto configure processor" isChecked={shouldAutoConfig} onChange={handleShouldAutoConfigChanged} - isDisabled={!isReady} + isDisabled={isBusy} /> ); }; diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetBeginEnd.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetBeginEnd.tsx index 7d0c53fe40..c08ecf1bb2 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetBeginEnd.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetBeginEnd.tsx @@ -10,13 +10,15 @@ import { RangeSliderTrack, Tooltip, } from '@chakra-ui/react'; -import { useAppDispatch } from 'app/store/storeHooks'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { controlNetBeginStepPctChanged, controlNetEndStepPctChanged, } from 'features/controlNet/store/controlNetSlice'; -import { memo, useCallback } from 'react'; -import { useTranslation } from 'react-i18next'; +import { memo, useCallback, useMemo } from 'react'; const SLIDER_MARK_STYLES: ChakraProps['sx'] = { mt: 1.5, @@ -27,17 +29,30 @@ const SLIDER_MARK_STYLES: ChakraProps['sx'] = { type Props = { controlNetId: string; - beginStepPct: number; - endStepPct: number; mini?: boolean; }; const formatPct = (v: number) => `${Math.round(v * 100)}%`; const ParamControlNetBeginEnd = (props: Props) => { - const { controlNetId, beginStepPct, mini = false, endStepPct } = props; + const { controlNetId, mini = false } = props; const dispatch = useAppDispatch(); - const { t } = useTranslation(); + + const selector = useMemo( + () => + createSelector( + stateSelector, + ({ controlNet }) => { + const { beginStepPct, endStepPct } = + controlNet.controlNets[controlNetId]; + return { beginStepPct, endStepPct }; + }, + defaultSelectorOptions + ), + [controlNetId] + ); + + const { beginStepPct, endStepPct } = useAppSelector(selector); const handleStepPctChanged = useCallback( (v: number[]) => { diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetControlMode.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetControlMode.tsx index b8737004fd..07b58384e1 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetControlMode.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetControlMode.tsx @@ -1,15 +1,17 @@ -import { useAppDispatch } from 'app/store/storeHooks'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { ControlModes, controlNetControlModeChanged, } from 'features/controlNet/store/controlNetSlice'; -import { useCallback } from 'react'; +import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; type ParamControlNetControlModeProps = { controlNetId: string; - controlMode: string; }; const CONTROL_MODE_DATA = [ @@ -22,8 +24,19 @@ const CONTROL_MODE_DATA = [ export default function ParamControlNetControlMode( props: ParamControlNetControlModeProps ) { - const { controlNetId, controlMode = false } = props; + const { controlNetId } = props; const dispatch = useAppDispatch(); + const selector = useMemo( + () => + createSelector( + stateSelector, + ({ controlNet }) => controlNet.controlNets[controlNetId]?.controlMode, + defaultSelectorOptions + ), + [controlNetId] + ); + + const controlMode = useAppSelector(selector); const { t } = useTranslation(); diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetProcessorSelect.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetProcessorSelect.tsx index 5d091be1ef..a57de5d70e 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetProcessorSelect.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetProcessorSelect.tsx @@ -1,24 +1,21 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIMantineSearchableSelect, { IAISelectDataType, } from 'common/components/IAIMantineSearchableSelect'; -import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; import { configSelector } from 'features/system/store/configSelectors'; +import { selectIsBusy } from 'features/system/store/systemSelectors'; import { map } from 'lodash-es'; -import { memo, useCallback } from 'react'; +import { memo, useCallback, useMemo } from 'react'; import { CONTROLNET_PROCESSORS } from '../../store/constants'; import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice'; -import { - ControlNetProcessorNode, - ControlNetProcessorType, -} from '../../store/types'; +import { ControlNetProcessorType } from '../../store/types'; type ParamControlNetProcessorSelectProps = { controlNetId: string; - processorNode: ControlNetProcessorNode; }; const selector = createSelector( @@ -54,10 +51,22 @@ const selector = createSelector( const ParamControlNetProcessorSelect = ( props: ParamControlNetProcessorSelectProps ) => { - const { controlNetId, processorNode } = props; const dispatch = useAppDispatch(); - const isReady = useIsReadyToInvoke(); + const { controlNetId } = props; + const processorNodeSelector = useMemo( + () => + createSelector( + stateSelector, + ({ controlNet }) => ({ + processorNode: controlNet.controlNets[controlNetId]?.processorNode, + }), + defaultSelectorOptions + ), + [controlNetId] + ); + const isBusy = useAppSelector(selectIsBusy); const controlNetProcessors = useAppSelector(selector); + const { processorNode } = useAppSelector(processorNodeSelector); const handleProcessorTypeChanged = useCallback( (v: string | null) => { @@ -77,7 +86,7 @@ const ParamControlNetProcessorSelect = ( value={processorNode.type ?? 'canny_image_processor'} data={controlNetProcessors} onChange={handleProcessorTypeChanged} - disabled={!isReady} + disabled={isBusy} /> ); }; diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetWeight.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetWeight.tsx index c2b77125d0..d0973f3d81 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetWeight.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetWeight.tsx @@ -1,18 +1,30 @@ -import { useAppDispatch } from 'app/store/storeHooks'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAISlider from 'common/components/IAISlider'; import { controlNetWeightChanged } from 'features/controlNet/store/controlNetSlice'; -import { memo, useCallback } from 'react'; +import { memo, useCallback, useMemo } from 'react'; type ParamControlNetWeightProps = { controlNetId: string; - weight: number; mini?: boolean; }; const ParamControlNetWeight = (props: ParamControlNetWeightProps) => { - const { controlNetId, weight, mini = false } = props; + const { controlNetId, mini = false } = props; const dispatch = useAppDispatch(); + const selector = useMemo( + () => + createSelector( + stateSelector, + ({ controlNet }) => controlNet.controlNets[controlNetId]?.weight, + defaultSelectorOptions + ), + [controlNetId] + ); + const weight = useAppSelector(selector); const handleWeightChanged = useCallback( (weight: number) => { dispatch(controlNetWeightChanged({ controlNetId, weight })); diff --git a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts index c735a47510..8e6f96add3 100644 --- a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts +++ b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts @@ -1,7 +1,7 @@ import { PayloadAction, createSlice } from '@reduxjs/toolkit'; import { RootState } from 'app/store/store'; import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas'; -import { forEach } from 'lodash-es'; +import { cloneDeep, forEach } from 'lodash-es'; import { imageDeleted } from 'services/api/thunks/image'; import { isAnySessionRejected } from 'services/api/thunks/session'; import { appSocketInvocationError } from 'services/events/actions'; @@ -84,6 +84,19 @@ export const controlNetSlice = createSlice({ controlNetId, }; }, + controlNetDuplicated: ( + state, + action: PayloadAction<{ + sourceControlNetId: string; + newControlNetId: string; + }> + ) => { + const { sourceControlNetId, newControlNetId } = action.payload; + + const newControlnet = cloneDeep(state.controlNets[sourceControlNetId]); + newControlnet.controlNetId = newControlNetId; + state.controlNets[newControlNetId] = newControlnet; + }, controlNetAddedFromImage: ( state, action: PayloadAction<{ controlNetId: string; controlImage: string }> @@ -315,6 +328,7 @@ export const controlNetSlice = createSlice({ export const { isControlNetEnabledToggled, controlNetAdded, + controlNetDuplicated, controlNetAddedFromImage, controlNetRemoved, controlNetImageChanged, diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx index 59bf7542eb..201cf860c9 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx @@ -55,7 +55,7 @@ const ParamControlNetCollapse = () => { {controlNetsArray.map((c, i) => ( {i > 0 && } - + ))} From 7b6d91c69f170cdf7f9e0391ecbd7588ea86585b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 15 Jul 2023 17:41:35 +1000 Subject: [PATCH 21/37] feat(ui): control net UI weights 0 to 2 --- .../components/parameters/ParamControlNetWeight.tsx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetWeight.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetWeight.tsx index d0973f3d81..3784829be9 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetWeight.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetWeight.tsx @@ -38,11 +38,11 @@ const ParamControlNetWeight = (props: ParamControlNetWeightProps) => { sliderFormLabelProps={{ pb: 2 }} value={weight} onChange={handleWeightChanged} - min={-1} - max={1} + min={0} + max={2} step={0.01} withSliderMarks={!mini} - sliderMarks={[-1, 0, 1]} + sliderMarks={[0, 1, 2]} /> ); }; From 952a7a86741a57477e7ae7559554fcc8d44bae53 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 15 Jul 2023 18:44:53 +1000 Subject: [PATCH 22/37] feat(ui): do not autoprocess if user just disabled autoconfig --- .../listeners/controlNetAutoProcess.ts | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess.ts index dd2fb6f469..a923bd0b60 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess.ts @@ -13,7 +13,11 @@ import { RootState } from 'app/store/store'; const moduleLog = log.child({ namespace: 'controlNet' }); -const predicate: AnyListenerPredicate = (action, state) => { +const predicate: AnyListenerPredicate = ( + action, + state, + prevState +) => { const isActionMatched = controlNetProcessorParamsChanged.match(action) || controlNetModelChanged.match(action) || @@ -25,6 +29,16 @@ const predicate: AnyListenerPredicate = (action, state) => { return false; } + if (controlNetAutoConfigToggled.match(action)) { + // do not process if the user just disabled auto-config + if ( + prevState.controlNet.controlNets[action.payload.controlNetId] + .shouldAutoConfig === true + ) { + return false; + } + } + const { controlImage, processorType, shouldAutoConfig } = state.controlNet.controlNets[action.payload.controlNetId]; From 77ad3c959bbd08d3123ea459d045f066f98706e9 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 15 Jul 2023 18:50:06 +1000 Subject: [PATCH 23/37] feat(ui): tweak slider styles --- .../web/src/common/components/IAISlider.tsx | 17 ++--------------- .../frontend/web/src/theme/components/slider.ts | 2 +- 2 files changed, 3 insertions(+), 16 deletions(-) diff --git a/invokeai/frontend/web/src/common/components/IAISlider.tsx b/invokeai/frontend/web/src/common/components/IAISlider.tsx index d99fbfa149..00492b28d6 100644 --- a/invokeai/frontend/web/src/common/components/IAISlider.tsx +++ b/invokeai/frontend/web/src/common/components/IAISlider.tsx @@ -43,11 +43,6 @@ import { useTranslation } from 'react-i18next'; import { BiReset } from 'react-icons/bi'; import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton'; -const SLIDER_MARK_STYLES: ChakraProps['sx'] = { - mt: 1.5, - fontSize: '2xs', -}; - export type IAIFullSliderProps = { label?: string; value: number; @@ -206,11 +201,7 @@ const IAISlider = (props: IAIFullSliderProps) => { isDisabled={isDisabled} {...sliderFormControlProps} > - {label && ( - - {label} - - )} + {label && {label}} { sx={{ insetInlineStart: '0 !important', insetInlineEnd: 'unset !important', - ...SLIDER_MARK_STYLES, }} {...sliderMarkProps} > @@ -244,7 +234,6 @@ const IAISlider = (props: IAIFullSliderProps) => { sx={{ insetInlineStart: 'unset !important', insetInlineEnd: '0 !important', - ...SLIDER_MARK_STYLES, }} {...sliderMarkProps} > @@ -263,7 +252,6 @@ const IAISlider = (props: IAIFullSliderProps) => { sx={{ insetInlineStart: '0 !important', insetInlineEnd: 'unset !important', - ...SLIDER_MARK_STYLES, }} {...sliderMarkProps} > @@ -278,7 +266,6 @@ const IAISlider = (props: IAIFullSliderProps) => { sx={{ insetInlineStart: 'unset !important', insetInlineEnd: '0 !important', - ...SLIDER_MARK_STYLES, }} {...sliderMarkProps} > @@ -291,7 +278,7 @@ const IAISlider = (props: IAIFullSliderProps) => { key={m} value={m} sx={{ - ...SLIDER_MARK_STYLES, + transform: 'translateX(-50%)', }} {...sliderMarkProps} > diff --git a/invokeai/frontend/web/src/theme/components/slider.ts b/invokeai/frontend/web/src/theme/components/slider.ts index 397dea786a..98a2556b9e 100644 --- a/invokeai/frontend/web/src/theme/components/slider.ts +++ b/invokeai/frontend/web/src/theme/components/slider.ts @@ -30,7 +30,7 @@ const invokeAIThumb = defineStyle((props) => { const invokeAIMark = defineStyle((props) => { return { - fontSize: 'xs', + fontSize: '2xs', fontWeight: '500', color: mode('base.700', 'base.400')(props), mt: 2, From 8a14c5db001a527054c7cc93f724394946058cde Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 15 Jul 2023 18:50:37 +1000 Subject: [PATCH 24/37] feat(ui): wip controlnet layout --- .../controlNet/components/ControlNet.tsx | 122 +++++++++--------- .../parameters/ParamControlNetBeginEnd.tsx | 70 ++++------ .../parameters/ParamControlNetWeight.tsx | 6 +- 3 files changed, 88 insertions(+), 110 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx index 8f90797fd9..40b4856d99 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx @@ -1,4 +1,4 @@ -import { Box, Flex, useColorMode } from '@chakra-ui/react'; +import { Box, Flex } from '@chakra-ui/react'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { memo, useCallback } from 'react'; import { FaCopy, FaTrash } from 'react-icons/fa'; @@ -17,7 +17,6 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIIconButton from 'common/components/IAIIconButton'; import IAISwitch from 'common/components/IAISwitch'; import { useToggle } from 'react-use'; -import { mode } from 'theme/util/mode'; import { v4 as uuidv4 } from 'uuid'; import ControlNetImagePreview from './ControlNetImagePreview'; import ControlNetProcessorComponent from './ControlNetProcessorComponent'; @@ -46,9 +45,8 @@ const ControlNet = (props: ControlNetProps) => { ); const { isEnabled, shouldAutoConfig } = useAppSelector(selector); - const [isExpanded, toggleIsExpanded] = useToggle(false); - const { colorMode } = useColorMode(); + const handleDelete = useCallback(() => { dispatch(controlNetRemoved({ controlNetId })); }, [controlNetId, dispatch]); @@ -72,9 +70,12 @@ const ControlNet = (props: ControlNetProps) => { flexDir: 'column', gap: 2, p: 3, - bg: mode('base.200', 'base.850')(colorMode), borderRadius: 'base', position: 'relative', + bg: 'base.200', + _dark: { + bg: 'base.850', + }, }} > @@ -120,10 +121,13 @@ const ControlNet = (props: ControlNetProps) => { } @@ -136,72 +140,62 @@ const ControlNet = (props: ControlNetProps) => { w: 1.5, h: 1.5, borderRadius: 'full', - bg: mode('error.700', 'error.200')(colorMode), top: 4, insetInlineEnd: 4, + bg: 'error.700', + _dark: { + bg: 'error.200', + }, }} /> )} - - - - - {isEnabled && ( - <> - - - - - - - {!isExpanded && ( - - - - )} - - + + + + + - - {isExpanded && ( - <> - - - - - + {!isExpanded && ( + + + )} + + + + {isExpanded && ( + <> + + + + + + + + )} diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetBeginEnd.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetBeginEnd.tsx index c08ecf1bb2..57cf3f439e 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetBeginEnd.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetBeginEnd.tsx @@ -1,5 +1,4 @@ import { - ChakraProps, FormControl, FormLabel, HStack, @@ -20,22 +19,14 @@ import { } from 'features/controlNet/store/controlNetSlice'; import { memo, useCallback, useMemo } from 'react'; -const SLIDER_MARK_STYLES: ChakraProps['sx'] = { - mt: 1.5, - fontSize: '2xs', - fontWeight: '500', - color: 'base.400', -}; - type Props = { controlNetId: string; - mini?: boolean; }; const formatPct = (v: number) => `${Math.round(v * 100)}%`; const ParamControlNetBeginEnd = (props: Props) => { - const { controlNetId, mini = false } = props; + const { controlNetId } = props; const dispatch = useAppDispatch(); const selector = useMemo( @@ -91,38 +82,33 @@ const ParamControlNetBeginEnd = (props: Props) => { - {!mini && ( - <> - - 0% - - - 50% - - - 100% - - - )} + + 0% + + + 50% + + + 100% + diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetWeight.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetWeight.tsx index 3784829be9..f7a2a24cd5 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetWeight.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetWeight.tsx @@ -8,11 +8,10 @@ import { memo, useCallback, useMemo } from 'react'; type ParamControlNetWeightProps = { controlNetId: string; - mini?: boolean; }; const ParamControlNetWeight = (props: ParamControlNetWeightProps) => { - const { controlNetId, mini = false } = props; + const { controlNetId } = props; const dispatch = useAppDispatch(); const selector = useMemo( () => @@ -35,13 +34,12 @@ const ParamControlNetWeight = (props: ParamControlNetWeightProps) => { return ( ); From 19e076cd1502fe1fd1935a15543d2ae7174a3e25 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 15 Jul 2023 19:03:32 +1000 Subject: [PATCH 25/37] fix(ui): fix no controlnet model selected by default --- .../ControlNet/ParamControlNetCollapse.tsx | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx index 201cf860c9..2051eed3e3 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx @@ -8,6 +8,7 @@ import ControlNet from 'features/controlNet/components/ControlNet'; import ParamControlNetFeatureToggle from 'features/controlNet/components/parameters/ParamControlNetFeatureToggle'; import { controlNetAdded, + controlNetModelChanged, controlNetSelector, } from 'features/controlNet/store/controlNetSlice'; import { getValidControlNets } from 'features/controlNet/util/getValidControlNets'; @@ -15,6 +16,7 @@ import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { map } from 'lodash-es'; import { Fragment, memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; +import { useGetControlNetModelsQuery } from 'services/api/endpoints/models'; import { v4 as uuidv4 } from 'uuid'; const selector = createSelector( @@ -39,10 +41,23 @@ const ParamControlNetCollapse = () => { const { controlNetsArray, activeLabel } = useAppSelector(selector); const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled; const dispatch = useAppDispatch(); + const { firstModel } = useGetControlNetModelsQuery(undefined, { + selectFromResult: (result) => { + const firstModel = result.data?.entities[result.data?.ids[0]]; + return { + firstModel, + }; + }, + }); const handleClickedAddControlNet = useCallback(() => { - dispatch(controlNetAdded({ controlNetId: uuidv4() })); - }, [dispatch]); + if (!firstModel) { + return; + } + const controlNetId = uuidv4(); + dispatch(controlNetAdded({ controlNetId })); + dispatch(controlNetModelChanged({ controlNetId, model: firstModel })); + }, [dispatch, firstModel]); if (isControlNetDisabled) { return null; @@ -58,7 +73,11 @@ const ParamControlNetCollapse = () => { ))} - + Add ControlNet From 401727b0c95525d5d1714b26e29de77de36a2acb Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 15 Jul 2023 19:23:32 +1000 Subject: [PATCH 26/37] feat(ui): add cnet advanced tooltip --- .../web/src/features/controlNet/components/ControlNet.tsx | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx index 40b4856d99..b131e8641c 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx @@ -114,7 +114,8 @@ const ControlNet = (props: ControlNetProps) => { /> Date: Sat, 15 Jul 2023 19:23:44 +1000 Subject: [PATCH 27/37] feat(ui): move cnet add button to top of list --- .../ControlNet/ParamControlNetCollapse.tsx | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx index 2051eed3e3..999bfd0de9 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx @@ -67,19 +67,20 @@ const ParamControlNetCollapse = () => { + + Add ControlNet + {controlNetsArray.map((c, i) => ( {i > 0 && } ))} - - Add ControlNet - ); From 7dec2d09f091daa2ecd6cf4b839467615799fb98 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 15 Jul 2023 19:51:17 +1000 Subject: [PATCH 28/37] feat(ui): disable specific controlnet inputs when that controlnet is disabled The UX is clearer now, but it's still easy to miss that your individual controlnets are enabled, but the overall controlnet feature is disabled. --- .../controlNet/components/ControlNet.tsx | 8 +++--- .../components/ControlNetImagePreview.tsx | 12 ++++++-- .../ControlNetProcessorComponent.tsx | 26 +++++++++++++++-- .../ParamControlNetShouldAutoConfig.tsx | 11 +++++--- .../parameters/ParamControlNetBeginEnd.tsx | 9 +++--- .../parameters/ParamControlNetControlMode.tsx | 9 ++++-- .../parameters/ParamControlNetIsEnabled.tsx | 28 ------------------- .../parameters/ParamControlNetModel.tsx | 7 +++-- .../ParamControlNetProcessorSelect.tsx | 12 ++++---- .../parameters/ParamControlNetWeight.tsx | 8 ++++-- .../components/processors/CannyProcessor.tsx | 15 ++++++---- .../processors/ContentShuffleProcessor.tsx | 21 ++++++++------ .../components/processors/HedProcessor.tsx | 16 +++++++---- .../processors/LineartAnimeProcessor.tsx | 15 ++++++---- .../processors/LineartProcessor.tsx | 17 ++++++----- .../processors/MediapipeFaceProcessor.tsx | 15 ++++++---- .../processors/MidasDepthProcessor.tsx | 15 ++++++---- .../processors/MlsdImageProcessor.tsx | 19 +++++++------ .../processors/NormalBaeProcessor.tsx | 15 ++++++---- .../processors/OpenposeProcessor.tsx | 17 ++++++----- .../components/processors/PidiProcessor.tsx | 17 ++++++----- .../processors/ZoeDepthProcessor.tsx | 1 + 22 files changed, 182 insertions(+), 131 deletions(-) delete mode 100644 invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetIsEnabled.tsx diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx index b131e8641c..82398a7483 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx @@ -80,8 +80,8 @@ const ControlNet = (props: ControlNetProps) => { > @@ -143,9 +143,9 @@ const ControlNet = (props: ControlNetProps) => { borderRadius: 'full', top: 4, insetInlineEnd: 4, - bg: 'error.700', + bg: 'accent.700', _dark: { - bg: 'error.200', + bg: 'accent.400', }, }} /> diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx index f82e39d34e..bb4ca07bbb 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx @@ -29,13 +29,18 @@ const ControlNetImagePreview = (props: Props) => { stateSelector, ({ controlNet }) => { const { pendingControlImages } = controlNet; - const { controlImage, processedControlImage, processorType } = - controlNet.controlNets[controlNetId]; + const { + controlImage, + processedControlImage, + processorType, + isEnabled, + } = controlNet.controlNets[controlNetId]; return { controlImageName: controlImage, processedControlImageName: processedControlImage, processorType, + isEnabled, pendingControlImages, }; }, @@ -49,6 +54,7 @@ const ControlNetImagePreview = (props: Props) => { processedControlImageName, processorType, pendingControlImages, + isEnabled, } = useAppSelector(selector); const [isMouseOverImage, setIsMouseOverImage] = useState(false); @@ -119,6 +125,8 @@ const ControlNetImagePreview = (props: Props) => { h: height, alignItems: 'center', justifyContent: 'center', + pointerEvents: isEnabled ? 'auto' : 'none', + opacity: isEnabled ? 1 : 0.5, }} > { () => createSelector( stateSelector, - ({ controlNet }) => controlNet.controlNets[controlNetId]?.processorNode, + ({ controlNet }) => { + const { isEnabled, processorNode } = + controlNet.controlNets[controlNetId]; + + return { isEnabled, processorNode }; + }, defaultSelectorOptions ), [controlNetId] ); - const processorNode = useAppSelector(selector); + const { isEnabled, processorNode } = useAppSelector(selector); if (processorNode.type === 'canny_image_processor') { return ( ); } if (processorNode.type === 'hed_image_processor') { return ( - + ); } @@ -55,6 +65,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => { ); } @@ -64,6 +75,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => { ); } @@ -73,6 +85,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => { ); } @@ -82,6 +95,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => { ); } @@ -91,6 +105,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => { ); } @@ -100,6 +115,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => { ); } @@ -109,6 +125,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => { ); } @@ -118,6 +135,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => { ); } @@ -127,6 +145,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => { ); } @@ -136,6 +155,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => { ); } diff --git a/invokeai/frontend/web/src/features/controlNet/components/ParamControlNetShouldAutoConfig.tsx b/invokeai/frontend/web/src/features/controlNet/components/ParamControlNetShouldAutoConfig.tsx index 8c0963561c..285fcf7b80 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ParamControlNetShouldAutoConfig.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ParamControlNetShouldAutoConfig.tsx @@ -18,14 +18,17 @@ const ParamControlNetShouldAutoConfig = (props: Props) => { () => createSelector( stateSelector, - ({ controlNet }) => - controlNet.controlNets[controlNetId]?.shouldAutoConfig, + ({ controlNet }) => { + const { isEnabled, shouldAutoConfig } = + controlNet.controlNets[controlNetId]; + return { isEnabled, shouldAutoConfig }; + }, defaultSelectorOptions ), [controlNetId] ); - const shouldAutoConfig = useAppSelector(selector); + const { isEnabled, shouldAutoConfig } = useAppSelector(selector); const isBusy = useAppSelector(selectIsBusy); const handleShouldAutoConfigChanged = useCallback(() => { @@ -38,7 +41,7 @@ const ParamControlNetShouldAutoConfig = (props: Props) => { aria-label="Auto configure processor" isChecked={shouldAutoConfig} onChange={handleShouldAutoConfigChanged} - isDisabled={isBusy} + isDisabled={isBusy || !isEnabled} /> ); }; diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetBeginEnd.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetBeginEnd.tsx index 57cf3f439e..f2f8a8bef2 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetBeginEnd.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetBeginEnd.tsx @@ -34,16 +34,16 @@ const ParamControlNetBeginEnd = (props: Props) => { createSelector( stateSelector, ({ controlNet }) => { - const { beginStepPct, endStepPct } = + const { beginStepPct, endStepPct, isEnabled } = controlNet.controlNets[controlNetId]; - return { beginStepPct, endStepPct }; + return { beginStepPct, endStepPct, isEnabled }; }, defaultSelectorOptions ), [controlNetId] ); - const { beginStepPct, endStepPct } = useAppSelector(selector); + const { beginStepPct, endStepPct, isEnabled } = useAppSelector(selector); const handleStepPctChanged = useCallback( (v: number[]) => { @@ -61,7 +61,7 @@ const ParamControlNetBeginEnd = (props: Props) => { }, [controlNetId, dispatch]); return ( - + Begin / End Step Percentage { max={1} step={0.01} minStepsBetweenThumbs={5} + isDisabled={!isEnabled} > diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetControlMode.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetControlMode.tsx index 07b58384e1..a1eff1263a 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetControlMode.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetControlMode.tsx @@ -30,13 +30,17 @@ export default function ParamControlNetControlMode( () => createSelector( stateSelector, - ({ controlNet }) => controlNet.controlNets[controlNetId]?.controlMode, + ({ controlNet }) => { + const { controlMode, isEnabled } = + controlNet.controlNets[controlNetId]; + return { controlMode, isEnabled }; + }, defaultSelectorOptions ), [controlNetId] ); - const controlMode = useAppSelector(selector); + const { controlMode, isEnabled } = useAppSelector(selector); const { t } = useTranslation(); @@ -49,6 +53,7 @@ export default function ParamControlNetControlMode( return ( { - const { controlNetId, isEnabled } = props; - const dispatch = useAppDispatch(); - - const handleIsEnabledChanged = useCallback(() => { - dispatch(controlNetToggled({ controlNetId })); - }, [dispatch, controlNetId]); - - return ( - - ); -}; - -export default memo(ParamControlNetIsEnabled); diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx index 548acfaea7..8392bdd2e3 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx @@ -29,14 +29,15 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => { ({ generation, controlNet }) => { const { model } = generation; const controlNetModel = controlNet.controlNets[controlNetId]?.model; - return { mainModel: model, controlNetModel }; + const isEnabled = controlNet.controlNets[controlNetId]?.isEnabled; + return { mainModel: model, controlNetModel, isEnabled }; }, defaultSelectorOptions ), [controlNetId] ); - const { mainModel, controlNetModel } = useAppSelector(selector); + const { mainModel, controlNetModel, isEnabled } = useAppSelector(selector); const { data: controlNetModels } = useGetControlNetModelsQuery(); @@ -110,7 +111,7 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => { placeholder="Select a model" value={selectedModel?.id ?? null} onChange={handleModelChanged} - disabled={isBusy} + disabled={isBusy || !isEnabled} tooltip={selectedModel?.description} /> ); diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetProcessorSelect.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetProcessorSelect.tsx index a57de5d70e..83c66363ac 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetProcessorSelect.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetProcessorSelect.tsx @@ -57,16 +57,18 @@ const ParamControlNetProcessorSelect = ( () => createSelector( stateSelector, - ({ controlNet }) => ({ - processorNode: controlNet.controlNets[controlNetId]?.processorNode, - }), + ({ controlNet }) => { + const { isEnabled, processorNode } = + controlNet.controlNets[controlNetId]; + return { isEnabled, processorNode }; + }, defaultSelectorOptions ), [controlNetId] ); const isBusy = useAppSelector(selectIsBusy); const controlNetProcessors = useAppSelector(selector); - const { processorNode } = useAppSelector(processorNodeSelector); + const { isEnabled, processorNode } = useAppSelector(processorNodeSelector); const handleProcessorTypeChanged = useCallback( (v: string | null) => { @@ -86,7 +88,7 @@ const ParamControlNetProcessorSelect = ( value={processorNode.type ?? 'canny_image_processor'} data={controlNetProcessors} onChange={handleProcessorTypeChanged} - disabled={isBusy} + disabled={isBusy || !isEnabled} /> ); }; diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetWeight.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetWeight.tsx index f7a2a24cd5..8643fd7dad 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetWeight.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetWeight.tsx @@ -17,13 +17,16 @@ const ParamControlNetWeight = (props: ParamControlNetWeightProps) => { () => createSelector( stateSelector, - ({ controlNet }) => controlNet.controlNets[controlNetId]?.weight, + ({ controlNet }) => { + const { weight, isEnabled } = controlNet.controlNets[controlNetId]; + return { weight, isEnabled }; + }, defaultSelectorOptions ), [controlNetId] ); - const weight = useAppSelector(selector); + const { weight, isEnabled } = useAppSelector(selector); const handleWeightChanged = useCallback( (weight: number) => { dispatch(controlNetWeightChanged({ controlNetId, weight })); @@ -33,6 +36,7 @@ const ParamControlNetWeight = (props: ParamControlNetWeightProps) => { return ( { - const { controlNetId, processorNode } = props; + const { controlNetId, processorNode, isEnabled } = props; const { low_threshold, high_threshold } = processorNode; - const isReady = useIsReadyToInvoke(); + const isBusy = useAppSelector(selectIsBusy); const processorChanged = useProcessorNodeChanged(); const handleLowThresholdChanged = useCallback( @@ -48,7 +51,7 @@ const CannyProcessor = (props: CannyProcessorProps) => { return ( { withSliderMarks /> { - const { controlNetId, processorNode } = props; + const { controlNetId, processorNode, isEnabled } = props; const { image_resolution, detect_resolution, w, h, f } = processorNode; const processorChanged = useProcessorNodeChanged(); - const isReady = useIsReadyToInvoke(); + const isBusy = useAppSelector(selectIsBusy); const handleDetectResolutionChanged = useCallback( (v: number) => { @@ -96,7 +99,7 @@ const ContentShuffleProcessor = (props: Props) => { max={4096} withInput withSliderMarks - isDisabled={!isReady} + isDisabled={isBusy || !isEnabled} /> { max={4096} withInput withSliderMarks - isDisabled={!isReady} + isDisabled={isBusy || !isEnabled} /> { max={4096} withInput withSliderMarks - isDisabled={!isReady} + isDisabled={isBusy || !isEnabled} /> { max={4096} withInput withSliderMarks - isDisabled={!isReady} + isDisabled={isBusy || !isEnabled} /> { max={4096} withInput withSliderMarks - isDisabled={!isReady} + isDisabled={isBusy || !isEnabled} /> ); diff --git a/invokeai/frontend/web/src/features/controlNet/components/processors/HedProcessor.tsx b/invokeai/frontend/web/src/features/controlNet/components/processors/HedProcessor.tsx index c51a44d1c3..04d6027ff9 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/processors/HedProcessor.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/processors/HedProcessor.tsx @@ -1,25 +1,29 @@ +import { useAppSelector } from 'app/store/storeHooks'; import IAISlider from 'common/components/IAISlider'; import IAISwitch from 'common/components/IAISwitch'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { RequiredHedImageProcessorInvocation } from 'features/controlNet/store/types'; +import { selectIsBusy } from 'features/system/store/systemSelectors'; import { ChangeEvent, memo, useCallback } from 'react'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import ProcessorWrapper from './common/ProcessorWrapper'; -import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; -const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor.default; +const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor + .default as RequiredHedImageProcessorInvocation; type HedProcessorProps = { controlNetId: string; processorNode: RequiredHedImageProcessorInvocation; + isEnabled: boolean; }; const HedPreprocessor = (props: HedProcessorProps) => { const { controlNetId, processorNode: { detect_resolution, image_resolution, scribble }, + isEnabled, } = props; - const isReady = useIsReadyToInvoke(); + const isBusy = useAppSelector(selectIsBusy); const processorChanged = useProcessorNodeChanged(); const handleDetectResolutionChanged = useCallback( @@ -67,7 +71,7 @@ const HedPreprocessor = (props: HedProcessorProps) => { max={4096} withInput withSliderMarks - isDisabled={!isReady} + isDisabled={isBusy || !isEnabled} /> { max={4096} withInput withSliderMarks - isDisabled={!isReady} + isDisabled={isBusy || !isEnabled} /> ); diff --git a/invokeai/frontend/web/src/features/controlNet/components/processors/LineartAnimeProcessor.tsx b/invokeai/frontend/web/src/features/controlNet/components/processors/LineartAnimeProcessor.tsx index bc64e3c843..90177e7ceb 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/processors/LineartAnimeProcessor.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/processors/LineartAnimeProcessor.tsx @@ -1,23 +1,26 @@ +import { useAppSelector } from 'app/store/storeHooks'; import IAISlider from 'common/components/IAISlider'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { RequiredLineartAnimeImageProcessorInvocation } from 'features/controlNet/store/types'; +import { selectIsBusy } from 'features/system/store/systemSelectors'; import { memo, useCallback } from 'react'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import ProcessorWrapper from './common/ProcessorWrapper'; -import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; -const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor.default; +const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor + .default as RequiredLineartAnimeImageProcessorInvocation; type Props = { controlNetId: string; processorNode: RequiredLineartAnimeImageProcessorInvocation; + isEnabled: boolean; }; const LineartAnimeProcessor = (props: Props) => { - const { controlNetId, processorNode } = props; + const { controlNetId, processorNode, isEnabled } = props; const { image_resolution, detect_resolution } = processorNode; const processorChanged = useProcessorNodeChanged(); - const isReady = useIsReadyToInvoke(); + const isBusy = useAppSelector(selectIsBusy); const handleDetectResolutionChanged = useCallback( (v: number) => { @@ -57,7 +60,7 @@ const LineartAnimeProcessor = (props: Props) => { max={4096} withInput withSliderMarks - isDisabled={!isReady} + isDisabled={isBusy || !isEnabled} /> { max={4096} withInput withSliderMarks - isDisabled={!isReady} + isDisabled={isBusy || !isEnabled} /> ); diff --git a/invokeai/frontend/web/src/features/controlNet/components/processors/LineartProcessor.tsx b/invokeai/frontend/web/src/features/controlNet/components/processors/LineartProcessor.tsx index 11245bf4a7..7f661b1cb9 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/processors/LineartProcessor.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/processors/LineartProcessor.tsx @@ -1,24 +1,27 @@ +import { useAppSelector } from 'app/store/storeHooks'; import IAISlider from 'common/components/IAISlider'; import IAISwitch from 'common/components/IAISwitch'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { RequiredLineartImageProcessorInvocation } from 'features/controlNet/store/types'; +import { selectIsBusy } from 'features/system/store/systemSelectors'; import { ChangeEvent, memo, useCallback } from 'react'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import ProcessorWrapper from './common/ProcessorWrapper'; -import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; -const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor.default; +const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor + .default as RequiredLineartImageProcessorInvocation; type LineartProcessorProps = { controlNetId: string; processorNode: RequiredLineartImageProcessorInvocation; + isEnabled: boolean; }; const LineartProcessor = (props: LineartProcessorProps) => { - const { controlNetId, processorNode } = props; + const { controlNetId, processorNode, isEnabled } = props; const { image_resolution, detect_resolution, coarse } = processorNode; const processorChanged = useProcessorNodeChanged(); - const isReady = useIsReadyToInvoke(); + const isBusy = useAppSelector(selectIsBusy); const handleDetectResolutionChanged = useCallback( (v: number) => { @@ -65,7 +68,7 @@ const LineartProcessor = (props: LineartProcessorProps) => { max={4096} withInput withSliderMarks - isDisabled={!isReady} + isDisabled={isBusy || !isEnabled} /> { max={4096} withInput withSliderMarks - isDisabled={!isReady} + isDisabled={isBusy || !isEnabled} /> ); diff --git a/invokeai/frontend/web/src/features/controlNet/components/processors/MediapipeFaceProcessor.tsx b/invokeai/frontend/web/src/features/controlNet/components/processors/MediapipeFaceProcessor.tsx index 27aa22ca40..b8b24ce877 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/processors/MediapipeFaceProcessor.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/processors/MediapipeFaceProcessor.tsx @@ -1,23 +1,26 @@ +import { useAppSelector } from 'app/store/storeHooks'; import IAISlider from 'common/components/IAISlider'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { RequiredMediapipeFaceProcessorInvocation } from 'features/controlNet/store/types'; +import { selectIsBusy } from 'features/system/store/systemSelectors'; import { memo, useCallback } from 'react'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import ProcessorWrapper from './common/ProcessorWrapper'; -import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; -const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor.default; +const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor + .default as RequiredMediapipeFaceProcessorInvocation; type Props = { controlNetId: string; processorNode: RequiredMediapipeFaceProcessorInvocation; + isEnabled: boolean; }; const MediapipeFaceProcessor = (props: Props) => { - const { controlNetId, processorNode } = props; + const { controlNetId, processorNode, isEnabled } = props; const { max_faces, min_confidence } = processorNode; const processorChanged = useProcessorNodeChanged(); - const isReady = useIsReadyToInvoke(); + const isBusy = useAppSelector(selectIsBusy); const handleMaxFacesChanged = useCallback( (v: number) => { @@ -53,7 +56,7 @@ const MediapipeFaceProcessor = (props: Props) => { max={20} withInput withSliderMarks - isDisabled={!isReady} + isDisabled={isBusy || !isEnabled} /> { step={0.01} withInput withSliderMarks - isDisabled={!isReady} + isDisabled={isBusy || !isEnabled} /> ); diff --git a/invokeai/frontend/web/src/features/controlNet/components/processors/MidasDepthProcessor.tsx b/invokeai/frontend/web/src/features/controlNet/components/processors/MidasDepthProcessor.tsx index ffecb68061..54b174b5df 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/processors/MidasDepthProcessor.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/processors/MidasDepthProcessor.tsx @@ -1,23 +1,26 @@ +import { useAppSelector } from 'app/store/storeHooks'; import IAISlider from 'common/components/IAISlider'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { RequiredMidasDepthImageProcessorInvocation } from 'features/controlNet/store/types'; +import { selectIsBusy } from 'features/system/store/systemSelectors'; import { memo, useCallback } from 'react'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import ProcessorWrapper from './common/ProcessorWrapper'; -import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; -const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor.default; +const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor + .default as RequiredMidasDepthImageProcessorInvocation; type Props = { controlNetId: string; processorNode: RequiredMidasDepthImageProcessorInvocation; + isEnabled: boolean; }; const MidasDepthProcessor = (props: Props) => { - const { controlNetId, processorNode } = props; + const { controlNetId, processorNode, isEnabled } = props; const { a_mult, bg_th } = processorNode; const processorChanged = useProcessorNodeChanged(); - const isReady = useIsReadyToInvoke(); + const isBusy = useAppSelector(selectIsBusy); const handleAMultChanged = useCallback( (v: number) => { @@ -54,7 +57,7 @@ const MidasDepthProcessor = (props: Props) => { step={0.01} withInput withSliderMarks - isDisabled={!isReady} + isDisabled={isBusy || !isEnabled} /> { step={0.01} withInput withSliderMarks - isDisabled={!isReady} + isDisabled={isBusy || !isEnabled} /> ); diff --git a/invokeai/frontend/web/src/features/controlNet/components/processors/MlsdImageProcessor.tsx b/invokeai/frontend/web/src/features/controlNet/components/processors/MlsdImageProcessor.tsx index 728c4b44fa..9a86cd2cb5 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/processors/MlsdImageProcessor.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/processors/MlsdImageProcessor.tsx @@ -1,23 +1,26 @@ +import { useAppSelector } from 'app/store/storeHooks'; import IAISlider from 'common/components/IAISlider'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { RequiredMlsdImageProcessorInvocation } from 'features/controlNet/store/types'; +import { selectIsBusy } from 'features/system/store/systemSelectors'; import { memo, useCallback } from 'react'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import ProcessorWrapper from './common/ProcessorWrapper'; -import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; -const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor.default; +const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor + .default as RequiredMlsdImageProcessorInvocation; type Props = { controlNetId: string; processorNode: RequiredMlsdImageProcessorInvocation; + isEnabled: boolean; }; const MlsdImageProcessor = (props: Props) => { - const { controlNetId, processorNode } = props; + const { controlNetId, processorNode, isEnabled } = props; const { image_resolution, detect_resolution, thr_d, thr_v } = processorNode; const processorChanged = useProcessorNodeChanged(); - const isReady = useIsReadyToInvoke(); + const isBusy = useAppSelector(selectIsBusy); const handleDetectResolutionChanged = useCallback( (v: number) => { @@ -79,7 +82,7 @@ const MlsdImageProcessor = (props: Props) => { max={4096} withInput withSliderMarks - isDisabled={!isReady} + isDisabled={isBusy || !isEnabled} /> { max={4096} withInput withSliderMarks - isDisabled={!isReady} + isDisabled={isBusy || !isEnabled} /> { step={0.01} withInput withSliderMarks - isDisabled={!isReady} + isDisabled={isBusy || !isEnabled} /> { step={0.01} withInput withSliderMarks - isDisabled={!isReady} + isDisabled={isBusy || !isEnabled} /> ); diff --git a/invokeai/frontend/web/src/features/controlNet/components/processors/NormalBaeProcessor.tsx b/invokeai/frontend/web/src/features/controlNet/components/processors/NormalBaeProcessor.tsx index d5dce8d492..a0a8ad72cf 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/processors/NormalBaeProcessor.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/processors/NormalBaeProcessor.tsx @@ -1,23 +1,26 @@ +import { useAppSelector } from 'app/store/storeHooks'; import IAISlider from 'common/components/IAISlider'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { RequiredNormalbaeImageProcessorInvocation } from 'features/controlNet/store/types'; +import { selectIsBusy } from 'features/system/store/systemSelectors'; import { memo, useCallback } from 'react'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import ProcessorWrapper from './common/ProcessorWrapper'; -import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; -const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor.default; +const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor + .default as RequiredNormalbaeImageProcessorInvocation; type Props = { controlNetId: string; processorNode: RequiredNormalbaeImageProcessorInvocation; + isEnabled: boolean; }; const NormalBaeProcessor = (props: Props) => { - const { controlNetId, processorNode } = props; + const { controlNetId, processorNode, isEnabled } = props; const { image_resolution, detect_resolution } = processorNode; const processorChanged = useProcessorNodeChanged(); - const isReady = useIsReadyToInvoke(); + const isBusy = useAppSelector(selectIsBusy); const handleDetectResolutionChanged = useCallback( (v: number) => { @@ -57,7 +60,7 @@ const NormalBaeProcessor = (props: Props) => { max={4096} withInput withSliderMarks - isDisabled={!isReady} + isDisabled={isBusy || !isEnabled} /> { max={4096} withInput withSliderMarks - isDisabled={!isReady} + isDisabled={isBusy || !isEnabled} /> ); diff --git a/invokeai/frontend/web/src/features/controlNet/components/processors/OpenposeProcessor.tsx b/invokeai/frontend/web/src/features/controlNet/components/processors/OpenposeProcessor.tsx index e97b933b6b..8335a5f75d 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/processors/OpenposeProcessor.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/processors/OpenposeProcessor.tsx @@ -1,24 +1,27 @@ +import { useAppSelector } from 'app/store/storeHooks'; import IAISlider from 'common/components/IAISlider'; import IAISwitch from 'common/components/IAISwitch'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { RequiredOpenposeImageProcessorInvocation } from 'features/controlNet/store/types'; +import { selectIsBusy } from 'features/system/store/systemSelectors'; import { ChangeEvent, memo, useCallback } from 'react'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import ProcessorWrapper from './common/ProcessorWrapper'; -import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; -const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor.default; +const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor + .default as RequiredOpenposeImageProcessorInvocation; type Props = { controlNetId: string; processorNode: RequiredOpenposeImageProcessorInvocation; + isEnabled: boolean; }; const OpenposeProcessor = (props: Props) => { - const { controlNetId, processorNode } = props; + const { controlNetId, processorNode, isEnabled } = props; const { image_resolution, detect_resolution, hand_and_face } = processorNode; const processorChanged = useProcessorNodeChanged(); - const isReady = useIsReadyToInvoke(); + const isBusy = useAppSelector(selectIsBusy); const handleDetectResolutionChanged = useCallback( (v: number) => { @@ -65,7 +68,7 @@ const OpenposeProcessor = (props: Props) => { max={4096} withInput withSliderMarks - isDisabled={!isReady} + isDisabled={isBusy || !isEnabled} /> { max={4096} withInput withSliderMarks - isDisabled={!isReady} + isDisabled={isBusy || !isEnabled} /> ); diff --git a/invokeai/frontend/web/src/features/controlNet/components/processors/PidiProcessor.tsx b/invokeai/frontend/web/src/features/controlNet/components/processors/PidiProcessor.tsx index 8251447195..4eab83136d 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/processors/PidiProcessor.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/processors/PidiProcessor.tsx @@ -1,24 +1,27 @@ +import { useAppSelector } from 'app/store/storeHooks'; import IAISlider from 'common/components/IAISlider'; import IAISwitch from 'common/components/IAISwitch'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { RequiredPidiImageProcessorInvocation } from 'features/controlNet/store/types'; +import { selectIsBusy } from 'features/system/store/systemSelectors'; import { ChangeEvent, memo, useCallback } from 'react'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import ProcessorWrapper from './common/ProcessorWrapper'; -import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; -const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor.default; +const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor + .default as RequiredPidiImageProcessorInvocation; type Props = { controlNetId: string; processorNode: RequiredPidiImageProcessorInvocation; + isEnabled: boolean; }; const PidiProcessor = (props: Props) => { - const { controlNetId, processorNode } = props; + const { controlNetId, processorNode, isEnabled } = props; const { image_resolution, detect_resolution, scribble, safe } = processorNode; const processorChanged = useProcessorNodeChanged(); - const isReady = useIsReadyToInvoke(); + const isBusy = useAppSelector(selectIsBusy); const handleDetectResolutionChanged = useCallback( (v: number) => { @@ -72,7 +75,7 @@ const PidiProcessor = (props: Props) => { max={4096} withInput withSliderMarks - isDisabled={!isReady} + isDisabled={isBusy || !isEnabled} /> { max={4096} withInput withSliderMarks - isDisabled={!isReady} + isDisabled={isBusy || !isEnabled} /> { label="Safe" isChecked={safe} onChange={handleSafeChanged} - isDisabled={!isReady} + isDisabled={isBusy || !isEnabled} /> ); diff --git a/invokeai/frontend/web/src/features/controlNet/components/processors/ZoeDepthProcessor.tsx b/invokeai/frontend/web/src/features/controlNet/components/processors/ZoeDepthProcessor.tsx index d0a34784bf..b4b45025eb 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/processors/ZoeDepthProcessor.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/processors/ZoeDepthProcessor.tsx @@ -4,6 +4,7 @@ import { memo } from 'react'; type Props = { controlNetId: string; processorNode: RequiredZoeDepthImageProcessorInvocation; + isEnabled: boolean; }; const ZoeDepthProcessor = (props: Props) => { From d1ecd007aba52be5d1435d0264d9daece35576fd Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 15 Jul 2023 19:53:08 +1000 Subject: [PATCH 29/37] feat(ui): promote controlnet to be just under general It is the most impactful feature, and also takes up the most space when you expand it. Promoted. --- .../components/tabs/ImageToImage/ImageToImageTabParameters.tsx | 2 +- .../ui/components/tabs/TextToImage/TextToImageTabParameters.tsx | 2 +- .../components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabParameters.tsx index 16c0f44158..87e19993d7 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabParameters.tsx @@ -19,9 +19,9 @@ const ImageToImageTabParameters = () => { + - diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabParameters.tsx index 987f4ff0bc..ed1c3dd706 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabParameters.tsx @@ -20,9 +20,9 @@ const TextToImageTabParameters = () => { + - diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx index 6c19a61372..e413ad7ab2 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx @@ -19,9 +19,9 @@ const UnifiedCanvasParameters = () => { + - From 457e4b7fc572bd7349ddda0af133c8369a344af0 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 15 Jul 2023 19:55:17 +1000 Subject: [PATCH 30/37] feat(ui): tweak slider label spacing --- invokeai/frontend/web/src/common/components/IAISlider.tsx | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/common/components/IAISlider.tsx b/invokeai/frontend/web/src/common/components/IAISlider.tsx index 00492b28d6..3e7f38cf91 100644 --- a/invokeai/frontend/web/src/common/components/IAISlider.tsx +++ b/invokeai/frontend/web/src/common/components/IAISlider.tsx @@ -201,7 +201,11 @@ const IAISlider = (props: IAIFullSliderProps) => { isDisabled={isDisabled} {...sliderFormControlProps} > - {label && {label}} + {label && ( + + {label} + + )} Date: Sat, 15 Jul 2023 20:04:33 +1000 Subject: [PATCH 31/37] fix(ui): fix invoke button styles when processing --- .../parameters/components/ProcessButtons/InvokeButton.tsx | 2 ++ 1 file changed, 2 insertions(+) diff --git a/invokeai/frontend/web/src/features/parameters/components/ProcessButtons/InvokeButton.tsx b/invokeai/frontend/web/src/features/parameters/components/ProcessButtons/InvokeButton.tsx index 2e399647d8..ab4949392d 100644 --- a/invokeai/frontend/web/src/features/parameters/components/ProcessButtons/InvokeButton.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/ProcessButtons/InvokeButton.tsx @@ -17,8 +17,10 @@ import { FaPlay } from 'react-icons/fa'; const IN_PROGRESS_STYLES: ChakraProps['sx'] = { _disabled: { bg: 'none', + color: 'base.600', cursor: 'not-allowed', _hover: { + color: 'base.600', bg: 'none', }, }, From 4ac0ce59fb764691601b4e2ed745ce41850888fc Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 15 Jul 2023 20:27:05 +1000 Subject: [PATCH 32/37] fix(ui): add custom label to IAIMantineSelects needed to have their label styles match chakras --- .../components/IAIMantineMultiSelect.tsx | 22 +++++++++++++++--- .../components/IAIMantineSearchableSelect.tsx | 23 ++++++++++++++++--- .../common/components/IAIMantineSelect.tsx | 21 +++++++++++++---- .../controlNet/components/ControlNet.tsx | 5 ++-- .../parameters/ParamControlNetControlMode.tsx | 2 +- .../ParamControlNetProcessorSelect.tsx | 1 + 6 files changed, 60 insertions(+), 14 deletions(-) diff --git a/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx b/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx index dd5c602150..28c680b824 100644 --- a/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx +++ b/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx @@ -1,17 +1,25 @@ -import { Tooltip } from '@chakra-ui/react'; +import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react'; import { MultiSelect, MultiSelectProps } from '@mantine/core'; import { useAppDispatch } from 'app/store/storeHooks'; import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice'; import { useMantineMultiSelectStyles } from 'mantine-theme/hooks/useMantineMultiSelectStyles'; import { KeyboardEvent, RefObject, memo, useCallback } from 'react'; -type IAIMultiSelectProps = MultiSelectProps & { +type IAIMultiSelectProps = Omit & { tooltip?: string; inputRef?: RefObject; + label?: string; }; const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { - const { searchable = true, tooltip, inputRef, ...rest } = props; + const { + searchable = true, + tooltip, + inputRef, + label, + disabled, + ...rest + } = props; const dispatch = useAppDispatch(); const handleKeyDown = useCallback( @@ -37,7 +45,15 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { return ( + {label} + + ) : undefined + } ref={inputRef} + disabled={disabled} onKeyDown={handleKeyDown} onKeyUp={handleKeyUp} searchable={searchable} diff --git a/invokeai/frontend/web/src/common/components/IAIMantineSearchableSelect.tsx b/invokeai/frontend/web/src/common/components/IAIMantineSearchableSelect.tsx index edf1665bb4..2c3f5434ad 100644 --- a/invokeai/frontend/web/src/common/components/IAIMantineSearchableSelect.tsx +++ b/invokeai/frontend/web/src/common/components/IAIMantineSearchableSelect.tsx @@ -1,4 +1,4 @@ -import { Tooltip } from '@chakra-ui/react'; +import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react'; import { Select, SelectProps } from '@mantine/core'; import { useAppDispatch } from 'app/store/storeHooks'; import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice'; @@ -11,13 +11,22 @@ export type IAISelectDataType = { tooltip?: string; }; -type IAISelectProps = SelectProps & { +type IAISelectProps = Omit & { tooltip?: string; + label?: string; inputRef?: RefObject; }; const IAIMantineSearchableSelect = (props: IAISelectProps) => { - const { searchable = true, tooltip, inputRef, onChange, ...rest } = props; + const { + searchable = true, + tooltip, + inputRef, + onChange, + label, + disabled, + ...rest + } = props; const dispatch = useAppDispatch(); const [searchValue, setSearchValue] = useState(''); @@ -61,6 +70,14 @@ const IAIMantineSearchableSelect = (props: IAISelectProps) => { +