repair medvram and lowvram
This commit is contained in:
parent
abb948dab0
commit
9a3f35b028
@ -100,7 +100,9 @@ def setup_for_low_vram(sd_model, use_medvram):
|
|||||||
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
|
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
|
||||||
if sd_model.embedder:
|
if sd_model.embedder:
|
||||||
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
|
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
|
||||||
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
|
|
||||||
|
if hasattr(sd_model, 'cond_stage_model'):
|
||||||
|
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
|
||||||
|
|
||||||
if use_medvram:
|
if use_medvram:
|
||||||
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
|
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
|
||||||
|
@ -32,7 +32,7 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit
|
|||||||
def encode_embedding_init_text(self, init_text, nvpt):
|
def encode_embedding_init_text(self, init_text, nvpt):
|
||||||
ids = tokenizer.encode(init_text)
|
ids = tokenizer.encode(init_text)
|
||||||
ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
|
ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
|
||||||
embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
|
embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)
|
||||||
|
|
||||||
return embedded
|
return embedded
|
||||||
|
|
||||||
@ -66,6 +66,6 @@ class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWi
|
|||||||
def encode_embedding_init_text(self, init_text, nvpt):
|
def encode_embedding_init_text(self, init_text, nvpt):
|
||||||
ids = tokenizer.encode(init_text)
|
ids = tokenizer.encode(init_text)
|
||||||
ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
|
ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
|
||||||
embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
|
embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)
|
||||||
|
|
||||||
return embedded
|
return embedded
|
||||||
|
Loading…
Reference in New Issue
Block a user