mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Save models on rescan, uncache model on edit/delete, fixes
This commit is contained in:
@ -3,11 +3,13 @@ import sys
|
||||
import typing
|
||||
import inspect
|
||||
from enum import Enum
|
||||
from abc import ABCMeta, abstractmethod
|
||||
import torch
|
||||
import safetensors.torch
|
||||
from diffusers import DiffusionPipeline, ConfigMixin
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Optional, Type, Literal
|
||||
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any
|
||||
|
||||
class BaseModelType(str, Enum):
|
||||
StableDiffusion1 = "sd-1"
|
||||
@ -52,6 +54,9 @@ class ModelConfigBase(BaseModel):
|
||||
# do not save to config
|
||||
error: Optional[ModelError] = Field(None, exclude=True)
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
|
||||
|
||||
class EmptyConfigLoader(ConfigMixin):
|
||||
@classmethod
|
||||
@ -59,7 +64,18 @@ class EmptyConfigLoader(ConfigMixin):
|
||||
cls.config_name = kwargs.pop("config_name")
|
||||
return super().load_config(*args, **kwargs)
|
||||
|
||||
class ModelBase:
|
||||
T_co = TypeVar('T_co', covariant=True)
|
||||
class classproperty(Generic[T_co]):
|
||||
def __init__(self, fget: Callable[[Any], T_co]) -> None:
|
||||
self.fget = fget
|
||||
|
||||
def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co:
|
||||
return self.fget(owner)
|
||||
|
||||
def __set__(self, instance: Optional[Any], value: Any) -> None:
|
||||
raise AttributeError('cannot set attribute')
|
||||
|
||||
class ModelBase(metaclass=ABCMeta):
|
||||
#model_path: str
|
||||
#base_model: BaseModelType
|
||||
#model_type: ModelType
|
||||
@ -121,7 +137,7 @@ class ModelBase:
|
||||
return cls.__configs
|
||||
|
||||
@classmethod
|
||||
def create_config(cls, **kwargs):
|
||||
def create_config(cls, **kwargs) -> ModelConfigBase:
|
||||
if "format" not in kwargs:
|
||||
raise Exception("Field 'format' not found in model config")
|
||||
|
||||
@ -129,16 +145,33 @@ class ModelBase:
|
||||
return configs[kwargs["format"]](**kwargs)
|
||||
|
||||
@classmethod
|
||||
def probe_config(cls, path: str, **kwargs):
|
||||
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
|
||||
return cls.create_config(
|
||||
path=path,
|
||||
format=cls.detect_format(path),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def detect_format(cls, path: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@classproperty
|
||||
@abstractmethod
|
||||
def save_to_config(cls) -> bool:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def get_model(
|
||||
self,
|
||||
torch_dtype: Optional[torch.dtype],
|
||||
child_type: Optional[SubModelType] = None,
|
||||
) -> Any:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DiffusersModel(ModelBase):
|
||||
|
@ -6,6 +6,7 @@ from .base import (
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
classproperty,
|
||||
)
|
||||
# TODO: naming
|
||||
from ..lora import LoRAModel as LoRAModelRaw
|
||||
@ -43,7 +44,7 @@ class LoRAModel(ModelBase):
|
||||
self.model_size = model.calc_size()
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
@classproperty
|
||||
def save_to_config(cls) -> bool:
|
||||
return False
|
||||
|
||||
|
@ -1,7 +1,5 @@
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
import safetensors.torch
|
||||
from pydantic import Field
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, Union
|
||||
@ -16,6 +14,7 @@ from .base import (
|
||||
SchedulerPredictionType,
|
||||
SilenceWarnings,
|
||||
read_checkpoint_meta,
|
||||
classproperty,
|
||||
)
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from omegaconf import OmegaConf
|
||||
@ -87,7 +86,7 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
variant=variant,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@classproperty
|
||||
def save_to_config(cls) -> bool:
|
||||
return True
|
||||
|
||||
@ -198,7 +197,7 @@ class StableDiffusion2Model(DiffusersModel):
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@classproperty
|
||||
def save_to_config(cls) -> bool:
|
||||
return True
|
||||
|
||||
|
@ -6,6 +6,7 @@ from .base import (
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
classproperty,
|
||||
)
|
||||
# TODO: naming
|
||||
from ..lora import TextualInversionModel as TextualInversionModelRaw
|
||||
@ -43,7 +44,7 @@ class TextualInversionModel(ModelBase):
|
||||
self.model_size = model.embedding.nelement() * model.embedding.element_size()
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
@classproperty
|
||||
def save_to_config(cls) -> bool:
|
||||
return False
|
||||
|
||||
|
@ -12,6 +12,7 @@ from .base import (
|
||||
EmptyConfigLoader,
|
||||
calc_model_size_by_fs,
|
||||
calc_model_size_by_data,
|
||||
classproperty,
|
||||
)
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from diffusers.utils import is_safetensors_available
|
||||
@ -62,7 +63,7 @@ class VaeModel(ModelBase):
|
||||
self.model_size = calc_model_size_by_data(model)
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
@classproperty
|
||||
def save_to_config(cls) -> bool:
|
||||
return False
|
||||
|
||||
|
Reference in New Issue
Block a user