Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00

ruff fixes and restore default map location of object serializer load

This commit is contained in:
parent 02957be333
commit 9dcace7d82
@@ -1,7 +1,6 @@
 from typing import Iterator, List, Optional, Tuple, Union, cast

 import torch
-import threading
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
@@ -140,7 +139,6 @@ class SDXLPromptInvocationBase:
         lora_prefix: str,
         zero_on_empty: bool,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        tid = threading.current_thread().ident
         tokenizer_info = context.models.load(clip_field.tokenizer)
         text_encoder_info = context.models.load(clip_field.text_encoder)

@@ -1,7 +1,6 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
 import copy
 import inspect
-import threading
 from contextlib import ExitStack
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

@@ -194,7 +193,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
         """Get the text embeddings and masks from the input conditioning fields."""
         text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
         text_embeddings_masks: list[Optional[torch.Tensor]] = []
-        tid = threading.current_thread().ident
         for cond in cond_list:
             cond_data = copy.deepcopy(context.conditioning.load(cond.conditioning_name))
             text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
@@ -319,7 +317,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
         if not isinstance(uncond_list, list):
             uncond_list = [uncond_list]

-        tid = threading.current_thread().ident
         cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
             cond_list, context, unet.device, unet.dtype
         )
@@ -47,7 +47,7 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
     def load(self, name: str) -> T:
         file_path = self._get_path(name)
         try:
-            return torch.load(file_path, map_location=TorchDevice.choose_torch_device()) # pyright: ignore [reportUnknownMemberType]
+            return torch.load(file_path) # pyright: ignore [reportUnknownMemberType]
         except FileNotFoundError as e:
             raise ObjectNotFoundError(name) from e

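For context on the "restore default map location" part of the commit message: torch.load with no map_location argument restores each tensor to the device recorded in the file, while an explicit map_location remaps everything at load time. A minimal standalone sketch of the difference (the file name and device choices are illustrative, not taken from the repository):

import torch

t = torch.ones(4)                    # lives on the CPU (or CUDA, if moved there first)
torch.save(t, "example.pt")          # the saved file records the tensor's device

restored = torch.load("example.pt")  # default: tensors come back on the recorded device
restored_cpu = torch.load("example.pt", map_location="cpu")  # forced onto the CPU instead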
@@ -652,7 +652,7 @@ class Graph(BaseModel):
         output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]

         # Input type must be a list
-        if get_origin(input_field) != list:
+        if get_origin(input_field) is not list:
             return False

         # Validate that all outputs match the input type
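The change from != to is not above is a typical ruff comparison fix: typing.get_origin returns the bare origin class of a generic alias (or None), so an identity check against the list type is the idiomatic test. A small sketch, independent of the InvokeAI graph code:

from typing import List, Optional, get_origin

print(get_origin(list[int]))          # <class 'list'>
print(get_origin(List[int]) is list)  # True -> the input field accepts a list
print(get_origin(Optional[int]))      # typing.Union, not list
print(get_origin(int))                # None -> not a generic alias at all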
@@ -20,25 +20,21 @@ context. Use like this:

 import copy
 import gc
-import math
 import sys
 import threading
-import time
 from contextlib import contextmanager, suppress
 from logging import Logger
 from threading import BoundedSemaphore
 from typing import Dict, Generator, List, Optional, Set

 import torch
-from diffusers.configuration_utils import ConfigMixin

 from invokeai.backend.model_manager import AnyModel, SubModelType
-from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
+from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
 from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger

-from ..optimizations import skip_torch_weight_init
 from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
 from .model_locker import ModelLocker

@@ -222,7 +218,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
         size = calc_model_size_by_data(model)
         self.make_room(size)

-        tid = threading.current_thread().ident
         cache_record = CacheRecord(key=key, model=model, size=size)
         self._cached_models[key] = cache_record
         self._cache_stack.append(key)
@@ -57,4 +57,3 @@ class ModelLocker(ModelLockerBase):
     def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
         """Return the state dict (if any) for the cached model."""
         return None
-
@@ -2,8 +2,9 @@
 """These classes implement model patching with LoRAs and Textual Inversions."""

 from __future__ import annotations
-import threading
+
 import pickle
+import threading
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union

@@ -34,7 +35,6 @@ with LoRAHelper.apply_lora_unet(unet, loras):

 # TODO: rename smth like ModelPatcher and add TI method?
 class ModelPatcher:
-
     _thread_lock = threading.Lock()

     @staticmethod
@@ -109,10 +109,7 @@ class ModelPatcher:
         """
         original_weights = {}
         try:
-            with (
-                torch.no_grad(),
-                cls._thread_lock
-            ):
+            with torch.no_grad(), cls._thread_lock:
                 for lora, lora_weight in loras:
                     # assert lora.device.type == "cpu"
                     for layer_key, layer in lora.layers.items():
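The collapsed with statement above is behaviour-preserving: a single with can take several comma-separated context managers, entered left to right and exited in reverse, and the parenthesized multi-line form it replaces only became official syntax in Python 3.10, so formatters collapse it when it fits on one line. A standalone sketch of the pattern (the function and its body are illustrative, not the repository's patching logic):

import threading

import torch

_lock = threading.Lock()

def patch_in_place(module: torch.nn.Module) -> None:
    # Same shape as the diff: disable autograd tracking and hold the lock
    # while mutating weights in place.
    with torch.no_grad(), _lock:
        for p in module.parameters():
            p.mul_(1.0)  # stand-in for the real LoRA weight patching

# Equivalent parenthesized form (the style the commit removes), Python 3.10+:
# with (
#     torch.no_grad(),
#     _lock,
# ):
#     ...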
@@ -1,5 +1,4 @@
 import math
-import threading
 from dataclasses import dataclass
 from typing import List, Optional, Union

@@ -32,7 +31,6 @@ class SDXLConditioningInfo(BasicConditioningInfo):
     add_time_ids: torch.Tensor

     def to(self, device, dtype=None):
-        tid = threading.current_thread().ident
         self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
         assert self.pooled_embeds.device == device
         self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
@@ -296,8 +296,8 @@ class InvokeAIDiffuserComponent:
         added_cond_kwargs = None
         try:
             if conditioning_data.is_sdxl():
-                #tid = threading.current_thread().ident
-                #print(f'DEBUG {tid} {conditioning_data.uncond_text.pooled_embeds.device=} {conditioning_data.cond_text.pooled_embeds.device=}', flush=True),
+                # tid = threading.current_thread().ident
+                # print(f'DEBUG {tid} {conditioning_data.uncond_text.pooled_embeds.device=} {conditioning_data.cond_text.pooled_embeds.device=}', flush=True),
                 added_cond_kwargs = {
                     "text_embeds": torch.cat(
                         [
@@ -317,7 +317,7 @@ class InvokeAIDiffuserComponent:
                 }
         except Exception as e:
             tid = threading.current_thread().ident
-            print(f'DEBUG: {tid} {str(e)}')
+            print(f"DEBUG: {tid} {str(e)}")
             raise e

         if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None: