move name/description logic into model_probe.py

Lincoln Stein 2023-11-22 22:29:02 -05:00
parent 9cf060115d
commit 4aab728590
2 changed files with 19 additions and 3 deletions

View File

@@ -16,6 +16,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.app.services.model_records import ModelRecordServiceBase
 from invokeai.app.services.events import EventServiceBase
 from invokeai.backend.util.logging import InvokeAILogger
+from invokeai.backend.model_manager.hash import FastModelHash
 
 # marker that the queue is done and that thread should exit
 STOP_JOB = ModelInstallJob(source="stop")
@@ -80,7 +81,9 @@ class ModelInstallService(ModelInstallServiceBase):
         description: Optional[str] = None,
         metadata: Optional[Dict[str, str]] = None,
     ) -> str:
-        raise NotImplementedError
+        model_path = Path(model_path)
+        info: ModelProbeInfo = self._probe_model(model_path, metadata)
+        return self._register(model_path, info)
 
     def install_path(
         self,
@@ -105,7 +108,7 @@ class ModelInstallService(ModelInstallServiceBase):
         raise NotImplementedError
 
     def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
-        raise NotImplementedError
+        self._install_queue.join()
 
     def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
         raise NotImplementedError
@@ -114,7 +117,7 @@ class ModelInstallService(ModelInstallServiceBase):
         raise NotImplementedError
 
     def hash(self, model_path: Union[Path, str]) -> str:
-        raise NotImplementedError
+        return FastModelHash.hash(model_path)
 
     # The following are internal methods
     def _create_name(self, model_path: Union[Path, str]) -> str:
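For orientation, a minimal sketch of what the two new method bodies in this file amount to. Only FastModelHash.hash(), _probe_model() and _register() come from the diff above; the fingerprint helper, the installer handle, and the example path are made-up stand-ins for illustration.

from pathlib import Path

from invokeai.backend.model_manager.hash import FastModelHash


def fingerprint(model_path: Path) -> str:
    # Same one-liner the service's hash() method now delegates to.
    return FastModelHash.hash(model_path)

# Registering an in-place model follows the same probe-then-record pattern.
# Roughly (installer is a hypothetical, already-configured install service):
#   info = installer._probe_model(model_path, metadata)   # -> ModelProbeInfo
#   key = installer._register(model_path, info)           # -> record key (str)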

View File

@@ -32,6 +32,8 @@ class ModelProbeInfo(object):
     upcast_attention: bool
     format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"]
     image_size: int
+    name: Optional[str] = None
+    description: Optional[str] = None
 
 
 class ProbeBase(object):
@@ -112,12 +114,16 @@ class ModelProbe(object):
         base_type = probe.get_base_type()
         variant_type = probe.get_variant_type()
         prediction_type = probe.get_scheduler_prediction_type()
+        name = cls.get_model_name(model_path)
+        description = f"{base_type.value} {model_type.value} model {name}"
         format = probe.get_format()
         model_info = ModelProbeInfo(
             model_type=model_type,
             base_type=base_type,
             variant_type=variant_type,
             prediction_type=prediction_type,
+            name = name,
+            description = description,
             upcast_attention=(
                 base_type == BaseModelType.StableDiffusion2
                 and prediction_type == SchedulerPredictionType.VPrediction
@@ -141,6 +147,13 @@ class ModelProbe(object):
         return model_info
 
+    @classmethod
+    def get_model_name(cls, model_path: Path) -> str:
+        if model_path.suffix in {'.safetensors', '.bin', '.pt', '.ckpt'}:
+            return model_path.stem
+        else:
+            return model_path.name
+
     @classmethod
     def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
         if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
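A quick sketch of the naming rule that get_model_name() introduces, with made-up example paths; the default description is then assembled from the probed base and model types plus this name.

from pathlib import Path

def get_model_name(model_path: Path) -> str:
    # Single-file models drop their extension; diffusers folders keep the
    # directory name unchanged.
    if model_path.suffix in {'.safetensors', '.bin', '.pt', '.ckpt'}:
        return model_path.stem
    return model_path.name

print(get_model_name(Path("analog-diffusion-1.0.safetensors")))  # analog-diffusion-1.0
print(get_model_name(Path("stable-diffusion-v1-5")))             # stable-diffusion-v1-5

# The probe's default description, f"{base_type.value} {model_type.value} model {name}",
# then reads something like "sd-1 main model analog-diffusion-1.0".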