mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
ruff fixes and restore default map location of object serializer load
This commit is contained in:
parent
02957be333
commit
9dcace7d82
@ -1,7 +1,6 @@
|
||||
from typing import Iterator, List, Optional, Tuple, Union, cast
|
||||
|
||||
import torch
|
||||
import threading
|
||||
from compel import Compel, ReturnedEmbeddingsType
|
||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
@ -140,7 +139,6 @@ class SDXLPromptInvocationBase:
|
||||
lora_prefix: str,
|
||||
zero_on_empty: bool,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
tid = threading.current_thread().ident
|
||||
tokenizer_info = context.models.load(clip_field.tokenizer)
|
||||
text_encoder_info = context.models.load(clip_field.text_encoder)
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
import copy
|
||||
import inspect
|
||||
import threading
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
@ -194,7 +193,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
"""Get the text embeddings and masks from the input conditioning fields."""
|
||||
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
|
||||
text_embeddings_masks: list[Optional[torch.Tensor]] = []
|
||||
tid = threading.current_thread().ident
|
||||
for cond in cond_list:
|
||||
cond_data = copy.deepcopy(context.conditioning.load(cond.conditioning_name))
|
||||
text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
|
||||
@ -319,7 +317,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
if not isinstance(uncond_list, list):
|
||||
uncond_list = [uncond_list]
|
||||
|
||||
tid = threading.current_thread().ident
|
||||
cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
|
||||
cond_list, context, unet.device, unet.dtype
|
||||
)
|
||||
|
@ -47,7 +47,7 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
|
||||
def load(self, name: str) -> T:
|
||||
file_path = self._get_path(name)
|
||||
try:
|
||||
return torch.load(file_path, map_location=TorchDevice.choose_torch_device()) # pyright: ignore [reportUnknownMemberType]
|
||||
return torch.load(file_path) # pyright: ignore [reportUnknownMemberType]
|
||||
except FileNotFoundError as e:
|
||||
raise ObjectNotFoundError(name) from e
|
||||
|
||||
|
@ -652,7 +652,7 @@ class Graph(BaseModel):
|
||||
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
|
||||
|
||||
# Input type must be a list
|
||||
if get_origin(input_field) != list:
|
||||
if get_origin(input_field) is not list:
|
||||
return False
|
||||
|
||||
# Validate that all outputs match the input type
|
||||
|
@ -20,25 +20,21 @@ context. Use like this:
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import math
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from contextlib import contextmanager, suppress
|
||||
from logging import Logger
|
||||
from threading import BoundedSemaphore
|
||||
from typing import Dict, Generator, List, Optional, Set
|
||||
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, SubModelType
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from ..optimizations import skip_torch_weight_init
|
||||
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
|
||||
from .model_locker import ModelLocker
|
||||
|
||||
@ -222,7 +218,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
size = calc_model_size_by_data(model)
|
||||
self.make_room(size)
|
||||
|
||||
tid = threading.current_thread().ident
|
||||
cache_record = CacheRecord(key=key, model=model, size=size)
|
||||
self._cached_models[key] = cache_record
|
||||
self._cache_stack.append(key)
|
||||
|
@ -57,4 +57,3 @@ class ModelLocker(ModelLockerBase):
|
||||
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
|
||||
"""Return the state dict (if any) for the cached model."""
|
||||
return None
|
||||
|
||||
|
@ -2,8 +2,9 @@
|
||||
"""These classes implement model patching with LoRAs and Textual Inversions."""
|
||||
|
||||
from __future__ import annotations
|
||||
import threading
|
||||
|
||||
import pickle
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
@ -34,7 +35,6 @@ with LoRAHelper.apply_lora_unet(unet, loras):
|
||||
|
||||
# TODO: rename smth like ModelPatcher and add TI method?
|
||||
class ModelPatcher:
|
||||
|
||||
_thread_lock = threading.Lock()
|
||||
|
||||
@staticmethod
|
||||
@ -109,10 +109,7 @@ class ModelPatcher:
|
||||
"""
|
||||
original_weights = {}
|
||||
try:
|
||||
with (
|
||||
torch.no_grad(),
|
||||
cls._thread_lock
|
||||
):
|
||||
with torch.no_grad(), cls._thread_lock:
|
||||
for lora, lora_weight in loras:
|
||||
# assert lora.device.type == "cpu"
|
||||
for layer_key, layer in lora.layers.items():
|
||||
|
@ -1,5 +1,4 @@
|
||||
import math
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
@ -32,7 +31,6 @@ class SDXLConditioningInfo(BasicConditioningInfo):
|
||||
add_time_ids: torch.Tensor
|
||||
|
||||
def to(self, device, dtype=None):
|
||||
tid = threading.current_thread().ident
|
||||
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
|
||||
assert self.pooled_embeds.device == device
|
||||
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
|
||||
|
@ -296,8 +296,8 @@ class InvokeAIDiffuserComponent:
|
||||
added_cond_kwargs = None
|
||||
try:
|
||||
if conditioning_data.is_sdxl():
|
||||
#tid = threading.current_thread().ident
|
||||
#print(f'DEBUG {tid} {conditioning_data.uncond_text.pooled_embeds.device=} {conditioning_data.cond_text.pooled_embeds.device=}', flush=True),
|
||||
# tid = threading.current_thread().ident
|
||||
# print(f'DEBUG {tid} {conditioning_data.uncond_text.pooled_embeds.device=} {conditioning_data.cond_text.pooled_embeds.device=}', flush=True),
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": torch.cat(
|
||||
[
|
||||
@ -317,7 +317,7 @@ class InvokeAIDiffuserComponent:
|
||||
}
|
||||
except Exception as e:
|
||||
tid = threading.current_thread().ident
|
||||
print(f'DEBUG: {tid} {str(e)}')
|
||||
print(f"DEBUG: {tid} {str(e)}")
|
||||
raise e
|
||||
|
||||
if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
|
||||
|
Loading…
Reference in New Issue
Block a user