Apply lora by patching lora instead of hooks

This commit is contained in:
Sergey Borisov 2023-06-26 03:57:33 +03:00
parent 1ba94a92b3
commit 5cebf67ee4
4 changed files with 52 additions and 38 deletions

View File

@ -65,23 +65,20 @@ class CompelInvocation(BaseInvocation):
**self.clip.text_encoder.dict(), **self.clip.text_encoder.dict(),
) )
with tokenizer_info as orig_tokenizer,\ with tokenizer_info as orig_tokenizer,\
text_encoder_info as text_encoder,\ text_encoder_info as text_encoder:
ExitStack() as stack:
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras] loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = [] ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
name = trigger[1:-1] name = trigger[1:-1]
try: try:
ti_list.append( ti_list.append(
stack.enter_context( context.services.model_manager.get_model(
context.services.model_manager.get_model( model_name=name,
model_name=name, base_model=self.clip.text_encoder.base_model,
base_model=self.clip.text_encoder.base_model, model_type=ModelType.TextualInversion,
model_type=ModelType.TextualInversion, ).context.model
)
)
) )
except Exception: except Exception:
#print(e) #print(e)

View File

@ -362,8 +362,7 @@ class TextToLatentsInvocation(BaseInvocation):
self.dispatch_progress(context, source_node_id, state) self.dispatch_progress(context, source_node_id, state)
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict()) unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
with unet_info as unet,\ with unet_info as unet:
ExitStack() as stack:
scheduler = get_scheduler( scheduler = get_scheduler(
context=context, context=context,
@ -374,7 +373,7 @@ class TextToLatentsInvocation(BaseInvocation):
pipeline = self.create_pipeline(unet, scheduler) pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler) conditioning_data = self.get_conditioning_data(context, scheduler)
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras] loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
control_data = self.prep_control_data( control_data = self.prep_control_data(
model=pipeline, context=context, control_input=self.control, model=pipeline, context=context, control_input=self.control,
@ -438,8 +437,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
**self.unet.unet.dict(), **self.unet.unet.dict(),
) )
with unet_info as unet,\ with unet_info as unet:
ExitStack() as stack:
scheduler = get_scheduler( scheduler = get_scheduler(
context=context, context=context,
@ -468,7 +466,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
device=unet.device, device=unet.device,
) )
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras] loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
with ModelPatcher.apply_lora_unet(pipeline.unet, loras): with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(

View File

@ -177,9 +177,13 @@ class LoraLoaderInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LoraLoaderOutput: def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
# TODO: ui rewrite
base_model = BaseModelType.StableDiffusion1
if not context.services.model_manager.model_exists( if not context.services.model_manager.model_exists(
base_model=base_model,
model_name=self.lora_name, model_name=self.lora_name,
model_type=SDModelType.Lora, model_type=ModelType.Lora,
): ):
raise Exception(f"Unkown lora name: {self.lora_name}!") raise Exception(f"Unkown lora name: {self.lora_name}!")
@ -195,8 +199,9 @@ class LoraLoaderInvocation(BaseInvocation):
output.unet = copy.deepcopy(self.unet) output.unet = copy.deepcopy(self.unet)
output.unet.loras.append( output.unet.loras.append(
LoraInfo( LoraInfo(
base_model=base_model,
model_name=self.lora_name, model_name=self.lora_name,
model_type=SDModelType.Lora, model_type=ModelType.Lora,
submodel=None, submodel=None,
weight=self.weight, weight=self.weight,
) )
@ -206,8 +211,9 @@ class LoraLoaderInvocation(BaseInvocation):
output.clip = copy.deepcopy(self.clip) output.clip = copy.deepcopy(self.clip)
output.clip.loras.append( output.clip.loras.append(
LoraInfo( LoraInfo(
base_model=base_model,
model_name=self.lora_name, model_name=self.lora_name,
model_type=SDModelType.Lora, model_type=ModelType.Lora,
submodel=None, submodel=None,
weight=self.weight, weight=self.weight,
) )

View File

@ -70,7 +70,7 @@ class LoRALayerBase:
op = torch.nn.functional.linear op = torch.nn.functional.linear
extra_args = {} extra_args = {}
weight = self.get_weight(module) weight = self.get_weight()
bias = self.bias if self.bias is not None else 0 bias = self.bias if self.bias is not None else 0
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0 scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
@ -81,7 +81,7 @@ class LoRALayerBase:
**extra_args, **extra_args,
) * multiplier * scale ) * multiplier * scale
def get_weight(self, module: torch.nn.Module): def get_weight(self):
raise NotImplementedError() raise NotImplementedError()
def calc_size(self) -> int: def calc_size(self) -> int:
@ -122,7 +122,7 @@ class LoRALayer(LoRALayerBase):
self.rank = self.down.shape[0] self.rank = self.down.shape[0]
def get_weight(self, module: torch.nn.Module): def get_weight(self):
if self.mid is not None: if self.mid is not None:
up = self.up.reshape(up.shape[0], up.shape[1]) up = self.up.reshape(up.shape[0], up.shape[1])
down = self.down.reshape(up.shape[0], up.shape[1]) down = self.down.reshape(up.shape[0], up.shape[1])
@ -185,7 +185,7 @@ class LoHALayer(LoRALayerBase):
self.rank = self.w1_b.shape[0] self.rank = self.w1_b.shape[0]
def get_weight(self, module: torch.nn.Module): def get_weight(self):
if self.t1 is None: if self.t1 is None:
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
@ -271,7 +271,7 @@ class LoKRLayer(LoRALayerBase):
else: else:
self.rank = None # unscaled self.rank = None # unscaled
def get_weight(self, module: torch.nn.Module): def get_weight(self):
w1 = self.w1 w1 = self.w1
if w1 is None: if w1 is None:
w1 = self.w1_a @ self.w1_b w1 = self.w1_a @ self.w1_b
@ -286,7 +286,7 @@ class LoKRLayer(LoRALayerBase):
if len(w2.shape) == 4: if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2) w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous() w2 = w2.contiguous()
weight = torch.kron(w1, w2).reshape(module.weight.shape) # TODO: can we remove reshape? weight = torch.kron(w1, w2)
return weight return weight
@ -471,7 +471,7 @@ class ModelPatcher:
submodule_name += "_" + key_parts.pop(0) submodule_name += "_" + key_parts.pop(0)
module = module.get_submodule(submodule_name) module = module.get_submodule(submodule_name)
module_key = module_key.rstrip(".") module_key = (module_key + "." + submodule_name).lstrip(".")
return (module_key, module) return (module_key, module)
@ -525,23 +525,36 @@ class ModelPatcher:
loras: List[Tuple[LoraModel, float]], loras: List[Tuple[LoraModel, float]],
prefix: str, prefix: str,
): ):
hooks = dict() original_weights = dict()
try: try:
for lora, lora_weight in loras: with torch.no_grad():
for layer_key, layer in lora.layers.items(): for lora, lora_weight in loras:
if not layer_key.startswith(prefix): #assert lora.device.type == "cpu"
continue for layer_key, layer in lora.layers.items():
if not layer_key.startswith(prefix):
continue
module_key, module = cls._resolve_lora_key(model, layer_key, prefix) module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
if module_key not in hooks: if module_key not in original_weights:
hooks[module_key] = module.register_forward_hook(cls._lora_forward_hook(loras, layer_key)) original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
# enable autocast to calc fp16 loras on cpu
with torch.autocast(device_type="cpu"):
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
layer_weight = layer.get_weight() * lora_weight * layer_scale
if module.weight.shape != layer_weight.shape:
# TODO: debug on lycoris
layer_weight = layer_weight.reshape(module.weight.shape)
module.weight += layer_weight.to(device=module.weight.device, dtype=module.weight.dtype)
yield # wait for context manager exit yield # wait for context manager exit
finally: finally:
for module_key, hook in hooks.items(): with torch.no_grad():
hook.remove() for module_key, weight in original_weights.items():
hooks.clear() model.get_submodule(module_key).weight.copy_(weight)
@classmethod @classmethod
@ -591,7 +604,7 @@ class ModelPatcher:
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}." f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}."
) )
model_embeddings.weight.data[token_id] = embedding model_embeddings.weight.data[token_id] = embedding.to(device=text_encoder.device, dtype=text_encoder.dtype)
ti_tokens.append(token_id) ti_tokens.append(token_id)
if len(ti_tokens) > 1: if len(ti_tokens) > 1: