restore 3.9 compatibility by replacing | with Union[]

This commit is contained in:
Lincoln Stein
2023-07-03 10:55:04 -04:00
parent 2465c7987b
commit ac9ec4e75a
16 changed files with 43 additions and 37 deletions

View File

@ -132,7 +132,7 @@ class BoardImagesService(BoardImagesServiceABC):
def board_record_to_dto(
board_record: BoardRecord, cover_image_name: str | None, image_count: int
board_record: BoardRecord, cover_image_name: Union[str, None], image_count: int
) -> BoardDTO:
"""Converts a board record to a board DTO."""
return BoardDTO(

View File

@ -1,6 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any
from typing import Any, Union
from invokeai.app.models.image import ProgressImage
from invokeai.app.util.misc import get_timestamp
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo
@ -28,7 +28,7 @@ class EventServiceBase:
graph_execution_state_id: str,
node: dict,
source_node_id: str,
progress_image: ProgressImage | None,
progress_image: Union[ProgressImage, None],
step: int,
total_steps: int,
) -> None:

View File

@ -3,7 +3,6 @@
import copy
import itertools
import uuid
from types import NoneType
from typing import (
Annotated,
Any,
@ -26,6 +25,8 @@ from ..invocations.baseinvocation import (
InvocationContext,
)
# in 3.10 this would be "from types import NoneType"
NoneType = type(None)
class EdgeConnection(BaseModel):
node_id: str = Field(description="The id of the node for this edge connection")
@ -846,7 +847,7 @@ class GraphExecutionState(BaseModel):
]
}
def next(self) -> BaseInvocation | None:
def next(self) -> Union[BaseInvocation, None]:
"""Gets the next node ready to execute."""
# TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes

View File

@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
from typing import Dict, Optional
from typing import Dict, Optional, Union
from PIL.Image import Image as PILImageType
from PIL import Image, PngImagePlugin
@ -80,7 +80,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
__cache: Dict[Path, PILImageType]
__max_cache_size: int
def __init__(self, output_folder: str | Path):
def __init__(self, output_folder: Union[str, Path]):
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config
@ -164,7 +164,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
return path
def validate_path(self, path: str | Path) -> bool:
def validate_path(self, path: Union[str, Path]) -> bool:
"""Validates the path given for an image or thumbnail."""
path = path if isinstance(path, Path) else Path(path)
return path.exists()
@ -175,7 +175,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
for folder in folders:
folder.mkdir(parents=True, exist_ok=True)
def __get_cache(self, image_name: Path) -> PILImageType | None:
def __get_cache(self, image_name: Path) -> Union[PILImageType, None]:
return None if image_name not in self.__cache else self.__cache[image_name]
def __set_cache(self, image_name: Path, image: PILImageType):

View File

@ -116,7 +116,7 @@ class ImageRecordStorageBase(ABC):
pass
@abstractmethod
def get_most_recent_image_for_board(self, board_id: str) -> ImageRecord | None:
def get_most_recent_image_for_board(self, board_id: str) -> Union[ImageRecord, None]:
"""Gets the most recent image for a board."""
pass

View File

@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
from queue import Queue
from pydantic import BaseModel, Field
from typing import Union
class InvocationQueueItem(BaseModel):
@ -22,7 +23,7 @@ class InvocationQueueABC(ABC):
pass
@abstractmethod
def put(self, item: InvocationQueueItem | None) -> None:
def put(self, item: Union[InvocationQueueItem, None]) -> None:
pass
@abstractmethod
@ -57,7 +58,7 @@ class MemoryInvocationQueue(InvocationQueueABC):
return item
def put(self, item: InvocationQueueItem | None) -> None:
def put(self, item: Union[InvocationQueueItem, None]) -> None:
self.__queue.put(item)
def cancel(self, graph_execution_state_id: str) -> None:

View File

@ -2,6 +2,7 @@
from abc import ABC
from threading import Event, Thread
from typing import Union
from ..invocations.baseinvocation import InvocationContext
from .graph import Graph, GraphExecutionState
@ -21,7 +22,7 @@ class Invoker:
def invoke(
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
) -> str | None:
) -> Union[str, None]:
"""Determines the next node to invoke and enqueues it, preparing if needed.
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
@ -45,7 +46,7 @@ class Invoker:
return invocation.id
def create_execution_state(self, graph: Graph | None = None) -> GraphExecutionState:
def create_execution_state(self, graph: Union[Graph, None] = None) -> GraphExecutionState:
"""Creates a new execution state for the given graph"""
new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
self.services.graph_execution_manager.set(new_state)

View File

@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
from typing import Dict
from typing import Dict, Union
import torch
@ -55,7 +55,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
if name in self.__cache:
del self.__cache[name]
def __get_cache(self, name: str) -> torch.Tensor|None:
def __get_cache(self, name: str) -> Union[torch.Tensor, None]:
return None if name not in self.__cache else self.__cache[name]
def __set_cache(self, name: str, data: torch.Tensor):
@ -69,9 +69,9 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
class DiskLatentsStorage(LatentsStorageBase):
"""Stores latents in a folder on disk without caching"""
__output_folder: str | Path
__output_folder: Union[str, Path]
def __init__(self, output_folder: str | Path):
def __init__(self, output_folder: Union[str, Path]):
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__output_folder.mkdir(parents=True, exist_ok=True)
@ -91,4 +91,4 @@ class DiskLatentsStorage(LatentsStorageBase):
def get_path(self, name: str) -> Path:
return self.__output_folder / name