mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Apply lora by model patching (#3583)
Rewrite lora to be applied by model patching as it gives us benefits: 1) On model execution calculates result only on model weight, while with hooks we need to calculate on model and each lora 2) As lora now patched in model weights, there no need to store lora in vram Results: Speed: | loras count | hook | patch | | --- | --- | --- | | 0 | ~4.92 it/s | ~4.92 it/s | | 1 | ~3.51 it/s | ~4.89 it/s | | 2 | ~2.76 it/s | ~4.92 it/s | VRAM: | loras count | hook | patch | | --- | --- | --- | | 0 | ~3.6 gb | ~3.6 gb | | 1 | ~4.0 gb | ~3.6 gb | | 2 | ~4.4 gb | ~3.7 gb | As based on #3547 wait to merge.
This commit is contained in:
commit
8a90e51408
@ -65,23 +65,20 @@ class CompelInvocation(BaseInvocation):
|
||||
**self.clip.text_encoder.dict(),
|
||||
)
|
||||
with tokenizer_info as orig_tokenizer,\
|
||||
text_encoder_info as text_encoder,\
|
||||
ExitStack() as stack:
|
||||
text_encoder_info as text_encoder:
|
||||
|
||||
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
|
||||
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_list.append(
|
||||
stack.enter_context(
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=self.clip.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
)
|
||||
)
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=self.clip.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
).context.model
|
||||
)
|
||||
except Exception:
|
||||
#print(e)
|
||||
|
@ -285,8 +285,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
self.dispatch_progress(context, source_node_id, state)
|
||||
|
||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
||||
with unet_info as unet,\
|
||||
ExitStack() as stack:
|
||||
with unet_info as unet:
|
||||
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
@ -297,7 +296,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
pipeline = self.create_pipeline(unet, scheduler)
|
||||
conditioning_data = self.get_conditioning_data(context, scheduler)
|
||||
|
||||
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
|
||||
|
||||
control_data = self.prep_control_data(
|
||||
model=pipeline, context=context, control_input=self.control,
|
||||
@ -361,8 +360,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
**self.unet.unet.dict(),
|
||||
)
|
||||
|
||||
with unet_info as unet,\
|
||||
ExitStack() as stack:
|
||||
with unet_info as unet:
|
||||
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
@ -391,7 +389,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
device=unet.device,
|
||||
)
|
||||
|
||||
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
|
||||
|
||||
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
|
||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||
|
@ -177,9 +177,13 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
||||
|
||||
# TODO: ui rewrite
|
||||
base_model = BaseModelType.StableDiffusion1
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
base_model=base_model,
|
||||
model_name=self.lora_name,
|
||||
model_type=SDModelType.Lora,
|
||||
model_type=ModelType.Lora,
|
||||
):
|
||||
raise Exception(f"Unkown lora name: {self.lora_name}!")
|
||||
|
||||
@ -195,8 +199,9 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
output.unet = copy.deepcopy(self.unet)
|
||||
output.unet.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=self.lora_name,
|
||||
model_type=SDModelType.Lora,
|
||||
model_type=ModelType.Lora,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
@ -206,8 +211,9 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
output.clip = copy.deepcopy(self.clip)
|
||||
output.clip.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=self.lora_name,
|
||||
model_type=SDModelType.Lora,
|
||||
model_type=ModelType.Lora,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
|
@ -70,7 +70,7 @@ class LoRALayerBase:
|
||||
op = torch.nn.functional.linear
|
||||
extra_args = {}
|
||||
|
||||
weight = self.get_weight(module)
|
||||
weight = self.get_weight()
|
||||
|
||||
bias = self.bias if self.bias is not None else 0
|
||||
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
||||
@ -81,7 +81,7 @@ class LoRALayerBase:
|
||||
**extra_args,
|
||||
) * multiplier * scale
|
||||
|
||||
def get_weight(self, module: torch.nn.Module):
|
||||
def get_weight(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def calc_size(self) -> int:
|
||||
@ -122,7 +122,7 @@ class LoRALayer(LoRALayerBase):
|
||||
|
||||
self.rank = self.down.shape[0]
|
||||
|
||||
def get_weight(self, module: torch.nn.Module):
|
||||
def get_weight(self):
|
||||
if self.mid is not None:
|
||||
up = self.up.reshape(up.shape[0], up.shape[1])
|
||||
down = self.down.reshape(up.shape[0], up.shape[1])
|
||||
@ -166,7 +166,7 @@ class LoHALayer(LoRALayerBase):
|
||||
layer_key: str,
|
||||
values: dict,
|
||||
):
|
||||
super().__init__(module_key, rank, alpha, bias)
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.w1_a = values["hada_w1_a"]
|
||||
self.w1_b = values["hada_w1_b"]
|
||||
@ -185,7 +185,7 @@ class LoHALayer(LoRALayerBase):
|
||||
|
||||
self.rank = self.w1_b.shape[0]
|
||||
|
||||
def get_weight(self, module: torch.nn.Module):
|
||||
def get_weight(self):
|
||||
if self.t1 is None:
|
||||
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||
|
||||
@ -239,7 +239,7 @@ class LoKRLayer(LoRALayerBase):
|
||||
layer_key: str,
|
||||
values: dict,
|
||||
):
|
||||
super().__init__(module_key, rank, alpha, bias)
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
if "lokr_w1" in values:
|
||||
self.w1 = values["lokr_w1"]
|
||||
@ -271,7 +271,7 @@ class LoKRLayer(LoRALayerBase):
|
||||
else:
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self, module: torch.nn.Module):
|
||||
def get_weight(self):
|
||||
w1 = self.w1
|
||||
if w1 is None:
|
||||
w1 = self.w1_a @ self.w1_b
|
||||
@ -286,7 +286,7 @@ class LoKRLayer(LoRALayerBase):
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
w2 = w2.contiguous()
|
||||
weight = torch.kron(w1, w2).reshape(module.weight.shape) # TODO: can we remove reshape?
|
||||
weight = torch.kron(w1, w2)
|
||||
|
||||
return weight
|
||||
|
||||
@ -471,7 +471,7 @@ class ModelPatcher:
|
||||
submodule_name += "_" + key_parts.pop(0)
|
||||
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key = module_key.rstrip(".")
|
||||
module_key = (module_key + "." + submodule_name).lstrip(".")
|
||||
|
||||
return (module_key, module)
|
||||
|
||||
@ -525,23 +525,36 @@ class ModelPatcher:
|
||||
loras: List[Tuple[LoraModel, float]],
|
||||
prefix: str,
|
||||
):
|
||||
hooks = dict()
|
||||
original_weights = dict()
|
||||
try:
|
||||
for lora, lora_weight in loras:
|
||||
for layer_key, layer in lora.layers.items():
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
with torch.no_grad():
|
||||
for lora, lora_weight in loras:
|
||||
#assert lora.device.type == "cpu"
|
||||
for layer_key, layer in lora.layers.items():
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
|
||||
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
|
||||
if module_key not in hooks:
|
||||
hooks[module_key] = module.register_forward_hook(cls._lora_forward_hook(loras, layer_key))
|
||||
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
|
||||
if module_key not in original_weights:
|
||||
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
|
||||
|
||||
# enable autocast to calc fp16 loras on cpu
|
||||
with torch.autocast(device_type="cpu"):
|
||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||
layer_weight = layer.get_weight() * lora_weight * layer_scale
|
||||
|
||||
if module.weight.shape != layer_weight.shape:
|
||||
# TODO: debug on lycoris
|
||||
layer_weight = layer_weight.reshape(module.weight.shape)
|
||||
|
||||
module.weight += layer_weight.to(device=module.weight.device, dtype=module.weight.dtype)
|
||||
|
||||
yield # wait for context manager exit
|
||||
|
||||
finally:
|
||||
for module_key, hook in hooks.items():
|
||||
hook.remove()
|
||||
hooks.clear()
|
||||
with torch.no_grad():
|
||||
for module_key, weight in original_weights.items():
|
||||
model.get_submodule(module_key).weight.copy_(weight)
|
||||
|
||||
|
||||
@classmethod
|
||||
@ -591,7 +604,7 @@ class ModelPatcher:
|
||||
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}."
|
||||
)
|
||||
|
||||
model_embeddings.weight.data[token_id] = embedding
|
||||
model_embeddings.weight.data[token_id] = embedding.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
ti_tokens.append(token_id)
|
||||
|
||||
if len(ti_tokens) > 1:
|
||||
|
Loading…
Reference in New Issue
Block a user