Fix loading diffusers ti

This commit is contained in:
Sergey Borisov 2023-07-05 19:46:00 +03:00
parent 818616a0c5
commit 0ac9dca926
6 changed files with 22 additions and 7 deletions

View File

@ -9,6 +9,7 @@ from compel.prompt_parser import (Blend, Conjunction,
FlattenedPrompt, Fragment) FlattenedPrompt, Fragment)
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from ...backend.model_management.models import ModelNotFoundException
from ...backend.model_management import BaseModelType, ModelType, SubModelType from ...backend.model_management import BaseModelType, ModelType, SubModelType
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
@ -86,10 +87,10 @@ class CompelInvocation(BaseInvocation):
model_type=ModelType.TextualInversion, model_type=ModelType.TextualInversion,
).context.model ).context.model
) )
except Exception: except ModelNotFoundException:
# print(e) # print(e)
#import traceback #import traceback
# print(traceback.format_exc()) #print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found") print(f"Warn: trigger: \"{trigger}\" not found")
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\ with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\

View File

@ -655,6 +655,9 @@ class TextualInversionModel:
else: else:
result.embedding = next(iter(state_dict.values())) result.embedding = next(iter(state_dict.values()))
if len(result.embedding.shape) == 1:
result.embedding = result.embedding.unsqueeze(0)
if not isinstance(result.embedding, torch.Tensor): if not isinstance(result.embedding, torch.Tensor):
raise ValueError(f"Invalid embeddings file: {file_path.name}") raise ValueError(f"Invalid embeddings file: {file_path.name}")

View File

@ -249,7 +249,7 @@ from .model_cache import ModelCache, ModelLocker
from .models import ( from .models import (
BaseModelType, ModelType, SubModelType, BaseModelType, ModelType, SubModelType,
ModelError, SchedulerPredictionType, MODEL_CLASSES, ModelError, SchedulerPredictionType, MODEL_CLASSES,
ModelConfigBase, ModelConfigBase, ModelNotFoundException,
) )
# We are only starting to number the config file with release 3. # We are only starting to number the config file with release 3.
@ -409,7 +409,7 @@ class ModelManager(object):
if model_key not in self.models: if model_key not in self.models:
self.scan_models_directory(base_model=base_model, model_type=model_type) self.scan_models_directory(base_model=base_model, model_type=model_type)
if model_key not in self.models: if model_key not in self.models:
raise Exception(f"Model not found - {model_key}") raise ModelNotFoundException(f"Model not found - {model_key}")
model_config = self.models[model_key] model_config = self.models[model_key]
model_path = self.app_config.root_path / model_config.path model_path = self.app_config.root_path / model_config.path
@ -421,7 +421,7 @@ class ModelManager(object):
else: else:
self.models.pop(model_key, None) self.models.pop(model_key, None)
raise Exception(f"Model not found - {model_key}") raise ModelNotFoundException(f"Model not found - {model_key}")
# vae/movq override # vae/movq override
# TODO: # TODO:

View File

@ -2,7 +2,7 @@ import inspect
from enum import Enum from enum import Enum
from pydantic import BaseModel from pydantic import BaseModel
from typing import Literal, get_origin from typing import Literal, get_origin
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .vae import VaeModel from .vae import VaeModel
from .lora import LoRAModel from .lora import LoRAModel

View File

@ -15,6 +15,9 @@ from contextlib import suppress
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
class ModelNotFoundException(Exception):
pass
class BaseModelType(str, Enum): class BaseModelType(str, Enum):
StableDiffusion1 = "sd-1" StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2" StableDiffusion2 = "sd-2"

View File

@ -8,6 +8,7 @@ from .base import (
ModelType, ModelType,
SubModelType, SubModelType,
classproperty, classproperty,
ModelNotFoundException,
) )
# TODO: naming # TODO: naming
from ..lora import TextualInversionModel as TextualInversionModelRaw from ..lora import TextualInversionModel as TextualInversionModelRaw
@ -37,8 +38,15 @@ class TextualInversionModel(ModelBase):
if child_type is not None: if child_type is not None:
raise Exception("There is no child models in textual inversion") raise Exception("There is no child models in textual inversion")
checkpoint_path = self.model_path
if os.path.isdir(checkpoint_path):
checkpoint_path = os.path.join(checkpoint_path, "learned_embeds.bin")
if not os.path.exists(checkpoint_path):
raise ModelNotFoundException()
model = TextualInversionModelRaw.from_checkpoint( model = TextualInversionModelRaw.from_checkpoint(
file_path=self.model_path, file_path=checkpoint_path,
dtype=torch_dtype, dtype=torch_dtype,
) )