mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
restore 3.9 compatibility by replacing | with Union[]
This commit is contained in:
parent
90aa97edd4
commit
76bafeb99e
@ -47,7 +47,7 @@ def add_parsers(
|
|||||||
commands: list[type],
|
commands: list[type],
|
||||||
command_field: str = "type",
|
command_field: str = "type",
|
||||||
exclude_fields: list[str] = ["id", "type"],
|
exclude_fields: list[str] = ["id", "type"],
|
||||||
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
|
add_arguments: Union[Callable[[argparse.ArgumentParser], None],None] = None
|
||||||
):
|
):
|
||||||
"""Adds parsers for each command to the subparsers"""
|
"""Adds parsers for each command to the subparsers"""
|
||||||
|
|
||||||
@ -72,7 +72,7 @@ def add_parsers(
|
|||||||
def add_graph_parsers(
|
def add_graph_parsers(
|
||||||
subparsers,
|
subparsers,
|
||||||
graphs: list[LibraryGraph],
|
graphs: list[LibraryGraph],
|
||||||
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
|
add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
|
||||||
):
|
):
|
||||||
for graph in graphs:
|
for graph in graphs:
|
||||||
command_parser = subparsers.add_parser(graph.name, help=graph.description)
|
command_parser = subparsers.add_parser(graph.name, help=graph.description)
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
|
||||||
import re
|
import re
|
||||||
import shlex
|
import shlex
|
||||||
import sys
|
import sys
|
||||||
@ -348,7 +347,7 @@ def invoke_cli():
|
|||||||
|
|
||||||
# Parse invocation
|
# Parse invocation
|
||||||
command: CliCommand = None # type:ignore
|
command: CliCommand = None # type:ignore
|
||||||
system_graph: LibraryGraph|None = None
|
system_graph: Union[LibraryGraph,None] = None
|
||||||
if args['type'] in system_graph_names:
|
if args['type'] in system_graph_names:
|
||||||
system_graph = next(filter(lambda g: g.name == args['type'], system_graphs))
|
system_graph = next(filter(lambda g: g.name == args['type'], system_graphs))
|
||||||
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))
|
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))
|
||||||
|
@ -132,7 +132,7 @@ class BoardImagesService(BoardImagesServiceABC):
|
|||||||
|
|
||||||
|
|
||||||
def board_record_to_dto(
|
def board_record_to_dto(
|
||||||
board_record: BoardRecord, cover_image_name: str | None, image_count: int
|
board_record: BoardRecord, cover_image_name: Union[str, None], image_count: int
|
||||||
) -> BoardDTO:
|
) -> BoardDTO:
|
||||||
"""Converts a board record to a board DTO."""
|
"""Converts a board record to a board DTO."""
|
||||||
return BoardDTO(
|
return BoardDTO(
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any, Union
|
||||||
from invokeai.app.models.image import ProgressImage
|
from invokeai.app.models.image import ProgressImage
|
||||||
from invokeai.app.util.misc import get_timestamp
|
from invokeai.app.util.misc import get_timestamp
|
||||||
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo
|
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo
|
||||||
@ -28,7 +28,7 @@ class EventServiceBase:
|
|||||||
graph_execution_state_id: str,
|
graph_execution_state_id: str,
|
||||||
node: dict,
|
node: dict,
|
||||||
source_node_id: str,
|
source_node_id: str,
|
||||||
progress_image: ProgressImage | None,
|
progress_image: Union[ProgressImage, None],
|
||||||
step: int,
|
step: int,
|
||||||
total_steps: int,
|
total_steps: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
import copy
|
import copy
|
||||||
import itertools
|
import itertools
|
||||||
import uuid
|
import uuid
|
||||||
from types import NoneType
|
|
||||||
from typing import (
|
from typing import (
|
||||||
Annotated,
|
Annotated,
|
||||||
Any,
|
Any,
|
||||||
@ -26,6 +25,8 @@ from ..invocations.baseinvocation import (
|
|||||||
InvocationContext,
|
InvocationContext,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# in 3.10 this would be "from types import NoneType"
|
||||||
|
NoneType = type(None)
|
||||||
|
|
||||||
class EdgeConnection(BaseModel):
|
class EdgeConnection(BaseModel):
|
||||||
node_id: str = Field(description="The id of the node for this edge connection")
|
node_id: str = Field(description="The id of the node for this edge connection")
|
||||||
@ -846,7 +847,7 @@ class GraphExecutionState(BaseModel):
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
def next(self) -> BaseInvocation | None:
|
def next(self) -> Union[BaseInvocation, None]:
|
||||||
"""Gets the next node ready to execute."""
|
"""Gets the next node ready to execute."""
|
||||||
|
|
||||||
# TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes
|
# TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
from PIL.Image import Image as PILImageType
|
from PIL.Image import Image as PILImageType
|
||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image, PngImagePlugin
|
||||||
@ -80,7 +80,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
__cache: Dict[Path, PILImageType]
|
__cache: Dict[Path, PILImageType]
|
||||||
__max_cache_size: int
|
__max_cache_size: int
|
||||||
|
|
||||||
def __init__(self, output_folder: str | Path):
|
def __init__(self, output_folder: Union[str, Path]):
|
||||||
self.__cache = dict()
|
self.__cache = dict()
|
||||||
self.__cache_ids = Queue()
|
self.__cache_ids = Queue()
|
||||||
self.__max_cache_size = 10 # TODO: get this from config
|
self.__max_cache_size = 10 # TODO: get this from config
|
||||||
@ -164,7 +164,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
|
|
||||||
return path
|
return path
|
||||||
|
|
||||||
def validate_path(self, path: str | Path) -> bool:
|
def validate_path(self, path: Union[str, Path]) -> bool:
|
||||||
"""Validates the path given for an image or thumbnail."""
|
"""Validates the path given for an image or thumbnail."""
|
||||||
path = path if isinstance(path, Path) else Path(path)
|
path = path if isinstance(path, Path) else Path(path)
|
||||||
return path.exists()
|
return path.exists()
|
||||||
@ -175,7 +175,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
for folder in folders:
|
for folder in folders:
|
||||||
folder.mkdir(parents=True, exist_ok=True)
|
folder.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
def __get_cache(self, image_name: Path) -> PILImageType | None:
|
def __get_cache(self, image_name: Path) -> Union[PILImageType, None]:
|
||||||
return None if image_name not in self.__cache else self.__cache[image_name]
|
return None if image_name not in self.__cache else self.__cache[image_name]
|
||||||
|
|
||||||
def __set_cache(self, image_name: Path, image: PILImageType):
|
def __set_cache(self, image_name: Path, image: PILImageType):
|
||||||
|
@ -116,7 +116,7 @@ class ImageRecordStorageBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_most_recent_image_for_board(self, board_id: str) -> ImageRecord | None:
|
def get_most_recent_image_for_board(self, board_id: str) -> Union[ImageRecord, None]:
|
||||||
"""Gets the most recent image for a board."""
|
"""Gets the most recent image for a board."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
|
|||||||
from queue import Queue
|
from queue import Queue
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
|
||||||
class InvocationQueueItem(BaseModel):
|
class InvocationQueueItem(BaseModel):
|
||||||
@ -22,7 +23,7 @@ class InvocationQueueABC(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def put(self, item: InvocationQueueItem | None) -> None:
|
def put(self, item: Union[InvocationQueueItem, None]) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -57,7 +58,7 @@ class MemoryInvocationQueue(InvocationQueueABC):
|
|||||||
|
|
||||||
return item
|
return item
|
||||||
|
|
||||||
def put(self, item: InvocationQueueItem | None) -> None:
|
def put(self, item: Union[InvocationQueueItem, None]) -> None:
|
||||||
self.__queue.put(item)
|
self.__queue.put(item)
|
||||||
|
|
||||||
def cancel(self, graph_execution_state_id: str) -> None:
|
def cancel(self, graph_execution_state_id: str) -> None:
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from threading import Event, Thread
|
from threading import Event, Thread
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
from ..invocations.baseinvocation import InvocationContext
|
from ..invocations.baseinvocation import InvocationContext
|
||||||
from .graph import Graph, GraphExecutionState
|
from .graph import Graph, GraphExecutionState
|
||||||
@ -21,7 +22,7 @@ class Invoker:
|
|||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
|
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
|
||||||
) -> str | None:
|
) -> Union[str, None]:
|
||||||
"""Determines the next node to invoke and enqueues it, preparing if needed.
|
"""Determines the next node to invoke and enqueues it, preparing if needed.
|
||||||
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
|
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
|
||||||
|
|
||||||
@ -45,7 +46,7 @@ class Invoker:
|
|||||||
|
|
||||||
return invocation.id
|
return invocation.id
|
||||||
|
|
||||||
def create_execution_state(self, graph: Graph | None = None) -> GraphExecutionState:
|
def create_execution_state(self, graph: Union[Graph, None] = None) -> GraphExecutionState:
|
||||||
"""Creates a new execution state for the given graph"""
|
"""Creates a new execution state for the given graph"""
|
||||||
new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
|
new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
|
||||||
self.services.graph_execution_manager.set(new_state)
|
self.services.graph_execution_manager.set(new_state)
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from typing import Dict
|
from typing import Dict, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -55,7 +55,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
|
|||||||
if name in self.__cache:
|
if name in self.__cache:
|
||||||
del self.__cache[name]
|
del self.__cache[name]
|
||||||
|
|
||||||
def __get_cache(self, name: str) -> torch.Tensor|None:
|
def __get_cache(self, name: str) -> Union[torch.Tensor, None]:
|
||||||
return None if name not in self.__cache else self.__cache[name]
|
return None if name not in self.__cache else self.__cache[name]
|
||||||
|
|
||||||
def __set_cache(self, name: str, data: torch.Tensor):
|
def __set_cache(self, name: str, data: torch.Tensor):
|
||||||
@ -69,9 +69,9 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
|
|||||||
class DiskLatentsStorage(LatentsStorageBase):
|
class DiskLatentsStorage(LatentsStorageBase):
|
||||||
"""Stores latents in a folder on disk without caching"""
|
"""Stores latents in a folder on disk without caching"""
|
||||||
|
|
||||||
__output_folder: str | Path
|
__output_folder: Union[str, Path]
|
||||||
|
|
||||||
def __init__(self, output_folder: str | Path):
|
def __init__(self, output_folder: Union[str, Path]):
|
||||||
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||||
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@ from PIL import Image, ImageChops, ImageFilter
|
|||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
from typing import Callable, List, Iterator, Optional, Type
|
from typing import Callable, List, Iterator, Optional, Type, Union
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
|
|
||||||
@ -178,7 +178,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
# ------------------------------------
|
# ------------------------------------
|
||||||
class Img2Img(InvokeAIGenerator):
|
class Img2Img(InvokeAIGenerator):
|
||||||
def generate(self,
|
def generate(self,
|
||||||
init_image: Image.Image | torch.FloatTensor,
|
init_image: Union[Image.Image, torch.FloatTensor],
|
||||||
strength: float=0.75,
|
strength: float=0.75,
|
||||||
**keyword_args
|
**keyword_args
|
||||||
)->Iterator[InvokeAIGeneratorOutput]:
|
)->Iterator[InvokeAIGeneratorOutput]:
|
||||||
@ -195,7 +195,7 @@ class Img2Img(InvokeAIGenerator):
|
|||||||
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
|
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
|
||||||
class Inpaint(Img2Img):
|
class Inpaint(Img2Img):
|
||||||
def generate(self,
|
def generate(self,
|
||||||
mask_image: Image.Image | torch.FloatTensor,
|
mask_image: Union[Image.Image, torch.FloatTensor],
|
||||||
# Seam settings - when 0, doesn't fill seam
|
# Seam settings - when 0, doesn't fill seam
|
||||||
seam_size: int = 96,
|
seam_size: int = 96,
|
||||||
seam_blur: int = 16,
|
seam_blur: int = 16,
|
||||||
|
@ -203,8 +203,8 @@ class Inpaint(Img2Img):
|
|||||||
cfg_scale,
|
cfg_scale,
|
||||||
ddim_eta,
|
ddim_eta,
|
||||||
conditioning,
|
conditioning,
|
||||||
init_image: Image.Image | torch.FloatTensor,
|
init_image: Union[Image.Image, torch.FloatTensor],
|
||||||
mask_image: Image.Image | torch.FloatTensor,
|
mask_image: Union[Image.Image, torch.FloatTensor],
|
||||||
strength: float,
|
strength: float,
|
||||||
mask_blur_radius: int = 8,
|
mask_blur_radius: int = 8,
|
||||||
# Seam settings - when 0, doesn't fill seam
|
# Seam settings - when 0, doesn't fill seam
|
||||||
|
@ -68,7 +68,11 @@ def get_model_config_enums():
|
|||||||
enums = list()
|
enums = list()
|
||||||
|
|
||||||
for model_config in MODEL_CONFIGS:
|
for model_config in MODEL_CONFIGS:
|
||||||
fields = inspect.get_annotations(model_config)
|
|
||||||
|
if hasattr(inspect,'get_annotations'):
|
||||||
|
fields = inspect.get_annotations(model_config)
|
||||||
|
else:
|
||||||
|
fields = model_config.__annotations__
|
||||||
try:
|
try:
|
||||||
field = fields["model_format"]
|
field = fields["model_format"]
|
||||||
except:
|
except:
|
||||||
|
@ -7,7 +7,7 @@ import secrets
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
|
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import Field
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
@ -17,12 +17,11 @@ import psutil
|
|||||||
import torch
|
import torch
|
||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||||
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
|
from diffusers.models.controlnet import ControlNetModel
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
|
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
)
|
)
|
||||||
from diffusers.pipelines.controlnet import MultiControlNetModel
|
|
||||||
|
|
||||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
|
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
|
||||||
StableDiffusionImg2ImgPipeline,
|
StableDiffusionImg2ImgPipeline,
|
||||||
@ -46,7 +45,7 @@ from .diffusion import (
|
|||||||
InvokeAIDiffuserComponent,
|
InvokeAIDiffuserComponent,
|
||||||
PostprocessingSettings,
|
PostprocessingSettings,
|
||||||
)
|
)
|
||||||
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
|
from .offloading import FullyLoadedModelGroup, ModelGroup
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PipelineIntermediateState:
|
class PipelineIntermediateState:
|
||||||
@ -105,7 +104,7 @@ class AddsMaskGuidance:
|
|||||||
_debug: Optional[Callable] = None
|
_debug: Optional[Callable] = None
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, step_output: BaseOutput | SchedulerOutput, t: torch.Tensor, conditioning
|
self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning
|
||||||
) -> BaseOutput:
|
) -> BaseOutput:
|
||||||
output_class = step_output.__class__ # We'll create a new one with masked data.
|
output_class = step_output.__class__ # We'll create a new one with masked data.
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ import warnings
|
|||||||
import weakref
|
import weakref
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
from collections.abc import MutableMapping
|
from collections.abc import MutableMapping
|
||||||
from typing import Callable
|
from typing import Callable, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.utils import send_to_device
|
from accelerate.utils import send_to_device
|
||||||
@ -117,7 +117,7 @@ class LazilyLoadedModelGroup(ModelGroup):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
_hooks: MutableMapping[torch.nn.Module, RemovableHandle]
|
_hooks: MutableMapping[torch.nn.Module, RemovableHandle]
|
||||||
_current_model_ref: Callable[[], torch.nn.Module | _NoModel]
|
_current_model_ref: Callable[[], Union[torch.nn.Module, _NoModel]]
|
||||||
|
|
||||||
def __init__(self, execution_device: torch.device):
|
def __init__(self, execution_device: torch.device):
|
||||||
super().__init__(execution_device)
|
super().__init__(execution_device)
|
||||||
|
@ -4,6 +4,7 @@ from contextlib import nullcontext
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import autocast
|
from torch import autocast
|
||||||
|
from typing import Union
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
CPU_DEVICE = torch.device("cpu")
|
CPU_DEVICE = torch.device("cpu")
|
||||||
@ -49,7 +50,7 @@ def choose_autocast(precision):
|
|||||||
return nullcontext
|
return nullcontext
|
||||||
|
|
||||||
|
|
||||||
def normalize_device(device: str | torch.device) -> torch.device:
|
def normalize_device(device: Union[str, torch.device]) -> torch.device:
|
||||||
"""Ensure device has a device index defined, if appropriate."""
|
"""Ensure device has a device index defined, if appropriate."""
|
||||||
device = torch.device(device)
|
device = torch.device(device)
|
||||||
if device.index is None:
|
if device.index is None:
|
||||||
|
Loading…
Reference in New Issue
Block a user