diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index b711c654de..b19699de73 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -1627,7 +1627,7 @@ payload=dict( queue_batch_id=queue_batch_id, graph_execution_state_id=graph_execution_state_id, model_key=model_key, - submodel=submodel, + submodel_type=submodel, hash=model_info.hash, location=str(model_info.location), precision=str(model_info.precision), diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index b0419f424f..d628e7b49c 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -842,7 +842,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata): ) with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae: - assert isinstance(vae, torch.Tensor) + assert isinstance(vae, torch.nn.Module) latents = latents.to(vae.device) if self.fp32: vae.to(dtype=torch.float32) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index a3aaf4c9e1..2381471899 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -21,7 +21,7 @@ from .baseinvocation import ( class ModelInfo(BaseModel): key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()") - submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel") + submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel") class LoraInfo(ModelInfo): @@ -113,22 +113,22 @@ class MainModelLoaderInvocation(BaseInvocation): unet=UNetField( unet=ModelInfo( key=key, - submodel=SubModelType.UNet, + submodel_type=SubModelType.UNet, ), scheduler=ModelInfo( key=key, - submodel=SubModelType.Scheduler, + submodel_type=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( key=key, - submodel=SubModelType.Tokenizer, + submodel_type=SubModelType.Tokenizer, ), text_encoder=ModelInfo( key=key, - submodel=SubModelType.TextEncoder, + submodel_type=SubModelType.TextEncoder, ), loras=[], skipped_layers=0, @@ -136,7 +136,7 @@ class MainModelLoaderInvocation(BaseInvocation): vae=VaeField( vae=ModelInfo( key=key, - submodel=SubModelType.Vae, + submodel_type=SubModelType.Vae, ), ), ) @@ -191,7 +191,7 @@ class LoraLoaderInvocation(BaseInvocation): output.unet.loras.append( LoraInfo( key=lora_key, - submodel=None, + submodel_type=None, weight=self.weight, ) ) @@ -201,7 +201,7 @@ class LoraLoaderInvocation(BaseInvocation): output.clip.loras.append( LoraInfo( key=lora_key, - submodel=None, + submodel_type=None, weight=self.weight, ) ) @@ -274,7 +274,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation): output.unet.loras.append( LoraInfo( key=lora_key, - submodel=None, + submodel_type=None, weight=self.weight, ) ) @@ -284,7 +284,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation): output.clip.loras.append( LoraInfo( key=lora_key, - submodel=None, + submodel_type=None, weight=self.weight, ) ) @@ -294,7 +294,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation): output.clip2.loras.append( LoraInfo( key=lora_key, - submodel=None, + submodel_type=None, weight=self.weight, ) ) diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index 118e48f89e..ea138bff9f 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -415,29 +415,29 @@ class OnnxModelLoaderInvocation(BaseInvocation): model_key = self.model.key # TODO: not found exceptions - if not context.services.model_records.exists(model_key): + if not context.services.model_manager.store.exists(model_key): raise Exception(f"Unknown model: {model_key}") return ONNXModelLoaderOutput( unet=UNetField( unet=ModelInfo( key=model_key, - submodel=SubModelType.UNet, + submodel_type=SubModelType.UNet, ), scheduler=ModelInfo( key=model_key, - submodel=SubModelType.Scheduler, + submodel_type=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( key=model_key, - submodel=SubModelType.Tokenizer, + submodel_type=SubModelType.Tokenizer, ), text_encoder=ModelInfo( key=model_key, - submodel=SubModelType.TextEncoder, + submodel_type=SubModelType.TextEncoder, ), loras=[], skipped_layers=0, @@ -445,13 +445,13 @@ class OnnxModelLoaderInvocation(BaseInvocation): vae_decoder=VaeField( vae=ModelInfo( key=model_key, - submodel=SubModelType.VaeDecoder, + submodel_type=SubModelType.VaeDecoder, ), ), vae_encoder=VaeField( vae=ModelInfo( key=model_key, - submodel=SubModelType.VaeEncoder, + submodel_type=SubModelType.VaeEncoder, ), ), ) diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 4cb5efbbb6..c38e5448c8 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -47,29 +47,29 @@ class SDXLModelLoaderInvocation(BaseInvocation): model_key = self.model.key # TODO: not found exceptions - if not context.services.model_records.exists(model_key): + if not context.services.model_manager.store.exists(model_key): raise Exception(f"Unknown model: {model_key}") return SDXLModelLoaderOutput( unet=UNetField( unet=ModelInfo( key=model_key, - submodel=SubModelType.UNet, + submodel_type=SubModelType.UNet, ), scheduler=ModelInfo( key=model_key, - submodel=SubModelType.Scheduler, + submodel_type=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( key=model_key, - submodel=SubModelType.Tokenizer, + submodel_type=SubModelType.Tokenizer, ), text_encoder=ModelInfo( key=model_key, - submodel=SubModelType.TextEncoder, + submodel_type=SubModelType.TextEncoder, ), loras=[], skipped_layers=0, @@ -77,11 +77,11 @@ class SDXLModelLoaderInvocation(BaseInvocation): clip2=ClipField( tokenizer=ModelInfo( key=model_key, - submodel=SubModelType.Tokenizer2, + submodel_type=SubModelType.Tokenizer2, ), text_encoder=ModelInfo( key=model_key, - submodel=SubModelType.TextEncoder2, + submodel_type=SubModelType.TextEncoder2, ), loras=[], skipped_layers=0, @@ -89,7 +89,7 @@ class SDXLModelLoaderInvocation(BaseInvocation): vae=VaeField( vae=ModelInfo( key=model_key, - submodel=SubModelType.Vae, + submodel_type=SubModelType.Vae, ), ), ) @@ -116,29 +116,29 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation): model_key = self.model.key # TODO: not found exceptions - if not context.services.model_records.exists(model_key): + if not context.services.model_manager.store.exists(model_key): raise Exception(f"Unknown model: {model_key}") return SDXLRefinerModelLoaderOutput( unet=UNetField( unet=ModelInfo( key=model_key, - submodel=SubModelType.UNet, + submodel_type=SubModelType.UNet, ), scheduler=ModelInfo( key=model_key, - submodel=SubModelType.Scheduler, + submodel_type=SubModelType.Scheduler, ), loras=[], ), clip2=ClipField( tokenizer=ModelInfo( key=model_key, - submodel=SubModelType.Tokenizer2, + submodel_type=SubModelType.Tokenizer2, ), text_encoder=ModelInfo( key=model_key, - submodel=SubModelType.TextEncoder2, + submodel_type=SubModelType.TextEncoder2, ), loras=[], skipped_layers=0, @@ -146,7 +146,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation): vae=VaeField( vae=ModelInfo( key=model_key, - submodel=SubModelType.Vae, + submodel_type=SubModelType.Vae, ), ), ) diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 362d8d3ff5..eabbdf819a 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -499,7 +499,7 @@ class ModelManager(object): model_class=model_class, base_model=base_model, model_type=model_type, - submodel=submodel_type, + submodel_type=submodel_type, ) if model_key not in self.cache_keys: diff --git a/pyproject.toml b/pyproject.toml index 09713a1cbe..243b0b1f21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -242,7 +242,7 @@ module = [ "invokeai.app.services.invocation_stats.invocation_stats_default", "invokeai.app.services.model_manager.model_manager_base", "invokeai.app.services.model_manager.model_manager_default", - "invokeai.app.services.model_records.model_records_sql", + "invokeai.app.services.model_manager.store.model_records_sql", "invokeai.app.util.controlnet_utils", "invokeai.backend.image_util.txt2mask", "invokeai.backend.image_util.safety_checker",