diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 8dbeaa3d05..c298114cbc 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -1,6 +1,7 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein +import pathlib from typing import Literal, List, Optional, Union from fastapi import Body, Path, Query, Response @@ -22,6 +23,7 @@ UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] +ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)] class ModelsList(BaseModel): models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]] @@ -78,7 +80,7 @@ async def update_model( return model_response @models_router.post( - "/", + "/import", operation_id="import_model", responses= { 201: {"description" : "The model imported successfully"}, @@ -94,7 +96,7 @@ async def import_model( prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \ Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"), ) -> ImportModelResponse: - """ Add a model using its local path, repo_id, or remote URL """ + """ Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """ items_to_import = {location} prediction_types = { x.value: x for x in SchedulerPredictionType } @@ -126,18 +128,100 @@ async def import_model( logger.error(str(e)) raise HTTPException(status_code=409, detail=str(e)) +@models_router.post( + "/add", + operation_id="add_model", + responses= { + 201: {"description" : "The model added successfully"}, + 404: {"description" : "The model could not be found"}, + 424: {"description" : "The model appeared to add successfully, but could not be found in the model manager"}, + 409: {"description" : "There is already a model corresponding to this path or repo_id"}, + }, + status_code=201, + response_model=ImportModelResponse +) +async def add_model( + info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), +) -> ImportModelResponse: + """ Add a model using the configuration information appropriate for its type. 
Only local models can be added by path""" + + logger = ApiDependencies.invoker.services.logger + try: + ApiDependencies.invoker.services.model_manager.add_model( + info.model_name, + info.base_model, + info.model_type, + model_attributes = info.dict() + ) + logger.info(f'Successfully added {info.model_name}') + model_raw = ApiDependencies.invoker.services.model_manager.list_model( + model_name=info.model_name, + base_model=info.base_model, + model_type=info.model_type + ) + return parse_obj_as(ImportModelResponse, model_raw) + except KeyError as e: + logger.error(str(e)) + raise HTTPException(status_code=404, detail=str(e)) + except ValueError as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + +@models_router.post( + "/rename/{base_model}/{model_type}/{model_name}", + operation_id="rename_model", + responses= { + 201: {"description" : "The model was renamed successfully"}, + 404: {"description" : "The model could not be found"}, + 409: {"description" : "There is already a model corresponding to the new name"}, + }, + status_code=201, + response_model=ImportModelResponse +) +async def rename_model( + base_model: BaseModelType = Path(description="Base model"), + model_type: ModelType = Path(description="The type of model"), + model_name: str = Path(description="current model name"), + new_name: Optional[str] = Query(description="new model name", default=None), + new_base: Optional[BaseModelType] = Query(description="new model base", default=None), +) -> ImportModelResponse: + """ Rename a model""" + + logger = ApiDependencies.invoker.services.logger + + try: + result = ApiDependencies.invoker.services.model_manager.rename_model( + base_model = base_model, + model_type = model_type, + model_name = model_name, + new_name = new_name, + new_base = new_base, + ) + logger.debug(result) + logger.info(f'Successfully renamed {model_name}=>{new_name}') + model_raw = ApiDependencies.invoker.services.model_manager.list_model( + model_name=new_name or model_name, + base_model=new_base or base_model, + model_type=model_type + ) + return parse_obj_as(ImportModelResponse, model_raw) + except KeyError as e: + logger.error(str(e)) + raise HTTPException(status_code=404, detail=str(e)) + except ValueError as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + @models_router.delete( "/{base_model}/{model_type}/{model_name}", operation_id="del_model", responses={ - 204: { - "description": "Model deleted successfully" - }, - 404: { - "description": "Model not found" - } + 204: { "description": "Model deleted successfully" }, + 404: { "description": "Model not found" } }, + status_code = 204, + response_model = None, ) async def delete_model( base_model: BaseModelType = Path(description="Base model"), @@ -173,14 +257,17 @@ async def convert_model( base_model: BaseModelType = Path(description="Base model"), model_type: ModelType = Path(description="The type of model"), model_name: str = Path(description="model name"), + convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"), ) -> ConvertModelResponse: - """Convert a checkpoint model into a diffusers model""" + """Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none.""" logger = ApiDependencies.invoker.services.logger try: logger.info(f"Converting model: {model_name}") + dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None 
ApiDependencies.invoker.services.model_manager.convert_model(model_name, base_model = base_model, - model_type = model_type + model_type = model_type, + convert_dest_directory = dest, ) model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name, base_model = base_model, @@ -191,6 +278,53 @@ async def convert_model( except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) return response + +@models_router.get( + "/search", + operation_id="search_for_models", + responses={ + 200: { "description": "Directory searched successfully" }, + 404: { "description": "Invalid directory path" }, + }, + status_code = 200, + response_model = List[pathlib.Path] +) +async def search_for_models( + search_path: pathlib.Path = Query(description="Directory path to search for models") +)->List[pathlib.Path]: + if not search_path.is_dir(): + raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not a directory") + return ApiDependencies.invoker.services.model_manager.search_for_models([search_path]) + +@models_router.get( + "/ckpt_confs", + operation_id="list_ckpt_configs", + responses={ + 200: { "description" : "paths retrieved successfully" }, + }, + status_code = 200, + response_model = List[pathlib.Path] +) +async def list_ckpt_configs( +)->List[pathlib.Path]: + """Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT.""" + return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs() + + +@models_router.get( + "/sync", + operation_id="sync_to_config", + responses={ + 201: { "description": "synchronization successful" }, + }, + status_code = 201, + response_model = None +) +async def sync_to_config( +)->None: + """Call after making changes to models.yaml, autoimport directories or models directory to synchronize + in-memory data structures with disk data structures.""" + return ApiDependencies.invoker.services.model_manager.sync_to_config() @models_router.put( "/merge/{base_model}", @@ -210,17 +344,21 @@ async def merge_models( alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5), interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"), force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False), + merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None) ) -> MergeModelResponse: """Merge two or three models, producing a new merged diffusers model""" logger = ApiDependencies.invoker.services.logger try: - logger.info(f"Merging models: {model_names}") + logger.info(f"Merging models: {model_names} into {merge_dest_directory or ''}/{merged_model_name}") + dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None result = ApiDependencies.invoker.services.model_manager.merge_models(model_names, base_model, - merged_model_name or "+".join(model_names), - alpha, - interp, - force) + merged_model_name=merged_model_name or "+".join(model_names), + alpha=alpha, + interp=interp, + force=force, + merge_dest_directory = dest + ) model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name, base_model = base_model, model_type = ModelType.Main,
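For orientation: the `alpha` and `interp` parameters of the merge route map onto diffusers' checkpoint-merging semantics, and the `ModelMerger` docstring later in this diff notes that passing `interp=None` selects weighted-sum interpolation. A toy sketch of what weighted-sum merging with `alpha` does per tensor; this is an illustration under the assumption of identical state-dict keys, not the actual diffusers implementation:

import torch

def weighted_sum_merge(theta_a: dict, theta_b: dict, alpha: float = 0.5) -> dict:
    # theta = (1 - alpha) * theta_A + alpha * theta_B for every shared tensor;
    # the route's default alpha=0.5 is an even blend of the two models.
    return {k: torch.lerp(theta_a[k], theta_b[k], alpha) for k in theta_a.keys() & theta_b.keys()}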
diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 303e0a0c84..a5a9701149 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -100,7 +100,7 @@ class CompelInvocation(BaseInvocation): text_encoder=text_encoder, textual_inversion_manager=ti_manager, dtype_for_device_getter=torch_dtype, - truncate_long_prompts=True, # TODO: + truncate_long_prompts=False, ) conjunction = Compel.parse_prompt_string(self.prompt) @@ -112,9 +112,6 @@ class CompelInvocation(BaseInvocation): c, options = compel.build_conditioning_tensor_for_prompt_object( prompt) - # TODO: long prompt support - # if not self.truncate_long_prompts: - # [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc]) ec = InvokeAIDiffuserComponent.ExtraConditioningInfo( tokens_count_including_eos_bos=get_max_token_count( tokenizer, conjunction), diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index c37dcda998..7eff62a8a5 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -9,6 +9,7 @@ from typing import Literal, Optional, Union, List, Dict from PIL import Image from pydantic import BaseModel, Field, validator +from ...backend.model_management import BaseModelType, ModelType from ..models.image import ImageField, ImageCategory, ResourceOrigin from .baseinvocation import ( BaseInvocation, @@ -105,9 +106,15 @@ CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])] # CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])] +class ControlNetModelField(BaseModel): + """ControlNet model field""" + + model_name: str = Field(description="Name of the ControlNet model") + base_model: BaseModelType = Field(description="Base model") + class ControlField(BaseModel): image: ImageField = Field(default=None, description="The control image") - control_model: Optional[str] = Field(default=None, description="The ControlNet model to use") + control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use") # control_weight: Optional[float] = Field(default=1, description="weight given to controlnet") control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") begin_step_percent: float = Field(default=0, ge=0, le=1, @@ -118,15 +125,15 @@ class ControlField(BaseModel): # resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use") @validator("control_weight") - def abs_le_one(cls, v): - """validate that all abs(values) are <=1""" + def validate_control_weight(cls, v): + """Validate that all control weights are in the valid range""" if isinstance(v, list): for i in v: - if abs(i) > 1: - raise ValueError('all abs(control_weight) must be <= 1') + if i < -1 or i > 2: + raise ValueError('Control weights must be within -1 to 2 range') else: - if abs(v) > 1: - raise ValueError('abs(control_weight) must be <= 1') + if v < -1 or v > 2: + raise ValueError('Control weights must be within -1 to 2 range') return v class Config: schema_extra = { @@ -134,6 +141,7 @@ class ControlField(BaseModel): "ui": { "type_hints": { "control_weight": "float", + "control_model": "controlnet_model", # "control_weight": "number", } }
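Note for API clients: this is a breaking change to the node payload shape, since `control_model` goes from a bare string (optionally "name,subfolder" for HF subfolders, as the old loader further down shows) to a structured object matching `ControlNetModelField`. A sketch of the before/after JSON, where the `canny` model name and the `sd-1` base value are illustrative rather than taken from this diff:

# before: control_model was a plain string
old_node = {"type": "controlnet", "control_model": "lllyasviel/sd-controlnet-canny"}

# after: control_model carries a name plus the base model it belongs to
new_node = {"type": "controlnet", "control_model": {"model_name": "canny", "base_model": "sd-1"}}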
Field(default="lllyasviel/sd-controlnet-canny", + control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny", description="control model used") control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet") - begin_step_percent: float = Field(default=0, ge=0, le=1, + begin_step_percent: float = Field(default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)") end_step_percent: float = Field(default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)") diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index b3f95f3658..baf78c7c23 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -1,5 +1,6 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) +from contextlib import ExitStack from typing import List, Literal, Optional, Union import einops @@ -11,6 +12,7 @@ from pydantic import BaseModel, Field, validator from invokeai.app.invocations.metadata import CoreMetadata from invokeai.app.util.step_callback import stable_diffusion_step_callback +from invokeai.backend.model_management.models.base import ModelType from ...backend.model_management.lora import ModelPatcher from ...backend.stable_diffusion import PipelineIntermediateState @@ -71,16 +73,21 @@ def get_scheduler( scheduler_name: str, ) -> Scheduler: scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get( - scheduler_name, SCHEDULER_MAP['ddim']) + scheduler_name, SCHEDULER_MAP['ddim'] + ) orig_scheduler_info = context.services.model_manager.get_model( - **scheduler_info.dict()) + **scheduler_info.dict() + ) with orig_scheduler_info as orig_scheduler: scheduler_config = orig_scheduler.config if "_backup" in scheduler_config: scheduler_config = scheduler_config["_backup"] - scheduler_config = {**scheduler_config, ** - scheduler_extra_config, "_backup": scheduler_config} + scheduler_config = { + **scheduler_config, + **scheduler_extra_config, + "_backup": scheduler_config, + } scheduler = scheduler_class.from_config(scheduler_config) # hack copied over from generate.py @@ -137,8 +144,11 @@ class TextToLatentsInvocation(BaseInvocation): # TODO: pass this an emitter method or something? or a session for dispatching? 
def dispatch_progress( - self, context: InvocationContext, source_node_id: str, - intermediate_state: PipelineIntermediateState) -> None: + self, + context: InvocationContext, + source_node_id: str, + intermediate_state: PipelineIntermediateState, + ) -> None: stable_diffusion_step_callback( context=context, intermediate_state=intermediate_state, @@ -147,11 +157,16 @@ class TextToLatentsInvocation(BaseInvocation): ) def get_conditioning_data( - self, context: InvocationContext, scheduler) -> ConditioningData: + self, + context: InvocationContext, + scheduler, + ) -> ConditioningData: c, extra_conditioning_info = context.services.latents.get( - self.positive_conditioning.conditioning_name) + self.positive_conditioning.conditioning_name + ) uc, _ = context.services.latents.get( - self.negative_conditioning.conditioning_name) + self.negative_conditioning.conditioning_name + ) conditioning_data = ConditioningData( unconditioned_embeddings=uc, @@ -178,7 +193,10 @@ class TextToLatentsInvocation(BaseInvocation): return conditioning_data def create_pipeline( - self, unet, scheduler) -> StableDiffusionGeneratorPipeline: + self, + unet, + scheduler, + ) -> StableDiffusionGeneratorPipeline: # TODO: # configure_model_padding( # unet, @@ -213,6 +231,7 @@ class TextToLatentsInvocation(BaseInvocation): model: StableDiffusionGeneratorPipeline, control_input: List[ControlField], latents_shape: List[int], + exit_stack: ExitStack, do_classifier_free_guidance: bool = True, ) -> List[ControlNetData]: @@ -238,25 +257,19 @@ class TextToLatentsInvocation(BaseInvocation): control_data = [] control_models = [] for control_info in control_list: - # handle control models - if ("," in control_info.control_model): - control_model_split = control_info.control_model.split(",") - control_name = control_model_split[0] - control_subfolder = control_model_split[1] - print("Using HF model subfolders") - print(" control_name: ", control_name) - print(" control_subfolder: ", control_subfolder) - control_model = ControlNetModel.from_pretrained( - control_name, subfolder=control_subfolder, - torch_dtype=model.unet.dtype).to( - model.device) - else: - control_model = ControlNetModel.from_pretrained( - control_info.control_model, torch_dtype=model.unet.dtype).to(model.device) + control_model = exit_stack.enter_context( + context.services.model_manager.get_model( + model_name=control_info.control_model.model_name, + model_type=ModelType.ControlNet, + base_model=control_info.control_model.base_model, + ) + ) + control_models.append(control_model) control_image_field = control_info.image input_image = context.services.images.get_pil_image( - control_image_field.image_name) + control_image_field.image_name + ) # self.image.image_type, self.image.image_name # FIXME: still need to test with different widths, heights, devices, dtypes # and add in batch_size, num_images_per_prompt? 
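The `exit_stack` threaded through the control-data preparation method above is what lets a runtime-determined number of ControlNet model contexts stay open for the whole denoising run, instead of each model being loaded ad hoc with `from_pretrained()`. A minimal sketch of the pattern, with a hypothetical `load_model` context manager standing in for `model_manager.get_model()`:

from contextlib import ExitStack, contextmanager

@contextmanager
def load_model(name: str):
    print(f"acquired {name}")  # stand-in for cache locking / device placement
    try:
        yield f"<model {name}>"
    finally:
        print(f"released {name}")

control_names = ["canny", "depth", "openpose"]  # count unknown until runtime

with ExitStack() as stack:
    # enter_context() registers each model's cleanup with the stack, so all
    # of the models remain live for the body of the with-block...
    models = [stack.enter_context(load_model(n)) for n in control_names]
    print("denoising with", models)
# ...and all are released here, in reverse order of acquisition.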
@@ -278,7 +291,8 @@ class TextToLatentsInvocation(BaseInvocation): weight=control_info.control_weight, begin_step_percent=control_info.begin_step_percent, end_step_percent=control_info.end_step_percent, - control_mode=control_info.control_mode,) + control_mode=control_info.control_mode, + ) control_data.append(control_item) # MultiControlNetModel has been refactored out, just need list[ControlNetData] return control_data @@ -289,7 +303,8 @@ class TextToLatentsInvocation(BaseInvocation): # Get the source node id (we are invoking the prepared node) graph_execution_state = context.services.graph_execution_manager.get( - context.graph_execution_state_id) + context.graph_execution_state_id + ) source_node_id = graph_execution_state.prepared_source_mapping[self.id] def step_callback(state: PipelineIntermediateState): @@ -298,14 +313,17 @@ class TextToLatentsInvocation(BaseInvocation): def _lora_loader(): for lora in self.unet.loras: lora_info = context.services.model_manager.get_model( - **lora.dict(exclude={"weight"})) + **lora.dict(exclude={"weight"}) + ) yield (lora_info.context.model, lora.weight) del lora_info return unet_info = context.services.model_manager.get_model( - **self.unet.unet.dict()) - with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ + **self.unet.unet.dict() + ) + with ExitStack() as exit_stack,\ + ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ unet_info as unet: scheduler = get_scheduler( @@ -322,6 +340,7 @@ class TextToLatentsInvocation(BaseInvocation): latents_shape=noise.shape, # do_classifier_free_guidance=(self.cfg_scale >= 1.0)) do_classifier_free_guidance=True, + exit_stack=exit_stack, ) # TODO: Verify the noise is the right size @@ -374,7 +393,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): # Get the source node id (we are invoking the prepared node) graph_execution_state = context.services.graph_execution_manager.get( - context.graph_execution_state_id) + context.graph_execution_state_id + ) source_node_id = graph_execution_state.prepared_source_mapping[self.id] def step_callback(state: PipelineIntermediateState): @@ -383,14 +403,17 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): def _lora_loader(): for lora in self.unet.loras: lora_info = context.services.model_manager.get_model( - **lora.dict(exclude={"weight"})) + **lora.dict(exclude={"weight"}) + ) yield (lora_info.context.model, lora.weight) del lora_info return unet_info = context.services.model_manager.get_model( - **self.unet.unet.dict()) - with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ + **self.unet.unet.dict() + ) + with ExitStack() as exit_stack,\ + ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ unet_info as unet: scheduler = get_scheduler( @@ -407,11 +430,13 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): latents_shape=noise.shape, # do_classifier_free_guidance=(self.cfg_scale >= 1.0)) do_classifier_free_guidance=True, + exit_stack=exit_stack, ) # TODO: Verify the noise is the right size initial_latents = latent if self.strength < 1.0 else torch.zeros_like( - latent, device=unet.device, dtype=latent.dtype) + latent, device=unet.device, dtype=latent.dtype + ) timesteps, _ = pipeline.get_img2img_timesteps( self.steps, @@ -535,7 +560,8 @@ class ResizeLatentsInvocation(BaseInvocation): resized_latents = torch.nn.functional.interpolate( latents, size=(self.height // 8, self.width // 8), mode=self.mode, antialias=self.antialias - if self.mode in ["bilinear", 
"bicubic"] else False,) + if self.mode in ["bilinear", "bicubic"] else False, + ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -569,7 +595,8 @@ class ScaleLatentsInvocation(BaseInvocation): resized_latents = torch.nn.functional.interpolate( latents, scale_factor=self.scale_factor, mode=self.mode, antialias=self.antialias - if self.mode in ["bilinear", "bicubic"] else False,) + if self.mode in ["bilinear", "bicubic"] else False, + ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 1b1c43dc11..67db5c9478 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -19,7 +19,7 @@ from invokeai.backend.model_management import ( ModelMerger, MergeInterpolationMethod, ) - +from invokeai.backend.model_management.model_search import FindModels import torch from invokeai.app.models.exceptions import CanceledException @@ -167,6 +167,27 @@ class ModelManagerServiceBase(ABC): """ pass + @abstractmethod + def rename_model(self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + new_name: str, + ): + """ + Rename the indicated model. + """ + pass + + @abstractmethod + def list_checkpoint_configs( + self + )->List[Path]: + """ + List the checkpoint config paths from ROOT/configs/stable-diffusion. + """ + pass + @abstractmethod def convert_model( self, @@ -220,6 +241,7 @@ class ModelManagerServiceBase(ABC): alpha: Optional[float] = 0.5, interp: Optional[MergeInterpolationMethod] = None, force: Optional[bool] = False, + merge_dest_directory: Optional[Path] = None ) -> AddModelResult: """ Merge two to three diffusrs pipeline models and save as a new model. @@ -228,9 +250,26 @@ class ModelManagerServiceBase(ABC): :param merged_model_name: Name of destination merged model :param alpha: Alpha strength to apply to 2d and 3d model :param interp: Interpolation method. None (default) + :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) """ pass - + + @abstractmethod + def search_for_models(self, directory: Path)->List[Path]: + """ + Return list of all models found in the designated directory. + """ + pass + + @abstractmethod + def sync_to_config(self): + """ + Re-read models.yaml, rescan the models directory, and reimport models + in the autoimport directories. Call after making changes outside the + model manager API. + """ + pass + @abstractmethod def commit(self, conf_file: Optional[Path] = None) -> None: """ @@ -431,16 +470,18 @@ class ModelManagerService(ModelManagerServiceBase): """ Delete the named model from configuration. If delete_files is true, then the underlying weight file or diffusers directory will be deleted - as well. Call commit() to write to disk. + as well. 
""" self.logger.debug(f'delete model {model_name}') self.mgr.del_model(model_name, base_model, model_type) + self.mgr.commit() def convert_model( self, model_name: str, base_model: BaseModelType, model_type: Union[ModelType.Main,ModelType.Vae], + convert_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"), ) -> AddModelResult: """ Convert a checkpoint file into a diffusers folder, deleting the cached @@ -449,13 +490,14 @@ class ModelManagerService(ModelManagerServiceBase): :param model_name: Name of the model to convert :param base_model: Base model type :param model_type: Type of model ['vae' or 'main'] + :param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default) This will raise a ValueError unless the model is not a checkpoint. It will also raise a ValueError in the event that there is a similarly-named diffusers directory already in place. """ self.logger.debug(f'convert model {model_name}') - return self.mgr.convert_model(model_name, base_model, model_type) + return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory) def commit(self, conf_file: Optional[Path]=None): """ @@ -536,6 +578,7 @@ class ModelManagerService(ModelManagerServiceBase): alpha: Optional[float] = 0.5, interp: Optional[MergeInterpolationMethod] = None, force: Optional[bool] = False, + merge_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"), ) -> AddModelResult: """ Merge two to three diffusrs pipeline models and save as a new model. @@ -544,6 +587,7 @@ class ModelManagerService(ModelManagerServiceBase): :param merged_model_name: Name of destination merged model :param alpha: Alpha strength to apply to 2d and 3d model :param interp: Interpolation method. None (default) + :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) """ merger = ModelMerger(self.mgr) try: @@ -554,7 +598,55 @@ class ModelManagerService(ModelManagerServiceBase): alpha = alpha, interp = interp, force = force, + merge_dest_directory=merge_dest_directory, ) except AssertionError as e: raise ValueError(e) return result + + def search_for_models(self, directory: Path)->List[Path]: + """ + Return list of all models found in the designated directory. + """ + search = FindModels(directory,self.logger) + return search.list_models() + + def sync_to_config(self): + """ + Re-read models.yaml, rescan the models directory, and reimport models + in the autoimport directories. Call after making changes outside the + model manager API. + """ + return self.mgr.sync_to_config() + + def list_checkpoint_configs(self)->List[Path]: + """ + List the checkpoint config paths from ROOT/configs/stable-diffusion. + """ + config = self.mgr.app_config + conf_path = config.legacy_conf_path + root_path = config.root_path + return [(conf_path / x).relative_to(root_path) for x in conf_path.glob('**/*.yaml')] + + def rename_model(self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + new_name: str = None, + new_base: BaseModelType = None, + ): + """ + Rename the indicated model. Can provide a new name and/or a new base. 
+ :param model_name: Current name of the model + :param base_model: Current base of the model + :param model_type: Model type (can't be changed) + :param new_name: New name for the model + :param new_base: New base for the model + """ + self.mgr.rename_model(base_model = base_model, + model_type = model_type, + model_name = model_name, + new_name = new_name, + new_base = new_base, + ) + diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py index c1fa30c3b7..559dac6f61 100644 --- a/invokeai/backend/install/model_install_backend.py +++ b/invokeai/backend/install/model_install_backend.py @@ -71,8 +71,6 @@ class ModelInstallList: class InstallSelections(): install_models: List[str]= field(default_factory=list) remove_models: List[str]=field(default_factory=list) -# scan_directory: Path = None -# autoscan_on_startup: bool=False @dataclass class ModelLoadInfo(): diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index a38fcf6c24..1fcd9148cd 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -247,6 +247,7 @@ import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.util import CUDA_DEVICE, Chdir from .model_cache import ModelCache, ModelLocker +from .model_search import ModelSearch from .models import ( BaseModelType, ModelType, SubModelType, ModelError, SchedulerPredictionType, MODEL_CLASSES, @@ -322,16 +323,7 @@ class ModelManager(object): self.config_meta = ConfigMeta(**config.pop("__metadata__")) # TODO: metadata not found # TODO: version check - - self.models = dict() - for model_key, model_config in config.items(): - model_name, base_model, model_type = self.parse_key(model_key) - model_class = MODEL_CLASSES[base_model][model_type] - # alias for config file - model_config["model_format"] = model_config.pop("format") - self.models[model_key] = model_class.create_config(**model_config) - - # check config version number and update on disk/RAM if necessary + self.app_config = InvokeAIAppConfig.get_config() self.logger = logger self.cache = ModelCache( @@ -342,11 +334,41 @@ class ModelManager(object): sequential_offload = sequential_offload, logger = logger, ) + + self._read_models(config) + + def _read_models(self, config: Optional[DictConfig] = None): + if not config: + if self.config_path: + config = OmegaConf.load(self.config_path) + else: + return + + self.models = dict() + for model_key, model_config in config.items(): + if model_key.startswith('_'): + continue + model_name, base_model, model_type = self.parse_key(model_key) + model_class = MODEL_CLASSES[base_model][model_type] + # alias for config file + model_config["model_format"] = model_config.pop("format") + self.models[model_key] = model_class.create_config(**model_config) + + # check config version number and update on disk/RAM if necessary self.cache_keys = dict() # add controlnet, lora and textual_inversion models from disk self.scan_models_directory() + def sync_to_config(self): + """ + Call this when `models.yaml` has been changed externally. 
+ This will reinitialize internal data structures + """ + # Reread models directory; note that this will reinitialize the cache, + # causing otherwise unreferenced models to be removed from memory + self._read_models() + def model_exists( self, model_name: str, @@ -527,7 +549,10 @@ class ModelManager(object): model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold) models = [] for model_key in model_keys: - model_config = self.models[model_key] + model_config = self.models.get(model_key) + if not model_config: + self.logger.error(f'Unknown model {model_name}') + raise KeyError(f'Unknown model {model_name}') cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key) if base_model is not None and cur_base_model != base_model: @@ -646,11 +671,61 @@ class ModelManager(object): config = model_config, ) + def rename_model( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + new_name: str = None, + new_base: BaseModelType = None, + ): + ''' + Rename or rebase a model. + ''' + if new_name is None and new_base is None: + self.logger.error(f"rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.") + return + + model_key = self.create_key(model_name, base_model, model_type) + model_cfg = self.models.get(model_key, None) + if not model_cfg: + raise KeyError(f"Unknown model: {model_key}") + + old_path = self.app_config.root_path / model_cfg.path + new_name = new_name or model_name + new_base = new_base or base_model + new_key = self.create_key(new_name, new_base, model_type) + if new_key in self.models: + raise ValueError(f'Attempt to overwrite existing model definition "{new_key}"') + + # if this is a model file/directory that we manage ourselves, we need to move it + if old_path.is_relative_to(self.app_config.models_path): + new_path = self.app_config.root_path / 'models' / new_base.value / model_type.value / new_name + move(old_path, new_path) + model_cfg.path = str(new_path.relative_to(self.app_config.root_path)) + + # clean up caches + old_model_cache = self._get_model_cache_path(old_path) + if old_model_cache.exists(): + if old_model_cache.is_dir(): + rmtree(str(old_model_cache)) + else: + old_model_cache.unlink() + + cache_ids = self.cache_keys.pop(model_key, []) + for cache_id in cache_ids: + self.cache.uncache_model(cache_id) + + self.models.pop(model_key, None) # delete + self.models[new_key] = model_cfg + self.commit() + def convert_model ( self, model_name: str, base_model: BaseModelType, model_type: Union[ModelType.Main,ModelType.Vae], + dest_directory: Optional[Path]=None, ) -> AddModelResult: ''' Convert a checkpoint file into a diffusers folder, deleting the cached @@ -677,14 +752,14 @@ class ModelManager(object): ) checkpoint_path = self.app_config.root_path / info["path"] old_diffusers_path = self.app_config.models_path / model.location - new_diffusers_path = self.app_config.models_path / base_model.value / model_type.value / model_name + new_diffusers_path = (dest_directory or self.app_config.models_path / base_model.value / model_type.value) / model_name if new_diffusers_path.exists(): raise ValueError(f"A diffusers model already exists at {new_diffusers_path}") try: move(old_diffusers_path,new_diffusers_path) info["model_format"] = "diffusers" - info["path"] = str(new_diffusers_path.relative_to(self.app_config.root_path)) + info["path"] = str(new_diffusers_path) if dest_directory else 
str(new_diffusers_path.relative_to(self.app_config.root_path)) info.pop('config') result = self.add_model(model_name, base_model, model_type, @@ -824,6 +899,7 @@ class ModelManager(object): if (new_models_found or imported_models) and self.config_path: self.commit() + def autoimport(self)->Dict[str, AddModelResult]: ''' Scan the autoimport directory (if defined) and import new models, delete defunct models. @@ -831,62 +907,41 @@ class ModelManager(object): # avoid circular import from invokeai.backend.install.model_install_backend import ModelInstall from invokeai.frontend.install.model_install import ask_user_for_prediction_type - + + class ScanAndImport(ModelSearch): + def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall): + super().__init__(directories, logger) + self.installer = installer + self.ignore = ignore + + def on_search_started(self): + self.new_models_found = dict() + + def on_model_found(self, model: Path): + if model not in self.ignore: + self.new_models_found.update(self.installer.heuristic_import(model)) + + def on_search_completed(self): + self.logger.info(f'Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models') + + def models_found(self): + return self.new_models_found + + installer = ModelInstall(config = self.app_config, model_manager = self, prediction_type_helper = ask_user_for_prediction_type, ) - - scanned_dirs = set() - config = self.app_config - known_paths = {(self.app_config.root_path / x['path']) for x in self.list_models()} - - for autodir in [config.autoimport_dir, - config.lora_dir, - config.embedding_dir, - config.controlnet_dir]: - if autodir is None: - continue - - installed = dict() - - autodir = self.app_config.root_path / autodir - if not autodir.exists(): - continue - - items_scanned = 0 - new_models_found = dict() - - for root, dirs, files in os.walk(autodir): - items_scanned += len(dirs) + len(files) - for d in dirs: - path = Path(root) / d - if path in known_paths or path.parent in scanned_dirs: - scanned_dirs.add(path) - continue - if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]): - try: - new_models_found.update(installer.heuristic_import(path)) - scanned_dirs.add(path) - except ValueError as e: - self.logger.warning(str(e)) - - for f in files: - path = Path(root) / f - if path in known_paths or path.parent in scanned_dirs: - continue - if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}: - try: - import_result = installer.heuristic_import(path) - new_models_found.update(import_result) - except ValueError as e: - self.logger.warning(str(e)) - - installed.update(new_models_found) - - self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models') - return installed + known_paths = {config.root_path / x['path'] for x in self.list_models()} + directories = {config.root_path / x for x in [config.autoimport_dir, + config.lora_dir, + config.embedding_dir, + config.controlnet_dir] + } + scanner = ScanAndImport(directories, self.logger, ignore=known_paths, installer=installer) + scanner.search() + return scanner.models_found() def heuristic_import(self, items_to_import: Set[str], @@ -924,3 +979,4 @@ class ModelManager(object): successfully_installed.update(installed) self.commit() return successfully_installed + diff --git a/invokeai/backend/model_management/model_merge.py b/invokeai/backend/model_management/model_merge.py index 39f951d2b4..6427b9e430 
100644 --- a/invokeai/backend/model_management/model_merge.py +++ b/invokeai/backend/model_management/model_merge.py @@ -11,7 +11,7 @@ from enum import Enum from pathlib import Path from diffusers import DiffusionPipeline from diffusers import logging as dlogging -from typing import List, Union +from typing import List, Union, Optional import invokeai.backend.util.logging as logger @@ -74,6 +74,7 @@ class ModelMerger(object): alpha: float = 0.5, interp: MergeInterpolationMethod = None, force: bool = False, + merge_dest_directory: Optional[Path] = None, **kwargs, ) -> AddModelResult: """ @@ -85,7 +86,7 @@ class ModelMerger(object): :param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None. Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C). :param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False. - + :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) **kwargs - the default DiffusionPipeline.get_config_dict kwargs: cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map """ @@ -111,7 +112,7 @@ class ModelMerger(object): merged_pipe = self.merge_diffusion_models( model_paths, alpha, merge_method, force, **kwargs ) - dump_path = config.models_path / base_model.value / ModelType.Main.value + dump_path = Path(merge_dest_directory) if merge_dest_directory else config.models_path / base_model.value / ModelType.Main.value dump_path.mkdir(parents=True, exist_ok=True) dump_path = dump_path / merged_model_name diff --git a/invokeai/backend/model_management/model_search.py b/invokeai/backend/model_management/model_search.py new file mode 100644 index 0000000000..1e282b4bb8 --- /dev/null +++ b/invokeai/backend/model_management/model_search.py @@ -0,0 +1,103 @@ +# Copyright 2023, Lincoln D. Stein and the InvokeAI Team +""" +Abstract base class for recursive directory search for models. +""" + +import os +from abc import ABC, abstractmethod +from typing import List, Set, types +from pathlib import Path + +import invokeai.backend.util.logging as logger + +class ModelSearch(ABC): + def __init__(self, directories: List[Path], logger: types.ModuleType=logger): + """ + Initialize a recursive model directory search. + :param directories: List of directory Paths to recurse through + :param logger: Logger to use + """ + self.directories = directories + self.logger = logger + self._items_scanned = 0 + self._models_found = 0 + self._scanned_dirs = set() + self._scanned_paths = set() + self._pruned_paths = set() + + @abstractmethod + def on_search_started(self): + """ + Called before the scan starts. + """ + pass + + @abstractmethod + def on_model_found(self, model: Path): + """ + Process a found model. Raise an exception if something goes wrong. + :param model: Model to process - could be a directory or checkpoint. + """ + pass + + @abstractmethod + def on_search_completed(self): + """ + Perform some activity when the scan is completed. 
May use instance + variables, items_scanned and models_found + """ + pass + + def search(self): + self.on_search_started() + for dir in self.directories: + self.walk_directory(dir) + self.on_search_completed() + + def walk_directory(self, path: Path): + for root, dirs, files in os.walk(path): + if str(Path(root).name).startswith('.'): + self._pruned_paths.add(root) + if any([Path(root).is_relative_to(x) for x in self._pruned_paths]): + continue + + self._items_scanned += len(dirs) + len(files) + for d in dirs: + path = Path(root) / d + if path in self._scanned_paths or path.parent in self._scanned_dirs: + self._scanned_dirs.add(path) + continue + if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]): + try: + self.on_model_found(path) + self._models_found += 1 + self._scanned_dirs.add(path) + except Exception as e: + self.logger.warning(str(e)) + + for f in files: + path = Path(root) / f + if path.parent in self._scanned_dirs: + continue + if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}: + try: + self.on_model_found(path) + self._models_found += 1 + except Exception as e: + self.logger.warning(str(e)) + +class FindModels(ModelSearch): + def on_search_started(self): + self.models_found: Set[Path] = set() + + def on_model_found(self,model: Path): + self.models_found.add(model) + + def on_search_completed(self): + pass + + def list_models(self) -> List[Path]: + self.search() + return self.models_found + + diff --git a/invokeai/backend/model_management/models/__init__.py b/invokeai/backend/model_management/models/__init__.py index 1c573b26b6..e404c56bdf 100644 --- a/invokeai/backend/model_management/models/__init__.py +++ b/invokeai/backend/model_management/models/__init__.py @@ -48,7 +48,9 @@ for base_model, models in MODEL_CLASSES.items(): model_configs.discard(None) MODEL_CONFIGS.extend(model_configs) - for cfg in model_configs: + # LS: sort to get the checkpoint configs first, which makes + # for a better template in the Swagger docs + for cfg in sorted(model_configs, key=lambda x: str(x)): model_name, cfg_name = cfg.__qualname__.split('.')[-2:] openapi_cfg_name = model_name + cfg_name if openapi_cfg_name in vars(): diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py index ddbc401e5b..c569872a81 100644 --- a/invokeai/backend/model_management/models/base.py +++ b/invokeai/backend/model_management/models/base.py @@ -59,7 +59,6 @@ class ModelConfigBase(BaseModel): path: str # or Path description: Optional[str] = Field(None) model_format: Optional[str] = Field(None) - # do not save to config error: Optional[ModelError] = Field(None) class Config: diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_management/models/stable_diffusion.py index 74751a40dd..3d2e50d8fb 100644 --- a/invokeai/backend/model_management/models/stable_diffusion.py +++ b/invokeai/backend/model_management/models/stable_diffusion.py @@ -37,8 +37,7 @@ class StableDiffusion1Model(DiffusersModel): vae: Optional[str] = Field(None) config: str variant: ModelVariantType - - + def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert base_model == BaseModelType.StableDiffusion1 assert model_type == ModelType.Main diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index 1175475bba..307e949ef8 100644 --- 
a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -241,11 +241,45 @@ class InvokeAIDiffuserComponent: def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs): # fast batched path + + def _pad_conditioning(cond, target_len, encoder_attention_mask): + conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype) + + if cond.shape[1] < max_len: + conditioning_attention_mask = torch.cat([ + conditioning_attention_mask, + torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype), + ], dim=1) + + cond = torch.cat([ + cond, + torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype), + ], dim=1) + + if encoder_attention_mask is None: + encoder_attention_mask = conditioning_attention_mask + else: + encoder_attention_mask = torch.cat([ + encoder_attention_mask, + conditioning_attention_mask, + ]) + + return cond, encoder_attention_mask + x_twice = torch.cat([x] * 2) sigma_twice = torch.cat([sigma] * 2) + + encoder_attention_mask = None + if unconditioning.shape[1] != conditioning.shape[1]: + max_len = max(unconditioning.shape[1], conditioning.shape[1]) + unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask) + conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask) + both_conditionings = torch.cat([unconditioning, conditioning]) both_results = self.model_forward_callback( - x_twice, sigma_twice, both_conditionings, **kwargs, + x_twice, sigma_twice, both_conditionings, + encoder_attention_mask=encoder_attention_mask, + **kwargs, ) unconditioned_next_x, conditioned_next_x = both_results.chunk(2) return unconditioned_next_x, conditioned_next_x diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess.ts index dd2fb6f469..a923bd0b60 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess.ts @@ -13,7 +13,11 @@ import { RootState } from 'app/store/store'; const moduleLog = log.child({ namespace: 'controlNet' }); -const predicate: AnyListenerPredicate = (action, state) => { +const predicate: AnyListenerPredicate = ( + action, + state, + prevState +) => { const isActionMatched = controlNetProcessorParamsChanged.match(action) || controlNetModelChanged.match(action) || @@ -25,6 +29,16 @@ const predicate: AnyListenerPredicate = (action, state) => { return false; } + if (controlNetAutoConfigToggled.match(action)) { + // do not process if the user just disabled auto-config + if ( + prevState.controlNet.controlNets[action.payload.controlNetId] + .shouldAutoConfig === true + ) { + return false; + } + } + const { controlImage, processorType, shouldAutoConfig } = state.controlNet.controlNets[action.payload.controlNetId]; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts index ee879a8915..05076960fb 100644 --- 
a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts @@ -10,6 +10,7 @@ import { zMainModel } from 'features/parameters/types/parameterSchemas'; import { addToast } from 'features/system/store/systemSlice'; import { forEach } from 'lodash-es'; import { startAppListening } from '..'; +import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice'; const moduleLog = log.child({ module: 'models' }); @@ -51,7 +52,14 @@ export const addModelSelectedListener = () => { modelsCleared += 1; } - // TODO: handle incompatible controlnet; pending model manager support + const { controlNets } = state.controlNet; + forEach(controlNets, (controlNet, controlNetId) => { + if (controlNet.model?.base_model !== base_model) { + dispatch(controlNetRemoved({ controlNetId })); + modelsCleared += 1; + } + }); + if (modelsCleared > 0) { dispatch( addToast( diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts index f8abcfa758..5e3caa7c99 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts @@ -11,6 +11,7 @@ import { import { forEach, some } from 'lodash-es'; import { modelsApi } from 'services/api/endpoints/models'; import { startAppListening } from '..'; +import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice'; const moduleLog = log.child({ module: 'models' }); @@ -127,7 +128,22 @@ export const addModelsLoadedListener = () => { matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled, effect: async (action, { getState, dispatch }) => { // ControlNet models loaded - need to remove missing ControlNets from state - // TODO: pending model manager controlnet support + const controlNets = getState().controlNet.controlNets; + + forEach(controlNets, (controlNet, controlNetId) => { + const isControlNetAvailable = some( + action.payload.entities, + (m) => + m?.model_name === controlNet?.model?.model_name && + m?.base_model === controlNet?.model?.base_model + ); + + if (isControlNetAvailable) { + return; + } + + dispatch(controlNetRemoved({ controlNetId })); + }); }, }); }; diff --git a/invokeai/frontend/web/src/app/types/invokeai.ts b/invokeai/frontend/web/src/app/types/invokeai.ts index 40b8c1c73a..be642a6435 100644 --- a/invokeai/frontend/web/src/app/types/invokeai.ts +++ b/invokeai/frontend/web/src/app/types/invokeai.ts @@ -1,5 +1,5 @@ import { - CONTROLNET_MODELS, + // CONTROLNET_MODELS, CONTROLNET_PROCESSORS, } from 'features/controlNet/store/constants'; import { InvokeTabName } from 'features/ui/store/tabMap'; @@ -128,7 +128,7 @@ export type AppConfig = { canRestoreDeletedImagesFromBin: boolean; sd: { defaultModel?: string; - disabledControlNetModels: (keyof typeof CONTROLNET_MODELS)[]; + disabledControlNetModels: string[]; disabledControlNetProcessors: (keyof typeof CONTROLNET_PROCESSORS)[]; iterations: { initial: number; diff --git a/invokeai/frontend/web/src/common/components/IAIDndImage.tsx b/invokeai/frontend/web/src/common/components/IAIDndImage.tsx index 59a1d281fe..991398f5a0 100644 --- a/invokeai/frontend/web/src/common/components/IAIDndImage.tsx +++ b/invokeai/frontend/web/src/common/components/IAIDndImage.tsx @@ -170,12 +170,14 @@ const 
IAIDndImage = (props: IAIDndImageProps) => { )} {!imageDTO && isUploadDisabled && noContentFallback} - - {imageDTO && ( + {!isDropDisabled && ( + + )} + {imageDTO && !isDragDisabled && ( & { tooltip?: string; inputRef?: RefObject; + label?: string; }; const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { - const { searchable = true, tooltip, inputRef, ...rest } = props; + const { + searchable = true, + tooltip, + inputRef, + label, + disabled, + ...rest + } = props; const dispatch = useAppDispatch(); const handleKeyDown = useCallback( @@ -37,7 +45,15 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { return ( + {label} + + ) : undefined + } ref={inputRef} + disabled={disabled} onKeyDown={handleKeyDown} onKeyUp={handleKeyUp} searchable={searchable} diff --git a/invokeai/frontend/web/src/common/components/IAIMantineSearchableSelect.tsx b/invokeai/frontend/web/src/common/components/IAIMantineSearchableSelect.tsx index edf1665bb4..2c3f5434ad 100644 --- a/invokeai/frontend/web/src/common/components/IAIMantineSearchableSelect.tsx +++ b/invokeai/frontend/web/src/common/components/IAIMantineSearchableSelect.tsx @@ -1,4 +1,4 @@ -import { Tooltip } from '@chakra-ui/react'; +import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react'; import { Select, SelectProps } from '@mantine/core'; import { useAppDispatch } from 'app/store/storeHooks'; import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice'; @@ -11,13 +11,22 @@ export type IAISelectDataType = { tooltip?: string; }; -type IAISelectProps = SelectProps & { +type IAISelectProps = Omit & { tooltip?: string; + label?: string; inputRef?: RefObject; }; const IAIMantineSearchableSelect = (props: IAISelectProps) => { - const { searchable = true, tooltip, inputRef, onChange, ...rest } = props; + const { + searchable = true, + tooltip, + inputRef, + onChange, + label, + disabled, + ...rest + } = props; const dispatch = useAppDispatch(); const [searchValue, setSearchValue] = useState(''); @@ -61,6 +70,14 @@ const IAIMantineSearchableSelect = (props: IAISelectProps) => { +
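Taken together, the router changes at the top of this diff give external clients a complete model-management surface. A rough smoke test of the new routes using `requests`; the host and the `/api/v1/models` prefix are assumptions (the router's mount point is not shown in this diff), and the model names, paths, and the `sd-1` base value are illustrative, while the route paths, methods, and parameters come from the decorators above:

import requests

BASE = "http://localhost:9090/api/v1/models"  # assumed mount point

# POST /import: the location is probed and the model configured automatically
requests.post(f"{BASE}/import", json={"location": "stabilityai/sd-vae-ft-mse"}).raise_for_status()

# POST /rename/{base_model}/{model_type}/{model_name}: rename and/or rebase
requests.post(f"{BASE}/rename/sd-1/main/my-model", params={"new_name": "my-model-v2"})

# GET /search: recursively list model files and folders under a directory
found = requests.get(f"{BASE}/search", params={"search_path": "/opt/models"}).json()

# GET /sync: rebuild in-memory state after hand-editing models.yaml
requests.get(f"{BASE}/sync")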