kopyl

Untitled

Jan 29th, 2024 (edited)
import functools
import gc
import logging
import math
import os
import random
import shutil
from pathlib import Path

import accelerate
import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from accelerate.utils import InitProcessGroupKwargs
from datasets import load_dataset, DatasetDict, load_from_disk
from huggingface_hub import create_repo, upload_folder
from packaging import version
from torchvision import transforms
from torchvision.transforms.functional import crop
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel, compute_snr
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from datetime import timedelta
from multiprocess import set_start_method


resolution = 128
dataset_name = "kopyl/3M_icons_monochrome_only_no_captioning"
cache_dir = "/workspace/dataset-cache"
pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
pretrained_vae_model_name_or_path = "madebyollin/sdxl-vae-fp16-fix"

max_train_samples = None


check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
DATASET_NAME_MAPPING = {
    "lambdalabs/pokemon-blip-captions": ("image", "text"),
}

dataset = load_dataset(
    dataset_name,
    cache_dir=cache_dir,
    token="",
    num_proc=50
)


if max_train_samples is not None:
    dataset["train"] = dataset["train"].select(range(max_train_samples))


column_names = dataset["train"].column_names


train_resize = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR)
train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
train_crop = transforms.RandomCrop(resolution)


def preprocess_train(examples):
    image_column = column_names[0]
    images = [image.convert("RGB") for image in examples[image_column]]
    # Image augmentation: resize, random crop, normalize.
    # Original sizes and crop coordinates are kept for SDXL micro-conditioning.
    original_sizes = []
    all_images = []
    crop_top_lefts = []
    for image in images:
        original_sizes.append((image.height, image.width))
        image = train_resize(image)

        y1, x1, h, w = train_crop.get_params(image, (resolution, resolution))
        image = crop(image, y1, x1, h, w)

        crop_top_left = (y1, x1)
        crop_top_lefts.append(crop_top_left)
        image = train_transforms(image)
        all_images.append(image)

    examples["original_sizes"] = original_sizes
    examples["crop_top_lefts"] = crop_top_lefts
    examples["pixel_values"] = all_images
    return examples


train_dataset = dataset["train"].with_transform(preprocess_train)
print(type(train_dataset))
print(train_dataset)


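# Optional sanity check (not in the original paste): with_transform applies
# preprocess_train lazily on access, so a single example should already expose
# pixel_values, original_sizes and crop_top_lefts.
sample = train_dataset[0]
print(sample["pixel_values"].shape, sample["original_sizes"], sample["crop_top_lefts"])

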
def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection
    else:
        raise ValueError(f"{model_class} is not supported.")


text_encoder_cls_one = import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path, None
)
text_encoder_cls_two = import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path, None, subfolder="text_encoder_2"
)

text_encoder_one = text_encoder_cls_one.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder"
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder_2"
)
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path, subfolder="vae"
)

tokenizer_one = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path, subfolder="tokenizer", use_fast=False
)
tokenizer_two = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path, subfolder="tokenizer_2"
)


text_encoders = [text_encoder_one, text_encoder_two]
tokenizers = [tokenizer_one, tokenizer_two]


def encode_prompt(batch, rank, text_encoders, tokenizers, proportion_empty_prompts, caption_column, is_train=True):
    # Each map worker moves its own copy of the text encoders to the GPU matching its rank.
    for text_encoder in text_encoders:
        text_encoder.to(f"cuda:{rank}")

    prompt_embeds_list = []
    prompt_batch = batch[caption_column]

    captions = []
    for caption in prompt_batch:
        if random.random() < proportion_empty_prompts:
            captions.append("")
        elif isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # take a random caption if there are multiple
            captions.append(random.choice(caption) if is_train else caption[0])

    with torch.no_grad():
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            text_inputs = tokenizer(
                captions,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds = text_encoder(
                text_input_ids.to(text_encoder.device),
                output_hidden_states=True,
            )

            # Only the pooled output of the final (second) text encoder is kept;
            # the loop overwrites pooled_prompt_embeds on each iteration.
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]
            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
            prompt_embeds_list.append(prompt_embeds)

    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
    return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()}


def compute_vae_encodings(batch, vae):
    # Encodes pixel values to latents using whatever device/dtype the VAE is currently on.
    images = batch.pop("pixel_values")
    pixel_values = torch.stack(list(images))
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    pixel_values = pixel_values.to(vae.device, dtype=vae.dtype)

    with torch.no_grad():
        model_input = vae.encode(pixel_values).latent_dist.sample()
    model_input = model_input * vae.config.scaling_factor
    return {"model_input": model_input.cpu()}


compute_embeddings_fn = functools.partial(
    encode_prompt,
    text_encoders=text_encoders,
    tokenizers=tokenizers,
    proportion_empty_prompts=0,
    caption_column="text",
)
compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)

def map_train():
    return train_dataset.map(compute_embeddings_fn, batched=True, batch_size=16, with_rank=True, num_proc=2, keep_in_memory=True)


if __name__ == "__main__":
    set_start_method("spawn")
    train_dataset_with_embeddings = map_train()
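

# A possible continuation (not in the original paste, sketched under assumptions):
# map_train() only keeps its result in memory, and compute_vae_encodings_fn is
# defined above but never applied. The helper below shows one way to precompute the
# VAE latents as well and persist both outputs with save_to_disk. The output paths
# and the single-GPU placement are assumptions.
def precompute_vae_and_save(embeddings_dataset, output_dir="/workspace/precomputed-dataset"):
    # compute_vae_encodings reads vae.device, so place the VAE on a GPU first.
    vae.to("cuda")
    # Encode latents in a single process; CUDA tensors and spawned map workers
    # are deliberately kept apart here.
    dataset_with_vae = train_dataset.map(compute_vae_encodings_fn, batched=True, batch_size=16)
    embeddings_dataset.save_to_disk(os.path.join(output_dir, "text_embeddings"))
    dataset_with_vae.save_to_disk(os.path.join(output_dir, "vae_latents"))
    return dataset_with_vae

# Assumed usage, e.g. as a final line of the __main__ block:
#     precompute_vae_and_save(train_dataset_with_embeddings)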