xosski

AI image generator

Dec 4th, 2024
import argparse
import os
import math
import numpy as np
import torch
import safetensors.torch as sf
# import db_examples  # not used in this script; only needed by the original Gradio demo

from PIL import Image
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
from diffusers.models.attention_processor import AttnProcessor2_0
from transformers import CLIPTextModel, CLIPTokenizer
from briarmbg import BriaRMBG
from enum import Enum
from torch.hub import download_url_to_file

# 'stablediffusionapi/realistic-vision-v51'
# 'runwayml/stable-diffusion-v1-5'
sd15_name = 'stablediffusionapi/realistic-vision-v51'
tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4")

# Change UNet

with torch.no_grad():
    new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
    new_conv_in.weight.zero_()
    new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
    new_conv_in.bias = unet.conv_in.bias
    unet.conv_in = new_conv_in

unet_original_forward = unet.forward


def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
    c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
    c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
    new_sample = torch.cat([sample, c_concat], dim=1)
    kwargs['cross_attention_kwargs'] = {}
    return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)


unet.forward = hooked_unet_forward

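# Note: conv_in is widened from 4 to 8 input channels so the VAE-encoded foreground
# latents (passed as 'concat_conds' via cross_attention_kwargs) can be concatenated
# with the noisy latent along the channel dimension on every denoising step.
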
# Load

model_path = './models/iclight_sd15_fc.safetensors'

if not os.path.exists(model_path):
    os.makedirs(os.path.dirname(model_path), exist_ok=True)  # ensure ./models exists before downloading
    download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors', dst=model_path)

sd_offset = sf.load_file(model_path)
sd_origin = unet.state_dict()
keys = sd_origin.keys()
sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
unet.load_state_dict(sd_merged, strict=True)
del sd_offset, sd_origin, sd_merged, keys

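# Note: the downloaded safetensors file stores the IC-Light weights as offsets; they
# are added element-wise to the base UNet weights above rather than loaded as a
# replacement state dict.
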
# Device

device = torch.device('cuda')
text_encoder = text_encoder.to(device=device, dtype=torch.float16)
vae = vae.to(device=device, dtype=torch.bfloat16)
unet = unet.to(device=device, dtype=torch.float16)
rmbg = rmbg.to(device=device, dtype=torch.float32)

# SDP

unet.set_attn_processor(AttnProcessor2_0())
vae.set_attn_processor(AttnProcessor2_0())

# Samplers

ddim_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)

euler_a_scheduler = EulerAncestralDiscreteScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    steps_offset=1
)

dpmpp_2m_sde_karras_scheduler = DPMSolverMultistepScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    algorithm_type="sde-dpmsolver++",
    use_karras_sigmas=True,
    steps_offset=1
)

# Pipelines

t2i_pipe = StableDiffusionPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=dpmpp_2m_sde_karras_scheduler,
    safety_checker=None,
    requires_safety_checker=False,
    feature_extractor=None,
    image_encoder=None
)

i2i_pipe = StableDiffusionImg2ImgPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=dpmpp_2m_sde_karras_scheduler,
    safety_checker=None,
    requires_safety_checker=False,
    feature_extractor=None,
    image_encoder=None
)

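# Note: both pipelines share the same patched UNet, VAE, text encoder, and tokenizer.
# The text-to-image pipeline handles bg_source == "NONE"; the img2img pipeline handles
# the gradient-lit initial latents and the high-resolution refinement pass. The DDIM
# and Euler-ancestral schedulers above are constructed only as alternatives and are
# not wired into either pipeline.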

@torch.inference_mode()
def encode_prompt_inner(txt: str):
    max_length = tokenizer.model_max_length
    chunk_length = tokenizer.model_max_length - 2
    id_start = tokenizer.bos_token_id
    id_end = tokenizer.eos_token_id
    id_pad = id_end

    def pad(x, p, i):
        return x[:i] if len(x) >= i else x + [p] * (i - len(x))

    tokens = tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
    chunks = [[id_start] + tokens[i: i + chunk_length] + [id_end] for i in range(0, len(tokens), chunk_length)]
    chunks = [pad(ck, id_pad, max_length) for ck in chunks]

    token_ids = torch.tensor(chunks).to(device=device, dtype=torch.int64)
    conds = text_encoder(token_ids).last_hidden_state

    return conds

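# Note: prompts are split into 75-token chunks (77 minus BOS/EOS) so text longer than
# CLIP's 77-token limit is encoded chunk by chunk instead of being truncated;
# encode_prompt_pair below repeats chunks so the positive and negative embeddings end
# up with matching shapes.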

@torch.inference_mode()
def encode_prompt_pair(positive_prompt, negative_prompt):
    c = encode_prompt_inner(positive_prompt)
    uc = encode_prompt_inner(negative_prompt)

    c_len = float(len(c))
    uc_len = float(len(uc))
    max_count = max(c_len, uc_len)
    c_repeat = int(math.ceil(max_count / c_len))
    uc_repeat = int(math.ceil(max_count / uc_len))
    max_chunk = max(len(c), len(uc))

    c = torch.cat([c] * c_repeat, dim=0)[:max_chunk]
    uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk]

    c = torch.cat([p[None, ...] for p in c], dim=1)
    uc = torch.cat([p[None, ...] for p in uc], dim=1)

    return c, uc


@torch.inference_mode()
def pytorch2numpy(imgs, quant=True):
    results = []
    for x in imgs:
        y = x.movedim(0, -1)

        if quant:
            y = y * 127.5 + 127.5
            y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
        else:
            y = y * 0.5 + 0.5
            y = y.detach().float().cpu().numpy().clip(0, 1).astype(np.float32)

        results.append(y)
    return results


@torch.inference_mode()
def numpy2pytorch(imgs):
    h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0  # so that 127 maps exactly to 0.0
    h = h.movedim(-1, 1)
    return h


def resize_and_center_crop(image, target_width, target_height):
    pil_image = Image.fromarray(image)
    original_width, original_height = pil_image.size
    scale_factor = max(target_width / original_width, target_height / original_height)
    resized_width = int(round(original_width * scale_factor))
    resized_height = int(round(original_height * scale_factor))
    resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
    left = (resized_width - target_width) / 2
    top = (resized_height - target_height) / 2
    right = (resized_width + target_width) / 2
    bottom = (resized_height + target_height) / 2
    cropped_image = resized_image.crop((left, top, right, bottom))
    return np.array(cropped_image)


def resize_without_crop(image, target_width, target_height):
    pil_image = Image.fromarray(image)
    resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
    return np.array(resized_image)


@torch.inference_mode()
def run_rmbg(img, sigma=0.0):
    H, W, C = img.shape
    assert C == 3
    k = (256.0 / float(H * W)) ** 0.5
    feed = resize_without_crop(img, int(64 * round(W * k)), int(64 * round(H * k)))
    feed = numpy2pytorch([feed]).to(device=device, dtype=torch.float32)
    alpha = rmbg(feed)[0][0]
    alpha = torch.nn.functional.interpolate(alpha, size=(H, W), mode="bilinear")
    alpha = alpha.movedim(1, -1)[0]
    alpha = alpha.detach().float().cpu().numpy().clip(0, 1)
    result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
    return result.clip(0, 255).astype(np.uint8), alpha

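# Note: run_rmbg downscales the input for the RMBG-1.4 matting model, upsamples the
# predicted alpha back to full resolution, and composites the foreground over neutral
# gray (127), so the relighting stage sees a background-free subject.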

@torch.inference_mode()
def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
    input_bg = None

    if bg_source == "NONE":
        pass
    elif bg_source == "LEFT":
        gradient = np.linspace(255, 0, image_width)
        image = np.tile(gradient, (image_height, 1))
        input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
    elif bg_source == "RIGHT":
        gradient = np.linspace(0, 255, image_width)
        image = np.tile(gradient, (image_height, 1))
        input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
    elif bg_source == "TOP":
        gradient = np.linspace(255, 0, image_height)[:, None]
        image = np.tile(gradient, (1, image_width))
        input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
    elif bg_source == "BOTTOM":
        gradient = np.linspace(0, 255, image_height)[:, None]
        image = np.tile(gradient, (1, image_width))
        input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
    else:
        raise ValueError('Wrong initial latent!')

    rng = torch.Generator(device=device).manual_seed(int(seed))

    fg = resize_and_center_crop(input_fg, image_width, image_height)

    concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
    concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor

    conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)

    if input_bg is None:
        latents = t2i_pipe(
            prompt_embeds=conds,
            negative_prompt_embeds=unconds,
            width=image_width,
            height=image_height,
            num_inference_steps=steps,
            num_images_per_prompt=num_samples,
            generator=rng,
            output_type='latent',
            guidance_scale=cfg,
            cross_attention_kwargs={'concat_conds': concat_conds},
        ).images.to(vae.dtype) / vae.config.scaling_factor
    else:
        bg = resize_and_center_crop(input_bg, image_width, image_height)
        bg_latent = numpy2pytorch([bg]).to(device=vae.device, dtype=vae.dtype)
        bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor
        latents = i2i_pipe(
            image=bg_latent,
            strength=lowres_denoise,
            prompt_embeds=conds,
            negative_prompt_embeds=unconds,
            width=image_width,
            height=image_height,
            num_inference_steps=int(round(steps / lowres_denoise)),
            num_images_per_prompt=num_samples,
            generator=rng,
            output_type='latent',
            guidance_scale=cfg,
            cross_attention_kwargs={'concat_conds': concat_conds},
        ).images.to(vae.dtype) / vae.config.scaling_factor

    pixels = vae.decode(latents).sample
    pixels = pytorch2numpy(pixels)
    pixels = [resize_without_crop(
        image=p,
        target_width=int(round(image_width * highres_scale / 64.0) * 64),
        target_height=int(round(image_height * highres_scale / 64.0) * 64))
        for p in pixels]

    pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
    latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
    latents = latents.to(device=unet.device, dtype=unet.dtype)

    image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8

    fg = resize_and_center_crop(input_fg, image_width, image_height)
    concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
    concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor

    latents = i2i_pipe(
        image=latents,
        strength=highres_denoise,
        prompt_embeds=conds,
        negative_prompt_embeds=unconds,
        width=image_width,
        height=image_height,
        num_inference_steps=int(round(steps / highres_denoise)),
        num_images_per_prompt=num_samples,
        generator=rng,
        output_type='latent',
        guidance_scale=cfg,
        cross_attention_kwargs={'concat_conds': concat_conds},
    ).images.to(vae.dtype) / vae.config.scaling_factor

    pixels = vae.decode(latents).sample

    return pytorch2numpy(pixels)

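# Note: process runs two passes: a low-resolution pass (text-to-image when
# bg_source == "NONE", otherwise img2img from a left/right/top/bottom gradient that
# acts as a lighting-direction hint), followed by an upscale to a multiple of 64 and
# an img2img refinement at highres_denoise strength. The foreground latents are
# re-encoded at the new resolution before the second pass.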

@torch.inference_mode()
def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
    input_fg, matting = run_rmbg(input_fg)
    results = process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
    return input_fg, results

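# Note: process_relight removes the background with RMBG before relighting. The
# __main__ block below calls process() directly, so the input foreground is used
# as-is; call process_relight() instead if matting is wanted.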

# quick_prompts = [
#     'sunshine from window',
#     'neon light, city',
#     'sunset over sea',
#     'golden time',
#     'sci-fi RGB glowing, cyberpunk',
#     'natural lighting',
#     'warm atmosphere, at home, bedroom',
#     'magic lit',
#     'evil, gothic, Yharnam',
#     'light and shadow',
#     'shadow from window',
#     'soft studio lighting',
#     'home atmosphere, cozy bedroom illumination',
#     'neon, Wong Kar-wai, warm'
# ]
# quick_prompts = [[x] for x in quick_prompts]


# quick_subjects = [
#     'beautiful woman, detailed face',
#     'handsome man, detailed face',
# ]
# quick_subjects = [[x] for x in quick_subjects]

def parse_args():
    parser = argparse.ArgumentParser(description="Process image generation parameters")
    parser.add_argument('--input_fg', type=str, required=True, help="Path to the foreground image")
    parser.add_argument('--prompt', type=str, required=True, help="Text prompt for the image generation")
    parser.add_argument('--image_width', type=int, default=960, help="Width of the generated image (must be a multiple of 8; a multiple of 64 is safest)")
    parser.add_argument('--image_height', type=int, default=600, help="Height of the generated image (must be a multiple of 8; a multiple of 64 is safest)")
    parser.add_argument('--num_samples', type=int, default=1, help="Number of images to generate")
    parser.add_argument('--seed', type=int, default=12345, help="Random seed for generation")
    parser.add_argument('--steps', type=int, default=25, help="Number of inference steps")
    parser.add_argument('--a_prompt', type=str, default="best quality", help="Additional prompt")
    parser.add_argument('--n_prompt', type=str, default="lowres, bad anatomy", help="Negative prompt")
    parser.add_argument('--cfg', type=float, default=7.5, help="CFG scale")
    parser.add_argument('--highres_scale', type=float, default=1.5, help="Highres scale")
    parser.add_argument('--highres_denoise', type=float, default=0.5, help="Highres denoise strength")
    parser.add_argument('--lowres_denoise', type=float, default=0.9, help="Lowres denoise strength (used with a gradient background)")
    parser.add_argument('--bg_source', type=str, default="NONE", choices=["NONE", "LEFT", "RIGHT", "TOP", "BOTTOM"], help="Background source preference")

    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()

    input_fg = np.array(Image.open(args.input_fg).convert("RGB"))

    result = process(
        input_fg=input_fg,
        prompt=args.prompt,
        image_width=args.image_width,
        image_height=args.image_height,
        num_samples=args.num_samples,
        seed=args.seed,
        steps=args.steps,
        a_prompt=args.a_prompt,
        n_prompt=args.n_prompt,
        cfg=args.cfg,
        highres_scale=args.highres_scale,
        highres_denoise=args.highres_denoise,
        lowres_denoise=args.lowres_denoise,
        bg_source=args.bg_source
    )

    output_image = Image.fromarray(result[0])
    output_image.save("generated_image.png")

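# Example invocation (script filename and paths are illustrative, not part of the
# original paste; the flags are the ones defined in parse_args above):
#   python ic_light_cli.py --input_fg ./subject.png --prompt "sunshine from window" \
#       --image_width 512 --image_height 640 --bg_source LEFT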