rename installer __del__() to stop(). Improve probe error messages

This commit is contained in:
Lincoln Stein 2023-12-10 12:55:01 -05:00
parent 913c68982a
commit 2f3457c02a
4 changed files with 12 additions and 11 deletions

View File

@ -7,7 +7,6 @@ from .model_install_base import (
ModelInstallJob, ModelInstallJob,
ModelInstallServiceBase, ModelInstallServiceBase,
ModelSource, ModelSource,
ModelSourceValidator,
UnknownInstallJobException, UnknownInstallJobException,
URLModelSource, URLModelSource,
) )
@ -20,7 +19,6 @@ __all__ = [
"ModelInstallJob", "ModelInstallJob",
"UnknownInstallJobException", "UnknownInstallJobException",
"ModelSource", "ModelSource",
"ModelSourceValidator",
"LocalModelSource", "LocalModelSource",
"HFModelSource", "HFModelSource",
"URLModelSource", "URLModelSource",

View File

@ -123,8 +123,6 @@ class URLModelSource(StringLikeSource):
# https://github.com/tiangolo/fastapi/discussions/9287 # https://github.com/tiangolo/fastapi/discussions/9287
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Body(discriminator="type")] ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Body(discriminator="type")]
ModelSourceValidator = TypeAdapter(ModelSource)
class ModelInstallJob(BaseModel): class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request.""" """Object that tracks the current status of an install request."""
@ -147,7 +145,7 @@ class ModelInstallJob(BaseModel):
def set_error(self, e: Exception) -> None: def set_error(self, e: Exception) -> None:
"""Record the error and traceback from an exception.""" """Record the error and traceback from an exception."""
self.error_type = e.__class__.__name__ self.error_type = e.__class__.__name__
self.error = traceback.format_exc() self.error = "".join(traceback.format_exception(e))
self.status = InstallStatus.ERROR self.status = InstallStatus.ERROR
@ -173,6 +171,10 @@ class ModelInstallServiceBase(ABC):
"""Call at InvokeAI startup time.""" """Call at InvokeAI startup time."""
self.sync_to_config() self.sync_to_config()
@abstractmethod
def stop(self) -> None:
"""Stop the model install service. After this the objection can be safely deleted."""
@property @property
@abstractmethod @abstractmethod
def app_config(self) -> InvokeAIAppConfig: def app_config(self) -> InvokeAIAppConfig:

View File

@ -73,10 +73,6 @@ class ModelInstallService(ModelInstallServiceBase):
self._models_installed = set() self._models_installed = set()
self._start_installer_thread() self._start_installer_thread()
def __del__(self) -> None:
"""At GC time, we stop the install thread and release its resources."""
self._install_queue.put(STOP_JOB)
@property @property
def app_config(self) -> InvokeAIAppConfig: # noqa D102 def app_config(self) -> InvokeAIAppConfig: # noqa D102
return self._app_config return self._app_config
@ -89,6 +85,10 @@ class ModelInstallService(ModelInstallServiceBase):
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102 def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
return self._event_bus return self._event_bus
def stop(self) -> None:
"""Stop the install thread; after this the object can be deleted and garbage collected."""
self._install_queue.put(STOP_JOB)
def _start_installer_thread(self) -> None: def _start_installer_thread(self) -> None:
threading.Thread(target=self._install_next_item, daemon=True).start() threading.Thread(target=self._install_next_item, daemon=True).start()
@ -114,6 +114,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._signal_job_errored(job, excp) self._signal_job_errored(job, excp)
finally: finally:
self._install_queue.task_done() self._install_queue.task_done()
self._logger.info("Install thread exiting")
def _signal_job_running(self, job: ModelInstallJob) -> None: def _signal_job_running(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.RUNNING job.status = InstallStatus.RUNNING

View File

@ -425,7 +425,7 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase):
elif token_dim == 1280: elif token_dim == 1280:
return BaseModelType.StableDiffusionXL return BaseModelType.StableDiffusionXL
else: else:
raise InvalidModelConfigException("Could not determine base type") raise InvalidModelConfigException(f"{self.model_path}: Could not determine base type")
class ControlNetCheckpointProbe(CheckpointProbeBase): class ControlNetCheckpointProbe(CheckpointProbeBase):
@ -443,7 +443,7 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
return BaseModelType.StableDiffusion1 return BaseModelType.StableDiffusion1
elif checkpoint[key_name].shape[-1] == 1024: elif checkpoint[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2 return BaseModelType.StableDiffusion2
raise InvalidModelConfigException("Unable to determine base type for {self.checkpoint_path}") raise InvalidModelConfigException("{self.model_path}: Unable to determine base type")
class IPAdapterCheckpointProbe(CheckpointProbeBase): class IPAdapterCheckpointProbe(CheckpointProbeBase):