Remove references to model_records service, change submodel property on ModelInfo to submodel_type to support new params in model manager

This commit is contained in:
Brandon Rising 2024-02-14 09:36:30 -05:00 committed by psychedelicious
parent b0835db47d
commit 35e8a33dfd
6 changed files with 29 additions and 29 deletions

View File

@ -1627,7 +1627,7 @@ payload=dict(
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
model_key=model_key,
submodel=submodel,
submodel_type=submodel,
hash=model_info.hash,
location=str(model_info.location),
precision=str(model_info.precision),

View File

@ -812,7 +812,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
)
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, torch.Tensor)
assert isinstance(vae, torch.nn.Module)
latents = latents.to(vae.device)
if self.fp32:
vae.to(dtype=torch.float32)

View File

@ -18,7 +18,7 @@ from .baseinvocation import (
class ModelInfo(BaseModel):
key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()")
submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
class LoraInfo(ModelInfo):
@ -110,22 +110,22 @@ class MainModelLoaderInvocation(BaseInvocation):
unet=UNetField(
unet=ModelInfo(
key=key,
submodel=SubModelType.UNet,
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
key=key,
submodel=SubModelType.Scheduler,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
key=key,
submodel=SubModelType.Tokenizer,
submodel_type=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
key=key,
submodel=SubModelType.TextEncoder,
submodel_type=SubModelType.TextEncoder,
),
loras=[],
skipped_layers=0,
@ -133,7 +133,7 @@ class MainModelLoaderInvocation(BaseInvocation):
vae=VaeField(
vae=ModelInfo(
key=key,
submodel=SubModelType.Vae,
submodel_type=SubModelType.Vae,
),
),
)
@ -188,7 +188,7 @@ class LoraLoaderInvocation(BaseInvocation):
output.unet.loras.append(
LoraInfo(
key=lora_key,
submodel=None,
submodel_type=None,
weight=self.weight,
)
)
@ -198,7 +198,7 @@ class LoraLoaderInvocation(BaseInvocation):
output.clip.loras.append(
LoraInfo(
key=lora_key,
submodel=None,
submodel_type=None,
weight=self.weight,
)
)
@ -271,7 +271,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
output.unet.loras.append(
LoraInfo(
key=lora_key,
submodel=None,
submodel_type=None,
weight=self.weight,
)
)
@ -281,7 +281,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
output.clip.loras.append(
LoraInfo(
key=lora_key,
submodel=None,
submodel_type=None,
weight=self.weight,
)
)
@ -291,7 +291,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
output.clip2.loras.append(
LoraInfo(
key=lora_key,
submodel=None,
submodel_type=None,
weight=self.weight,
)
)

View File

@ -43,29 +43,29 @@ class SDXLModelLoaderInvocation(BaseInvocation):
model_key = self.model.key
# TODO: not found exceptions
if not context.services.model_records.exists(model_key):
if not context.services.model_manager.store.exists(model_key):
raise Exception(f"Unknown model: {model_key}")
return SDXLModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
key=model_key,
submodel=SubModelType.UNet,
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
key=model_key,
submodel=SubModelType.Scheduler,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
key=model_key,
submodel=SubModelType.Tokenizer,
submodel_type=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
key=model_key,
submodel=SubModelType.TextEncoder,
submodel_type=SubModelType.TextEncoder,
),
loras=[],
skipped_layers=0,
@ -73,11 +73,11 @@ class SDXLModelLoaderInvocation(BaseInvocation):
clip2=ClipField(
tokenizer=ModelInfo(
key=model_key,
submodel=SubModelType.Tokenizer2,
submodel_type=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
key=model_key,
submodel=SubModelType.TextEncoder2,
submodel_type=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
@ -85,7 +85,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
vae=VaeField(
vae=ModelInfo(
key=model_key,
submodel=SubModelType.Vae,
submodel_type=SubModelType.Vae,
),
),
)
@ -112,29 +112,29 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
model_key = self.model.key
# TODO: not found exceptions
if not context.services.model_records.exists(model_key):
if not context.services.model_manager.store.exists(model_key):
raise Exception(f"Unknown model: {model_key}")
return SDXLRefinerModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
key=model_key,
submodel=SubModelType.UNet,
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
key=model_key,
submodel=SubModelType.Scheduler,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip2=ClipField(
tokenizer=ModelInfo(
key=model_key,
submodel=SubModelType.Tokenizer2,
submodel_type=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
key=model_key,
submodel=SubModelType.TextEncoder2,
submodel_type=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
@ -142,7 +142,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
vae=VaeField(
vae=ModelInfo(
key=model_key,
submodel=SubModelType.Vae,
submodel_type=SubModelType.Vae,
),
),
)

View File

@ -499,7 +499,7 @@ class ModelManager(object):
model_class=model_class,
base_model=base_model,
model_type=model_type,
submodel=submodel_type,
submodel_type=submodel_type,
)
if model_key not in self.cache_keys:

View File

@ -245,7 +245,7 @@ module = [
"invokeai.app.services.invocation_stats.invocation_stats_default",
"invokeai.app.services.model_manager.model_manager_base",
"invokeai.app.services.model_manager.model_manager_default",
"invokeai.app.services.model_records.model_records_sql",
"invokeai.app.services.model_manager.store.model_records_sql",
"invokeai.app.util.controlnet_utils",
"invokeai.backend.image_util.txt2mask",
"invokeai.backend.image_util.safety_checker",