Merge branch 'refactor/model-manager2/loader' of github.com:invoke-ai/InvokeAI into refactor/model-manager2/loader

Lincoln Stein 2024-02-14 11:11:00 -05:00
commit 8ac4b9b32c
8 changed files with 49 additions and 49 deletions

View File

@@ -1627,7 +1627,7 @@ payload=dict(
             queue_batch_id=queue_batch_id,
             graph_execution_state_id=graph_execution_state_id,
             model_key=model_key,
-            submodel=submodel,
+            submodel_type=submodel,
             hash=model_info.hash,
             location=str(model_info.location),
             precision=str(model_info.precision),
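
The hunk above renames the `submodel` key of the model-load event payload to `submodel_type`. A minimal sketch of the resulting payload shape; only the key names come from the diff, all values are illustrative placeholders:

```python
# Sketch of the renamed event payload; values are placeholders, not real IDs.
payload = dict(
    queue_batch_id="batch-1",
    graph_execution_state_id="state-1",
    model_key="some-model-key",
    submodel_type="unet",  # previously emitted under the key "submodel"
)
```

Event consumers that read the old `submodel` key need the matching update.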

View File

@@ -710,7 +710,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]

         # get the unet's config so that we can pass the base to dispatch_progress()
-        unet_config = context.services.model_manager.store.get_model(**self.unet.unet.model_dump())
+        unet_config = context.services.model_manager.store.get_model(self.unet.unet.key)

         def step_callback(state: PipelineIntermediateState) -> None:
             self.dispatch_progress(context, source_node_id, state, unet_config.base)
@@ -738,7 +738,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
             # Apply the LoRA after unet has been moved to its target device for faster patching.
             ModelPatcher.apply_lora_unet(unet, _lora_loader()),
         ):
-            assert isinstance(unet, torch.Tensor)
+            assert isinstance(unet, UNet2DConditionModel)
             latents = latents.to(device=unet.device, dtype=unet.dtype)
             if noise is not None:
                 noise = noise.to(device=unet.device, dtype=unet.dtype)
@@ -842,7 +842,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata):
         )

         with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
-            assert isinstance(vae, torch.Tensor)
+            assert isinstance(vae, torch.nn.Module)
             latents = latents.to(vae.device)
             if self.fp32:
                 vae.to(dtype=torch.float32)
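
Both asserts above previously checked for `torch.Tensor`, which the loaded UNet and VAE can never be: they are `torch.nn.Module` subclasses (diffusers models additionally expose `device` and `dtype` properties, which is why `unet.device` works). A small sketch of the class relationships, assuming `diffusers` is installed:

```python
import torch
from diffusers import UNet2DConditionModel

# The new assert is satisfiable: every diffusers model is an nn.Module.
assert issubclass(UNet2DConditionModel, torch.nn.Module)
# The old assert never was: a model is not a Tensor.
assert not issubclass(UNet2DConditionModel, torch.Tensor)
```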

View File

@@ -21,7 +21,7 @@ from .baseinvocation import (

 class ModelInfo(BaseModel):
     key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()")
-    submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
+    submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel")


 class LoraInfo(ModelInfo):
@@ -113,22 +113,22 @@ class MainModelLoaderInvocation(BaseInvocation):
             unet=UNetField(
                 unet=ModelInfo(
                     key=key,
-                    submodel=SubModelType.UNet,
+                    submodel_type=SubModelType.UNet,
                 ),
                 scheduler=ModelInfo(
                     key=key,
-                    submodel=SubModelType.Scheduler,
+                    submodel_type=SubModelType.Scheduler,
                 ),
                 loras=[],
             ),
             clip=ClipField(
                 tokenizer=ModelInfo(
                     key=key,
-                    submodel=SubModelType.Tokenizer,
+                    submodel_type=SubModelType.Tokenizer,
                 ),
                 text_encoder=ModelInfo(
                     key=key,
-                    submodel=SubModelType.TextEncoder,
+                    submodel_type=SubModelType.TextEncoder,
                 ),
                 loras=[],
                 skipped_layers=0,
@@ -136,7 +136,7 @@ class MainModelLoaderInvocation(BaseInvocation):
             vae=VaeField(
                 vae=ModelInfo(
                     key=key,
-                    submodel=SubModelType.Vae,
+                    submodel_type=SubModelType.Vae,
                 ),
             ),
         )
@@ -191,7 +191,7 @@ class LoraLoaderInvocation(BaseInvocation):
             output.unet.loras.append(
                 LoraInfo(
                     key=lora_key,
-                    submodel=None,
+                    submodel_type=None,
                     weight=self.weight,
                 )
             )
@@ -201,7 +201,7 @@ class LoraLoaderInvocation(BaseInvocation):
             output.clip.loras.append(
                 LoraInfo(
                     key=lora_key,
-                    submodel=None,
+                    submodel_type=None,
                     weight=self.weight,
                 )
             )
@@ -274,7 +274,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
             output.unet.loras.append(
                 LoraInfo(
                     key=lora_key,
-                    submodel=None,
+                    submodel_type=None,
                     weight=self.weight,
                 )
             )
@@ -284,7 +284,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
             output.clip.loras.append(
                 LoraInfo(
                     key=lora_key,
-                    submodel=None,
+                    submodel_type=None,
                     weight=self.weight,
                 )
             )
@@ -294,7 +294,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
             output.clip2.loras.append(
                 LoraInfo(
                     key=lora_key,
-                    submodel=None,
+                    submodel_type=None,
                     weight=self.weight,
                 )
             )
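
Because `ModelInfo` is a pydantic model, renaming the field changes the keyword argument at every construction site, which is exactly what the hunks above do. A reduced, self-contained sketch of the renamed model, assuming pydantic v2 and standing in a two-member enum for the real `SubModelType`:

```python
from enum import Enum
from typing import Optional

from pydantic import BaseModel, Field

class SubModelType(str, Enum):
    # Reduced stand-in for the real enum, which also covers tokenizers, VAEs, etc.
    UNet = "unet"
    Scheduler = "scheduler"

class ModelInfo(BaseModel):
    key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()")
    submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel")

info = ModelInfo(key="some-model-key", submodel_type=SubModelType.UNet)
print(info.model_dump())  # {'key': 'some-model-key', 'submodel_type': <SubModelType.UNet: 'unet'>}
```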

View File

@@ -415,29 +415,29 @@ class OnnxModelLoaderInvocation(BaseInvocation):
         model_key = self.model.key

         # TODO: not found exceptions
-        if not context.services.model_records.exists(model_key):
+        if not context.services.model_manager.store.exists(model_key):
             raise Exception(f"Unknown model: {model_key}")

         return ONNXModelLoaderOutput(
             unet=UNetField(
                 unet=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.UNet,
+                    submodel_type=SubModelType.UNet,
                 ),
                 scheduler=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.Scheduler,
+                    submodel_type=SubModelType.Scheduler,
                 ),
                 loras=[],
             ),
             clip=ClipField(
                 tokenizer=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.Tokenizer,
+                    submodel_type=SubModelType.Tokenizer,
                 ),
                 text_encoder=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.TextEncoder,
+                    submodel_type=SubModelType.TextEncoder,
                 ),
                 loras=[],
                 skipped_layers=0,
@@ -445,13 +445,13 @@ class OnnxModelLoaderInvocation(BaseInvocation):
             vae_decoder=VaeField(
                 vae=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.VaeDecoder,
+                    submodel_type=SubModelType.VaeDecoder,
                 ),
             ),
             vae_encoder=VaeField(
                 vae=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.VaeEncoder,
+                    submodel_type=SubModelType.VaeEncoder,
                 ),
             ),
         )
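
Beyond the field rename, this file also moves the existence check from the standalone `model_records` service to the store hanging off the model manager. A minimal sketch of that access-path change; the two classes are illustrative stand-ins, not the real implementations:

```python
class ModelRecordStore:
    """Stand-in for the record store; only the exists() lookup is sketched."""

    def __init__(self) -> None:
        self._records: dict[str, dict] = {}

    def exists(self, key: str) -> bool:
        return key in self._records

class ModelManager:
    """Stand-in showing that records are now reached via the manager's store attribute."""

    def __init__(self) -> None:
        self.store = ModelRecordStore()

# old: context.services.model_records.exists(model_key)
# new: context.services.model_manager.store.exists(model_key)
```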

View File

@@ -47,29 +47,29 @@ class SDXLModelLoaderInvocation(BaseInvocation):
         model_key = self.model.key

         # TODO: not found exceptions
-        if not context.services.model_records.exists(model_key):
+        if not context.services.model_manager.store.exists(model_key):
             raise Exception(f"Unknown model: {model_key}")

         return SDXLModelLoaderOutput(
             unet=UNetField(
                 unet=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.UNet,
+                    submodel_type=SubModelType.UNet,
                 ),
                 scheduler=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.Scheduler,
+                    submodel_type=SubModelType.Scheduler,
                 ),
                 loras=[],
             ),
             clip=ClipField(
                 tokenizer=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.Tokenizer,
+                    submodel_type=SubModelType.Tokenizer,
                 ),
                 text_encoder=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.TextEncoder,
+                    submodel_type=SubModelType.TextEncoder,
                 ),
                 loras=[],
                 skipped_layers=0,
@@ -77,11 +77,11 @@ class SDXLModelLoaderInvocation(BaseInvocation):
             clip2=ClipField(
                 tokenizer=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.Tokenizer2,
+                    submodel_type=SubModelType.Tokenizer2,
                 ),
                 text_encoder=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.TextEncoder2,
+                    submodel_type=SubModelType.TextEncoder2,
                 ),
                 loras=[],
                 skipped_layers=0,
@@ -89,7 +89,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
             vae=VaeField(
                 vae=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.Vae,
+                    submodel_type=SubModelType.Vae,
                 ),
             ),
         )
@@ -116,29 +116,29 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
         model_key = self.model.key

         # TODO: not found exceptions
-        if not context.services.model_records.exists(model_key):
+        if not context.services.model_manager.store.exists(model_key):
             raise Exception(f"Unknown model: {model_key}")

         return SDXLRefinerModelLoaderOutput(
             unet=UNetField(
                 unet=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.UNet,
+                    submodel_type=SubModelType.UNet,
                 ),
                 scheduler=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.Scheduler,
+                    submodel_type=SubModelType.Scheduler,
                 ),
                 loras=[],
             ),
             clip2=ClipField(
                 tokenizer=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.Tokenizer2,
+                    submodel_type=SubModelType.Tokenizer2,
                 ),
                 text_encoder=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.TextEncoder2,
+                    submodel_type=SubModelType.TextEncoder2,
                 ),
                 loras=[],
                 skipped_layers=0,
@@ -146,7 +146,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
             vae=VaeField(
                 vae=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.Vae,
+                    submodel_type=SubModelType.Vae,
                 ),
             ),
         )

View File

@ -499,7 +499,7 @@ class ModelManager(object):
model_class=model_class, model_class=model_class,
base_model=base_model, base_model=base_model,
model_type=model_type, model_type=model_type,
submodel=submodel_type, submodel_type=submodel_type,
) )
if model_key not in self.cache_keys: if model_key not in self.cache_keys:
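
Renaming a keyword argument is a breaking change at every call site, which is why the `submodel` → `submodel_type` rename fans out across so many files in this commit. A tiny illustration with a hypothetical helper, not InvokeAI code:

```python
def get_model_key(model_type: str, submodel_type: str | None = None) -> str:
    """Hypothetical cache-key helper, used only to illustrate the kwarg rename."""
    return f"{model_type}:{submodel_type or 'main'}"

get_model_key("main", submodel_type="unet")  # new spelling works
# get_model_key("main", submodel="unet")     # old spelling now raises TypeError
```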

View File

@@ -303,18 +303,18 @@ class ModelCache(ModelCacheBase[AnyModel]):
         in_vram_models = 0
         locked_in_vram_models = 0
         for cache_record in self._cached_models.values():
-            assert hasattr(cache_record.model, "device")
-            if cache_record.model.device == self.storage_device:
-                in_ram_models += 1
-            else:
-                in_vram_models += 1
-                if cache_record.locked:
-                    locked_in_vram_models += 1
+            if hasattr(cache_record.model, "device"):
+                if cache_record.model.device == self.storage_device:
+                    in_ram_models += 1
+                else:
+                    in_vram_models += 1
+                    if cache_record.locked:
+                        locked_in_vram_models += 1

         self.logger.debug(
             f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
             f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
         )

     def make_room(self, model_size: int) -> None:
         """Make enough room in the cache to accommodate a new model of indicated size."""

View File

@ -242,7 +242,7 @@ module = [
"invokeai.app.services.invocation_stats.invocation_stats_default", "invokeai.app.services.invocation_stats.invocation_stats_default",
"invokeai.app.services.model_manager.model_manager_base", "invokeai.app.services.model_manager.model_manager_base",
"invokeai.app.services.model_manager.model_manager_default", "invokeai.app.services.model_manager.model_manager_default",
"invokeai.app.services.model_records.model_records_sql", "invokeai.app.services.model_manager.store.model_records_sql",
"invokeai.app.util.controlnet_utils", "invokeai.app.util.controlnet_utils",
"invokeai.backend.image_util.txt2mask", "invokeai.backend.image_util.txt2mask",
"invokeai.backend.image_util.safety_checker", "invokeai.backend.image_util.safety_checker",