mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(mm): fix up mm service types
This commit is contained in:
parent
5d099f4a49
commit
5d4d0e795c
@ -2,22 +2,26 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import torch
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from types import ModuleType
|
||||||
|
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from invokeai.backend.model_management.model_manager import (
|
|
||||||
ModelManager,
|
|
||||||
BaseModelType,
|
|
||||||
ModelType,
|
|
||||||
SubModelType,
|
|
||||||
ModelInfo,
|
|
||||||
)
|
|
||||||
from invokeai.app.models.exceptions import CanceledException
|
from invokeai.app.models.exceptions import CanceledException
|
||||||
from .config import InvokeAIAppConfig
|
from invokeai.backend.model_management.model_manager import (AddModelResult,
|
||||||
|
BaseModelType,
|
||||||
|
ModelInfo,
|
||||||
|
ModelManager,
|
||||||
|
ModelType,
|
||||||
|
SubModelType)
|
||||||
|
from invokeai.backend.model_management.models.base import \
|
||||||
|
SchedulerPredictionType
|
||||||
|
|
||||||
from ...backend.util import choose_precision, choose_torch_device
|
from ...backend.util import choose_precision, choose_torch_device
|
||||||
|
from .config import InvokeAIAppConfig
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
|
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
|
||||||
@ -30,7 +34,7 @@ class ModelManagerServiceBase(ABC):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: InvokeAIAppConfig,
|
config: InvokeAIAppConfig,
|
||||||
logger: types.ModuleType,
|
logger: ModuleType,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize with the path to the models.yaml config file.
|
Initialize with the path to the models.yaml config file.
|
||||||
@ -137,9 +141,9 @@ class ModelManagerServiceBase(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def heuristic_import(self,
|
def heuristic_import(self,
|
||||||
items_to_import: Set[str],
|
items_to_import: set[str],
|
||||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
|
||||||
)->Dict[str, AddModelResult]:
|
)->dict[str, AddModelResult]:
|
||||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||||
successfully imported items.
|
successfully imported items.
|
||||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||||
@ -159,7 +163,7 @@ class ModelManagerServiceBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def commit(self, conf_file: Path = None) -> None:
|
def commit(self, conf_file: Optional[Path] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Write current configuration out to the indicated file.
|
Write current configuration out to the indicated file.
|
||||||
If no conf_file is provided, then replaces the
|
If no conf_file is provided, then replaces the
|
||||||
@ -173,7 +177,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: InvokeAIAppConfig,
|
config: InvokeAIAppConfig,
|
||||||
logger: types.ModuleType,
|
logger: ModuleType,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize with the path to the models.yaml config file.
|
Initialize with the path to the models.yaml config file.
|
||||||
@ -387,9 +391,9 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
return self.mgr.logger
|
return self.mgr.logger
|
||||||
|
|
||||||
def heuristic_import(self,
|
def heuristic_import(self,
|
||||||
items_to_import: Set[str],
|
items_to_import: set[str],
|
||||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
|
||||||
)->Dict[str, AddModelResult]:
|
)->dict[str, AddModelResult]:
|
||||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||||
successfully imported items.
|
successfully imported items.
|
||||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user