mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into fix/controlnet_cfg_inj_cond
This commit is contained in:
@ -121,8 +121,8 @@ class ModelInstall(object):
|
||||
installed_models = self.mgr.list_models()
|
||||
for md in installed_models:
|
||||
base = md['base_model']
|
||||
model_type = md['type']
|
||||
name = md['name']
|
||||
model_type = md['model_type']
|
||||
name = md['model_name']
|
||||
key = ModelManager.create_key(name, base, model_type)
|
||||
if key in model_dict:
|
||||
model_dict[key].installed = True
|
||||
|
@ -250,8 +250,8 @@ from .model_cache import ModelCache, ModelLocker
|
||||
from .models import (
|
||||
BaseModelType, ModelType, SubModelType,
|
||||
ModelError, SchedulerPredictionType, MODEL_CLASSES,
|
||||
ModelConfigBase, ModelNotFoundException,
|
||||
)
|
||||
ModelConfigBase, ModelNotFoundException, InvalidModelException,
|
||||
)
|
||||
|
||||
# We are only starting to number the config file with release 3.
|
||||
# The config file version doesn't have to start at release version, but it will help
|
||||
@ -275,10 +275,6 @@ class ModelInfo():
|
||||
def __exit__(self,*args, **kwargs):
|
||||
self.context.__exit__(*args, **kwargs)
|
||||
|
||||
class InvalidModelError(Exception):
|
||||
"Raised when an invalid model is requested"
|
||||
pass
|
||||
|
||||
class AddModelResult(BaseModel):
|
||||
name: str = Field(description="The name of the model after installation")
|
||||
model_type: ModelType = Field(description="The type of model")
|
||||
@ -542,9 +538,9 @@ class ModelManager(object):
|
||||
model_dict = dict(
|
||||
**model_config.dict(exclude_defaults=True),
|
||||
# OpenAPIModelInfoBase
|
||||
name=cur_model_name,
|
||||
model_name=cur_model_name,
|
||||
base_model=cur_base_model,
|
||||
type=cur_model_type,
|
||||
model_type=cur_model_type,
|
||||
)
|
||||
|
||||
models.append(model_dict)
|
||||
@ -817,6 +813,8 @@ class ModelManager(object):
|
||||
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
|
||||
self.models[model_key] = model_config
|
||||
new_models_found = True
|
||||
except InvalidModelException:
|
||||
self.logger.warning(f"Not a valid model: {model_path}")
|
||||
except NotImplementedError as e:
|
||||
self.logger.warning(e)
|
||||
|
||||
|
@ -2,7 +2,7 @@ import inspect
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel
|
||||
from typing import Literal, get_origin
|
||||
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException
|
||||
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException, InvalidModelException
|
||||
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
||||
from .vae import VaeModel
|
||||
from .lora import LoRAModel
|
||||
@ -37,9 +37,9 @@ MODEL_CONFIGS = list()
|
||||
OPENAPI_MODEL_CONFIGS = list()
|
||||
|
||||
class OpenAPIModelInfoBase(BaseModel):
|
||||
name: str
|
||||
model_name: str
|
||||
base_model: BaseModelType
|
||||
type: ModelType
|
||||
model_type: ModelType
|
||||
|
||||
|
||||
for base_model, models in MODEL_CLASSES.items():
|
||||
@ -56,7 +56,7 @@ for base_model, models in MODEL_CLASSES.items():
|
||||
|
||||
api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict(
|
||||
__annotations__ = dict(
|
||||
type=Literal[model_type.value],
|
||||
model_type=Literal[model_type.value],
|
||||
),
|
||||
))
|
||||
|
||||
|
@ -15,6 +15,9 @@ from contextlib import suppress
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
|
||||
|
||||
class InvalidModelException(Exception):
|
||||
pass
|
||||
|
||||
class ModelNotFoundException(Exception):
|
||||
pass
|
||||
|
||||
|
@ -13,6 +13,7 @@ from .base import (
|
||||
calc_model_size_by_fs,
|
||||
calc_model_size_by_data,
|
||||
classproperty,
|
||||
InvalidModelException,
|
||||
)
|
||||
|
||||
class ControlNetModelFormat(str, Enum):
|
||||
@ -73,10 +74,18 @@ class ControlNetModel(ModelBase):
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
if not os.path.exists(path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
if os.path.isdir(path):
|
||||
return ControlNetModelFormat.Diffusers
|
||||
else:
|
||||
return ControlNetModelFormat.Checkpoint
|
||||
if os.path.exists(os.path.join(path, "config.json")):
|
||||
return ControlNetModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(path):
|
||||
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]]):
|
||||
return ControlNetModelFormat.Checkpoint
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
|
@ -9,6 +9,7 @@ from .base import (
|
||||
ModelType,
|
||||
SubModelType,
|
||||
classproperty,
|
||||
InvalidModelException,
|
||||
)
|
||||
# TODO: naming
|
||||
from ..lora import LoRAModel as LoRAModelRaw
|
||||
@ -56,10 +57,18 @@ class LoRAModel(ModelBase):
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
if not os.path.exists(path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
if os.path.isdir(path):
|
||||
return LoRAModelFormat.Diffusers
|
||||
else:
|
||||
return LoRAModelFormat.LyCORIS
|
||||
if os.path.exists(os.path.join(path, "pytorch_lora_weights.bin")):
|
||||
return LoRAModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(path):
|
||||
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
||||
return LoRAModelFormat.LyCORIS
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
|
@ -16,6 +16,7 @@ from .base import (
|
||||
SilenceWarnings,
|
||||
read_checkpoint_meta,
|
||||
classproperty,
|
||||
InvalidModelException,
|
||||
)
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from omegaconf import OmegaConf
|
||||
@ -98,10 +99,18 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, model_path: str):
|
||||
if not os.path.exists(model_path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
if os.path.isdir(model_path):
|
||||
return StableDiffusion1ModelFormat.Diffusers
|
||||
else:
|
||||
return StableDiffusion1ModelFormat.Checkpoint
|
||||
if os.path.exists(os.path.join(model_path, "model_index.json")):
|
||||
return StableDiffusion1ModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(model_path):
|
||||
if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
||||
return StableDiffusion1ModelFormat.Checkpoint
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {model_path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
@ -200,10 +209,18 @@ class StableDiffusion2Model(DiffusersModel):
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, model_path: str):
|
||||
if not os.path.exists(model_path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
if os.path.isdir(model_path):
|
||||
return StableDiffusion2ModelFormat.Diffusers
|
||||
else:
|
||||
return StableDiffusion2ModelFormat.Checkpoint
|
||||
if os.path.exists(os.path.join(model_path, "model_index.json")):
|
||||
return StableDiffusion2ModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(model_path):
|
||||
if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
||||
return StableDiffusion2ModelFormat.Checkpoint
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {model_path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
|
@ -9,6 +9,7 @@ from .base import (
|
||||
SubModelType,
|
||||
classproperty,
|
||||
ModelNotFoundException,
|
||||
InvalidModelException,
|
||||
)
|
||||
# TODO: naming
|
||||
from ..lora import TextualInversionModel as TextualInversionModelRaw
|
||||
@ -59,7 +60,18 @@ class TextualInversionModel(ModelBase):
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
return None
|
||||
if not os.path.exists(path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
if os.path.isdir(path):
|
||||
if os.path.exists(os.path.join(path, "learned_embeds.bin")):
|
||||
return None # diffusers-ti
|
||||
|
||||
if os.path.isfile(path):
|
||||
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
||||
return None
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
|
@ -15,6 +15,7 @@ from .base import (
|
||||
calc_model_size_by_fs,
|
||||
calc_model_size_by_data,
|
||||
classproperty,
|
||||
InvalidModelException,
|
||||
)
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from diffusers.utils import is_safetensors_available
|
||||
@ -75,10 +76,18 @@ class VaeModel(ModelBase):
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
if not os.path.exists(path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
if os.path.isdir(path):
|
||||
return VaeModelFormat.Diffusers
|
||||
else:
|
||||
return VaeModelFormat.Checkpoint
|
||||
if os.path.exists(os.path.join(path, "config.json")):
|
||||
return VaeModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(path):
|
||||
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
||||
return VaeModelFormat.Checkpoint
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
|
@ -127,7 +127,7 @@ class AddsMaskGuidance:
|
||||
|
||||
def _t_for_field(self, field_name: str, t):
|
||||
if field_name == "pred_original_sample":
|
||||
return torch.zeros_like(t, dtype=t.dtype) # it represents t=0
|
||||
return self.scheduler.timesteps[-1]
|
||||
return t
|
||||
|
||||
def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
|
||||
|
Reference in New Issue
Block a user