kopyl

Untitled

Jan 29th, 2024
import functools
import gc
import logging
import math
import os
import random
import shutil
from datetime import timedelta
from pathlib import Path

import accelerate
import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import InitProcessGroupKwargs, ProjectConfiguration, set_seed
from datasets import load_dataset, DatasetDict, load_from_disk
from huggingface_hub import create_repo, upload_folder
from multiprocess import set_start_method
from packaging import version
from torchvision import transforms
from torchvision.transforms.functional import crop
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel, compute_snr
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available

resolution = 128
dataset_name = "lambdalabs/pokemon-blip-captions"
cache_dir = "/workspace/dataset-cache"
pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
# Path of the fp16-safe SDXL VAE; defined here but not used below (the VAE is loaded from the base model).
pretrained_vae_model_name_or_path = "madebyollin/sdxl-vae-fp16-fix"

# Set to an integer to truncate the training set for quick experiments.
max_train_samples = None


check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
DATASET_NAME_MAPPING = {
    "lambdalabs/pokemon-blip-captions": ("image", "text"),
}

dataset = load_dataset(
    dataset_name,
    cache_dir=cache_dir,
)


if max_train_samples is not None:
    dataset["train"] = dataset["train"].select(range(max_train_samples))


column_names = dataset["train"].column_names


# SDXL-style augmentation: bilinear resize to the target resolution, random crop,
# then normalize to [-1, 1] for the VAE.
train_resize = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR)
train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
train_crop = transforms.RandomCrop(resolution)


def preprocess_train(examples):
    image_column = column_names[0]
    images = [image.convert("RGB") for image in examples[image_column]]
    # image aug
    original_sizes = []
    all_images = []
    crop_top_lefts = []
    for image in images:
        # Record the original size and crop offset for SDXL's micro-conditioning inputs.
        original_sizes.append((image.height, image.width))
        image = train_resize(image)

        y1, x1, h, w = train_crop.get_params(image, (resolution, resolution))
        image = crop(image, y1, x1, h, w)

        crop_top_left = (y1, x1)
        crop_top_lefts.append(crop_top_left)
        image = train_transforms(image)
        all_images.append(image)

    examples["original_sizes"] = original_sizes
    examples["crop_top_lefts"] = crop_top_lefts
    examples["pixel_values"] = all_images
    return examples


train_dataset = dataset["train"].with_transform(preprocess_train)
print(type(train_dataset))
print(train_dataset)

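# Quick sanity check (added; not part of the original paste). with_transform applies
# preprocess_train lazily, so indexing a single example runs the augmentation above.
sample = train_dataset[0]
print(sample["pixel_values"].shape)  # expected: torch.Size([3, 128, 128]) with resolution = 128
print(sample["original_sizes"], sample["crop_top_lefts"])
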

def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    # Only the config is needed here, to pick the matching text-encoder class.
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection
    else:
        raise ValueError(f"{model_class} is not supported.")


text_encoder_cls_one = import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path, None
)
text_encoder_cls_two = import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path, None, subfolder="text_encoder_2"
)

# Weights are loaded on CPU; the text encoders are moved to the per-process GPU inside
# encode_prompt, and the VAE can be moved with .to("cuda") when needed.
text_encoder_one = text_encoder_cls_one.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder"
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder_2"
)
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path, subfolder="vae"
)

tokenizer_one = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path, subfolder="tokenizer", use_fast=False
)
tokenizer_two = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path, subfolder="tokenizer_2"
)


text_encoders = [text_encoder_one, text_encoder_two]
tokenizers = [tokenizer_one, tokenizer_two]


def encode_prompt(batch, rank, text_encoders, tokenizers, proportion_empty_prompts, caption_column, is_train=True):
    print(f"rank: {rank}")
    # Move both text encoders to this worker's GPU (rank is supplied by datasets when
    # mapping with with_rank=True).
    for text_encoder in text_encoders:
        text_encoder.to(f"cuda:{rank}")

    prompt_embeds_list = []
    prompt_batch = batch[caption_column]

    captions = []
    for caption in prompt_batch:
        if random.random() < proportion_empty_prompts:
            captions.append("")
        elif isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # take a random caption if there are multiple
            captions.append(random.choice(caption) if is_train else caption[0])

    with torch.no_grad():
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            text_inputs = tokenizer(
                captions,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds = text_encoder(
                text_input_ids.to(text_encoder.device),
                output_hidden_states=True,
            )

            # We are only interested in the pooled output of the final text encoder.
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]
            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
            prompt_embeds_list.append(prompt_embeds)

    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
    return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()}

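# Shape note (added for reference, not in the original paste): with the two SDXL text
# encoders the per-encoder hidden states are 768- and 1280-dim, so prompt_embeds comes
# out as (batch, 77, 2048) and pooled_prompt_embeds, taken from the second encoder,
# as (batch, 1280).
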

def compute_vae_encodings(batch, vae):
    images = batch.pop("pixel_values")
    pixel_values = torch.stack(list(images))
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    pixel_values = pixel_values.to(vae.device, dtype=vae.dtype)

    with torch.no_grad():
        model_input = vae.encode(pixel_values).latent_dist.sample()
    # Scale latents by the VAE's configured scaling factor, as expected by the UNet.
    model_input = model_input * vae.config.scaling_factor
    return {"model_input": model_input.cpu()}


compute_embeddings_fn = functools.partial(
    encode_prompt,
    text_encoders=text_encoders,
    tokenizers=tokenizers,
    proportion_empty_prompts=0,
    caption_column="text",
)
compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)

def map_train():
    # with_rank=True passes each worker's rank into encode_prompt so it can pick its own GPU
    # (see the spawn-based sketch below for running this with CUDA inside the workers).
    return train_dataset.map(
        compute_embeddings_fn,
        batched=True, batch_size=2, with_rank=True, num_proc=2, keep_in_memory=True,
    )

map_train()
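

# Sketch (added; not part of the original paste): the `set_start_method` import above is
# never used, and with num_proc=2 the default forked map workers cannot initialize CUDA
# for text_encoder.to(f"cuda:{rank}"). The pattern documented for datasets.map with GPUs
# is to spawn the workers first; the helper below assumes two GPUs are available and also
# maps compute_vae_encodings_fn, which the paste defines but never uses.
def map_train_with_spawn():
    # Spawn (rather than fork) worker processes so each one can initialize CUDA.
    set_start_method("spawn", force=True)
    mapped = train_dataset.map(
        compute_embeddings_fn,
        batched=True,
        batch_size=2,
        with_rank=True,
        num_proc=2,  # one worker per GPU
        keep_in_memory=True,
    )
    # Precompute VAE latents the same way (single process; the VAE stays on its current device).
    mapped = mapped.map(
        compute_vae_encodings_fn,
        batched=True,
        batch_size=2,
        keep_in_memory=True,
    )
    return mapped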