Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit b10cf20eb1: Merge branch 'main' into refactor/model_manager_instantiate

# Conflicts:
#	invokeai/backend/model_management/model_manager.py
invokeai/app/api/dependencies.py

@@ -2,7 +2,6 @@

 from typing import Optional
 from logging import Logger
-import os
 from invokeai.app.services.board_image_record_storage import (
     SqliteBoardImageRecordStorage,
 )
@@ -30,6 +29,7 @@ from ..services.invoker import Invoker
 from ..services.processor import DefaultInvocationProcessor
 from ..services.sqlite import SqliteItemStorage
 from ..services.model_manager_service import ModelManagerService
+from ..services.invocation_stats import InvocationStatsService
 from .events import FastAPIEventService


@@ -128,6 +128,7 @@ class ApiDependencies:
             graph_execution_manager=graph_execution_manager,
             processor=DefaultInvocationProcessor(),
             configuration=config,
+            performance_statistics=InvocationStatsService(graph_execution_manager),
             logger=logger,
         )

invokeai/app/cli_app.py

@@ -37,6 +37,7 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
 from invokeai.app.services.images import ImageService, ImageServiceDependencies
 from invokeai.app.services.resource_name import SimpleNameService
 from invokeai.app.services.urls import LocalUrlService
+from invokeai.app.services.invocation_stats import InvocationStatsService
 from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
 from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage

@@ -311,6 +312,7 @@ def invoke_cli():
         graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
         graph_execution_manager=graph_execution_manager,
         processor=DefaultInvocationProcessor(),
+        performance_statistics=InvocationStatsService(graph_execution_manager),
         logger=logger,
         configuration=config,
     )
invokeai/app/invocations/compel.py

@@ -109,12 +109,15 @@ class CompelInvocation(BaseInvocation):
                 name = trigger[1:-1]
                 try:
                     ti_list.append(
-                        context.services.model_manager.get_model(
-                            model_name=name,
-                            base_model=self.clip.text_encoder.base_model,
-                            model_type=ModelType.TextualInversion,
-                            context=context,
-                        ).context.model
+                        (
+                            name,
+                            context.services.model_manager.get_model(
+                                model_name=name,
+                                base_model=self.clip.text_encoder.base_model,
+                                model_type=ModelType.TextualInversion,
+                                context=context,
+                            ).context.model,
+                        )
                     )
                 except ModelNotFoundException:
                     # print(e)
@@ -173,7 +176,7 @@ class CompelInvocation(BaseInvocation):


 class SDXLPromptInvocationBase:
-    def run_clip_raw(self, context, clip_field, prompt, get_pooled):
+    def run_clip_raw(self, context, clip_field, prompt, get_pooled, lora_prefix):
         tokenizer_info = context.services.model_manager.get_model(
             **clip_field.tokenizer.dict(),
             context=context,
@@ -197,12 +200,15 @@ class SDXLPromptInvocationBase:
                 name = trigger[1:-1]
                 try:
                     ti_list.append(
-                        context.services.model_manager.get_model(
-                            model_name=name,
-                            base_model=clip_field.text_encoder.base_model,
-                            model_type=ModelType.TextualInversion,
-                            context=context,
-                        ).context.model
+                        (
+                            name,
+                            context.services.model_manager.get_model(
+                                model_name=name,
+                                base_model=clip_field.text_encoder.base_model,
+                                model_type=ModelType.TextualInversion,
+                                context=context,
+                            ).context.model,
+                        )
                     )
                 except ModelNotFoundException:
                     # print(e)
@@ -210,8 +216,8 @@ class SDXLPromptInvocationBase:
                 # print(traceback.format_exc())
                 print(f'Warn: trigger: "{trigger}" not found')

-        with ModelPatcher.apply_lora_text_encoder(
-            text_encoder_info.context.model, _lora_loader()
+        with ModelPatcher.apply_lora(
+            text_encoder_info.context.model, _lora_loader(), lora_prefix
         ), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
             tokenizer,
             ti_manager,
@@ -247,7 +253,7 @@ class SDXLPromptInvocationBase:

         return c, c_pooled, None

-    def run_clip_compel(self, context, clip_field, prompt, get_pooled):
+    def run_clip_compel(self, context, clip_field, prompt, get_pooled, lora_prefix):
         tokenizer_info = context.services.model_manager.get_model(
             **clip_field.tokenizer.dict(),
             context=context,
@@ -271,12 +277,15 @@ class SDXLPromptInvocationBase:
                 name = trigger[1:-1]
                 try:
                     ti_list.append(
-                        context.services.model_manager.get_model(
-                            model_name=name,
-                            base_model=clip_field.text_encoder.base_model,
-                            model_type=ModelType.TextualInversion,
-                            context=context,
-                        ).context.model
+                        (
+                            name,
+                            context.services.model_manager.get_model(
+                                model_name=name,
+                                base_model=clip_field.text_encoder.base_model,
+                                model_type=ModelType.TextualInversion,
+                                context=context,
+                            ).context.model,
+                        )
                     )
                 except ModelNotFoundException:
                     # print(e)
@@ -284,8 +293,8 @@ class SDXLPromptInvocationBase:
                 # print(traceback.format_exc())
                 print(f'Warn: trigger: "{trigger}" not found')

-        with ModelPatcher.apply_lora_text_encoder(
-            text_encoder_info.context.model, _lora_loader()
+        with ModelPatcher.apply_lora(
+            text_encoder_info.context.model, _lora_loader(), lora_prefix
         ), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
             tokenizer,
             ti_manager,
@@ -357,11 +366,11 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> CompelOutput:
-        c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False)
+        c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False, "lora_te1_")
         if self.style.strip() == "":
-            c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True)
+            c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True, "lora_te2_")
         else:
-            c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
+            c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_")

         original_size = (self.original_height, self.original_width)
         crop_coords = (self.crop_top, self.crop_left)
@@ -415,7 +424,8 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> CompelOutput:
-        c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
+        # TODO: if a LoRA for the refiner ever appears, use the proper prefix here
+        c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>")

         original_size = (self.original_height, self.original_width)
         crop_coords = (self.crop_top, self.crop_left)
@@ -467,11 +477,11 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> CompelOutput:
-        c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False)
+        c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False, "lora_te1_")
         if self.style.strip() == "":
-            c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True)
+            c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True, "lora_te2_")
         else:
-            c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
+            c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "lora_te2_")

         original_size = (self.original_height, self.original_width)
         crop_coords = (self.crop_top, self.crop_left)
@@ -525,7 +535,8 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> CompelOutput:
-        c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
+        # TODO: if a LoRA for the refiner ever appears, use the proper prefix here
+        c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "<NONE>")

         original_size = (self.original_height, self.original_width)
         crop_coords = (self.crop_top, self.crop_left)
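A note on the recurring change above: `ti_list` entries go from bare TextualInversion models to `(trigger_name, model)` tuples, because the `name` attribute is removed from `TextualInversionModel` in the lora.py hunks further down, so the trigger name must now travel alongside the model. A minimal, runnable sketch of the new pairing; `FakeTI` is a hypothetical stand-in, not InvokeAI code:

from dataclasses import dataclass

@dataclass
class FakeTI:
    embedding_rows: int  # stands in for ti.embedding.shape[0]

def build_ti_list(triggers):
    # Triggers arrive as "<name>" tokens; the name is carried alongside
    # the loaded model instead of being stored on the model itself.
    ti_list = []
    for trigger in triggers:
        name = trigger[1:-1]               # strip the angle brackets
        ti_list.append((name, FakeTI(1)))  # (trigger name, loaded model)
    return ti_list

print(build_ti_list(["<easynegative>", "<charturner>"]))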
invokeai/app/invocations/latent.py

@@ -14,7 +14,7 @@ from invokeai.app.invocations.metadata import CoreMetadata
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 from invokeai.backend.model_management.models import ModelType, SilenceWarnings

-from ...backend.model_management.lora import ModelPatcher
+from ...backend.model_management import ModelPatcher
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.stable_diffusion.diffusers_pipeline import (
     ConditioningData,
invokeai/app/invocations/model.py

@@ -262,6 +262,103 @@ class LoraLoaderInvocation(BaseInvocation):
         return output


+class SDXLLoraLoaderOutput(BaseInvocationOutput):
+    """Model loader output"""
+
+    # fmt: off
+    type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output"
+
+    unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
+    clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
+    clip2: Optional[ClipField] = Field(default=None, description="Tokenizer2 and text_encoder2 submodels")
+    # fmt: on
+
+
+class SDXLLoraLoaderInvocation(BaseInvocation):
+    """Apply selected lora to unet and text_encoder."""
+
+    type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader"
+
+    lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name")
+    weight: float = Field(default=0.75, description="With what weight to apply lora")
+
+    unet: Optional[UNetField] = Field(description="UNet model for applying lora")
+    clip: Optional[ClipField] = Field(description="Clip model for applying lora")
+    clip2: Optional[ClipField] = Field(description="Clip2 model for applying lora")
+
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "title": "SDXL Lora Loader",
+                "tags": ["lora", "loader"],
+                "type_hints": {"lora": "lora_model"},
+            },
+        }
+
+    def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
+        if self.lora is None:
+            raise Exception("No LoRA provided")
+
+        base_model = self.lora.base_model
+        lora_name = self.lora.model_name
+
+        if not context.services.model_manager.model_exists(
+            base_model=base_model,
+            model_name=lora_name,
+            model_type=ModelType.Lora,
+        ):
+            raise Exception(f"Unknown lora name: {lora_name}!")
+
+        if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
+            raise Exception(f'Lora "{lora_name}" already applied to unet')
+
+        if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
+            raise Exception(f'Lora "{lora_name}" already applied to clip')
+
+        if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras):
+            raise Exception(f'Lora "{lora_name}" already applied to clip2')
+
+        output = SDXLLoraLoaderOutput()
+
+        if self.unet is not None:
+            output.unet = copy.deepcopy(self.unet)
+            output.unet.loras.append(
+                LoraInfo(
+                    base_model=base_model,
+                    model_name=lora_name,
+                    model_type=ModelType.Lora,
+                    submodel=None,
+                    weight=self.weight,
+                )
+            )
+
+        if self.clip is not None:
+            output.clip = copy.deepcopy(self.clip)
+            output.clip.loras.append(
+                LoraInfo(
+                    base_model=base_model,
+                    model_name=lora_name,
+                    model_type=ModelType.Lora,
+                    submodel=None,
+                    weight=self.weight,
+                )
+            )
+
+        if self.clip2 is not None:
+            output.clip2 = copy.deepcopy(self.clip2)
+            output.clip2.loras.append(
+                LoraInfo(
+                    base_model=base_model,
+                    model_name=lora_name,
+                    model_type=ModelType.Lora,
+                    submodel=None,
+                    weight=self.weight,
+                )
+            )
+
+        return output
+
+
 class VAEModelField(BaseModel):
     """Vae model field"""

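The new SDXLLoraLoaderInvocation deep-copies each incoming field before appending a LoraInfo, so upstream nodes that share the same UNetField/ClipField are never mutated, and it refuses to stack the same LoRA twice. A self-contained sketch of that guard-then-copy pattern; all names here are illustrative stand-ins, not InvokeAI types:

import copy
from dataclasses import dataclass, field
from typing import List

@dataclass
class FakeLoraInfo:
    model_name: str
    weight: float

@dataclass
class FakeUNetField:
    loras: List[FakeLoraInfo] = field(default_factory=list)

def add_lora_to_field(unet_field: FakeUNetField, name: str, weight: float) -> FakeUNetField:
    # Refuse to stack the same LoRA twice, mirroring the invocation's guard.
    if any(lora.model_name == name for lora in unet_field.loras):
        raise Exception(f'Lora "{name}" already applied to unet')
    # Deep-copy so the upstream node's field object is never mutated.
    out = copy.deepcopy(unet_field)
    out.loras.append(FakeLoraInfo(name, weight))
    return out

upstream = FakeUNetField()
patched = add_lora_to_field(upstream, "pixel-art", 0.75)
assert upstream.loras == [] and len(patched.loras) == 1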
invokeai/app/invocations/onnx.py

@@ -65,7 +65,6 @@ class ONNXPromptInvocation(BaseInvocation):
             **self.clip.text_encoder.dict(),
         )
         with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
-            # loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
             loras = [
                 (context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
                 for lora in self.clip.loras
@@ -76,18 +75,14 @@ class ONNXPromptInvocation(BaseInvocation):
                     name = trigger[1:-1]
                     try:
                         ti_list.append(
-                            # stack.enter_context(
-                            #     context.services.model_manager.get_model(
-                            #         model_name=name,
-                            #         base_model=self.clip.text_encoder.base_model,
-                            #         model_type=ModelType.TextualInversion,
-                            #     )
-                            # )
-                            context.services.model_manager.get_model(
-                                model_name=name,
-                                base_model=self.clip.text_encoder.base_model,
-                                model_type=ModelType.TextualInversion,
-                            ).context.model
+                            (
+                                name,
+                                context.services.model_manager.get_model(
+                                    model_name=name,
+                                    base_model=self.clip.text_encoder.base_model,
+                                    model_type=ModelType.TextualInversion,
+                                ).context.model,
+                            )
                         )
                     except Exception:
                         # print(e)
invokeai/app/invocations/sdxl.py

@@ -5,7 +5,7 @@ from typing import List, Literal, Optional, Union

 from pydantic import Field, validator

-from ...backend.model_management import ModelType, SubModelType
+from ...backend.model_management import ModelType, SubModelType, ModelPatcher
 from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext

@@ -293,10 +293,20 @@ class SDXLTextToLatentsInvocation(BaseInvocation):

         num_inference_steps = self.steps

+        def _lora_loader():
+            for lora in self.unet.loras:
+                lora_info = context.services.model_manager.get_model(
+                    **lora.dict(exclude={"weight"}),
+                    context=context,
+                )
+                yield (lora_info.context.model, lora.weight)
+                del lora_info
+            return
+
         unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
         do_classifier_free_guidance = True
         cross_attention_kwargs = None
-        with unet_info as unet:
+        with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
             scheduler.set_timesteps(num_inference_steps, device=unet.device)
             timesteps = scheduler.timesteps

@@ -543,9 +553,19 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
             context=context,
         )

+        def _lora_loader():
+            for lora in self.unet.loras:
+                lora_info = context.services.model_manager.get_model(
+                    **lora.dict(exclude={"weight"}),
+                    context=context,
+                )
+                yield (lora_info.context.model, lora.weight)
+                del lora_info
+            return
+
         do_classifier_free_guidance = True
         cross_attention_kwargs = None
-        with unet_info as unet:
+        with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
             # apply denoising_start
             num_inference_steps = self.steps
             scheduler.set_timesteps(num_inference_steps, device=unet.device)
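`_lora_loader` is written as a generator, so each LoRA is fetched from the model manager only when `ModelPatcher.apply_lora_unet` actually iterates it, and the local reference is dropped as soon as the `(model, weight)` pair has been yielded. A standalone sketch of the lazy-yield pattern; the loader and model names below are hypothetical:

def fake_get_model(name):
    # Hypothetical stand-in for context.services.model_manager.get_model(...)
    print(f"loading {name}")
    return f"<model:{name}>"

def lora_loader(requested):
    for name, weight in requested:
        model = fake_get_model(name)  # fetched only when the consumer asks
        yield (model, weight)
        del model                     # drop the local reference immediately
    return

loader = lora_loader([("add-detail", 0.6), ("pixel-art", 0.75)])  # nothing loaded yet
for model, weight in loader:          # models materialize one at a time here
    print(model, weight)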
invokeai/app/services/invocation_services.py

@@ -32,6 +32,7 @@ class InvocationServices:
     logger: "Logger"
     model_manager: "ModelManagerServiceBase"
     processor: "InvocationProcessorABC"
+    performance_statistics: "InvocationStatsServiceBase"
     queue: "InvocationQueueABC"

     def __init__(
@@ -47,6 +48,7 @@ class InvocationServices:
         logger: "Logger",
         model_manager: "ModelManagerServiceBase",
         processor: "InvocationProcessorABC",
+        performance_statistics: "InvocationStatsServiceBase",
         queue: "InvocationQueueABC",
     ):
         self.board_images = board_images
@@ -61,4 +63,5 @@ class InvocationServices:
         self.logger = logger
         self.model_manager = model_manager
         self.processor = processor
+        self.performance_statistics = performance_statistics
         self.queue = queue
invokeai/app/services/invocation_stats.py (new file, +223 lines)

@@ -0,0 +1,223 @@
+# Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
+"""Utility to collect execution time and GPU usage stats on invocations in flight"""
+
+"""
+Usage:
+
+statistics = InvocationStatsService(graph_execution_manager)
+with statistics.collect_stats(invocation, graph_execution_state.id):
+      ... execute graphs...
+statistics.log_stats()
+
+Typical output:
+[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3
+[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node                 Calls  Seconds  VRAM Used
+[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader        1   0.005s      0.01G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip                1   0.004s      0.01G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel                   2   0.512s      0.26G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int                 1   0.001s      0.01G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size            1   0.001s      0.01G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate                  1   0.001s      0.01G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator     1   0.002s      0.01G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise                    1   0.002s      0.01G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l                      1   3.541s      1.93G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i                      1   0.679s      0.58G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME: 4.749s
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G
+
+The abstract base class for this class is InvocationStatsServiceBase. An implementing class which
+writes to the system log is stored in InvocationServices.performance_statistics.
+"""
+
+import time
+from abc import ABC, abstractmethod
+from contextlib import AbstractContextManager
+from dataclasses import dataclass, field
+from typing import Dict
+
+import torch
+
+import invokeai.backend.util.logging as logger
+
+from ..invocations.baseinvocation import BaseInvocation
+from .graph import GraphExecutionState
+from .item_storage import ItemStorageABC
+
+
+class InvocationStatsServiceBase(ABC):
+    "Abstract base class for recording node memory/time performance statistics"
+
+    @abstractmethod
+    def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
+        """
+        Initialize the InvocationStatsService and reset counters to zero
+        :param graph_execution_manager: Graph execution manager for this session
+        """
+        pass
+
+    @abstractmethod
+    def collect_stats(
+        self,
+        invocation: BaseInvocation,
+        graph_execution_state_id: str,
+    ) -> AbstractContextManager:
+        """
+        Return a context object that will capture the statistics on the execution
+        of the invocation. Use with: to place around the part of the code that executes the invocation.
+        :param invocation: BaseInvocation object from the current graph.
+        :param graph_execution_state: GraphExecutionState object from the current session.
+        """
+        pass
+
+    @abstractmethod
+    def reset_stats(self, graph_execution_state_id: str):
+        """
+        Reset all statistics for the indicated graph
+        :param graph_execution_state_id
+        """
+        pass
+
+    @abstractmethod
+    def reset_all_stats(self):
+        """Zero all statistics"""
+        pass
+
+    @abstractmethod
+    def update_invocation_stats(
+        self,
+        graph_id: str,
+        invocation_type: str,
+        time_used: float,
+        vram_used: float,
+    ):
+        """
+        Add timing information on execution of a node. Usually
+        used internally.
+        :param graph_id: ID of the graph that is currently executing
+        :param invocation_type: String literal type of the node
+        :param time_used: Time used by node's execution (sec)
+        :param vram_used: Maximum VRAM used during execution (GB)
+        """
+        pass
+
+    @abstractmethod
+    def log_stats(self):
+        """
+        Write out the accumulated statistics to the log or somewhere else.
+        """
+        pass
+
+
+@dataclass
+class NodeStats:
+    """Class for tracking execution stats of an invocation node"""
+
+    calls: int = 0
+    time_used: float = 0.0  # seconds
+    max_vram: float = 0.0  # GB
+
+
+@dataclass
+class NodeLog:
+    """Class for tracking node usage"""
+
+    # {node_type => NodeStats}
+    nodes: Dict[str, NodeStats] = field(default_factory=dict)
+
+
+class InvocationStatsService(InvocationStatsServiceBase):
+    """Accumulate performance information about a running graph. Collects time spent in each node,
+    as well as the maximum and current VRAM utilisation for CUDA systems"""
+
+    def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
+        self.graph_execution_manager = graph_execution_manager
+        # {graph_id => NodeLog}
+        self._stats: Dict[str, NodeLog] = {}
+
+    class StatsContext:
+        def __init__(self, invocation: BaseInvocation, graph_id: str, collector: "InvocationStatsServiceBase"):
+            self.invocation = invocation
+            self.collector = collector
+            self.graph_id = graph_id
+            self.start_time = 0
+
+        def __enter__(self):
+            self.start_time = time.time()
+            if torch.cuda.is_available():
+                torch.cuda.reset_peak_memory_stats()
+
+        def __exit__(self, *args):
+            self.collector.update_invocation_stats(
+                self.graph_id,
+                self.invocation.type,
+                time.time() - self.start_time,
+                torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0,
+            )
+
+    def collect_stats(
+        self,
+        invocation: BaseInvocation,
+        graph_execution_state_id: str,
+    ) -> StatsContext:
+        """
+        Return a context object that will capture the statistics.
+        :param invocation: BaseInvocation object from the current graph.
+        :param graph_execution_state: GraphExecutionState object from the current session.
+        """
+        if not self._stats.get(graph_execution_state_id):  # first time we're seeing this
+            self._stats[graph_execution_state_id] = NodeLog()
+        return self.StatsContext(invocation, graph_execution_state_id, self)
+
+    def reset_all_stats(self):
+        """Zero all statistics"""
+        self._stats = {}
+
+    def reset_stats(self, graph_execution_id: str):
+        """Zero the statistics for the indicated graph."""
+        try:
+            self._stats.pop(graph_execution_id)
+        except KeyError:
+            logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
+
+    def update_invocation_stats(self, graph_id: str, invocation_type: str, time_used: float, vram_used: float):
+        """
+        Add timing information on execution of a node. Usually
+        used internally.
+        :param graph_id: ID of the graph that is currently executing
+        :param invocation_type: String literal type of the node
+        :param time_used: Floating point seconds used by node's execution
+        """
+        if not self._stats[graph_id].nodes.get(invocation_type):
+            self._stats[graph_id].nodes[invocation_type] = NodeStats()
+        stats = self._stats[graph_id].nodes[invocation_type]
+        stats.calls += 1
+        stats.time_used += time_used
+        stats.max_vram = max(stats.max_vram, vram_used)
+
+    def log_stats(self):
+        """
+        Send the statistics to the system logger at the info level.
+        Stats will only be printed when the execution of the graph
+        is complete.
+        """
+        completed = set()
+        for graph_id, node_log in self._stats.items():
+            current_graph_state = self.graph_execution_manager.get(graph_id)
+            if not current_graph_state.is_complete():
+                continue
+
+            total_time = 0
+            logger.info(f"Graph stats: {graph_id}")
+            logger.info("Node                 Calls  Seconds  VRAM Used")
+            for node_type, stats in self._stats[graph_id].nodes.items():
+                logger.info(f"{node_type:<20} {stats.calls:>5}  {stats.time_used:7.3f}s  {stats.max_vram:4.2f}G")
+                total_time += stats.time_used
+
+            logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
+            if torch.cuda.is_available():
+                logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9))
+
+            completed.add(graph_id)
+
+        for graph_id in completed:
+            del self._stats[graph_id]
invokeai/app/services/processor.py

@@ -1,14 +1,15 @@
 import time
 import traceback
-from threading import Event, Thread, BoundedSemaphore
+from threading import BoundedSemaphore, Event, Thread

-from ..invocations.baseinvocation import InvocationContext
-from .invocation_queue import InvocationQueueItem
-from .invoker import InvocationProcessorABC, Invoker
-from ..models.exceptions import CanceledException
-
 import invokeai.backend.util.logging as logger

+from ..invocations.baseinvocation import InvocationContext
+from ..models.exceptions import CanceledException
+from .invocation_queue import InvocationQueueItem
+from .invocation_stats import InvocationStatsServiceBase
+from .invoker import InvocationProcessorABC, Invoker
+

 class DefaultInvocationProcessor(InvocationProcessorABC):
     __invoker_thread: Thread
@@ -35,6 +36,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
     def __process(self, stop_event: Event):
         try:
             self.__threadLimit.acquire()
+            statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
+
             while not stop_event.is_set():
                 try:
                     queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
@@ -83,35 +86,38 @@ class DefaultInvocationProcessor(InvocationProcessorABC):

                    # Invoke
                    try:
-                        outputs = invocation.invoke(
-                            InvocationContext(
-                                services=self.__invoker.services,
-                                graph_execution_state_id=graph_execution_state.id,
-                            )
-                        )
+                        with statistics.collect_stats(invocation, graph_execution_state.id):
+                            outputs = invocation.invoke(
+                                InvocationContext(
+                                    services=self.__invoker.services,
+                                    graph_execution_state_id=graph_execution_state.id,
+                                )
+                            )

                        # Check queue to see if this is canceled, and skip if so
                        if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
                            continue

                        # Save outputs and history
                        graph_execution_state.complete(invocation.id, outputs)

                        # Save the state changes
                        self.__invoker.services.graph_execution_manager.set(graph_execution_state)

                        # Send complete event
                        self.__invoker.services.events.emit_invocation_complete(
                            graph_execution_state_id=graph_execution_state.id,
                            node=invocation.dict(),
                            source_node_id=source_node_id,
                            result=outputs.dict(),
                        )
+                        statistics.log_stats()

                    except KeyboardInterrupt:
                        pass

                    except CanceledException:
+                        statistics.reset_stats(graph_execution_state.id)
                        pass

                    except Exception as e:
@@ -133,7 +139,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                            error_type=e.__class__.__name__,
                            error=error,
                        )
+                        statistics.reset_stats(graph_execution_state.id)
                        pass

                    # Check queue to see if this is canceled, and skip if so
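On the consumer side, the processor now wraps each `invocation.invoke(...)` in `collect_stats(...)`, logs after success, and resets a graph's statistics on cancellation or error so aborted runs do not leave stale entries. A compact, self-contained sketch of that control flow; the `Stats` class is a minimal stand-in for `InvocationStatsService`:

import time
from contextlib import contextmanager

class Stats:
    # Minimal stand-in for InvocationStatsService.
    def __init__(self):
        self._stats = {}

    @contextmanager
    def collect_stats(self, graph_id, node_type):
        start = time.time()
        yield
        self._stats.setdefault(graph_id, {})[node_type] = time.time() - start

    def log_stats(self):
        print(self._stats)

    def reset_stats(self, graph_id):
        self._stats.pop(graph_id, None)

stats = Stats()
try:
    with stats.collect_stats("graph-1", "t2l"):
        time.sleep(0.01)          # stands in for invocation.invoke(...)
    stats.log_stats()             # reported once the node has completed
except (KeyboardInterrupt, Exception):
    stats.reset_stats("graph-1")  # discard partial stats for the aborted graph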
invokeai/backend/install/model_install_backend.py

@@ -13,6 +13,7 @@ import requests
 from diffusers import DiffusionPipeline
 from diffusers import logging as dlogging
 import onnx
+import torch
 from huggingface_hub import hf_hub_url, HfFolder, HfApi
 from omegaconf import OmegaConf
 from tqdm import tqdm
@@ -23,6 +24,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
 from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
 from invokeai.backend.util import download_with_resume
+from invokeai.backend.util.devices import torch_dtype, choose_torch_device
 from ..util.logging import InvokeAILogger

 warnings.filterwarnings("ignore")
@@ -303,7 +305,7 @@ class ModelInstall(object):

         with TemporaryDirectory(dir=self.config.models_path) as staging:
             staging = Path(staging)
-            if "model_index.json" in files and "unet/model.onnx" not in files:
+            if "model_index.json" in files:
                 location = self._download_hf_pipeline(repo_id, staging)  # pipeline
             elif "unet/model.onnx" in files:
                 location = self._download_hf_model(repo_id, files, staging)
@@ -416,15 +418,25 @@ class ModelInstall(object):
         does a save_pretrained() to the indicated staging area.
         """
         _, name = repo_id.split("/")
-        revisions = ["fp16", "main"] if self.config.precision == "float16" else ["main"]
+        precision = torch_dtype(choose_torch_device())
+        variants = ["fp16", None] if precision == torch.float16 else [None, "fp16"]
+
         model = None
-        for revision in revisions:
+        for variant in variants:
             try:
-                model = DiffusionPipeline.from_pretrained(repo_id, revision=revision, safety_checker=None)
-            except:  # most errors are due to fp16 not being present. Fix this to catch other errors
-                pass
+                model = DiffusionPipeline.from_pretrained(
+                    repo_id,
+                    variant=variant,
+                    torch_dtype=precision,
+                    safety_checker=None,
+                )
+            except Exception as e:  # most errors are due to fp16 not being present. Fix this to catch other errors
+                if "fp16" not in str(e):
+                    print(e)
+
             if model:
                 break
+
         if not model:
             logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.")
             return None
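`_download_hf_pipeline` now selects diffusers weight variants rather than Hugging Face branch revisions, and derives the preferred precision from the detected torch device instead of a config string. A hedged sketch of the fallback loop, assuming `diffusers` is installed; the repo id is illustrative only, and running this downloads real weights:

import torch
from diffusers import DiffusionPipeline

precision = torch.float16 if torch.cuda.is_available() else torch.float32
variants = ["fp16", None] if precision == torch.float16 else [None, "fp16"]

model = None
for variant in variants:
    try:
        model = DiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",  # illustrative repo id
            variant=variant,                   # None falls back to the full-precision weights
            torch_dtype=precision,
            safety_checker=None,
        )
    except Exception as e:
        if "fp16" not in str(e):  # missing fp16 variants are expected; report anything else
            print(e)
    if model:
        break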
invokeai/backend/model_management/__init__.py

@@ -13,3 +13,4 @@ from .models import (
     DuplicateModelException,
 )
 from .model_merge import ModelMerger, MergeInterpolationMethod
+from .lora import ModelPatcher
@ -20,424 +20,6 @@ from diffusers.models import UNet2DConditionModel
|
|||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
# TODO: rename and split this file
|
|
||||||
|
|
||||||
|
|
||||||
class LoRALayerBase:
|
|
||||||
# rank: Optional[int]
|
|
||||||
# alpha: Optional[float]
|
|
||||||
# bias: Optional[torch.Tensor]
|
|
||||||
# layer_key: str
|
|
||||||
|
|
||||||
# @property
|
|
||||||
# def scale(self):
|
|
||||||
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layer_key: str,
|
|
||||||
values: dict,
|
|
||||||
):
|
|
||||||
if "alpha" in values:
|
|
||||||
self.alpha = values["alpha"].item()
|
|
||||||
else:
|
|
||||||
self.alpha = None
|
|
||||||
|
|
||||||
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
|
|
||||||
self.bias = torch.sparse_coo_tensor(
|
|
||||||
values["bias_indices"],
|
|
||||||
values["bias_values"],
|
|
||||||
tuple(values["bias_size"]),
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
self.bias = None
|
|
||||||
|
|
||||||
self.rank = None # set in layer implementation
|
|
||||||
self.layer_key = layer_key
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
module: torch.nn.Module,
|
|
||||||
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
|
|
||||||
multiplier: float,
|
|
||||||
):
|
|
||||||
if type(module) == torch.nn.Conv2d:
|
|
||||||
op = torch.nn.functional.conv2d
|
|
||||||
extra_args = dict(
|
|
||||||
stride=module.stride,
|
|
||||||
padding=module.padding,
|
|
||||||
dilation=module.dilation,
|
|
||||||
groups=module.groups,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
op = torch.nn.functional.linear
|
|
||||||
extra_args = {}
|
|
||||||
|
|
||||||
weight = self.get_weight()
|
|
||||||
|
|
||||||
bias = self.bias if self.bias is not None else 0
|
|
||||||
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
|
||||||
return (
|
|
||||||
op(
|
|
||||||
*input_h,
|
|
||||||
(weight + bias).view(module.weight.shape),
|
|
||||||
None,
|
|
||||||
**extra_args,
|
|
||||||
)
|
|
||||||
* multiplier
|
|
||||||
* scale
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_weight(self):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = 0
|
|
||||||
for val in [self.bias]:
|
|
||||||
if val is not None:
|
|
||||||
model_size += val.nelement() * val.element_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
def to(
|
|
||||||
self,
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
):
|
|
||||||
if self.bias is not None:
|
|
||||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: find and debug lora/locon with bias
|
|
||||||
class LoRALayer(LoRALayerBase):
|
|
||||||
# up: torch.Tensor
|
|
||||||
# mid: Optional[torch.Tensor]
|
|
||||||
# down: torch.Tensor
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layer_key: str,
|
|
||||||
values: dict,
|
|
||||||
):
|
|
||||||
super().__init__(layer_key, values)
|
|
||||||
|
|
||||||
self.up = values["lora_up.weight"]
|
|
||||||
self.down = values["lora_down.weight"]
|
|
||||||
if "lora_mid.weight" in values:
|
|
||||||
self.mid = values["lora_mid.weight"]
|
|
||||||
else:
|
|
||||||
self.mid = None
|
|
||||||
|
|
||||||
self.rank = self.down.shape[0]
|
|
||||||
|
|
||||||
def get_weight(self):
|
|
||||||
if self.mid is not None:
|
|
||||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
|
||||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
|
||||||
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
|
||||||
else:
|
|
||||||
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
|
||||||
|
|
||||||
return weight
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = super().calc_size()
|
|
||||||
for val in [self.up, self.mid, self.down]:
|
|
||||||
if val is not None:
|
|
||||||
model_size += val.nelement() * val.element_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
def to(
|
|
||||||
self,
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
):
|
|
||||||
super().to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
self.up = self.up.to(device=device, dtype=dtype)
|
|
||||||
self.down = self.down.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
if self.mid is not None:
|
|
||||||
self.mid = self.mid.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class LoHALayer(LoRALayerBase):
|
|
||||||
# w1_a: torch.Tensor
|
|
||||||
# w1_b: torch.Tensor
|
|
||||||
# w2_a: torch.Tensor
|
|
||||||
# w2_b: torch.Tensor
|
|
||||||
# t1: Optional[torch.Tensor] = None
|
|
||||||
# t2: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layer_key: str,
|
|
||||||
values: dict,
|
|
||||||
):
|
|
||||||
super().__init__(layer_key, values)
|
|
||||||
|
|
||||||
self.w1_a = values["hada_w1_a"]
|
|
||||||
self.w1_b = values["hada_w1_b"]
|
|
||||||
self.w2_a = values["hada_w2_a"]
|
|
||||||
self.w2_b = values["hada_w2_b"]
|
|
||||||
|
|
||||||
if "hada_t1" in values:
|
|
||||||
self.t1 = values["hada_t1"]
|
|
||||||
else:
|
|
||||||
self.t1 = None
|
|
||||||
|
|
||||||
if "hada_t2" in values:
|
|
||||||
self.t2 = values["hada_t2"]
|
|
||||||
else:
|
|
||||||
self.t2 = None
|
|
||||||
|
|
||||||
self.rank = self.w1_b.shape[0]
|
|
||||||
|
|
||||||
def get_weight(self):
|
|
||||||
if self.t1 is None:
|
|
||||||
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
|
||||||
|
|
||||||
else:
|
|
||||||
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
|
|
||||||
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
|
|
||||||
weight = rebuild1 * rebuild2
|
|
||||||
|
|
||||||
return weight
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = super().calc_size()
|
|
||||||
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
|
|
||||||
if val is not None:
|
|
||||||
model_size += val.nelement() * val.element_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
def to(
|
|
||||||
self,
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
):
|
|
||||||
super().to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
|
||||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
|
||||||
if self.t1 is not None:
|
|
||||||
self.t1 = self.t1.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
|
||||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
|
||||||
if self.t2 is not None:
|
|
||||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class LoKRLayer(LoRALayerBase):
|
|
||||||
# w1: Optional[torch.Tensor] = None
|
|
||||||
# w1_a: Optional[torch.Tensor] = None
|
|
||||||
# w1_b: Optional[torch.Tensor] = None
|
|
||||||
# w2: Optional[torch.Tensor] = None
|
|
||||||
# w2_a: Optional[torch.Tensor] = None
|
|
||||||
# w2_b: Optional[torch.Tensor] = None
|
|
||||||
# t2: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layer_key: str,
|
|
||||||
values: dict,
|
|
||||||
):
|
|
||||||
super().__init__(layer_key, values)
|
|
||||||
|
|
||||||
if "lokr_w1" in values:
|
|
||||||
self.w1 = values["lokr_w1"]
|
|
||||||
self.w1_a = None
|
|
||||||
self.w1_b = None
|
|
||||||
else:
|
|
||||||
self.w1 = None
|
|
||||||
self.w1_a = values["lokr_w1_a"]
|
|
||||||
self.w1_b = values["lokr_w1_b"]
|
|
||||||
|
|
||||||
if "lokr_w2" in values:
|
|
||||||
self.w2 = values["lokr_w2"]
|
|
||||||
self.w2_a = None
|
|
||||||
self.w2_b = None
|
|
||||||
else:
|
|
||||||
self.w2 = None
|
|
||||||
self.w2_a = values["lokr_w2_a"]
|
|
||||||
self.w2_b = values["lokr_w2_b"]
|
|
||||||
|
|
||||||
if "lokr_t2" in values:
|
|
||||||
self.t2 = values["lokr_t2"]
|
|
||||||
else:
|
|
||||||
self.t2 = None
|
|
||||||
|
|
||||||
if "lokr_w1_b" in values:
|
|
||||||
self.rank = values["lokr_w1_b"].shape[0]
|
|
||||||
elif "lokr_w2_b" in values:
|
|
||||||
self.rank = values["lokr_w2_b"].shape[0]
|
|
||||||
else:
|
|
||||||
self.rank = None # unscaled
|
|
||||||
|
|
||||||
def get_weight(self):
|
|
||||||
w1 = self.w1
|
|
||||||
if w1 is None:
|
|
||||||
w1 = self.w1_a @ self.w1_b
|
|
||||||
|
|
||||||
w2 = self.w2
|
|
||||||
if w2 is None:
|
|
||||||
if self.t2 is None:
|
|
||||||
w2 = self.w2_a @ self.w2_b
|
|
||||||
else:
|
|
||||||
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
|
|
||||||
|
|
||||||
if len(w2.shape) == 4:
|
|
||||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
|
||||||
w2 = w2.contiguous()
|
|
||||||
weight = torch.kron(w1, w2)
|
|
||||||
|
|
||||||
return weight
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = super().calc_size()
|
|
||||||
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
|
|
||||||
if val is not None:
|
|
||||||
model_size += val.nelement() * val.element_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
def to(
|
|
||||||
self,
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
):
|
|
||||||
super().to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
if self.w1 is not None:
|
|
||||||
self.w1 = self.w1.to(device=device, dtype=dtype)
|
|
||||||
else:
|
|
||||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
|
||||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
if self.w2 is not None:
|
|
||||||
self.w2 = self.w2.to(device=device, dtype=dtype)
|
|
||||||
else:
|
|
||||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
|
||||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
if self.t2 is not None:
|
|
||||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAModel: # (torch.nn.Module):
|
|
||||||
_name: str
|
|
||||||
layers: Dict[str, LoRALayer]
|
|
||||||
_device: torch.device
|
|
||||||
_dtype: torch.dtype
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
layers: Dict[str, LoRALayer],
|
|
||||||
device: torch.device,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
):
|
|
||||||
self._name = name
|
|
||||||
self._device = device or torch.cpu
|
|
||||||
self._dtype = dtype or torch.float32
|
|
||||||
self.layers = layers
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self):
|
|
||||||
return self._name
|
|
||||||
|
|
||||||
@property
|
|
||||||
def device(self):
|
|
||||||
return self._device
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dtype(self):
|
|
||||||
return self._dtype
|
|
||||||
|
|
||||||
def to(
|
|
||||||
self,
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
) -> LoRAModel:
|
|
||||||
# TODO: try revert if exception?
|
|
||||||
for key, layer in self.layers.items():
|
|
||||||
layer.to(device=device, dtype=dtype)
|
|
||||||
self._device = device
|
|
||||||
self._dtype = dtype
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = 0
|
|
||||||
for _, layer in self.layers.items():
|
|
||||||
model_size += layer.calc_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_checkpoint(
|
|
||||||
cls,
|
|
||||||
file_path: Union[str, Path],
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
):
|
|
||||||
device = device or torch.device("cpu")
|
|
||||||
dtype = dtype or torch.float32
|
|
||||||
|
|
||||||
if isinstance(file_path, str):
|
|
||||||
file_path = Path(file_path)
|
|
||||||
|
|
||||||
model = cls(
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
name=file_path.stem, # TODO:
|
|
||||||
layers=dict(),
|
|
||||||
)
|
|
||||||
|
|
||||||
if file_path.suffix == ".safetensors":
|
|
||||||
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
|
||||||
else:
|
|
||||||
state_dict = torch.load(file_path, map_location="cpu")
|
|
||||||
|
|
||||||
state_dict = cls._group_state(state_dict)
|
|
||||||
|
|
||||||
for layer_key, values in state_dict.items():
|
|
||||||
# lora and locon
|
|
||||||
if "lora_down.weight" in values:
|
|
||||||
layer = LoRALayer(layer_key, values)
|
|
||||||
|
|
||||||
# loha
|
|
||||||
elif "hada_w1_b" in values:
|
|
||||||
layer = LoHALayer(layer_key, values)
|
|
||||||
|
|
||||||
# lokr
|
|
||||||
elif "lokr_w1_b" in values or "lokr_w1" in values:
|
|
||||||
layer = LoKRLayer(layer_key, values)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# TODO: diff/ia3/... format
|
|
||||||
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# lower memory consumption by removing already parsed layer values
|
|
||||||
state_dict[layer_key].clear()
|
|
||||||
|
|
||||||
layer.to(device=device, dtype=dtype)
|
|
||||||
model.layers[layer_key] = layer
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _group_state(state_dict: dict):
|
|
||||||
state_dict_groupped = dict()
|
|
||||||
|
|
||||||
for key, value in state_dict.items():
|
|
||||||
stem, leaf = key.split(".", 1)
|
|
||||||
if stem not in state_dict_groupped:
|
|
||||||
state_dict_groupped[stem] = dict()
|
|
||||||
state_dict_groupped[stem][leaf] = value
|
|
||||||
|
|
||||||
return state_dict_groupped
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
loras = [
|
loras = [
|
||||||
(lora_model1, 0.7),
|
(lora_model1, 0.7),
|
||||||
@ -516,6 +98,26 @@ class ModelPatcher:
|
|||||||
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@contextmanager
|
||||||
|
def apply_sdxl_lora_text_encoder(
|
||||||
|
cls,
|
||||||
|
text_encoder: CLIPTextModel,
|
||||||
|
loras: List[Tuple[LoRAModel, float]],
|
||||||
|
):
|
||||||
|
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
|
||||||
|
yield
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@contextmanager
|
||||||
|
def apply_sdxl_lora_text_encoder2(
|
||||||
|
cls,
|
||||||
|
text_encoder: CLIPTextModel,
|
||||||
|
loras: List[Tuple[LoRAModel, float]],
|
||||||
|
):
|
||||||
|
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
|
||||||
|
yield
|
||||||
|
|
||||||
     @classmethod
     @contextmanager
     def apply_lora(
@ -562,7 +164,7 @@ class ModelPatcher:
         cls,
         tokenizer: CLIPTokenizer,
         text_encoder: CLIPTextModel,
-        ti_list: List[Any],
+        ti_list: List[Tuple[str, Any]],
     ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
         init_tokens_count = None
         new_tokens_added = None
@ -572,27 +174,27 @@ class ModelPatcher:
         ti_manager = TextualInversionManager(ti_tokenizer)
         init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings

-        def _get_trigger(ti, index):
-            trigger = ti.name
+        def _get_trigger(ti_name, index):
+            trigger = ti_name
             if index > 0:
                 trigger += f"-!pad-{i}"
             return f"<{trigger}>"

         # modify tokenizer
         new_tokens_added = 0
-        for ti in ti_list:
+        for ti_name, ti in ti_list:
             for i in range(ti.embedding.shape[0]):
-                new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
+                new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))

         # modify text_encoder
         text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
         model_embeddings = text_encoder.get_input_embeddings()

-        for ti in ti_list:
+        for ti_name, ti in ti_list:
             ti_tokens = []
             for i in range(ti.embedding.shape[0]):
                 embedding = ti.embedding[i]
-                trigger = _get_trigger(ti, i)
+                trigger = _get_trigger(ti_name, i)

                 token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
                 if token_id == ti_tokenizer.unk_token_id:
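
With this change a textual inversion's trigger tokens are derived from the name it was registered under in ti_list rather than from a name attribute on the model itself (removed below). Each embedding vector beyond the first gets a -!pad-N suffix; replicating the trigger logic for an embedding named "easynegative" with three vectors:

    def _get_trigger(ti_name: str, index: int) -> str:
        trigger = ti_name
        if index > 0:
            trigger += f"-!pad-{index}"
        return f"<{trigger}>"

    print([_get_trigger("easynegative", i) for i in range(3)])
    # ['<easynegative>', '<easynegative-!pad-1>', '<easynegative-!pad-2>']
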
@ -637,7 +239,6 @@ class ModelPatcher:


 class TextualInversionModel:
-    name: str
     embedding: torch.Tensor  # [n, 768]|[n, 1280]

     @classmethod
@ -651,7 +252,6 @@ class TextualInversionModel:
             file_path = Path(file_path)

         result = cls()  # TODO:
-        result.name = file_path.stem  # TODO:

         if file_path.suffix == ".safetensors":
             state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
@ -828,7 +428,7 @@ class ONNXModelPatcher:
         cls,
         tokenizer: CLIPTokenizer,
         text_encoder: IAIOnnxRuntimeModel,
-        ti_list: List[Any],
+        ti_list: List[Tuple[str, Any]],
     ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
         from .models.base import IAIOnnxRuntimeModel

@ -841,17 +441,17 @@ class ONNXModelPatcher:
         ti_tokenizer = copy.deepcopy(tokenizer)
         ti_manager = TextualInversionManager(ti_tokenizer)

-        def _get_trigger(ti, index):
-            trigger = ti.name
+        def _get_trigger(ti_name, index):
+            trigger = ti_name
             if index > 0:
                 trigger += f"-!pad-{i}"
             return f"<{trigger}>"

         # modify tokenizer
         new_tokens_added = 0
-        for ti in ti_list:
+        for ti_name, ti in ti_list:
             for i in range(ti.embedding.shape[0]):
-                new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
+                new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))

         # modify text_encoder
         orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
@ -861,11 +461,11 @@ class ONNXModelPatcher:
             axis=0,
         )

-        for ti in ti_list:
+        for ti_name, ti in ti_list:
             ti_tokens = []
             for i in range(ti.embedding.shape[0]):
                 embedding = ti.embedding[i].detach().numpy()
-                trigger = _get_trigger(ti, i)
+                trigger = _get_trigger(ti_name, i)

                 token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
                 if token_id == ti_tokenizer.unk_token_id:

@ -28,8 +28,6 @@ import torch

 import logging
 import invokeai.backend.util.logging as logger
-from invokeai.app.services.config import get_invokeai_config
-from .lora import LoRAModel, TextualInversionModel
 from .models import BaseModelType, ModelType, SubModelType, ModelBase

 # Maximum size of the cache, in gigs
@ -188,7 +186,7 @@ class ModelCache(object):
         cache_entry = self._cached_models.get(key, None)
         if cache_entry is None:
             self.logger.info(
-                f"Loading model {model_path}, type {base_model.value}:{model_type.value}:{submodel.value if submodel else ''}"
+                f"Loading model {model_path}, type {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
             )

             # this will remove older cached models until
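
The f-string fix moves the colon inside the conditional so models loaded without a submodel no longer log a dangling separator; a sketch of the before/after output with placeholder string values standing in for the enum values:

    base, mtype, submodel = "sd-1", "main", None
    old = f"{base}:{mtype}:{submodel if submodel else ''}"        # 'sd-1:main:'
    new = f"{base}:{mtype}{':' + submodel if submodel else ''}"   # 'sd-1:main'
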

@ -719,7 +719,7 @@ class ModelManager(object):
             # TODO: if path changed and old_model.path inside models folder should we delete this too?

             # remove conversion cache as config changed
-            old_model_path = self.app_config.root_path / old_model.path
+            old_model_path = self.resolve_model_path(old_model.path)
             old_model_cache = self._get_model_cache_path(old_model_path)
             if old_model_cache.exists():
                 if old_model_cache.is_dir():
@ -829,7 +829,7 @@ class ModelManager(object):
             model_type,
             **submodel,
         )
-        checkpoint_path = self.app_config.root_path / info["path"]
+        checkpoint_path = self.resolve_model_path(info["path"])
         old_diffusers_path = self.resolve_model_path(model.location)
         new_diffusers_path = (
             dest_directory or self.app_config.models_path / base_model.value / model_type.value
@ -1041,7 +1041,7 @@ class ModelManager(object):
             model_manager=self,
             prediction_type_helper=ask_user_for_prediction_type,
         )
-        known_paths = {config.root_path / x["path"] for x in self.list_models()}
+        known_paths = {self.resolve_model_path(x["path"]) for x in self.list_models()}
         directories = {
             config.root_path / x
             for x in [

@ -315,21 +315,38 @@ class LoRACheckpointProbe(CheckpointProbeBase):

     def get_base_type(self) -> BaseModelType:
         checkpoint = self.checkpoint

+        # SD-2 models are very hard to probe. These probes are brittle and likely to fail in the future
+        # There are also some "SD-2 LoRAs" that have identical keys and shapes to SD-1 and will be
+        # misclassified as SD-1
+        key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
+        if key in checkpoint and checkpoint[key].shape[0] == 320:
+            return BaseModelType.StableDiffusion2
+
+        key = "lora_unet_output_blocks_5_1_transformer_blocks_1_ff_net_2.lora_up.weight"
+        if key in checkpoint:
+            return BaseModelType.StableDiffusionXL
+
         key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
-        key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
+        key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
+        key3 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"

         lora_token_vector_length = (
             checkpoint[key1].shape[1]
             if key1 in checkpoint
-            else checkpoint[key2].shape[0]
+            else checkpoint[key2].shape[1]
             if key2 in checkpoint
-            else 768
+            else checkpoint[key3].shape[0]
+            if key3 in checkpoint
+            else None
         )

         if lora_token_vector_length == 768:
             return BaseModelType.StableDiffusion1
         elif lora_token_vector_length == 1024:
             return BaseModelType.StableDiffusion2
         else:
-            return None
+            raise InvalidModelException(f"Unknown LoRA type")


 class TextualInversionCheckpointProbe(CheckpointProbeBase):
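
The probe now distinguishes base models by the width of the text-encoder LoRA matrices (768 for SD-1, 1024 for SD-2), after two early checks for cases the width test cannot separate: certain SD-2 LoRAs flagged by a 320-row text-encoder down-projection, and SDXL LoRAs flagged by a UNet key only SDXL produces. A condensed sketch of the width test against a toy checkpoint:

    import torch

    # Toy stand-in for a parsed LoRA checkpoint (illustration only).
    checkpoint = {
        "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight": torch.zeros(8, 768),
    }
    key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
    length = checkpoint[key1].shape[1] if key1 in checkpoint else None
    print({768: "sd-1", 1024: "sd-2"}.get(length, "unknown"))  # sd-1
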

@ -1,7 +1,9 @@
 import os
 import torch
 from enum import Enum
-from typing import Optional, Union, Literal
+from typing import Optional, Dict, Union, Literal, Any
+from pathlib import Path
+from safetensors.torch import load_file
 from .base import (
     ModelBase,
     ModelConfigBase,
@ -13,9 +15,6 @@ from .base import (
     ModelNotFoundException,
 )

-# TODO: naming
-from ..lora import LoRAModel as LoRAModelRaw
-

 class LoRAModelFormat(str, Enum):
     LyCORIS = "lycoris"
@ -50,6 +49,7 @@ class LoRAModel(ModelBase):
             model = LoRAModelRaw.from_checkpoint(
                 file_path=self.model_path,
                 dtype=torch_dtype,
+                base_model=self.base_model,
             )

             self.model_size = model.calc_size()
@ -87,3 +87,582 @@ class LoRAModel(ModelBase):
             raise NotImplementedError("Diffusers lora not supported")
         else:
             return model_path
+
+
+class LoRALayerBase:
+    # rank: Optional[int]
+    # alpha: Optional[float]
+    # bias: Optional[torch.Tensor]
+    # layer_key: str
+
+    # @property
+    # def scale(self):
+    #    return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
+
+    def __init__(
+        self,
+        layer_key: str,
+        values: dict,
+    ):
+        if "alpha" in values:
+            self.alpha = values["alpha"].item()
+        else:
+            self.alpha = None
+
+        if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
+            self.bias = torch.sparse_coo_tensor(
+                values["bias_indices"],
+                values["bias_values"],
+                tuple(values["bias_size"]),
+            )
+        else:
+            self.bias = None
+
+        self.rank = None  # set in layer implementation
+        self.layer_key = layer_key
+
+    def forward(
+        self,
+        module: torch.nn.Module,
+        input_h: Any,  # for real looks like Tuple[torch.nn.Tensor] but not sure
+        multiplier: float,
+    ):
+        if type(module) == torch.nn.Conv2d:
+            op = torch.nn.functional.conv2d
+            extra_args = dict(
+                stride=module.stride,
+                padding=module.padding,
+                dilation=module.dilation,
+                groups=module.groups,
+            )
+        else:
+            op = torch.nn.functional.linear
+            extra_args = {}
+
+        weight = self.get_weight()
+
+        bias = self.bias if self.bias is not None else 0
+        scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
+        return (
+            op(
+                *input_h,
+                (weight + bias).view(module.weight.shape),
+                None,
+                **extra_args,
+            )
+            * multiplier
+            * scale
+        )
+
+    def get_weight(self):
+        raise NotImplementedError()
+
+    def calc_size(self) -> int:
+        model_size = 0
+        for val in [self.bias]:
+            if val is not None:
+                model_size += val.nelement() * val.element_size()
+        return model_size
+
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        if self.bias is not None:
+            self.bias = self.bias.to(device=device, dtype=dtype)
+
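
All layer types share the alpha-over-rank scaling applied in forward: a checkpoint that stores alpha = 4 for a rank-8 layer contributes its delta at half strength before the caller's multiplier is applied. In sketch form:

    alpha, rank, multiplier = 4.0, 8, 0.75
    scale = alpha / rank if (alpha and rank) else 1.0
    print(multiplier * scale)  # 0.375, the factor applied to get_weight() + bias
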
+
+# TODO: find and debug lora/locon with bias
+class LoRALayer(LoRALayerBase):
+    # up: torch.Tensor
+    # mid: Optional[torch.Tensor]
+    # down: torch.Tensor
+
+    def __init__(
+        self,
+        layer_key: str,
+        values: dict,
+    ):
+        super().__init__(layer_key, values)
+
+        self.up = values["lora_up.weight"]
+        self.down = values["lora_down.weight"]
+        if "lora_mid.weight" in values:
+            self.mid = values["lora_mid.weight"]
+        else:
+            self.mid = None
+
+        self.rank = self.down.shape[0]
+
+    def get_weight(self):
+        if self.mid is not None:
+            up = self.up.reshape(self.up.shape[0], self.up.shape[1])
+            down = self.down.reshape(self.down.shape[0], self.down.shape[1])
+            weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
+        else:
+            weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
+
+        return weight
+
+    def calc_size(self) -> int:
+        model_size = super().calc_size()
+        for val in [self.up, self.mid, self.down]:
+            if val is not None:
+                model_size += val.nelement() * val.element_size()
+        return model_size
+
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        super().to(device=device, dtype=dtype)
+
+        self.up = self.up.to(device=device, dtype=dtype)
+        self.down = self.down.to(device=device, dtype=dtype)
+
+        if self.mid is not None:
+            self.mid = self.mid.to(device=device, dtype=dtype)
+
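
A plain LoRA layer stores the delta as two low-rank factors; get_weight multiplies them back into a full matrix (the einsum branch handles LoCon-style convolutions that also carry a mid tensor). For example, assuming a rank-4 adaptation of a 320x320 linear layer:

    import torch

    up = torch.randn(320, 4)    # lora_up.weight, flattened
    down = torch.randn(4, 320)  # lora_down.weight, flattened
    weight = up @ down          # (320, 320) delta added onto the module weight
    print(weight.shape)         # torch.Size([320, 320])
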
+
+class LoHALayer(LoRALayerBase):
+    # w1_a: torch.Tensor
+    # w1_b: torch.Tensor
+    # w2_a: torch.Tensor
+    # w2_b: torch.Tensor
+    # t1: Optional[torch.Tensor] = None
+    # t2: Optional[torch.Tensor] = None
+
+    def __init__(
+        self,
+        layer_key: str,
+        values: dict,
+    ):
+        super().__init__(layer_key, values)
+
+        self.w1_a = values["hada_w1_a"]
+        self.w1_b = values["hada_w1_b"]
+        self.w2_a = values["hada_w2_a"]
+        self.w2_b = values["hada_w2_b"]
+
+        if "hada_t1" in values:
+            self.t1 = values["hada_t1"]
+        else:
+            self.t1 = None
+
+        if "hada_t2" in values:
+            self.t2 = values["hada_t2"]
+        else:
+            self.t2 = None
+
+        self.rank = self.w1_b.shape[0]
+
+    def get_weight(self):
+        if self.t1 is None:
+            weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
+        else:
+            rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
+            rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
+            weight = rebuild1 * rebuild2
+
+        return weight
+
+    def calc_size(self) -> int:
+        model_size = super().calc_size()
+        for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
+            if val is not None:
+                model_size += val.nelement() * val.element_size()
+        return model_size
+
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        super().to(device=device, dtype=dtype)
+
+        self.w1_a = self.w1_a.to(device=device, dtype=dtype)
+        self.w1_b = self.w1_b.to(device=device, dtype=dtype)
+        if self.t1 is not None:
+            self.t1 = self.t1.to(device=device, dtype=dtype)
+
+        self.w2_a = self.w2_a.to(device=device, dtype=dtype)
+        self.w2_b = self.w2_b.to(device=device, dtype=dtype)
+        if self.t2 is not None:
+            self.t2 = self.t2.to(device=device, dtype=dtype)
+
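
LoHA (low-rank Hadamard adaptation) composes two low-rank products element-wise rather than using a single one, which is what the t1-is-None branch of get_weight computes. A toy reconstruction under the same shape assumptions as above:

    import torch

    w1_a, w1_b = torch.randn(320, 4), torch.randn(4, 320)
    w2_a, w2_b = torch.randn(320, 4), torch.randn(4, 320)
    weight = (w1_a @ w1_b) * (w2_a @ w2_b)  # element-wise (Hadamard) product
    print(weight.shape)                     # torch.Size([320, 320])
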
+
+class LoKRLayer(LoRALayerBase):
+    # w1: Optional[torch.Tensor] = None
+    # w1_a: Optional[torch.Tensor] = None
+    # w1_b: Optional[torch.Tensor] = None
+    # w2: Optional[torch.Tensor] = None
+    # w2_a: Optional[torch.Tensor] = None
+    # w2_b: Optional[torch.Tensor] = None
+    # t2: Optional[torch.Tensor] = None
+
+    def __init__(
+        self,
+        layer_key: str,
+        values: dict,
+    ):
+        super().__init__(layer_key, values)
+
+        if "lokr_w1" in values:
+            self.w1 = values["lokr_w1"]
+            self.w1_a = None
+            self.w1_b = None
+        else:
+            self.w1 = None
+            self.w1_a = values["lokr_w1_a"]
+            self.w1_b = values["lokr_w1_b"]
+
+        if "lokr_w2" in values:
+            self.w2 = values["lokr_w2"]
+            self.w2_a = None
+            self.w2_b = None
+        else:
+            self.w2 = None
+            self.w2_a = values["lokr_w2_a"]
+            self.w2_b = values["lokr_w2_b"]
+
+        if "lokr_t2" in values:
+            self.t2 = values["lokr_t2"]
+        else:
+            self.t2 = None
+
+        if "lokr_w1_b" in values:
+            self.rank = values["lokr_w1_b"].shape[0]
+        elif "lokr_w2_b" in values:
+            self.rank = values["lokr_w2_b"].shape[0]
+        else:
+            self.rank = None  # unscaled
+
+    def get_weight(self):
+        w1 = self.w1
+        if w1 is None:
+            w1 = self.w1_a @ self.w1_b
+
+        w2 = self.w2
+        if w2 is None:
+            if self.t2 is None:
+                w2 = self.w2_a @ self.w2_b
+            else:
+                w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
+
+        if len(w2.shape) == 4:
+            w1 = w1.unsqueeze(2).unsqueeze(2)
+        w2 = w2.contiguous()
+        weight = torch.kron(w1, w2)
+
+        return weight
+
+    def calc_size(self) -> int:
+        model_size = super().calc_size()
+        for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
+            if val is not None:
+                model_size += val.nelement() * val.element_size()
+        return model_size
+
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        super().to(device=device, dtype=dtype)
+
+        if self.w1 is not None:
+            self.w1 = self.w1.to(device=device, dtype=dtype)
+        else:
+            self.w1_a = self.w1_a.to(device=device, dtype=dtype)
+            self.w1_b = self.w1_b.to(device=device, dtype=dtype)
+
+        if self.w2 is not None:
+            self.w2 = self.w2.to(device=device, dtype=dtype)
+        else:
+            self.w2_a = self.w2_a.to(device=device, dtype=dtype)
+            self.w2_b = self.w2_b.to(device=device, dtype=dtype)
+
+        if self.t2 is not None:
+            self.t2 = self.t2.to(device=device, dtype=dtype)
+
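
LoKR factors the delta as a Kronecker product, so two small blocks expand into a much larger weight; either block may itself be stored low-rank, which the branches above rebuild first. A toy reconstruction:

    import torch

    w1 = torch.randn(8, 8)       # lokr_w1, or w1_a @ w1_b when factored
    w2 = torch.randn(40, 40)     # lokr_w2, or rebuilt from w2_a/w2_b/t2
    weight = torch.kron(w1, w2)  # (320, 320) delta
    print(weight.shape)          # torch.Size([320, 320])
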
+
+class FullLayer(LoRALayerBase):
+    # weight: torch.Tensor
+
+    def __init__(
+        self,
+        layer_key: str,
+        values: dict,
+    ):
+        super().__init__(layer_key, values)
+
+        self.weight = values["diff"]
+
+        if len(values.keys()) > 1:
+            _keys = list(values.keys())
+            _keys.remove("diff")
+            raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
+
+        self.rank = None  # unscaled
+
+    def get_weight(self):
+        return self.weight
+
+    def calc_size(self) -> int:
+        model_size = super().calc_size()
+        model_size += self.weight.nelement() * self.weight.element_size()
+        return model_size
+
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        super().to(device=device, dtype=dtype)
+
+        self.weight = self.weight.to(device=device, dtype=dtype)
+
+
+# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
+class LoRAModelRaw:  # (torch.nn.Module):
+    _name: str
+    layers: Dict[str, LoRALayer]
+    _device: torch.device
+    _dtype: torch.dtype
+
+    def __init__(
+        self,
+        name: str,
+        layers: Dict[str, LoRALayer],
+        device: torch.device,
+        dtype: torch.dtype,
+    ):
+        self._name = name
+        self._device = device or torch.cpu
+        self._dtype = dtype or torch.float32
+        self.layers = layers
+
+    @property
+    def name(self):
+        return self._name
+
+    @property
+    def device(self):
+        return self._device
+
+    @property
+    def dtype(self):
+        return self._dtype
+
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        # TODO: try revert if exception?
+        for key, layer in self.layers.items():
+            layer.to(device=device, dtype=dtype)
+        self._device = device
+        self._dtype = dtype
+
+    def calc_size(self) -> int:
+        model_size = 0
+        for _, layer in self.layers.items():
+            model_size += layer.calc_size()
+        return model_size
+
+    @classmethod
+    def _convert_sdxl_compvis_keys(cls, state_dict):
+        new_state_dict = dict()
+        for full_key, value in state_dict.items():
+            if full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
+                continue  # clip same
+
+            if not full_key.startswith("lora_unet_"):
+                raise NotImplementedError(f"Unknown prefix for sdxl lora key - {full_key}")
+            src_key = full_key.replace("lora_unet_", "")
+            try:
+                dst_key = None
+                while "_" in src_key:
+                    if src_key in SDXL_UNET_COMPVIS_MAP:
+                        dst_key = SDXL_UNET_COMPVIS_MAP[src_key]
+                        break
+                    src_key = "_".join(src_key.split("_")[:-1])
+
+                if dst_key is None:
+                    raise Exception(f"Unknown sdxl lora key - {full_key}")
+                new_key = full_key.replace(src_key, dst_key)
+            except:
+                print(SDXL_UNET_COMPVIS_MAP)
+                raise
+            new_state_dict[new_key] = value
+        return new_state_dict
+
+    @classmethod
+    def from_checkpoint(
+        cls,
+        file_path: Union[str, Path],
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        base_model: Optional[BaseModelType] = None,
+    ):
+        device = device or torch.device("cpu")
+        dtype = dtype or torch.float32
+
+        if isinstance(file_path, str):
+            file_path = Path(file_path)
+
+        model = cls(
+            device=device,
+            dtype=dtype,
+            name=file_path.stem,  # TODO:
+            layers=dict(),
+        )
+
+        if file_path.suffix == ".safetensors":
+            state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
+        else:
+            state_dict = torch.load(file_path, map_location="cpu")
+
+        state_dict = cls._group_state(state_dict)
+
+        if base_model == BaseModelType.StableDiffusionXL:
+            state_dict = cls._convert_sdxl_compvis_keys(state_dict)
+
+        for layer_key, values in state_dict.items():
+            # lora and locon
+            if "lora_down.weight" in values:
+                layer = LoRALayer(layer_key, values)
+
+            # loha
+            elif "hada_w1_b" in values:
+                layer = LoHALayer(layer_key, values)
+
+            # lokr
+            elif "lokr_w1_b" in values or "lokr_w1" in values:
+                layer = LoKRLayer(layer_key, values)
+
+            elif "diff" in values:
+                layer = FullLayer(layer_key, values)
+
+            else:
+                # TODO: ia3/... format
+                print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
+                raise Exception("Unknown lora format!")
+
+            # lower memory consumption by removing already parsed layer values
+            state_dict[layer_key].clear()
+
+            layer.to(device=device, dtype=dtype)
+            model.layers[layer_key] = layer
+
+        return model
+
+    @staticmethod
+    def _group_state(state_dict: dict):
+        state_dict_groupped = dict()
+
+        for key, value in state_dict.items():
+            stem, leaf = key.split(".", 1)
+            if stem not in state_dict_groupped:
+                state_dict_groupped[stem] = dict()
+            state_dict_groupped[stem][leaf] = value
+
+        return state_dict_groupped
+
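
from_checkpoint now takes the base model so SDXL checkpoints in kohya's compvis naming can be rekeyed before layer parsing. A minimal call; the file path is a placeholder, not a path from this diff:

    import torch

    # Hypothetical path; the loader dispatches on the .safetensors suffix.
    model = LoRAModelRaw.from_checkpoint(
        file_path="/models/loras/some_sdxl_lora.safetensors",
        dtype=torch.float16,
        base_model=BaseModelType.StableDiffusionXL,  # triggers _convert_sdxl_compvis_keys
    )
    print(model.name, model.calc_size())
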
+
+# code from
+# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
+def make_sdxl_unet_conversion_map():
+    unet_conversion_map_layer = []
+
+    for i in range(3):  # num_blocks is 3 in sdxl
+        # loop over downblocks/upblocks
+        for j in range(2):
+            # loop over resnets/attentions for downblocks
+            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+            if i < 3:
+                # no attention layers in down_blocks.3
+                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+        for j in range(3):
+            # loop over resnets/attentions for upblocks
+            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+            # if i > 0: commentout for sdxl
+            # no attention layers in up_blocks.0
+            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
+            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+        if i < 3:
+            # no downsample in down_blocks.3
+            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+            # no upsample in up_blocks.3
+            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}."  # change for sdxl
+            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+
+    hf_mid_atn_prefix = "mid_block.attentions.0."
+    sd_mid_atn_prefix = "middle_block.1."
+    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+    for j in range(2):
+        hf_mid_res_prefix = f"mid_block.resnets.{j}."
+        sd_mid_res_prefix = f"middle_block.{2*j}."
+        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+    unet_conversion_map_resnet = [
+        # (stable-diffusion, HF Diffusers)
+        ("in_layers.0.", "norm1."),
+        ("in_layers.2.", "conv1."),
+        ("out_layers.0.", "norm2."),
+        ("out_layers.3.", "conv2."),
+        ("emb_layers.1.", "time_emb_proj."),
+        ("skip_connection.", "conv_shortcut."),
+    ]
+
+    unet_conversion_map = []
+    for sd, hf in unet_conversion_map_layer:
+        if "resnets" in hf:
+            for sd_res, hf_res in unet_conversion_map_resnet:
+                unet_conversion_map.append((sd + sd_res, hf + hf_res))
+        else:
+            unet_conversion_map.append((sd, hf))
+
+    for j in range(2):
+        hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
+        sd_time_embed_prefix = f"time_embed.{j*2}."
+        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
+
+    for j in range(2):
+        hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
+        sd_label_embed_prefix = f"label_emb.0.{j*2}."
+        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
+
+    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
+    unet_conversion_map.append(("out.0.", "conv_norm_out."))
+    unet_conversion_map.append(("out.2.", "conv_out."))
+
+    return unet_conversion_map
+
+
+SDXL_UNET_COMPVIS_MAP = {
+    f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_")
+    for sd, hf in make_sdxl_unet_conversion_map()
+}
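
The resulting SDXL_UNET_COMPVIS_MAP uses underscores in place of dots, matching flattened kohya key names; _convert_sdxl_compvis_keys trims a key one underscore-separated segment at a time until a prefix hits the map. Tracing one entry generated above:

    src = "input_blocks_1_0_in_layers_2"  # compvis resnet conv, flattened
    while "_" in src:
        if src in SDXL_UNET_COMPVIS_MAP:
            print(SDXL_UNET_COMPVIS_MAP[src])  # down_blocks_0_resnets_0_conv1
            break
        src = "_".join(src.split("_")[:-1])
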
@ -40,7 +40,7 @@ export const addImageUploadedFulfilledListener = () => {
       // default action - just upload and alert user
       if (postUploadAction?.type === 'TOAST') {
         const { toastOptions } = postUploadAction;
-        if (!autoAddBoardId) {
+        if (!autoAddBoardId || autoAddBoardId === 'none') {
           dispatch(addToast({ ...DEFAULT_UPLOADED_TOAST, ...toastOptions }));
         } else {
           // Add this image to the board

@ -41,6 +41,10 @@ export const gallerySlice = createSlice({
       state.galleryView = 'images';
     },
     autoAddBoardIdChanged: (state, action: PayloadAction<BoardId>) => {
+      if (!action.payload) {
+        state.autoAddBoardId = 'none';
+        return;
+      }
       state.autoAddBoardId = action.payload;
     },
     galleryViewChanged: (state, action: PayloadAction<GalleryView>) => {

@ -2,10 +2,12 @@
 # Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)

 import warnings
-from invokeai.frontend.CLI import invokeai_command_line_interface as main

 warnings.warn(
     "dream.py is being deprecated, please run invoke.py for the " "new UI/API or legacy_api.py for the old API",
     DeprecationWarning,
 )
-main()
+
+from invokeai.app.cli_app import invoke_cli
+
+invoke_cli()

@ -9,8 +9,10 @@ parser = argparse.ArgumentParser(description="Probe model type")
 parser.add_argument(
     "model_path",
     type=Path,
+    nargs="+",
 )
 args = parser.parse_args()

-info = ModelProbe().probe(args.model_path)
-print(info)
+for path in args.model_path:
+    info = ModelProbe().probe(path)
+    print(f"{path}: {info}")
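
With nargs="+" the probe script accepts any number of paths per invocation, e.g. python probe_model.py a.safetensors b.ckpt, and prints one path: info line per model instead of a single bare result.
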
@ -16,6 +16,7 @@ from invokeai.app.invocations.baseinvocation import (
 from invokeai.app.invocations.collections import RangeInvocation
 from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
 from invokeai.app.services.invocation_services import InvocationServices
+from invokeai.app.services.invocation_stats import InvocationStatsService
 from invokeai.app.services.graph import (
     Graph,
     CollectInvocation,
@ -41,6 +42,9 @@ def simple_graph():
 @pytest.fixture
 def mock_services() -> InvocationServices:
     # NOTE: none of these are actually called by the test invocations
+    graph_execution_manager = SqliteItemStorage[GraphExecutionState](
+        filename=sqlite_memory, table_name="graph_executions"
+    )
     return InvocationServices(
         model_manager=None,  # type: ignore
         events=TestEventService(),
@ -51,9 +55,8 @@ def mock_services() -> InvocationServices:
         board_images=None,  # type: ignore
         queue=MemoryInvocationQueue(),
         graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"),
-        graph_execution_manager=SqliteItemStorage[GraphExecutionState](
-            filename=sqlite_memory, table_name="graph_executions"
-        ),
+        graph_execution_manager=graph_execution_manager,
+        performance_statistics=InvocationStatsService(graph_execution_manager),
         processor=DefaultInvocationProcessor(),
         configuration=None,  # type: ignore
     )

@ -11,6 +11,7 @@ from invokeai.app.services.processor import DefaultInvocationProcessor
 from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
 from invokeai.app.services.invoker import Invoker
 from invokeai.app.services.invocation_services import InvocationServices
+from invokeai.app.services.invocation_stats import InvocationStatsService
 from invokeai.app.services.graph import (
     Graph,
     GraphExecutionState,
@ -34,6 +35,9 @@ def simple_graph():
 @pytest.fixture
 def mock_services() -> InvocationServices:
     # NOTE: none of these are actually called by the test invocations
+    graph_execution_manager = SqliteItemStorage[GraphExecutionState](
+        filename=sqlite_memory, table_name="graph_executions"
+    )
     return InvocationServices(
         model_manager=None,  # type: ignore
         events=TestEventService(),
@ -44,10 +48,9 @@ def mock_services() -> InvocationServices:
         board_images=None,  # type: ignore
         queue=MemoryInvocationQueue(),
         graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"),
-        graph_execution_manager=SqliteItemStorage[GraphExecutionState](
-            filename=sqlite_memory, table_name="graph_executions"
-        ),
+        graph_execution_manager=graph_execution_manager,
         processor=DefaultInvocationProcessor(),
+        performance_statistics=InvocationStatsService(graph_execution_manager),
         configuration=None,  # type: ignore
     )