Save models on rescan, uncache model on edit/delete, fixes

Sergey Borisov
2023-06-14 03:12:12 +03:00
parent 26090011c4
commit 740c05a0bb
8 changed files with 88 additions and 299 deletions

View File

@@ -3,11 +3,13 @@ import sys
import typing
import inspect
from enum import Enum
from abc import ABCMeta, abstractmethod
import torch
import safetensors.torch
from diffusers import DiffusionPipeline, ConfigMixin
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any
class BaseModelType(str, Enum):
StableDiffusion1 = "sd-1"
@@ -52,6 +54,9 @@ class ModelConfigBase(BaseModel):
# do not save to config
error: Optional[ModelError] = Field(None, exclude=True)
class Config:
use_enum_values = True
class EmptyConfigLoader(ConfigMixin):
@classmethod
@@ -59,7 +64,18 @@ class EmptyConfigLoader(ConfigMixin):
cls.config_name = kwargs.pop("config_name")
return super().load_config(*args, **kwargs)
class ModelBase:
T_co = TypeVar('T_co', covariant=True)
class classproperty(Generic[T_co]):
def __init__(self, fget: Callable[[Any], T_co]) -> None:
self.fget = fget
def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co:
return self.fget(owner)
def __set__(self, instance: Optional[Any], value: Any) -> None:
raise AttributeError('cannot set attribute')
class ModelBase(metaclass=ABCMeta):
#model_path: str
#base_model: BaseModelType
#model_type: ModelType
@@ -121,7 +137,7 @@ class ModelBase:
return cls.__configs
@classmethod
def create_config(cls, **kwargs):
def create_config(cls, **kwargs) -> ModelConfigBase:
if "format" not in kwargs:
raise Exception("Field 'format' not found in model config")
@@ -129,16 +145,33 @@ class ModelBase:
return configs[kwargs["format"]](**kwargs)
@classmethod
def probe_config(cls, path: str, **kwargs):
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
return cls.create_config(
path=path,
format=cls.detect_format(path),
)
@classmethod
@abstractmethod
def detect_format(cls, path: str) -> str:
raise NotImplementedError()
@classproperty
@abstractmethod
def save_to_config(cls) -> bool:
raise NotImplementedError()
@abstractmethod
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
raise NotImplementedError()
@abstractmethod
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
) -> Any:
raise NotImplementedError()
class DiffusersModel(ModelBase):
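
Note: the classproperty descriptor introduced above lets save_to_config be read as a plain class attribute instead of being called as a classmethod. A minimal, self-contained sketch of the idea; ExampleModel is illustrative and not part of this commit:

from typing import Any, Callable, Generic, Optional, Type, TypeVar

T_co = TypeVar('T_co', covariant=True)

class classproperty(Generic[T_co]):
    # Descriptor that exposes a one-argument getter as a read-only, class-level attribute.
    def __init__(self, fget: Callable[[Any], T_co]) -> None:
        self.fget = fget

    def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co:
        # Dispatch on the owning class whether accessed via the class or an instance.
        return self.fget(owner)

    def __set__(self, instance: Optional[Any], value: Any) -> None:
        raise AttributeError('cannot set attribute')

class ExampleModel:  # illustrative only
    @classproperty
    def save_to_config(cls) -> bool:
        return True

assert ExampleModel.save_to_config is True    # read as an attribute, no call needed
assert ExampleModel().save_to_config is True  # the same value through an instance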

View File

@@ -6,6 +6,7 @@ from .base import (
BaseModelType,
ModelType,
SubModelType,
classproperty,
)
# TODO: naming
from ..lora import LoRAModel as LoRAModelRaw
@@ -43,7 +44,7 @@ class LoRAModel(ModelBase):
self.model_size = model.calc_size()
return model
@classmethod
@classproperty
def save_to_config(cls) -> bool:
return False
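
Note: with ModelBase now using ABCMeta, concrete model classes such as LoRAModel above are expected to provide detect_format, save_to_config, get_size and get_model. A sketch of a conforming subclass, assuming the ModelBase, classproperty and SubModelType definitions from the first file; the class name and return values below are illustrative only:

class ToyModel(ModelBase):  # illustrative; not a class from this commit
    @classmethod
    def detect_format(cls, path: str) -> str:
        return "checkpoint"

    @classproperty
    def save_to_config(cls) -> bool:
        return False

    def get_size(self, child_type: Optional[SubModelType] = None) -> int:
        return 0

    def get_model(
        self,
        torch_dtype: Optional[torch.dtype],
        child_type: Optional[SubModelType] = None,
    ) -> Any:
        raise NotImplementedError("toy example only")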

View File

@@ -1,7 +1,5 @@
import os
import json
import torch
import safetensors.torch
from pydantic import Field
from pathlib import Path
from typing import Literal, Optional, Union
@@ -16,6 +14,7 @@ from .base import (
SchedulerPredictionType,
SilenceWarnings,
read_checkpoint_meta,
classproperty,
)
from invokeai.app.services.config import InvokeAIAppConfig
from omegaconf import OmegaConf
@@ -87,7 +86,7 @@ class StableDiffusion1Model(DiffusersModel):
variant=variant,
)
@classmethod
@classproperty
def save_to_config(cls) -> bool:
return True
@@ -198,7 +197,7 @@ class StableDiffusion2Model(DiffusersModel):
upcast_attention=upcast_attention,
)
@classmethod
@classproperty
def save_to_config(cls) -> bool:
return True

View File

@@ -6,6 +6,7 @@ from .base import (
BaseModelType,
ModelType,
SubModelType,
classproperty,
)
# TODO: naming
from ..lora import TextualInversionModel as TextualInversionModelRaw
@@ -43,7 +44,7 @@ class TextualInversionModel(ModelBase):
self.model_size = model.embedding.nelement() * model.embedding.element_size()
return model
@classmethod
@classproperty
def save_to_config(cls) -> bool:
return False

View File

@@ -12,6 +12,7 @@ from .base import (
EmptyConfigLoader,
calc_model_size_by_fs,
calc_model_size_by_data,
classproperty,
)
from invokeai.app.services.config import InvokeAIAppConfig
from diffusers.utils import is_safetensors_available
@@ -62,7 +63,7 @@ class VaeModel(ModelBase):
self.model_size = calc_model_size_by_data(model)
return model
@classmethod
@classproperty
def save_to_config(cls) -> bool:
return False
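
Note: the model-manager changes the commit title refers to (saving models on rescan, uncaching on edit/delete) live in files not shown in this excerpt. A hypothetical sketch of how a rescan routine could consult the new flag; the function and config_store names are assumptions, not code from the commit:

def on_model_discovered(model_class, model_path: str, config_store) -> None:
    # Hypothetical rescan hook: probe the discovered model and, if its class opts in
    # via save_to_config, persist the probed config so it is kept across rescans.
    config = model_class.probe_config(model_path)  # probe_config is defined on ModelBase above
    if model_class.save_to_config:                 # plain attribute read via classproperty
        config_store.add(model_path, config)       # config_store.add is an assumed stand-in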