Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
38343917f8
In #6490 we enabled non-blocking torch device transfers throughout the model manager's memory management code. When using this torch feature, torch attempts to wait until the tensor transfer has completed before allowing any access to the tensor, so in theory it should be safe to use. In practice, it provides a small performance improvement but causes race conditions in some situations: specific platforms/systems are affected, and complicated data dependencies can make it unsafe.

- Intermittent black images on MPS devices - reported on Discord and in #6545, fixed with special handling in #6549.
- Intermittent OOMs and black images on a P4000 GPU on Windows - reported in #6613, fixed in this commit.

On my system, I haven't experienced any issues with generation, but targeted testing of non-blocking ops did expose a race condition when moving tensors from CUDA to CPU. One workaround is to use torch streams with manual sync points, but our application logic is complicated enough that this would be a lot of work and feels ripe for edge cases and missed spots. Much safer is to fully revert non-blocking transfers - which is what this change does.
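The CUDA-to-CPU hazard described above can be sketched as follows. This is an illustrative helper, not InvokeAI code: `safe_to_cpu` and `unsafe_to_cpu` are hypothetical names, and the sketch assumes the general rule that a device-to-host copy queued with `non_blocking=True` may still be in flight when the caller reads the destination tensor, so a manual synchronize is needed.

```python
# Hypothetical sketch of the CUDA -> CPU non-blocking race and one fix.
# These helpers are illustrative only; they are not part of InvokeAI.
import torch


def unsafe_to_cpu(t: torch.Tensor) -> torch.Tensor:
    # The copy is merely queued on the current CUDA stream; the returned
    # CPU tensor may still be filling in when the caller reads it.
    return t.to("cpu", non_blocking=True)


def safe_to_cpu(t: torch.Tensor) -> torch.Tensor:
    out = t.to("cpu", non_blocking=True)
    if t.device.type == "cuda":
        # Block until all queued work on the source device (including
        # this copy) has completed before the tensor is used.
        torch.cuda.synchronize(t.device)
    return out


# For a tensor already on the CPU both paths are equivalent; the race
# only arises for device-to-host copies from a CUDA device.
x = torch.ones(4)
assert torch.equal(safe_to_cpu(x), x)
```

Reverting, as this commit does, sidesteps the need to place such sync points throughout the codebase.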
25 lines
797 B
Python
"""Base class for 'Raw' models.

The RawModel class is the base class of LoRAModelRaw and TextualInversionModelRaw,
and is used for type checking of calls to the model patcher. Its main purpose
is to avoid circular import issues when lora.py tries to import BaseModelType
from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
from lora.py.

The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
that adds additional methods and attributes.
"""

from abc import ABC, abstractmethod
from typing import Optional

import torch


class RawModel(ABC):
    """Abstract base class for 'Raw' model wrappers."""

    @abstractmethod
    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        pass
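As an illustration of how this base class is used, a concrete subclass only needs to implement `to`. The wrapper below is a hypothetical example (`ExampleRawModel` is not an InvokeAI class; the real subclasses are LoRAModelRaw and TextualInversionModelRaw); it redefines RawModel so the sketch is self-contained.

```python
# Hypothetical RawModel subclass for illustration only.
from abc import ABC, abstractmethod
from typing import Optional

import torch


class RawModel(ABC):
    """Abstract base class for 'Raw' model wrappers."""

    @abstractmethod
    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        pass


class ExampleRawModel(RawModel):
    """Wraps a torch.nn.Module and adds an extra attribute (a name)."""

    def __init__(self, module: torch.nn.Module, name: str = "example") -> None:
        self.module = module
        self.name = name

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        # Delegate the (in-place) transfer to the wrapped module.
        self.module.to(device=device, dtype=dtype)
```

Because `RawModel` is abstract, instantiating it directly raises a TypeError, while any subclass satisfies `isinstance(obj, RawModel)` checks in the model patcher.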