ruff fixes and restore default map location of object serializer load

This commit is contained in:
Lincoln Stein 2024-07-18 15:07:09 -04:00
parent 02957be333
commit 9dcace7d82
9 changed files with 9 additions and 25 deletions

View File

@ -1,7 +1,6 @@
from typing import Iterator, List, Optional, Tuple, Union, cast from typing import Iterator, List, Optional, Tuple, Union, cast
import torch import torch
import threading
from compel import Compel, ReturnedEmbeddingsType from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
@ -140,7 +139,6 @@ class SDXLPromptInvocationBase:
lora_prefix: str, lora_prefix: str,
zero_on_empty: bool, zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
tid = threading.current_thread().ident
tokenizer_info = context.models.load(clip_field.tokenizer) tokenizer_info = context.models.load(clip_field.tokenizer)
text_encoder_info = context.models.load(clip_field.text_encoder) text_encoder_info = context.models.load(clip_field.text_encoder)

View File

@ -1,7 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import copy import copy
import inspect import inspect
import threading
from contextlib import ExitStack from contextlib import ExitStack
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
@ -194,7 +193,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
"""Get the text embeddings and masks from the input conditioning fields.""" """Get the text embeddings and masks from the input conditioning fields."""
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = [] text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
text_embeddings_masks: list[Optional[torch.Tensor]] = [] text_embeddings_masks: list[Optional[torch.Tensor]] = []
tid = threading.current_thread().ident
for cond in cond_list: for cond in cond_list:
cond_data = copy.deepcopy(context.conditioning.load(cond.conditioning_name)) cond_data = copy.deepcopy(context.conditioning.load(cond.conditioning_name))
text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype)) text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
@ -319,7 +317,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
if not isinstance(uncond_list, list): if not isinstance(uncond_list, list):
uncond_list = [uncond_list] uncond_list = [uncond_list]
tid = threading.current_thread().ident
cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks( cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
cond_list, context, unet.device, unet.dtype cond_list, context, unet.device, unet.dtype
) )

View File

@ -47,7 +47,7 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
def load(self, name: str) -> T: def load(self, name: str) -> T:
file_path = self._get_path(name) file_path = self._get_path(name)
try: try:
return torch.load(file_path, map_location=TorchDevice.choose_torch_device()) # pyright: ignore [reportUnknownMemberType] return torch.load(file_path) # pyright: ignore [reportUnknownMemberType]
except FileNotFoundError as e: except FileNotFoundError as e:
raise ObjectNotFoundError(name) from e raise ObjectNotFoundError(name) from e

View File

@ -652,7 +652,7 @@ class Graph(BaseModel):
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs] output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
# Input type must be a list # Input type must be a list
if get_origin(input_field) != list: if get_origin(input_field) is not list:
return False return False
# Validate that all outputs match the input type # Validate that all outputs match the input type

View File

@ -20,25 +20,21 @@ context. Use like this:
import copy import copy
import gc import gc
import math
import sys import sys
import threading import threading
import time
from contextlib import contextmanager, suppress from contextlib import contextmanager, suppress
from logging import Logger from logging import Logger
from threading import BoundedSemaphore from threading import BoundedSemaphore
from typing import Dict, Generator, List, Optional, Set from typing import Dict, Generator, List, Optional, Set
import torch import torch
from diffusers.configuration_utils import ConfigMixin
from invokeai.backend.model_manager import AnyModel, SubModelType from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from ..optimizations import skip_torch_weight_init
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
from .model_locker import ModelLocker from .model_locker import ModelLocker
@ -222,7 +218,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
size = calc_model_size_by_data(model) size = calc_model_size_by_data(model)
self.make_room(size) self.make_room(size)
tid = threading.current_thread().ident
cache_record = CacheRecord(key=key, model=model, size=size) cache_record = CacheRecord(key=key, model=model, size=size)
self._cached_models[key] = cache_record self._cached_models[key] = cache_record
self._cache_stack.append(key) self._cache_stack.append(key)

View File

@ -57,4 +57,3 @@ class ModelLocker(ModelLockerBase):
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]: def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
"""Return the state dict (if any) for the cached model.""" """Return the state dict (if any) for the cached model."""
return None return None

View File

@ -2,8 +2,9 @@
"""These classes implement model patching with LoRAs and Textual Inversions.""" """These classes implement model patching with LoRAs and Textual Inversions."""
from __future__ import annotations from __future__ import annotations
import threading
import pickle import pickle
import threading
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
@ -34,7 +35,6 @@ with LoRAHelper.apply_lora_unet(unet, loras):
# TODO: rename smth like ModelPatcher and add TI method? # TODO: rename smth like ModelPatcher and add TI method?
class ModelPatcher: class ModelPatcher:
_thread_lock = threading.Lock() _thread_lock = threading.Lock()
@staticmethod @staticmethod
@ -109,10 +109,7 @@ class ModelPatcher:
""" """
original_weights = {} original_weights = {}
try: try:
with ( with torch.no_grad(), cls._thread_lock:
torch.no_grad(),
cls._thread_lock
):
for lora, lora_weight in loras: for lora, lora_weight in loras:
# assert lora.device.type == "cpu" # assert lora.device.type == "cpu"
for layer_key, layer in lora.layers.items(): for layer_key, layer in lora.layers.items():

View File

@ -1,5 +1,4 @@
import math import math
import threading
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Union from typing import List, Optional, Union
@ -32,7 +31,6 @@ class SDXLConditioningInfo(BasicConditioningInfo):
add_time_ids: torch.Tensor add_time_ids: torch.Tensor
def to(self, device, dtype=None): def to(self, device, dtype=None):
tid = threading.current_thread().ident
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype) self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
assert self.pooled_embeds.device == device assert self.pooled_embeds.device == device
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype) self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)

View File

@ -296,8 +296,8 @@ class InvokeAIDiffuserComponent:
added_cond_kwargs = None added_cond_kwargs = None
try: try:
if conditioning_data.is_sdxl(): if conditioning_data.is_sdxl():
#tid = threading.current_thread().ident # tid = threading.current_thread().ident
#print(f'DEBUG {tid} {conditioning_data.uncond_text.pooled_embeds.device=} {conditioning_data.cond_text.pooled_embeds.device=}', flush=True), # print(f'DEBUG {tid} {conditioning_data.uncond_text.pooled_embeds.device=} {conditioning_data.cond_text.pooled_embeds.device=}', flush=True),
added_cond_kwargs = { added_cond_kwargs = {
"text_embeds": torch.cat( "text_embeds": torch.cat(
[ [
@ -317,7 +317,7 @@ class InvokeAIDiffuserComponent:
} }
except Exception as e: except Exception as e:
tid = threading.current_thread().ident tid = threading.current_thread().ident
print(f'DEBUG: {tid} {str(e)}') print(f"DEBUG: {tid} {str(e)}")
raise e raise e
if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None: if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None: