mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
model_manager: model to/from CPU methods are implemented on the Pipeline
This commit is contained in:
parent
42355b70c2
commit
ad7b1fa6fb
@ -104,7 +104,7 @@ class ModelManager(object):
|
||||
if model_name in self.models:
|
||||
requested_model = self.models[model_name]["model"]
|
||||
print(f">> Retrieving model {model_name} from system RAM cache")
|
||||
self.models[model_name]["model"] = self._model_from_cpu(requested_model)
|
||||
requested_model.ready()
|
||||
width = self.models[model_name]["width"]
|
||||
height = self.models[model_name]["height"]
|
||||
hash = self.models[model_name]["hash"]
|
||||
@ -499,6 +499,7 @@ class ModelManager(object):
|
||||
|
||||
print(f">> Offloading {model_name} to CPU")
|
||||
model = self.models[model_name]["model"]
|
||||
model.offload_all()
|
||||
self.models[model_name]["model"] = self._model_to_cpu(model)
|
||||
|
||||
gc.collect()
|
||||
@ -1044,43 +1045,6 @@ class ModelManager(object):
|
||||
self.stack.remove(model_name)
|
||||
self.models.pop(model_name, None)
|
||||
|
||||
def _model_to_cpu(self, model):
|
||||
if self.device == CPU_DEVICE:
|
||||
return model
|
||||
|
||||
if isinstance(model, StableDiffusionGeneratorPipeline):
|
||||
model.offload_all()
|
||||
return model
|
||||
|
||||
model.cond_stage_model.device = CPU_DEVICE
|
||||
model.to(CPU_DEVICE)
|
||||
|
||||
for submodel in ("first_stage_model", "cond_stage_model", "model"):
|
||||
try:
|
||||
getattr(model, submodel).to(CPU_DEVICE)
|
||||
except AttributeError:
|
||||
pass
|
||||
return model
|
||||
|
||||
def _model_from_cpu(self, model):
|
||||
if self.device == CPU_DEVICE:
|
||||
return model
|
||||
|
||||
if isinstance(model, StableDiffusionGeneratorPipeline):
|
||||
model.ready()
|
||||
return model
|
||||
|
||||
model.to(self.device)
|
||||
model.cond_stage_model.device = self.device
|
||||
|
||||
for submodel in ("first_stage_model", "cond_stage_model", "model"):
|
||||
try:
|
||||
getattr(model, submodel).to(self.device)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
return model
|
||||
|
||||
def _pop_oldest_model(self):
|
||||
"""
|
||||
Remove the first element of the FIFO, which ought
|
||||
|
Loading…
Reference in New Issue
Block a user