From 35e8a33dfd4bdf374f1db526484ad8801203b642 Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Wed, 14 Feb 2024 09:36:30 -0500 Subject: [PATCH] Remove references to model_records service, change submodel property on ModelInfo to submodel_type to support new params in model manager --- docs/contributing/MODEL_MANAGER.md | 2 +- invokeai/app/invocations/latent.py | 2 +- invokeai/app/invocations/model.py | 22 +++++++-------- invokeai/app/invocations/sdxl.py | 28 +++++++++---------- .../backend/model_management/model_manager.py | 2 +- pyproject.toml | 2 +- 6 files changed, 29 insertions(+), 29 deletions(-) diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index b711c654de..b19699de73 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -1627,7 +1627,7 @@ payload=dict( queue_batch_id=queue_batch_id, graph_execution_state_id=graph_execution_state_id, model_key=model_key, - submodel=submodel, + submodel_type=submodel, hash=model_info.hash, location=str(model_info.location), precision=str(model_info.precision), diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 289da2dd73..c3de521940 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -812,7 +812,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): ) with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae: - assert isinstance(vae, torch.Tensor) + assert isinstance(vae, torch.nn.Module) latents = latents.to(vae.device) if self.fp32: vae.to(dtype=torch.float32) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index f78425c6ee..71a71a63c8 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -18,7 +18,7 @@ from .baseinvocation import ( class ModelInfo(BaseModel): key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()") - submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel") + submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel") class LoraInfo(ModelInfo): @@ -110,22 +110,22 @@ class MainModelLoaderInvocation(BaseInvocation): unet=UNetField( unet=ModelInfo( key=key, - submodel=SubModelType.UNet, + submodel_type=SubModelType.UNet, ), scheduler=ModelInfo( key=key, - submodel=SubModelType.Scheduler, + submodel_type=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( key=key, - submodel=SubModelType.Tokenizer, + submodel_type=SubModelType.Tokenizer, ), text_encoder=ModelInfo( key=key, - submodel=SubModelType.TextEncoder, + submodel_type=SubModelType.TextEncoder, ), loras=[], skipped_layers=0, @@ -133,7 +133,7 @@ class MainModelLoaderInvocation(BaseInvocation): vae=VaeField( vae=ModelInfo( key=key, - submodel=SubModelType.Vae, + submodel_type=SubModelType.Vae, ), ), ) @@ -188,7 +188,7 @@ class LoraLoaderInvocation(BaseInvocation): output.unet.loras.append( LoraInfo( key=lora_key, - submodel=None, + submodel_type=None, weight=self.weight, ) ) @@ -198,7 +198,7 @@ class LoraLoaderInvocation(BaseInvocation): output.clip.loras.append( LoraInfo( key=lora_key, - submodel=None, + submodel_type=None, weight=self.weight, ) ) @@ -271,7 +271,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation): output.unet.loras.append( LoraInfo( key=lora_key, - submodel=None, + submodel_type=None, weight=self.weight, ) ) @@ -281,7 +281,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation): output.clip.loras.append( LoraInfo( key=lora_key, - submodel=None, + submodel_type=None, weight=self.weight, ) ) @@ -291,7 +291,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation): output.clip2.loras.append( LoraInfo( key=lora_key, - submodel=None, + submodel_type=None, weight=self.weight, ) ) diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 633a6477fd..85e6fb787f 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -43,29 +43,29 @@ class SDXLModelLoaderInvocation(BaseInvocation): model_key = self.model.key # TODO: not found exceptions - if not context.services.model_records.exists(model_key): + if not context.services.model_manager.store.exists(model_key): raise Exception(f"Unknown model: {model_key}") return SDXLModelLoaderOutput( unet=UNetField( unet=ModelInfo( key=model_key, - submodel=SubModelType.UNet, + submodel_type=SubModelType.UNet, ), scheduler=ModelInfo( key=model_key, - submodel=SubModelType.Scheduler, + submodel_type=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( key=model_key, - submodel=SubModelType.Tokenizer, + submodel_type=SubModelType.Tokenizer, ), text_encoder=ModelInfo( key=model_key, - submodel=SubModelType.TextEncoder, + submodel_type=SubModelType.TextEncoder, ), loras=[], skipped_layers=0, @@ -73,11 +73,11 @@ class SDXLModelLoaderInvocation(BaseInvocation): clip2=ClipField( tokenizer=ModelInfo( key=model_key, - submodel=SubModelType.Tokenizer2, + submodel_type=SubModelType.Tokenizer2, ), text_encoder=ModelInfo( key=model_key, - submodel=SubModelType.TextEncoder2, + submodel_type=SubModelType.TextEncoder2, ), loras=[], skipped_layers=0, @@ -85,7 +85,7 @@ class SDXLModelLoaderInvocation(BaseInvocation): vae=VaeField( vae=ModelInfo( key=model_key, - submodel=SubModelType.Vae, + submodel_type=SubModelType.Vae, ), ), ) @@ -112,29 +112,29 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation): model_key = self.model.key # TODO: not found exceptions - if not context.services.model_records.exists(model_key): + if not context.services.model_manager.store.exists(model_key): raise Exception(f"Unknown model: {model_key}") return SDXLRefinerModelLoaderOutput( unet=UNetField( unet=ModelInfo( key=model_key, - submodel=SubModelType.UNet, + submodel_type=SubModelType.UNet, ), scheduler=ModelInfo( key=model_key, - submodel=SubModelType.Scheduler, + submodel_type=SubModelType.Scheduler, ), loras=[], ), clip2=ClipField( tokenizer=ModelInfo( key=model_key, - submodel=SubModelType.Tokenizer2, + submodel_type=SubModelType.Tokenizer2, ), text_encoder=ModelInfo( key=model_key, - submodel=SubModelType.TextEncoder2, + submodel_type=SubModelType.TextEncoder2, ), loras=[], skipped_layers=0, @@ -142,7 +142,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation): vae=VaeField( vae=ModelInfo( key=model_key, - submodel=SubModelType.Vae, + submodel_type=SubModelType.Vae, ), ), ) diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index da74ca3fb5..84d93f15fa 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -499,7 +499,7 @@ class ModelManager(object): model_class=model_class, base_model=base_model, model_type=model_type, - submodel=submodel_type, + submodel_type=submodel_type, ) if model_key not in self.cache_keys: diff --git a/pyproject.toml b/pyproject.toml index 2958e3629a..f57607bc0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -245,7 +245,7 @@ module = [ "invokeai.app.services.invocation_stats.invocation_stats_default", "invokeai.app.services.model_manager.model_manager_base", "invokeai.app.services.model_manager.model_manager_default", - "invokeai.app.services.model_records.model_records_sql", + "invokeai.app.services.model_manager.store.model_records_sql", "invokeai.app.util.controlnet_utils", "invokeai.backend.image_util.txt2mask", "invokeai.backend.image_util.safety_checker",