mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
model_cache: let offload_model work with DiffusionPipeline, sorta.
This commit is contained in:
parent
95db6e80ee
commit
01ff1cff58
@ -4,6 +4,7 @@ They are moved between GPU and CPU as necessary. If CPU memory falls
|
||||
below a preset minimum, the least recently used model will be
|
||||
cleared and loaded from disk when next needed.
|
||||
'''
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
@ -413,10 +414,13 @@ class ModelCache(object):
|
||||
|
||||
def _model_to_cpu(self,model):
|
||||
if self.device != 'cpu':
|
||||
try:
|
||||
model.cond_stage_model.device = 'cpu'
|
||||
model.first_stage_model.to('cpu')
|
||||
model.cond_stage_model.to('cpu')
|
||||
model.model.to('cpu')
|
||||
except AttributeError as e:
|
||||
warnings.warn(f"TODO: clean up legacy model-management: {e}")
|
||||
return model.to('cpu')
|
||||
else:
|
||||
return model
|
||||
|
@ -1,5 +1,7 @@
|
||||
import math
|
||||
import os.path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from functools import partial
|
||||
@ -235,13 +237,15 @@ class SpatialRescaler(nn.Module):
|
||||
|
||||
class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
|
||||
tokenizer: CLIPTokenizer
|
||||
transformer: CLIPTextModel
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
version='openai/clip-vit-large-patch14',
|
||||
max_length=77,
|
||||
tokenizer=None,
|
||||
transformer=None,
|
||||
version:str='openai/clip-vit-large-patch14',
|
||||
max_length:int=77,
|
||||
tokenizer:Optional[CLIPTokenizer]=None,
|
||||
transformer:Optional[CLIPTextModel]=None,
|
||||
):
|
||||
super().__init__()
|
||||
cache = os.path.join(Globals.root,'models',version)
|
||||
@ -464,6 +468,10 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
def device(self):
|
||||
return self.transformer.device
|
||||
|
||||
@device.setter
|
||||
def device(self, device):
|
||||
self.transformer.to(device=device)
|
||||
|
||||
class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
|
||||
|
||||
fragment_weights_key = "fragment_weights"
|
||||
|
Loading…
Reference in New Issue
Block a user