mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix loading diffusers ti (#3661)
This commit is contained in:
commit
dd946790ec
@ -9,6 +9,7 @@ from compel.prompt_parser import (Blend, Conjunction,
|
|||||||
FlattenedPrompt, Fragment)
|
FlattenedPrompt, Fragment)
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from ...backend.model_management.models import ModelNotFoundException
|
||||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||||
from ...backend.model_management.lora import ModelPatcher
|
from ...backend.model_management.lora import ModelPatcher
|
||||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||||
@ -86,10 +87,10 @@ class CompelInvocation(BaseInvocation):
|
|||||||
model_type=ModelType.TextualInversion,
|
model_type=ModelType.TextualInversion,
|
||||||
).context.model
|
).context.model
|
||||||
)
|
)
|
||||||
except Exception:
|
except ModelNotFoundException:
|
||||||
# print(e)
|
# print(e)
|
||||||
#import traceback
|
#import traceback
|
||||||
# print(traceback.format_exc())
|
#print(traceback.format_exc())
|
||||||
print(f"Warn: trigger: \"{trigger}\" not found")
|
print(f"Warn: trigger: \"{trigger}\" not found")
|
||||||
|
|
||||||
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
|
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
|
||||||
|
@ -655,6 +655,9 @@ class TextualInversionModel:
|
|||||||
else:
|
else:
|
||||||
result.embedding = next(iter(state_dict.values()))
|
result.embedding = next(iter(state_dict.values()))
|
||||||
|
|
||||||
|
if len(result.embedding.shape) == 1:
|
||||||
|
result.embedding = result.embedding.unsqueeze(0)
|
||||||
|
|
||||||
if not isinstance(result.embedding, torch.Tensor):
|
if not isinstance(result.embedding, torch.Tensor):
|
||||||
raise ValueError(f"Invalid embeddings file: {file_path.name}")
|
raise ValueError(f"Invalid embeddings file: {file_path.name}")
|
||||||
|
|
||||||
|
@ -249,7 +249,7 @@ from .model_cache import ModelCache, ModelLocker
|
|||||||
from .models import (
|
from .models import (
|
||||||
BaseModelType, ModelType, SubModelType,
|
BaseModelType, ModelType, SubModelType,
|
||||||
ModelError, SchedulerPredictionType, MODEL_CLASSES,
|
ModelError, SchedulerPredictionType, MODEL_CLASSES,
|
||||||
ModelConfigBase,
|
ModelConfigBase, ModelNotFoundException,
|
||||||
)
|
)
|
||||||
|
|
||||||
# We are only starting to number the config file with release 3.
|
# We are only starting to number the config file with release 3.
|
||||||
@ -409,7 +409,7 @@ class ModelManager(object):
|
|||||||
if model_key not in self.models:
|
if model_key not in self.models:
|
||||||
self.scan_models_directory(base_model=base_model, model_type=model_type)
|
self.scan_models_directory(base_model=base_model, model_type=model_type)
|
||||||
if model_key not in self.models:
|
if model_key not in self.models:
|
||||||
raise Exception(f"Model not found - {model_key}")
|
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||||
|
|
||||||
model_config = self.models[model_key]
|
model_config = self.models[model_key]
|
||||||
model_path = self.app_config.root_path / model_config.path
|
model_path = self.app_config.root_path / model_config.path
|
||||||
@ -421,7 +421,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
self.models.pop(model_key, None)
|
self.models.pop(model_key, None)
|
||||||
raise Exception(f"Model not found - {model_key}")
|
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||||
|
|
||||||
# vae/movq override
|
# vae/movq override
|
||||||
# TODO:
|
# TODO:
|
||||||
|
@ -2,7 +2,7 @@ import inspect
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Literal, get_origin
|
from typing import Literal, get_origin
|
||||||
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
|
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException
|
||||||
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
||||||
from .vae import VaeModel
|
from .vae import VaeModel
|
||||||
from .lora import LoRAModel
|
from .lora import LoRAModel
|
||||||
|
@ -15,6 +15,9 @@ from contextlib import suppress
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
|
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
|
||||||
|
|
||||||
|
class ModelNotFoundException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
class BaseModelType(str, Enum):
|
class BaseModelType(str, Enum):
|
||||||
StableDiffusion1 = "sd-1"
|
StableDiffusion1 = "sd-1"
|
||||||
StableDiffusion2 = "sd-2"
|
StableDiffusion2 = "sd-2"
|
||||||
|
@ -8,6 +8,7 @@ from .base import (
|
|||||||
ModelType,
|
ModelType,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
classproperty,
|
classproperty,
|
||||||
|
ModelNotFoundException,
|
||||||
)
|
)
|
||||||
# TODO: naming
|
# TODO: naming
|
||||||
from ..lora import TextualInversionModel as TextualInversionModelRaw
|
from ..lora import TextualInversionModel as TextualInversionModelRaw
|
||||||
@ -37,8 +38,15 @@ class TextualInversionModel(ModelBase):
|
|||||||
if child_type is not None:
|
if child_type is not None:
|
||||||
raise Exception("There is no child models in textual inversion")
|
raise Exception("There is no child models in textual inversion")
|
||||||
|
|
||||||
|
checkpoint_path = self.model_path
|
||||||
|
if os.path.isdir(checkpoint_path):
|
||||||
|
checkpoint_path = os.path.join(checkpoint_path, "learned_embeds.bin")
|
||||||
|
|
||||||
|
if not os.path.exists(checkpoint_path):
|
||||||
|
raise ModelNotFoundException()
|
||||||
|
|
||||||
model = TextualInversionModelRaw.from_checkpoint(
|
model = TextualInversionModelRaw.from_checkpoint(
|
||||||
file_path=self.model_path,
|
file_path=checkpoint_path,
|
||||||
dtype=torch_dtype,
|
dtype=torch_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user