restore 3.9 compatibility by replacing | with Union[]

This commit is contained in:
Lincoln Stein 2023-07-03 10:55:04 -04:00
parent 90aa97edd4
commit 76bafeb99e
16 changed files with 43 additions and 37 deletions

View File

@ -47,7 +47,7 @@ def add_parsers(
commands: list[type], commands: list[type],
command_field: str = "type", command_field: str = "type",
exclude_fields: list[str] = ["id", "type"], exclude_fields: list[str] = ["id", "type"],
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None add_arguments: Union[Callable[[argparse.ArgumentParser], None],None] = None
): ):
"""Adds parsers for each command to the subparsers""" """Adds parsers for each command to the subparsers"""
@ -72,7 +72,7 @@ def add_parsers(
def add_graph_parsers( def add_graph_parsers(
subparsers, subparsers,
graphs: list[LibraryGraph], graphs: list[LibraryGraph],
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
): ):
for graph in graphs: for graph in graphs:
command_parser = subparsers.add_parser(graph.name, help=graph.description) command_parser = subparsers.add_parser(graph.name, help=graph.description)

View File

@ -1,7 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import argparse import argparse
import os
import re import re
import shlex import shlex
import sys import sys
@ -348,7 +347,7 @@ def invoke_cli():
# Parse invocation # Parse invocation
command: CliCommand = None # type:ignore command: CliCommand = None # type:ignore
system_graph: LibraryGraph|None = None system_graph: Union[LibraryGraph,None] = None
if args['type'] in system_graph_names: if args['type'] in system_graph_names:
system_graph = next(filter(lambda g: g.name == args['type'], system_graphs)) system_graph = next(filter(lambda g: g.name == args['type'], system_graphs))
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id)) invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))

View File

@ -132,7 +132,7 @@ class BoardImagesService(BoardImagesServiceABC):
def board_record_to_dto( def board_record_to_dto(
board_record: BoardRecord, cover_image_name: str | None, image_count: int board_record: BoardRecord, cover_image_name: Union[str, None], image_count: int
) -> BoardDTO: ) -> BoardDTO:
"""Converts a board record to a board DTO.""" """Converts a board record to a board DTO."""
return BoardDTO( return BoardDTO(

View File

@ -1,6 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any from typing import Any, Union
from invokeai.app.models.image import ProgressImage from invokeai.app.models.image import ProgressImage
from invokeai.app.util.misc import get_timestamp from invokeai.app.util.misc import get_timestamp
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo
@ -28,7 +28,7 @@ class EventServiceBase:
graph_execution_state_id: str, graph_execution_state_id: str,
node: dict, node: dict,
source_node_id: str, source_node_id: str,
progress_image: ProgressImage | None, progress_image: Union[ProgressImage, None],
step: int, step: int,
total_steps: int, total_steps: int,
) -> None: ) -> None:

View File

@ -3,7 +3,6 @@
import copy import copy
import itertools import itertools
import uuid import uuid
from types import NoneType
from typing import ( from typing import (
Annotated, Annotated,
Any, Any,
@ -26,6 +25,8 @@ from ..invocations.baseinvocation import (
InvocationContext, InvocationContext,
) )
# in 3.10 this would be "from types import NoneType"
NoneType = type(None)
class EdgeConnection(BaseModel): class EdgeConnection(BaseModel):
node_id: str = Field(description="The id of the node for this edge connection") node_id: str = Field(description="The id of the node for this edge connection")
@ -846,7 +847,7 @@ class GraphExecutionState(BaseModel):
] ]
} }
def next(self) -> BaseInvocation | None: def next(self) -> Union[BaseInvocation, None]:
"""Gets the next node ready to execute.""" """Gets the next node ready to execute."""
# TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes # TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes

View File

@ -2,7 +2,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from queue import Queue from queue import Queue
from typing import Dict, Optional from typing import Dict, Optional, Union
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
@ -80,7 +80,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
__cache: Dict[Path, PILImageType] __cache: Dict[Path, PILImageType]
__max_cache_size: int __max_cache_size: int
def __init__(self, output_folder: str | Path): def __init__(self, output_folder: Union[str, Path]):
self.__cache = dict() self.__cache = dict()
self.__cache_ids = Queue() self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config self.__max_cache_size = 10 # TODO: get this from config
@ -164,7 +164,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
return path return path
def validate_path(self, path: str | Path) -> bool: def validate_path(self, path: Union[str, Path]) -> bool:
"""Validates the path given for an image or thumbnail.""" """Validates the path given for an image or thumbnail."""
path = path if isinstance(path, Path) else Path(path) path = path if isinstance(path, Path) else Path(path)
return path.exists() return path.exists()
@ -175,7 +175,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
for folder in folders: for folder in folders:
folder.mkdir(parents=True, exist_ok=True) folder.mkdir(parents=True, exist_ok=True)
def __get_cache(self, image_name: Path) -> PILImageType | None: def __get_cache(self, image_name: Path) -> Union[PILImageType, None]:
return None if image_name not in self.__cache else self.__cache[image_name] return None if image_name not in self.__cache else self.__cache[image_name]
def __set_cache(self, image_name: Path, image: PILImageType): def __set_cache(self, image_name: Path, image: PILImageType):

View File

@ -116,7 +116,7 @@ class ImageRecordStorageBase(ABC):
pass pass
@abstractmethod @abstractmethod
def get_most_recent_image_for_board(self, board_id: str) -> ImageRecord | None: def get_most_recent_image_for_board(self, board_id: str) -> Union[ImageRecord, None]:
"""Gets the most recent image for a board.""" """Gets the most recent image for a board."""
pass pass

View File

@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
from queue import Queue from queue import Queue
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Union
class InvocationQueueItem(BaseModel): class InvocationQueueItem(BaseModel):
@ -22,7 +23,7 @@ class InvocationQueueABC(ABC):
pass pass
@abstractmethod @abstractmethod
def put(self, item: InvocationQueueItem | None) -> None: def put(self, item: Union[InvocationQueueItem, None]) -> None:
pass pass
@abstractmethod @abstractmethod
@ -57,7 +58,7 @@ class MemoryInvocationQueue(InvocationQueueABC):
return item return item
def put(self, item: InvocationQueueItem | None) -> None: def put(self, item: Union[InvocationQueueItem, None]) -> None:
self.__queue.put(item) self.__queue.put(item)
def cancel(self, graph_execution_state_id: str) -> None: def cancel(self, graph_execution_state_id: str) -> None:

View File

@ -2,6 +2,7 @@
from abc import ABC from abc import ABC
from threading import Event, Thread from threading import Event, Thread
from typing import Union
from ..invocations.baseinvocation import InvocationContext from ..invocations.baseinvocation import InvocationContext
from .graph import Graph, GraphExecutionState from .graph import Graph, GraphExecutionState
@ -21,7 +22,7 @@ class Invoker:
def invoke( def invoke(
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
) -> str | None: ) -> Union[str, None]:
"""Determines the next node to invoke and enqueues it, preparing if needed. """Determines the next node to invoke and enqueues it, preparing if needed.
Returns the id of the queued node, or `None` if there are no nodes left to enqueue.""" Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
@ -45,7 +46,7 @@ class Invoker:
return invocation.id return invocation.id
def create_execution_state(self, graph: Graph | None = None) -> GraphExecutionState: def create_execution_state(self, graph: Union[Graph, None] = None) -> GraphExecutionState:
"""Creates a new execution state for the given graph""" """Creates a new execution state for the given graph"""
new_state = GraphExecutionState(graph=Graph() if graph is None else graph) new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
self.services.graph_execution_manager.set(new_state) self.services.graph_execution_manager.set(new_state)

View File

@ -3,7 +3,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from queue import Queue from queue import Queue
from typing import Dict from typing import Dict, Union
import torch import torch
@ -55,7 +55,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
if name in self.__cache: if name in self.__cache:
del self.__cache[name] del self.__cache[name]
def __get_cache(self, name: str) -> torch.Tensor|None: def __get_cache(self, name: str) -> Union[torch.Tensor, None]:
return None if name not in self.__cache else self.__cache[name] return None if name not in self.__cache else self.__cache[name]
def __set_cache(self, name: str, data: torch.Tensor): def __set_cache(self, name: str, data: torch.Tensor):
@ -69,9 +69,9 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
class DiskLatentsStorage(LatentsStorageBase): class DiskLatentsStorage(LatentsStorageBase):
"""Stores latents in a folder on disk without caching""" """Stores latents in a folder on disk without caching"""
__output_folder: str | Path __output_folder: Union[str, Path]
def __init__(self, output_folder: str | Path): def __init__(self, output_folder: Union[str, Path]):
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder) self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__output_folder.mkdir(parents=True, exist_ok=True) self.__output_folder.mkdir(parents=True, exist_ok=True)

View File

@ -21,7 +21,7 @@ from PIL import Image, ImageChops, ImageFilter
from accelerate.utils import set_seed from accelerate.utils import set_seed
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from tqdm import trange from tqdm import trange
from typing import Callable, List, Iterator, Optional, Type from typing import Callable, List, Iterator, Optional, Type, Union
from dataclasses import dataclass, field from dataclasses import dataclass, field
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
@ -178,7 +178,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
# ------------------------------------ # ------------------------------------
class Img2Img(InvokeAIGenerator): class Img2Img(InvokeAIGenerator):
def generate(self, def generate(self,
init_image: Image.Image | torch.FloatTensor, init_image: Union[Image.Image, torch.FloatTensor],
strength: float=0.75, strength: float=0.75,
**keyword_args **keyword_args
)->Iterator[InvokeAIGeneratorOutput]: )->Iterator[InvokeAIGeneratorOutput]:
@ -195,7 +195,7 @@ class Img2Img(InvokeAIGenerator):
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff # Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
class Inpaint(Img2Img): class Inpaint(Img2Img):
def generate(self, def generate(self,
mask_image: Image.Image | torch.FloatTensor, mask_image: Union[Image.Image, torch.FloatTensor],
# Seam settings - when 0, doesn't fill seam # Seam settings - when 0, doesn't fill seam
seam_size: int = 96, seam_size: int = 96,
seam_blur: int = 16, seam_blur: int = 16,

View File

@ -203,8 +203,8 @@ class Inpaint(Img2Img):
cfg_scale, cfg_scale,
ddim_eta, ddim_eta,
conditioning, conditioning,
init_image: Image.Image | torch.FloatTensor, init_image: Union[Image.Image, torch.FloatTensor],
mask_image: Image.Image | torch.FloatTensor, mask_image: Union[Image.Image, torch.FloatTensor],
strength: float, strength: float,
mask_blur_radius: int = 8, mask_blur_radius: int = 8,
# Seam settings - when 0, doesn't fill seam # Seam settings - when 0, doesn't fill seam

View File

@ -68,7 +68,11 @@ def get_model_config_enums():
enums = list() enums = list()
for model_config in MODEL_CONFIGS: for model_config in MODEL_CONFIGS:
fields = inspect.get_annotations(model_config)
if hasattr(inspect,'get_annotations'):
fields = inspect.get_annotations(model_config)
else:
fields = model_config.__annotations__
try: try:
field = fields["model_format"] field = fields["model_format"]
except: except:

View File

@ -7,7 +7,7 @@ import secrets
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
from pydantic import BaseModel, Field from pydantic import Field
import einops import einops
import PIL.Image import PIL.Image
@ -17,12 +17,11 @@ import psutil
import torch import torch
import torchvision.transforms as T import torchvision.transforms as T
from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput from diffusers.models.controlnet import ControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
StableDiffusionPipeline, StableDiffusionPipeline,
) )
from diffusers.pipelines.controlnet import MultiControlNetModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
StableDiffusionImg2ImgPipeline, StableDiffusionImg2ImgPipeline,
@ -46,7 +45,7 @@ from .diffusion import (
InvokeAIDiffuserComponent, InvokeAIDiffuserComponent,
PostprocessingSettings, PostprocessingSettings,
) )
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup from .offloading import FullyLoadedModelGroup, ModelGroup
@dataclass @dataclass
class PipelineIntermediateState: class PipelineIntermediateState:
@ -105,7 +104,7 @@ class AddsMaskGuidance:
_debug: Optional[Callable] = None _debug: Optional[Callable] = None
def __call__( def __call__(
self, step_output: BaseOutput | SchedulerOutput, t: torch.Tensor, conditioning self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning
) -> BaseOutput: ) -> BaseOutput:
output_class = step_output.__class__ # We'll create a new one with masked data. output_class = step_output.__class__ # We'll create a new one with masked data.

View File

@ -4,7 +4,7 @@ import warnings
import weakref import weakref
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from collections.abc import MutableMapping from collections.abc import MutableMapping
from typing import Callable from typing import Callable, Union
import torch import torch
from accelerate.utils import send_to_device from accelerate.utils import send_to_device
@ -117,7 +117,7 @@ class LazilyLoadedModelGroup(ModelGroup):
""" """
_hooks: MutableMapping[torch.nn.Module, RemovableHandle] _hooks: MutableMapping[torch.nn.Module, RemovableHandle]
_current_model_ref: Callable[[], torch.nn.Module | _NoModel] _current_model_ref: Callable[[], Union[torch.nn.Module, _NoModel]]
def __init__(self, execution_device: torch.device): def __init__(self, execution_device: torch.device):
super().__init__(execution_device) super().__init__(execution_device)

View File

@ -4,6 +4,7 @@ from contextlib import nullcontext
import torch import torch
from torch import autocast from torch import autocast
from typing import Union
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
CPU_DEVICE = torch.device("cpu") CPU_DEVICE = torch.device("cpu")
@ -49,7 +50,7 @@ def choose_autocast(precision):
return nullcontext return nullcontext
def normalize_device(device: str | torch.device) -> torch.device: def normalize_device(device: Union[str, torch.device]) -> torch.device:
"""Ensure device has a device index defined, if appropriate.""" """Ensure device has a device index defined, if appropriate."""
device = torch.device(device) device = torch.device(device)
if device.index is None: if device.index is None: