mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
move name/description logic into model_probe.py
This commit is contained in:
parent
9cf060115d
commit
4aab728590
@ -16,6 +16,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.backend.model_manager.hash import FastModelHash
|
||||
|
||||
# marker that the queue is done and that thread should exit
|
||||
STOP_JOB = ModelInstallJob(source="stop")
|
||||
@ -80,7 +81,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
description: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, str]] = None,
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
model_path = Path(model_path)
|
||||
info: ModelProbeInfo = self._probe_model(model_path, metadata)
|
||||
return self._register(model_path, info)
|
||||
|
||||
def install_path(
|
||||
self,
|
||||
@ -105,7 +108,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
raise NotImplementedError
|
||||
|
||||
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
|
||||
raise NotImplementedError
|
||||
self._install_queue.join()
|
||||
|
||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
|
||||
raise NotImplementedError
|
||||
@ -114,7 +117,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
raise NotImplementedError
|
||||
|
||||
def hash(self, model_path: Union[Path, str]) -> str:
|
||||
raise NotImplementedError
|
||||
return FastModelHash.hash(model_path)
|
||||
|
||||
# The following are internal methods
|
||||
def _create_name(self, model_path: Union[Path, str]) -> str:
|
||||
|
@ -32,6 +32,8 @@ class ModelProbeInfo(object):
|
||||
upcast_attention: bool
|
||||
format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"]
|
||||
image_size: int
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class ProbeBase(object):
|
||||
@ -112,12 +114,16 @@ class ModelProbe(object):
|
||||
base_type = probe.get_base_type()
|
||||
variant_type = probe.get_variant_type()
|
||||
prediction_type = probe.get_scheduler_prediction_type()
|
||||
name = cls.get_model_name(model_path)
|
||||
description = f"{base_type.value} {model_type.value} model {name}"
|
||||
format = probe.get_format()
|
||||
model_info = ModelProbeInfo(
|
||||
model_type=model_type,
|
||||
base_type=base_type,
|
||||
variant_type=variant_type,
|
||||
prediction_type=prediction_type,
|
||||
name = name,
|
||||
description = description,
|
||||
upcast_attention=(
|
||||
base_type == BaseModelType.StableDiffusion2
|
||||
and prediction_type == SchedulerPredictionType.VPrediction
|
||||
@ -141,6 +147,13 @@ class ModelProbe(object):
|
||||
|
||||
return model_info
|
||||
|
||||
@classmethod
|
||||
def get_model_name(cls, model_path: Path) -> str:
|
||||
if model_path.suffix in {'.safetensors', '.bin', '.pt', '.ckpt'}:
|
||||
return model_path.stem
|
||||
else:
|
||||
return model_path.name
|
||||
|
||||
@classmethod
|
||||
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
|
||||
if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
|
||||
|
Loading…
x
Reference in New Issue
Block a user