potential fix for embeddings no loading on AMD cards

This commit is contained in:
AUTOMATIC 2022-09-25 15:04:39 +03:00
parent 615b2fc9ce
commit 073f6eac22

View File

@ -201,7 +201,7 @@ class StableDiffusionModelHijack:
def process_file(path, filename): def process_file(path, filename):
name = os.path.splitext(filename)[0] name = os.path.splitext(filename)[0]
data = torch.load(path) data = torch.load(path, map_location="cpu")
# textual inversion embeddings # textual inversion embeddings
if 'string_to_param' in data: if 'string_to_param' in data:
@ -217,7 +217,7 @@ class StableDiffusionModelHijack:
if len(emb.shape) == 1: if len(emb.shape) == 1:
emb = emb.unsqueeze(0) emb = emb.unsqueeze(0)
self.word_embeddings[name] = emb.detach() self.word_embeddings[name] = emb.detach().to(device)
self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1)*100)&0xffff:04x}' self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1)*100)&0xffff:04x}'
ids = tokenizer([name], add_special_tokens=False)['input_ids'][0] ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]