From bb91ca0462f12ebd0b9669be0d45e26a21ac5ecb Mon Sep 17 00:00:00 2001
From: Lincoln Stein <lincoln.stein@gmail.com>
Date: Sun, 21 Aug 2022 17:09:00 -0400
Subject: [PATCH] first attempt to fold k_lms changes proposed by hwharrison
 and bmaltais

---
 ldm/models/diffusion/ksampler.py | 65 ++++++++++++++++++++++++++++++++
 ldm/simplet2i.py                 | 12 ++++--
 2 files changed, 73 insertions(+), 4 deletions(-)
 create mode 100644 ldm/models/diffusion/ksampler.py

diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py
new file mode 100644
index 0000000000..39c0fdf542
--- /dev/null
+++ b/ldm/models/diffusion/ksampler.py
@@ -0,0 +1,65 @@
+'''wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers'''
+import k_diffusion as K
+import accelerate
+import torch
+import torch.nn as nn
+
+class CFGDenoiser(nn.Module):
+    '''Classifier-free-guidance wrapper: runs cond and uncond in one batch.'''
+    def __init__(self, model):
+        super().__init__()
+        self.inner_model = model
+
+    def forward(self, x, sigma, uncond, cond, cond_scale):
+        x_in = torch.cat([x] * 2)
+        sigma_in = torch.cat([sigma] * 2)
+        cond_in = torch.cat([uncond, cond])
+        uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
+        return uncond + (cond - uncond) * cond_scale
+
+class KSampler(object):
+    def __init__(self, model, schedule="lms", **kwargs):
+        super().__init__()
+        self.model = K.external.CompVisDenoiser(model)
+        self.accelerator = accelerate.Accelerator()
+        self.device = self.accelerator.device
+        self.schedule = schedule
+
+    # most of these arguments are ignored and are only present for compatibility with
+    # other samplers
+    @torch.no_grad()
+    def sample(self,
+               S,
+               batch_size,
+               shape,
+               conditioning=None,
+               callback=None,
+               normals_sequence=None,
+               img_callback=None,
+               quantize_x0=False,
+               eta=0.,
+               mask=None,
+               x0=None,
+               temperature=1.,
+               noise_dropout=0.,
+               score_corrector=None,
+               corrector_kwargs=None,
+               verbose=True,
+               x_T=None,
+               log_every_t=100,
+               unconditional_guidance_scale=1.,
+               unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               **kwargs
+               ):
+
+        sigmas = self.model.get_sigmas(S)
+        if x_T is not None:
+            x = x_T
+        else:
+            x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] # for GPU draw
+        model_wrap_cfg = CFGDenoiser(self.model)
+        extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}
+        return (K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process),
+                None)
+
+    def gather(self, samples_ddim):
+        return self.accelerator.gather(samples_ddim)
diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py
index 796a99396b..6f740d1f83 100644
--- a/ldm/simplet2i.py
+++ b/ldm/simplet2i.py
@@ -11,7 +11,7 @@ t2i = T2I(outdir = // outputs/txt2img-samples
     batch_size      = // how many images to generate per sampling (1)
     steps           = // 50
     seed            = // current system time
-    sampler         = ['ddim','plms']  // ddim
+    sampler         = ['ddim','plms','klms']  // klms
     grid            = // false
     width           = // image width, multiple of 64 (512)
     height          = // image height, multiple of 64 (512)
@@ -62,8 +62,9 @@ import time
 import math
 
 from ldm.util import instantiate_from_config
-from ldm.models.diffusion.ddim import DDIMSampler
-from ldm.models.diffusion.plms import PLMSSampler
+from ldm.models.diffusion.ddim     import DDIMSampler
+from ldm.models.diffusion.plms     import PLMSSampler
+from ldm.models.diffusion.ksampler import KSampler
 
 class T2I:
     """T2I class
@@ -101,7 +102,7 @@ class T2I:
         cfg_scale=7.5,
         weights="models/ldm/stable-diffusion-v1/model.ckpt",
         config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml",
-        sampler="plms",
+        sampler="klms",
         latent_channels=4,
         downsampling_factor=8,
         ddim_eta=0.0,  # deterministic
@@ -387,6 +388,9 @@ class T2I:
         elif self.sampler_name == 'ddim':
             print("setting sampler to ddim")
             self.sampler = DDIMSampler(self.model)
+        elif self.sampler_name == 'klms':
+            print("setting sampler to klms")
+            self.sampler = KSampler(self.model,'lms')
         else:
             print(f"unsupported sampler {self.sampler_name}, defaulting to plms")
             self.sampler = PLMSSampler(self.model)