ruff fixes and restore default map location of object serializer load

Lincoln Stein 2024-07-18 15:07:09 -04:00
parent 02957be333
commit 9dcace7d82
9 changed files with 9 additions and 25 deletions

View File

@@ -1,7 +1,6 @@
from typing import Iterator, List, Optional, Tuple, Union, cast
import torch
import threading
from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
@@ -140,7 +139,6 @@ class SDXLPromptInvocationBase:
        lora_prefix: str,
        zero_on_empty: bool,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        tid = threading.current_thread().ident
        tokenizer_info = context.models.load(clip_field.tokenizer)
        text_encoder_info = context.models.load(clip_field.text_encoder)
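
The hunks in this file, and the matching ones in the denoise-latents, model-cache, conditioning, and diffusion files below, delete thread-ID debug instrumentation: each removed tid = threading.current_thread().ident existed only to tag ad-hoc debug prints, and once the prints are gone ruff flags the dead assignment (F841) and, where nothing else uses it, the import threading line (F401). A minimal sketch of the idiom being stripped, with the print reconstructed as an assumption (the diff shows only the leftover assignments):

    import threading

    def debug_log(message: str) -> None:
        # threading.current_thread().ident is the OS-level id of the running
        # thread, useful for telling interleaved per-thread output apart.
        tid = threading.current_thread().ident
        print(f"DEBUG {tid}: {message}", flush=True)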

View File

@@ -1,7 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import copy
import inspect
import threading
from contextlib import ExitStack
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
@@ -194,7 +193,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
        """Get the text embeddings and masks from the input conditioning fields."""
        text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
        text_embeddings_masks: list[Optional[torch.Tensor]] = []
        tid = threading.current_thread().ident
        for cond in cond_list:
            cond_data = copy.deepcopy(context.conditioning.load(cond.conditioning_name))
            text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
@@ -319,7 +317,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
        if not isinstance(uncond_list, list):
            uncond_list = [uncond_list]
        tid = threading.current_thread().ident
        cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
            cond_list, context, unet.device, unet.dtype
        )

View File

@@ -47,7 +47,7 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
    def load(self, name: str) -> T:
        file_path = self._get_path(name)
        try:
            return torch.load(file_path, map_location=TorchDevice.choose_torch_device()) # pyright: ignore [reportUnknownMemberType]
            return torch.load(file_path) # pyright: ignore [reportUnknownMemberType]
        except FileNotFoundError as e:
            raise ObjectNotFoundError(name) from e
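
This one-line change is what the commit title calls restoring the default map location: without an explicit map_location argument, torch.load rebuilds each tensor on the device recorded at save time, instead of remapping everything to the execution device chosen by TorchDevice. A minimal sketch of the two behaviors (standard torch API; the file name is illustrative):

    import torch

    torch.save(torch.ones(2), "obj.pt")

    # Default map location (the behavior restored above): tensors come back
    # on whatever device they were on when saved.
    restored = torch.load("obj.pt")

    # Explicit override (the behavior removed above): every tensor is
    # remapped at load time, here to the CPU.
    remapped = torch.load("obj.pt", map_location="cpu")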

View File

@@ -652,7 +652,7 @@ class Graph(BaseModel):
        output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
        # Input type must be a list
        if get_origin(input_field) != list:
        if get_origin(input_field) is not list:
            return False
        # Validate that all outputs match the input type
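
The change from != to is not is a lint fix rather than a behavior change: typing.get_origin returns the bare origin class (list for list[int], None for non-generic types), and an identity test against a type object is the idiomatic form that type-comparison lint rules such as E721 ask for (rule attribution assumed). For example:

    from typing import get_origin

    assert get_origin(list[int]) is list      # origin of a parameterized list
    assert get_origin(dict[str, int]) is dict
    assert get_origin(int) is None            # non-generic types have no origin

    def is_list_annotation(annotation: object) -> bool:
        # Identity test against the type object, mirroring the Graph hunk above.
        return get_origin(annotation) is list

    assert is_list_annotation(list[str]) and not is_list_annotation(int)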

View File

@@ -20,25 +20,21 @@ context. Use like this:
import copy
import gc
import math
import sys
import threading
import time
from contextlib import contextmanager, suppress
from logging import Logger
from threading import BoundedSemaphore
from typing import Dict, Generator, List, Optional, Set
import torch
from diffusers.configuration_utils import ConfigMixin
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
from ..optimizations import skip_torch_weight_init
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
from .model_locker import ModelLocker
@@ -222,7 +218,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
        size = calc_model_size_by_data(model)
        self.make_room(size)
        tid = threading.current_thread().ident
        cache_record = CacheRecord(key=key, model=model, size=size)
        self._cached_models[key] = cache_record
        self._cache_stack.append(key)
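
The import hunk in this file is mechanical unused-import cleanup: the header counts (25 lines down to 21) show several imports were dropped, and the visible ones (threading, plus get_pretty_snapshot_diff from the memory_snapshot import) lost their last references when the debug lines went away. A generic sketch of the rule at work (hypothetical code; ruff's F401 rule code assumed):

    import threading        # F401 after cleanup: nothing below references it
    from math import floor  # still referenced, so ruff check --fix keeps it

    def half(n: int) -> int:
        return floor(n / 2)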

View File

@@ -57,4 +57,3 @@ class ModelLocker(ModelLockerBase):
    def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
        """Return the state dict (if any) for the cached model."""
        return None

View File

@@ -2,8 +2,9 @@
"""These classes implement model patching with LoRAs and Textual Inversions."""
from __future__ import annotations
import threading
import pickle
import threading
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
@@ -34,7 +35,6 @@ with LoRAHelper.apply_lora_unet(unet, loras):
# TODO: rename smth like ModelPatcher and add TI method?
class ModelPatcher:
    _thread_lock = threading.Lock()
    @staticmethod
@@ -109,10 +109,7 @@ class ModelPatcher:
        """
        original_weights = {}
        try:
            with (
                torch.no_grad(),
                cls._thread_lock
            ):
            with torch.no_grad(), cls._thread_lock:
                for lora, lora_weight in loras:
                    # assert lora.device.type == "cpu"
                    for layer_key, layer in lora.layers.items():
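
Two mechanical fixes in this file: the header hunk moves import threading after import pickle (isort-style alphabetical ordering, ruff's I001 rule assumed), and the body hunk collapses a parenthesized multi-context with-statement onto one line. The two spellings are equivalent; the parenthesized form simply requires Python 3.10 or newer. A sketch with a stand-in lock:

    import threading

    import torch

    _lock = threading.Lock()  # stand-in for ModelPatcher._thread_lock

    # Parenthesized multi-line form (Python 3.10+) ...
    with (
        torch.no_grad(),
        _lock,
    ):
        pass

    # ... and the equivalent single-line form the hunk rewrites it to.
    with torch.no_grad(), _lock:
        pass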

View File

@@ -1,5 +1,4 @@
import math
import threading
from dataclasses import dataclass
from typing import List, Optional, Union
@@ -32,7 +31,6 @@ class SDXLConditioningInfo(BasicConditioningInfo):
    add_time_ids: torch.Tensor
    def to(self, device, dtype=None):
        tid = threading.current_thread().ident
        self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
        assert self.pooled_embeds.device == device
        self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)

View File

@@ -296,8 +296,8 @@ class InvokeAIDiffuserComponent:
        added_cond_kwargs = None
        try:
            if conditioning_data.is_sdxl():
                #tid = threading.current_thread().ident
                #print(f'DEBUG {tid} {conditioning_data.uncond_text.pooled_embeds.device=} {conditioning_data.cond_text.pooled_embeds.device=}', flush=True),
                # tid = threading.current_thread().ident
                # print(f'DEBUG {tid} {conditioning_data.uncond_text.pooled_embeds.device=} {conditioning_data.cond_text.pooled_embeds.device=}', flush=True),
                added_cond_kwargs = {
                    "text_embeds": torch.cat(
                        [
@@ -317,7 +317,7 @@ class InvokeAIDiffuserComponent:
                }
        except Exception as e:
            tid = threading.current_thread().ident
            print(f'DEBUG: {tid} {str(e)}')
            print(f"DEBUG: {tid} {str(e)}")
            raise e
        if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
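
The last two hunks are formatter output only: block comments gain a space after the hash (ruff's E265 rule assumed), and f-strings are normalized to double quotes, the style ruff format enforces. Control flow is untouched; the tid assignment and print kept in the except block still run as before. A before/after sketch:

    tid = 12345  # stand-in for threading.current_thread().ident

    #no space after the hash: the form the comment fix above repairs
    # space after the hash: the accepted form

    print(f'DEBUG: {tid}')  # single quotes, as in the removed line
    print(f"DEBUG: {tid}")  # double quotes, as in its replacement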