diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index b711c654de..b19699de73 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -1627,7 +1627,7 @@ payload=dict( queue_batch_id=queue_batch_id, graph_execution_state_id=graph_execution_state_id, model_key=model_key, - submodel=submodel, + submodel_type=submodel, hash=model_info.hash, location=str(model_info.location), precision=str(model_info.precision), diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index b0419f424f..9d30459518 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -710,7 +710,7 @@ class DenoiseLatentsInvocation(BaseInvocation): source_node_id = graph_execution_state.prepared_source_mapping[self.id] # get the unet's config so that we can pass the base to dispatch_progress() - unet_config = context.services.model_manager.store.get_model(**self.unet.unet.model_dump()) + unet_config = context.services.model_manager.store.get_model(self.unet.unet.key) def step_callback(state: PipelineIntermediateState) -> None: self.dispatch_progress(context, source_node_id, state, unet_config.base) @@ -738,7 +738,7 @@ class DenoiseLatentsInvocation(BaseInvocation): # Apply the LoRA after unet has been moved to its target device for faster patching. ModelPatcher.apply_lora_unet(unet, _lora_loader()), ): - assert isinstance(unet, torch.Tensor) + assert isinstance(unet, UNet2DConditionModel) latents = latents.to(device=unet.device, dtype=unet.dtype) if noise is not None: noise = noise.to(device=unet.device, dtype=unet.dtype) @@ -842,7 +842,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata): ) with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae: - assert isinstance(vae, torch.Tensor) + assert isinstance(vae, torch.nn.Module) latents = latents.to(vae.device) if self.fp32: vae.to(dtype=torch.float32) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index a3aaf4c9e1..2381471899 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -21,7 +21,7 @@ from .baseinvocation import ( class ModelInfo(BaseModel): key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()") - submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel") + submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel") class LoraInfo(ModelInfo): @@ -113,22 +113,22 @@ class MainModelLoaderInvocation(BaseInvocation): unet=UNetField( unet=ModelInfo( key=key, - submodel=SubModelType.UNet, + submodel_type=SubModelType.UNet, ), scheduler=ModelInfo( key=key, - submodel=SubModelType.Scheduler, + submodel_type=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( key=key, - submodel=SubModelType.Tokenizer, + submodel_type=SubModelType.Tokenizer, ), text_encoder=ModelInfo( key=key, - submodel=SubModelType.TextEncoder, + submodel_type=SubModelType.TextEncoder, ), loras=[], skipped_layers=0, @@ -136,7 +136,7 @@ class MainModelLoaderInvocation(BaseInvocation): vae=VaeField( vae=ModelInfo( key=key, - submodel=SubModelType.Vae, + submodel_type=SubModelType.Vae, ), ), ) @@ -191,7 +191,7 @@ class LoraLoaderInvocation(BaseInvocation): output.unet.loras.append( LoraInfo( key=lora_key, - submodel=None, + submodel_type=None, weight=self.weight, ) ) @@ -201,7 +201,7 @@ class LoraLoaderInvocation(BaseInvocation): output.clip.loras.append( LoraInfo( key=lora_key, - submodel=None, + submodel_type=None, weight=self.weight, ) ) @@ -274,7 +274,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation): output.unet.loras.append( LoraInfo( key=lora_key, - submodel=None, + submodel_type=None, weight=self.weight, ) ) @@ -284,7 +284,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation): output.clip.loras.append( LoraInfo( key=lora_key, - submodel=None, + submodel_type=None, weight=self.weight, ) ) @@ -294,7 +294,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation): output.clip2.loras.append( LoraInfo( key=lora_key, - submodel=None, + submodel_type=None, weight=self.weight, ) ) diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index 118e48f89e..ea138bff9f 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -415,29 +415,29 @@ class OnnxModelLoaderInvocation(BaseInvocation): model_key = self.model.key # TODO: not found exceptions - if not context.services.model_records.exists(model_key): + if not context.services.model_manager.store.exists(model_key): raise Exception(f"Unknown model: {model_key}") return ONNXModelLoaderOutput( unet=UNetField( unet=ModelInfo( key=model_key, - submodel=SubModelType.UNet, + submodel_type=SubModelType.UNet, ), scheduler=ModelInfo( key=model_key, - submodel=SubModelType.Scheduler, + submodel_type=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( key=model_key, - submodel=SubModelType.Tokenizer, + submodel_type=SubModelType.Tokenizer, ), text_encoder=ModelInfo( key=model_key, - submodel=SubModelType.TextEncoder, + submodel_type=SubModelType.TextEncoder, ), loras=[], skipped_layers=0, @@ -445,13 +445,13 @@ class OnnxModelLoaderInvocation(BaseInvocation): vae_decoder=VaeField( vae=ModelInfo( key=model_key, - submodel=SubModelType.VaeDecoder, + submodel_type=SubModelType.VaeDecoder, ), ), vae_encoder=VaeField( vae=ModelInfo( key=model_key, - submodel=SubModelType.VaeEncoder, + submodel_type=SubModelType.VaeEncoder, ), ), ) diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 4cb5efbbb6..c38e5448c8 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -47,29 +47,29 @@ class SDXLModelLoaderInvocation(BaseInvocation): model_key = self.model.key # TODO: not found exceptions - if not context.services.model_records.exists(model_key): + if not context.services.model_manager.store.exists(model_key): raise Exception(f"Unknown model: {model_key}") return SDXLModelLoaderOutput( unet=UNetField( unet=ModelInfo( key=model_key, - submodel=SubModelType.UNet, + submodel_type=SubModelType.UNet, ), scheduler=ModelInfo( key=model_key, - submodel=SubModelType.Scheduler, + submodel_type=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( key=model_key, - submodel=SubModelType.Tokenizer, + submodel_type=SubModelType.Tokenizer, ), text_encoder=ModelInfo( key=model_key, - submodel=SubModelType.TextEncoder, + submodel_type=SubModelType.TextEncoder, ), loras=[], skipped_layers=0, @@ -77,11 +77,11 @@ class SDXLModelLoaderInvocation(BaseInvocation): clip2=ClipField( tokenizer=ModelInfo( key=model_key, - submodel=SubModelType.Tokenizer2, + submodel_type=SubModelType.Tokenizer2, ), text_encoder=ModelInfo( key=model_key, - submodel=SubModelType.TextEncoder2, + submodel_type=SubModelType.TextEncoder2, ), loras=[], skipped_layers=0, @@ -89,7 +89,7 @@ class SDXLModelLoaderInvocation(BaseInvocation): vae=VaeField( vae=ModelInfo( key=model_key, - submodel=SubModelType.Vae, + submodel_type=SubModelType.Vae, ), ), ) @@ -116,29 +116,29 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation): model_key = self.model.key # TODO: not found exceptions - if not context.services.model_records.exists(model_key): + if not context.services.model_manager.store.exists(model_key): raise Exception(f"Unknown model: {model_key}") return SDXLRefinerModelLoaderOutput( unet=UNetField( unet=ModelInfo( key=model_key, - submodel=SubModelType.UNet, + submodel_type=SubModelType.UNet, ), scheduler=ModelInfo( key=model_key, - submodel=SubModelType.Scheduler, + submodel_type=SubModelType.Scheduler, ), loras=[], ), clip2=ClipField( tokenizer=ModelInfo( key=model_key, - submodel=SubModelType.Tokenizer2, + submodel_type=SubModelType.Tokenizer2, ), text_encoder=ModelInfo( key=model_key, - submodel=SubModelType.TextEncoder2, + submodel_type=SubModelType.TextEncoder2, ), loras=[], skipped_layers=0, @@ -146,7 +146,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation): vae=VaeField( vae=ModelInfo( key=model_key, - submodel=SubModelType.Vae, + submodel_type=SubModelType.Vae, ), ), ) diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 362d8d3ff5..eabbdf819a 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -499,7 +499,7 @@ class ModelManager(object): model_class=model_class, base_model=base_model, model_type=model_type, - submodel=submodel_type, + submodel_type=submodel_type, ) if model_key not in self.cache_keys: diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index 786396062c..02ce1266c7 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -303,18 +303,18 @@ class ModelCache(ModelCacheBase[AnyModel]): in_vram_models = 0 locked_in_vram_models = 0 for cache_record in self._cached_models.values(): - assert hasattr(cache_record.model, "device") - if cache_record.model.device == self.storage_device: - in_ram_models += 1 - else: - in_vram_models += 1 - if cache_record.locked: - locked_in_vram_models += 1 + if hasattr(cache_record.model, "device"): + if cache_record.model.device == self.storage_device: + in_ram_models += 1 + else: + in_vram_models += 1 + if cache_record.locked: + locked_in_vram_models += 1 - self.logger.debug( - f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) =" - f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})" - ) + self.logger.debug( + f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) =" + f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})" + ) def make_room(self, model_size: int) -> None: """Make enough room in the cache to accommodate a new model of indicated size.""" diff --git a/pyproject.toml b/pyproject.toml index 09713a1cbe..243b0b1f21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -242,7 +242,7 @@ module = [ "invokeai.app.services.invocation_stats.invocation_stats_default", "invokeai.app.services.model_manager.model_manager_base", "invokeai.app.services.model_manager.model_manager_default", - "invokeai.app.services.model_records.model_records_sql", + "invokeai.app.services.model_manager.store.model_records_sql", "invokeai.app.util.controlnet_utils", "invokeai.backend.image_util.txt2mask", "invokeai.backend.image_util.safety_checker",