Merge branch 'refactor/model-manager2/loader' of github.com:invoke-ai/InvokeAI into refactor/model-manager2/loader
Commit 8ac4b9b32c
@@ -1627,7 +1627,7 @@ payload=dict(
     queue_batch_id=queue_batch_id,
     graph_execution_state_id=graph_execution_state_id,
     model_key=model_key,
-    submodel=submodel,
+    submodel_type=submodel,
     hash=model_info.hash,
     location=str(model_info.location),
     precision=str(model_info.precision),
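Note for consumers of this event: the payload key changes from "submodel" to "submodel_type" along with the field rename, so anything reading the event must follow suit. A hedged sketch of a consumer (the handler and the rest of the payload shape are hypothetical; only the renamed key comes from this hunk):

    # Hypothetical event consumer; only the "submodel_type" key is taken from this diff.
    def on_model_load_event(payload: dict) -> None:
        model_key = payload["model_key"]
        submodel_type = payload["submodel_type"]  # was "submodel" before this change
        print(f"loading {model_key} ({submodel_type})")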
@@ -710,7 +710,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]

         # get the unet's config so that we can pass the base to dispatch_progress()
-        unet_config = context.services.model_manager.store.get_model(**self.unet.unet.model_dump())
+        unet_config = context.services.model_manager.store.get_model(self.unet.unet.key)

         def step_callback(state: PipelineIntermediateState) -> None:
             self.dispatch_progress(context, source_node_id, state, unet_config.base)
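A plausible reading of this change: get_model() accepts a single key, so unpacking the whole ModelInfo with **model_dump() would also pass fields such as submodel_type that the method does not take. A standalone sketch of the assumed failure mode (the one-argument signature is an assumption):

    # Stand-in with the assumed one-argument signature.
    def get_model(key: str): ...

    info = {"key": "abc123", "submodel_type": "unet"}  # shape of ModelInfo.model_dump()
    # get_model(**info)      # TypeError: unexpected keyword argument 'submodel_type'
    get_model(info["key"])   # only the record key is passed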
@@ -738,7 +738,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
             # Apply the LoRA after unet has been moved to its target device for faster patching.
             ModelPatcher.apply_lora_unet(unet, _lora_loader()),
         ):
-            assert isinstance(unet, torch.Tensor)
+            assert isinstance(unet, UNet2DConditionModel)
             latents = latents.to(device=unet.device, dtype=unet.dtype)
             if noise is not None:
                 noise = noise.to(device=unet.device, dtype=unet.dtype)
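The corrected assert names the class the model context manager actually yields rather than torch.Tensor; both happen to expose .device and .dtype, which is presumably why the wrong assert never fired at runtime. The import path is an assumption, since the diff only shows the identifier:

    # Assumed import; diffusers exports the class at top level.
    from diffusers import UNet2DConditionModel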
@@ -842,7 +842,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata):
         )

         with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
-            assert isinstance(vae, torch.Tensor)
+            assert isinstance(vae, torch.nn.Module)
             latents = latents.to(vae.device)
             if self.fp32:
                 vae.to(dtype=torch.float32)
@@ -21,7 +21,7 @@ from .baseinvocation import (

 class ModelInfo(BaseModel):
     key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()")
-    submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
+    submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel")


 class LoraInfo(ModelInfo):
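After the rename, every call site constructs the field with the new keyword. A self-contained sketch of the behavior (the SubModelType stand-in below is hypothetical; InvokeAI's real enum lives in the model-manager package):

    from enum import Enum
    from typing import Optional

    from pydantic import BaseModel, Field

    class SubModelType(str, Enum):  # stand-in for the real enum
        UNet = "unet"

    class ModelInfo(BaseModel):
        key: str
        submodel_type: Optional[SubModelType] = Field(default=None)

    info = ModelInfo(key="sd15-main", submodel_type=SubModelType.UNet)
    # The old keyword is no longer a declared field; depending on the model's
    # extra-field config it is silently ignored or rejected.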
@@ -113,22 +113,22 @@ class MainModelLoaderInvocation(BaseInvocation):
             unet=UNetField(
                 unet=ModelInfo(
                     key=key,
-                    submodel=SubModelType.UNet,
+                    submodel_type=SubModelType.UNet,
                 ),
                 scheduler=ModelInfo(
                     key=key,
-                    submodel=SubModelType.Scheduler,
+                    submodel_type=SubModelType.Scheduler,
                 ),
                 loras=[],
             ),
             clip=ClipField(
                 tokenizer=ModelInfo(
                     key=key,
-                    submodel=SubModelType.Tokenizer,
+                    submodel_type=SubModelType.Tokenizer,
                 ),
                 text_encoder=ModelInfo(
                     key=key,
-                    submodel=SubModelType.TextEncoder,
+                    submodel_type=SubModelType.TextEncoder,
                 ),
                 loras=[],
                 skipped_layers=0,
@@ -136,7 +136,7 @@ class MainModelLoaderInvocation(BaseInvocation):
             vae=VaeField(
                 vae=ModelInfo(
                     key=key,
-                    submodel=SubModelType.Vae,
+                    submodel_type=SubModelType.Vae,
                 ),
             ),
         )
@@ -191,7 +191,7 @@ class LoraLoaderInvocation(BaseInvocation):
             output.unet.loras.append(
                 LoraInfo(
                     key=lora_key,
-                    submodel=None,
+                    submodel_type=None,
                     weight=self.weight,
                 )
             )
@@ -201,7 +201,7 @@ class LoraLoaderInvocation(BaseInvocation):
             output.clip.loras.append(
                 LoraInfo(
                     key=lora_key,
-                    submodel=None,
+                    submodel_type=None,
                     weight=self.weight,
                 )
             )
@@ -274,7 +274,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
             output.unet.loras.append(
                 LoraInfo(
                     key=lora_key,
-                    submodel=None,
+                    submodel_type=None,
                     weight=self.weight,
                 )
             )
@@ -284,7 +284,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
             output.clip.loras.append(
                 LoraInfo(
                     key=lora_key,
-                    submodel=None,
+                    submodel_type=None,
                     weight=self.weight,
                 )
             )
@@ -294,7 +294,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
             output.clip2.loras.append(
                 LoraInfo(
                     key=lora_key,
-                    submodel=None,
+                    submodel_type=None,
                     weight=self.weight,
                 )
             )
@@ -415,29 +415,29 @@ class OnnxModelLoaderInvocation(BaseInvocation):
         model_key = self.model.key

         # TODO: not found exceptions
-        if not context.services.model_records.exists(model_key):
+        if not context.services.model_manager.store.exists(model_key):
             raise Exception(f"Unknown model: {model_key}")

         return ONNXModelLoaderOutput(
             unet=UNetField(
                 unet=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.UNet,
+                    submodel_type=SubModelType.UNet,
                 ),
                 scheduler=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.Scheduler,
+                    submodel_type=SubModelType.Scheduler,
                 ),
                 loras=[],
             ),
             clip=ClipField(
                 tokenizer=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.Tokenizer,
+                    submodel_type=SubModelType.Tokenizer,
                 ),
                 text_encoder=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.TextEncoder,
+                    submodel_type=SubModelType.TextEncoder,
                 ),
                 loras=[],
                 skipped_layers=0,
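The existence check now goes through the store attached to the model manager instead of the standalone model_records service; the raise itself is unchanged, and the same substitution repeats in the SDXL loaders below. Isolated as a helper for illustration (service path as named in this diff; the helper itself is hypothetical):

    def require_model(context, model_key: str) -> None:
        # Fail fast on unknown keys before assembling loader output.
        if not context.services.model_manager.store.exists(model_key):
            raise Exception(f"Unknown model: {model_key}")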
@@ -445,13 +445,13 @@ class OnnxModelLoaderInvocation(BaseInvocation):
             vae_decoder=VaeField(
                 vae=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.VaeDecoder,
+                    submodel_type=SubModelType.VaeDecoder,
                 ),
             ),
             vae_encoder=VaeField(
                 vae=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.VaeEncoder,
+                    submodel_type=SubModelType.VaeEncoder,
                 ),
             ),
         )
@@ -47,29 +47,29 @@ class SDXLModelLoaderInvocation(BaseInvocation):
         model_key = self.model.key

         # TODO: not found exceptions
-        if not context.services.model_records.exists(model_key):
+        if not context.services.model_manager.store.exists(model_key):
             raise Exception(f"Unknown model: {model_key}")

         return SDXLModelLoaderOutput(
             unet=UNetField(
                 unet=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.UNet,
+                    submodel_type=SubModelType.UNet,
                 ),
                 scheduler=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.Scheduler,
+                    submodel_type=SubModelType.Scheduler,
                 ),
                 loras=[],
             ),
             clip=ClipField(
                 tokenizer=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.Tokenizer,
+                    submodel_type=SubModelType.Tokenizer,
                 ),
                 text_encoder=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.TextEncoder,
+                    submodel_type=SubModelType.TextEncoder,
                 ),
                 loras=[],
                 skipped_layers=0,
@@ -77,11 +77,11 @@ class SDXLModelLoaderInvocation(BaseInvocation):
             clip2=ClipField(
                 tokenizer=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.Tokenizer2,
+                    submodel_type=SubModelType.Tokenizer2,
                 ),
                 text_encoder=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.TextEncoder2,
+                    submodel_type=SubModelType.TextEncoder2,
                 ),
                 loras=[],
                 skipped_layers=0,
@@ -89,7 +89,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
             vae=VaeField(
                 vae=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.Vae,
+                    submodel_type=SubModelType.Vae,
                 ),
             ),
         )
@@ -116,29 +116,29 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
         model_key = self.model.key

         # TODO: not found exceptions
-        if not context.services.model_records.exists(model_key):
+        if not context.services.model_manager.store.exists(model_key):
             raise Exception(f"Unknown model: {model_key}")

         return SDXLRefinerModelLoaderOutput(
             unet=UNetField(
                 unet=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.UNet,
+                    submodel_type=SubModelType.UNet,
                 ),
                 scheduler=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.Scheduler,
+                    submodel_type=SubModelType.Scheduler,
                 ),
                 loras=[],
             ),
             clip2=ClipField(
                 tokenizer=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.Tokenizer2,
+                    submodel_type=SubModelType.Tokenizer2,
                 ),
                 text_encoder=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.TextEncoder2,
+                    submodel_type=SubModelType.TextEncoder2,
                 ),
                 loras=[],
                 skipped_layers=0,
@@ -146,7 +146,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
             vae=VaeField(
                 vae=ModelInfo(
                     key=model_key,
-                    submodel=SubModelType.Vae,
+                    submodel_type=SubModelType.Vae,
                 ),
             ),
         )
@@ -499,7 +499,7 @@ class ModelManager(object):
             model_class=model_class,
             base_model=base_model,
             model_type=model_type,
-            submodel=submodel_type,
+            submodel_type=submodel_type,
         )

         if model_key not in self.cache_keys:
@@ -303,18 +303,18 @@ class ModelCache(ModelCacheBase[AnyModel]):
         in_vram_models = 0
         locked_in_vram_models = 0
         for cache_record in self._cached_models.values():
-            assert hasattr(cache_record.model, "device")
-            if cache_record.model.device == self.storage_device:
-                in_ram_models += 1
-            else:
-                in_vram_models += 1
-                if cache_record.locked:
-                    locked_in_vram_models += 1
+            if hasattr(cache_record.model, "device"):
+                if cache_record.model.device == self.storage_device:
+                    in_ram_models += 1
+                else:
+                    in_vram_models += 1
+                    if cache_record.locked:
+                        locked_in_vram_models += 1

         self.logger.debug(
             f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
             f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
         )

     def make_room(self, model_size: int) -> None:
         """Make enough room in the cache to accommodate a new model of indicated size."""
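A plausible rationale for relaxing the assert to a guard: not every cached object exposes a torch-style .device attribute (a bare torch.nn.Module, for instance, does not), so non-device-aware entries are now skipped rather than crashing the stats pass. The counting logic, isolated as a standalone sketch:

    import torch

    def count_by_device(models, storage_device: torch.device) -> tuple[int, int]:
        # Count device-aware models on the storage (RAM) device vs. elsewhere (VRAM).
        in_ram, in_vram = 0, 0
        for m in models:
            if hasattr(m, "device"):  # skip entries without a device attribute
                if m.device == storage_device:
                    in_ram += 1
                else:
                    in_vram += 1
        return in_ram, in_vram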
@@ -242,7 +242,7 @@ module = [
     "invokeai.app.services.invocation_stats.invocation_stats_default",
     "invokeai.app.services.model_manager.model_manager_base",
     "invokeai.app.services.model_manager.model_manager_default",
-    "invokeai.app.services.model_records.model_records_sql",
+    "invokeai.app.services.model_manager.store.model_records_sql",
     "invokeai.app.util.controlnet_utils",
     "invokeai.backend.image_util.txt2mask",
     "invokeai.backend.image_util.safety_checker",