Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Apply lora by patching lora instead of hooks

parent 1ba94a92b3
commit 5cebf67ee4
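Summary of the change: previously, LoRA was applied by registering a forward hook on each targeted submodule and adding the LoRA contribution to the module's output at call time; this commit instead adds the scaled LoRA weight directly to the module's weight tensor up front, keeps a CPU backup of the original, and copies it back when the context manager exits. Below is a minimal, hypothetical sketch of the two patterns on a toy linear layer (names and shapes are illustrative, not InvokeAI's actual API):

    import torch

    # Toy LoRA factors for a Linear(8 -> 8); the effective delta is up @ down.
    down = torch.randn(4, 8) * 0.01
    up = torch.randn(8, 4) * 0.01
    lora_weight = 0.75
    layer = torch.nn.Linear(8, 8, bias=False)

    # Old style (removed by this commit): a forward hook recomputes the LoRA
    # contribution on every call and adds it to the module output.
    def lora_forward_hook(module, inputs, output):
        (x,) = inputs
        return output + torch.nn.functional.linear(x, up @ down) * lora_weight

    handle = layer.register_forward_hook(lora_forward_hook)
    y_hooked = layer(torch.randn(2, 8))
    handle.remove()

    # New style (introduced by this commit): fold the same delta into the weight
    # once, run inference, then restore the original weight (see the ModelPatcher
    # hunk below for the real backup/restore bookkeeping).
    original = layer.weight.detach().clone()
    with torch.no_grad():
        layer.weight += (up @ down) * lora_weight
    y_patched = layer(torch.randn(2, 8))
    with torch.no_grad():
        layer.weight.copy_(original)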
@@ -65,23 +65,20 @@ class CompelInvocation(BaseInvocation):
             **self.clip.text_encoder.dict(),
         )
         with tokenizer_info as orig_tokenizer,\
-            text_encoder_info as text_encoder,\
-            ExitStack() as stack:
+            text_encoder_info as text_encoder:
 
-            loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
+            loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
 
             ti_list = []
             for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
                 name = trigger[1:-1]
                 try:
                     ti_list.append(
-                        stack.enter_context(
-                            context.services.model_manager.get_model(
-                                model_name=name,
-                                base_model=self.clip.text_encoder.base_model,
-                                model_type=ModelType.TextualInversion,
-                            )
-                        )
+                        context.services.model_manager.get_model(
+                            model_name=name,
+                            base_model=self.clip.text_encoder.base_model,
+                            model_type=ModelType.TextualInversion,
+                        ).context.model
                     )
                 except Exception:
                     #print(e)

@@ -362,8 +362,7 @@ class TextToLatentsInvocation(BaseInvocation):
             self.dispatch_progress(context, source_node_id, state)
 
         unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
-        with unet_info as unet,\
-            ExitStack() as stack:
+        with unet_info as unet:
 
             scheduler = get_scheduler(
                 context=context,

@@ -374,7 +373,7 @@ class TextToLatentsInvocation(BaseInvocation):
             pipeline = self.create_pipeline(unet, scheduler)
             conditioning_data = self.get_conditioning_data(context, scheduler)
 
-            loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
+            loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
 
             control_data = self.prep_control_data(
                 model=pipeline, context=context, control_input=self.control,

@@ -438,8 +437,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             **self.unet.unet.dict(),
         )
 
-        with unet_info as unet,\
-            ExitStack() as stack:
+        with unet_info as unet:
 
             scheduler = get_scheduler(
                 context=context,

@@ -468,7 +466,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
                 device=unet.device,
             )
 
-            loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
+            loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
 
             with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
                 result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(

@@ -177,9 +177,13 @@ class LoraLoaderInvocation(BaseInvocation):
 
     def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
 
+        # TODO: ui rewrite
+        base_model = BaseModelType.StableDiffusion1
+
         if not context.services.model_manager.model_exists(
+            base_model=base_model,
             model_name=self.lora_name,
-            model_type=SDModelType.Lora,
+            model_type=ModelType.Lora,
         ):
             raise Exception(f"Unkown lora name: {self.lora_name}!")
 

@@ -195,8 +199,9 @@ class LoraLoaderInvocation(BaseInvocation):
             output.unet = copy.deepcopy(self.unet)
             output.unet.loras.append(
                 LoraInfo(
+                    base_model=base_model,
                     model_name=self.lora_name,
-                    model_type=SDModelType.Lora,
+                    model_type=ModelType.Lora,
                     submodel=None,
                     weight=self.weight,
                 )

@@ -206,8 +211,9 @@ class LoraLoaderInvocation(BaseInvocation):
             output.clip = copy.deepcopy(self.clip)
             output.clip.loras.append(
                 LoraInfo(
+                    base_model=base_model,
                     model_name=self.lora_name,
-                    model_type=SDModelType.Lora,
+                    model_type=ModelType.Lora,
                     submodel=None,
                     weight=self.weight,
                 )

@@ -70,7 +70,7 @@ class LoRALayerBase:
             op = torch.nn.functional.linear
             extra_args = {}
 
-        weight = self.get_weight(module)
+        weight = self.get_weight()
 
         bias = self.bias if self.bias is not None else 0
         scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0

@@ -81,7 +81,7 @@ class LoRALayerBase:
             **extra_args,
         ) * multiplier * scale
 
-    def get_weight(self, module: torch.nn.Module):
+    def get_weight(self):
         raise NotImplementedError()
 
     def calc_size(self) -> int:

@@ -122,7 +122,7 @@ class LoRALayer(LoRALayerBase):
 
         self.rank = self.down.shape[0]
 
-    def get_weight(self, module: torch.nn.Module):
+    def get_weight(self):
         if self.mid is not None:
             up = self.up.reshape(up.shape[0], up.shape[1])
             down = self.down.reshape(up.shape[0], up.shape[1])

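For a plain LoRA layer, the weight returned by get_weight() is the product of the up and down projection matrices; apply_lora then multiplies it by the per-layer alpha / rank scale and the user-chosen lora weight before adding it to the module. A small illustrative computation with toy sizes (not the actual class):

    import torch

    rank, in_features, out_features = 4, 16, 16
    down = torch.randn(rank, in_features) * 0.01   # "lora_down"
    up = torch.randn(out_features, rank) * 0.01    # "lora_up"
    alpha = 4.0

    delta = up @ down                   # same shape as the target Linear weight: (16, 16)
    scale = alpha / rank if (alpha and rank) else 1.0
    layer_weight = delta * 0.5 * scale  # get_weight() * lora_weight * layer_scale, with lora_weight=0.5
    print(layer_weight.shape)           # torch.Size([16, 16])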
@@ -185,7 +185,7 @@ class LoHALayer(LoRALayerBase):
 
         self.rank = self.w1_b.shape[0]
 
-    def get_weight(self, module: torch.nn.Module):
+    def get_weight(self):
         if self.t1 is None:
             weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
 

@@ -271,7 +271,7 @@ class LoKRLayer(LoRALayerBase):
         else:
             self.rank = None # unscaled
 
-    def get_weight(self, module: torch.nn.Module):
+    def get_weight(self):
         w1 = self.w1
         if w1 is None:
             w1 = self.w1_a @ self.w1_b

@@ -286,7 +286,7 @@ class LoKRLayer(LoRALayerBase):
         if len(w2.shape) == 4:
             w1 = w1.unsqueeze(2).unsqueeze(2)
         w2 = w2.contiguous()
-        weight = torch.kron(w1, w2).reshape(module.weight.shape) # TODO: can we remove reshape?
+        weight = torch.kron(w1, w2)
 
         return weight
 

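Since get_weight() no longer receives the target module, LoKRLayer now returns the raw Kronecker product and the shape fix-up happens in apply_lora, which reshapes the result only when it does not already match the module's weight. A hypothetical illustration of that shape handling:

    import torch

    w1 = torch.randn(4, 4)
    w2 = torch.randn(8, 8)
    weight = torch.kron(w1, w2)        # (32, 32) = (4*8, 4*8)

    module = torch.nn.Linear(32, 32, bias=False)
    if module.weight.shape != weight.shape:
        weight = weight.reshape(module.weight.shape)  # mirrors the reshape moved into apply_lora
    print(weight.shape)                # torch.Size([32, 32])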
@@ -471,7 +471,7 @@ class ModelPatcher:
                 submodule_name += "_" + key_parts.pop(0)
 
             module = module.get_submodule(submodule_name)
-            module_key = module_key.rstrip(".")
+            module_key = (module_key + "." + submodule_name).lstrip(".")
 
         return (module_key, module)
 

@@ -525,23 +525,36 @@ class ModelPatcher:
         loras: List[Tuple[LoraModel, float]],
         prefix: str,
     ):
-        hooks = dict()
+        original_weights = dict()
         try:
-            for lora, lora_weight in loras:
-                for layer_key, layer in lora.layers.items():
-                    if not layer_key.startswith(prefix):
-                        continue
+            with torch.no_grad():
+                for lora, lora_weight in loras:
+                    #assert lora.device.type == "cpu"
+                    for layer_key, layer in lora.layers.items():
+                        if not layer_key.startswith(prefix):
+                            continue
 
-                    module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
-                    if module_key not in hooks:
-                        hooks[module_key] = module.register_forward_hook(cls._lora_forward_hook(loras, layer_key))
+                        module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
+                        if module_key not in original_weights:
+                            original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
+
+                        # enable autocast to calc fp16 loras on cpu
+                        with torch.autocast(device_type="cpu"):
+                            layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
+                            layer_weight = layer.get_weight() * lora_weight * layer_scale
+
+                        if module.weight.shape != layer_weight.shape:
+                            # TODO: debug on lycoris
+                            layer_weight = layer_weight.reshape(module.weight.shape)
+
+                        module.weight += layer_weight.to(device=module.weight.device, dtype=module.weight.dtype)
 
             yield # wait for context manager exit
 
         finally:
-            for module_key, hook in hooks.items():
-                hook.remove()
-            hooks.clear()
+            with torch.no_grad():
+                for module_key, weight in original_weights.items():
+                    model.get_submodule(module_key).weight.copy_(weight)
 
 
     @classmethod

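Because the restore loop sits in a finally block, the backed-up weights are copied back even if generation inside the context raises. A small, hypothetical check of that contract with a generic context manager (not the actual ModelPatcher API):

    import torch
    from contextlib import contextmanager

    @contextmanager
    def patched(module: torch.nn.Linear, delta: torch.Tensor):
        # back up, patch in place, and always restore, mirroring the try/finally above
        original = module.weight.detach().to(device="cpu", copy=True)
        try:
            with torch.no_grad():
                module.weight += delta.to(device=module.weight.device, dtype=module.weight.dtype)
            yield
        finally:
            with torch.no_grad():
                module.weight.copy_(original)

    layer = torch.nn.Linear(3, 3, bias=False)
    before = layer.weight.detach().clone()
    try:
        with patched(layer, torch.ones(3, 3)):
            raise RuntimeError("simulated failure during denoising")
    except RuntimeError:
        pass
    assert torch.equal(layer.weight, before)  # weights restored despite the exception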
@@ -591,7 +604,7 @@ class ModelPatcher:
                     f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}."
                 )
 
-            model_embeddings.weight.data[token_id] = embedding
+            model_embeddings.weight.data[token_id] = embedding.to(device=text_encoder.device, dtype=text_encoder.dtype)
             ti_tokens.append(token_id)
 
             if len(ti_tokens) > 1: