From 2736d7e15ebf9906f20588d619504b474ee221a7 Mon Sep 17 00:00:00 2001 From: xra Date: Mon, 22 Aug 2022 22:59:06 +0900 Subject: [PATCH] optional weighting for creative blending of prompts example: "an apple: a banana:0 a watermelon:0.5" the above example turns into 3 sub-prompts: "an apple" 1.0 (default if no value) "a banana" 0.0 "a watermelon" 0.5 The weights are added and normalized The resulting image will be: apple 66%, banana 0%, watermelon 33% --- ldm/simplet2i.py | 81 ++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 79 insertions(+), 2 deletions(-) diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index 62f4bea8d4..4c25939621 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -200,7 +200,21 @@ The vast majority of these arguments default to reasonable values. uc = model.get_learned_conditioning(batch_size * [""]) if isinstance(prompts, tuple): prompts = list(prompts) - c = model.get_learned_conditioning(prompts) + + # weighted sub-prompts + subprompts,weights = T2I.split_weighted_subprompts(prompts[0]) + if len(subprompts) > 1: + # i dont know if this is correct.. but it works + c = torch.zeros_like(uc) + # get total weight for normalizing + totalWeight = sum(weights) + # normalize each "sub prompt" and add it + for i in range(0,len(subprompts)): + weight = weights[i] / totalWeight + c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight) + else: # just standard 1 prompt + c = model.get_learned_conditioning(prompts) + shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor] samples_ddim, _ = sampler.sample(S=steps, conditioning=c, @@ -319,7 +333,20 @@ The vast majority of these arguments default to reasonable values. uc = model.get_learned_conditioning(batch_size * [""]) if isinstance(prompts, tuple): prompts = list(prompts) - c = model.get_learned_conditioning(prompts) + + # weighted sub-prompts + subprompts,weights = T2I.split_weighted_subprompts(prompts[0]) + if len(subprompts) > 1: + # i dont know if this is correct.. but it works + c = torch.zeros_like(uc) + # get total weight for normalizing + totalWeight = sum(weights) + # normalize each "sub prompt" and add it + for i in range(0,len(subprompts)): + weight = weights[i] / totalWeight + c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight) + else: # just standard 1 prompt + c = model.get_learned_conditioning(prompts) # encode (scaled latent) z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device)) @@ -430,3 +457,53 @@ The vast majority of these arguments default to reasonable values. image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) return 2.*image - 1. + """ + example: "an apple: a banana:0 a watermelon:0.5" + grabs all text up to the first occurance of ':' + then removes the text, repeating until no characters left. + if ':' has no weight defined, defaults to 1.0 + + the above example turns into 3 sub-prompts: + "an apple" 1.0 + "a banana" 0.0 + "a watermelon" 0.5 + The weights are added and normalized + The resulting image will be: apple 66% (1.0 / 1.5), banana 0%, watermelon 33% (0.5 / 1.5) + """ + def split_weighted_subprompts(text): + # very simple, uses : to separate sub-prompts + # assumes number following : and space after number + # if no number found, defaults to 1.0 + remaining = len(text) + prompts = [] + weights = [] + while remaining > 0: + # find : + if ":" in text: + idx = text.index(":") # first occurrance from start + # snip sub prompt + prompt = text[:idx] + remaining -= idx + # remove from main text + text = text[idx+1:] + # get number + if " " in text: + idx = text.index(" ") # first occurance + else: # no space, read to end + idx = len(text) + if idx != 0: + weight = float(text[:idx]) + else: # no number to grab + weight = 1.0 + # remove + remaining -= idx + text = text[idx+1:] + prompts.append(prompt) + weights.append(weight) + else: + if len(text) > 0: + # take what remains as weight 1 + prompts.append(text) + weights.append(1.0) + remaining = 0 + return prompts, weights \ No newline at end of file