diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index d31963d4..f4ce4552 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -19,9 +19,10 @@ re_numbers_at_start = re.compile(r"^[-\d]+\s*") class DatasetEntry: - def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None): + def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None): self.filename = filename self.filename_text = filename_text + self.weight = weight self.latent_dist = latent_dist self.latent_sample = latent_sample self.cond = cond @@ -56,10 +57,16 @@ class PersonalizedBase(Dataset): print("Preparing dataset...") for path in tqdm.tqdm(self.image_paths): + alpha_channel = None if shared.state.interrupted: raise Exception("interrupted") try: - image = Image.open(path).convert('RGB') + image = Image.open(path) + #Currently does not work for single color transparency + #We would need to read image.info['transparency'] for that + if 'A' in image.getbands(): + alpha_channel = image.getchannel('A') + image = image.convert('RGB') if not varsize: image = image.resize((width, height), PIL.Image.BICUBIC) except Exception: @@ -87,17 +94,33 @@ class PersonalizedBase(Dataset): with devices.autocast(): latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0)) - if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)): - latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) - latent_sampling_method = "once" - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) - elif latent_sampling_method == "deterministic": - # Works only for DiagonalGaussianDistribution - latent_dist.std = 0 - latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) - elif latent_sampling_method == "random": - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist) + #Perform latent sampling, even for random sampling. + #We need the sample dimensions for the weights + if latent_sampling_method == "deterministic": + if isinstance(latent_dist, DiagonalGaussianDistribution): + # Works only for DiagonalGaussianDistribution + latent_dist.std = 0 + else: + latent_sampling_method = "once" + latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) + + if alpha_channel is not None: + channels, *latent_size = latent_sample.shape + weight_img = alpha_channel.resize(latent_size) + npweight = np.array(weight_img).astype(np.float32) + #Repeat for every channel in the latent sample + weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size) + #Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default. + weight -= weight.min() + weight /= weight.mean() + else: + #If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later + weight = torch.ones([channels] + latent_size) + + if latent_sampling_method == "random": + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight) + else: + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight) if not (self.tag_drop_out != 0 or self.shuffle_tags): entry.cond_text = self.create_text(filename_text) @@ -110,6 +133,7 @@ class PersonalizedBase(Dataset): del torchdata del latent_dist del latent_sample + del weight self.length = len(self.dataset) self.groups = list(groups.values()) @@ -195,6 +219,7 @@ class BatchLoader: self.cond_text = [entry.cond_text for entry in data] self.cond = [entry.cond for entry in data] self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1) + self.weight = torch.stack([entry.weight for entry in data]).squeeze(1) #self.emb_index = [entry.emb_index for entry in data] #print(self.latent_sample.device)