kopyl

single-gpu

Jan 27th, 2024 (edited)
import functools
import gc
import logging
import math
import os
import random
import shutil
from datetime import timedelta
from pathlib import Path
from typing import Callable

import accelerate
import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import InitProcessGroupKwargs, ProjectConfiguration, set_seed
from datasets import load_dataset, DatasetDict, load_from_disk
from huggingface_hub import create_repo, upload_folder
from multiprocess import set_start_method
from packaging import version
from torchvision import transforms
from torchvision.transforms.functional import crop
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel, compute_snr
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available

# Use a very long process-group timeout so lengthy preprocessing steps between
# barriers do not trip the distributed watchdog.
kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(days=10))]
accelerator = Accelerator(kwargs_handlers=kwargs_handlers)

resolution = 128
dataset_name = "lambdalabs/pokemon-blip-captions"
cache_dir = "/workspace/dataset-cache"
pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
pretrained_vae_model_name_or_path = "madebyollin/sdxl-vae-fp16-fix"

max_train_samples = None


check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
DATASET_NAME_MAPPING = {
    "lambdalabs/pokemon-blip-captions": ("image", "text"),
}

dataset = load_dataset(
    dataset_name,
    cache_dir=cache_dir,
)


if max_train_samples is not None:
    dataset["train"] = dataset["train"].select(range(max_train_samples))


column_names = dataset["train"].column_names


train_resize = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR)
train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
train_crop = transforms.RandomCrop(resolution)


def preprocess_train(examples):
    image_column = column_names[0]
    images = [image.convert("RGB") for image in examples[image_column]]
    # Image augmentation: resize, random-crop, normalize, and record the
    # original size and crop offset (used by SDXL micro-conditioning).
    original_sizes = []
    all_images = []
    crop_top_lefts = []
    for image in images:
        original_sizes.append((image.height, image.width))
        image = train_resize(image)

        y1, x1, h, w = train_crop.get_params(image, (resolution, resolution))
        image = crop(image, y1, x1, h, w)

        crop_top_left = (y1, x1)
        crop_top_lefts.append(crop_top_left)
        image = train_transforms(image)
        all_images.append(image)

    examples["original_sizes"] = original_sizes
    examples["crop_top_lefts"] = crop_top_lefts
    examples["pixel_values"] = all_images
    return examples


train_dataset = dataset["train"].with_transform(preprocess_train)
print(type(train_dataset))
# train_dataset  # bare expression left over from a notebook cell; no effect when run as a script

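
# Sketch (not part of the original script): how `original_sizes` and
# `crop_top_lefts` are typically consumed. In SDXL training they are packed,
# together with the target size, into the `add_time_ids` micro-conditioning
# tensor fed to the UNet. The helper name below is hypothetical.
def compute_time_ids(original_size, crop_top_left, target_size=(resolution, resolution)):
    # (orig_h, orig_w, crop_top, crop_left, target_h, target_w)
    add_time_ids = list(original_size) + list(crop_top_left) + list(target_size)
    return torch.tensor([add_time_ids], dtype=torch.float32, device=accelerator.device)
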
def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection
    else:
        raise ValueError(f"{model_class} is not supported.")


text_encoder_cls_one = import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path, None
)
text_encoder_cls_two = import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path, None, subfolder="text_encoder_2"
)

text_encoder_one = text_encoder_cls_one.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder"
)
text_encoder_one.to(accelerator.device)

text_encoder_two = text_encoder_cls_two.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder_2"
)
text_encoder_two.to(accelerator.device)

vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path, subfolder="vae"
)
vae.to(accelerator.device)

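# NOTE (observation, not in the original flow): `pretrained_vae_model_name_or_path`
# ("madebyollin/sdxl-vae-fp16-fix") is defined above but never used; the VAE is
# loaded from the base checkpoint's "vae" subfolder instead. If the fp16-fix VAE
# was intended, the load would presumably look like:
# vae = AutoencoderKL.from_pretrained(pretrained_vae_model_name_or_path)
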
tokenizer_one = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path, subfolder="tokenizer", use_fast=False
)
tokenizer_two = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path, subfolder="tokenizer_2"
)


text_encoders = [text_encoder_one, text_encoder_two]
tokenizers = [tokenizer_one, tokenizer_two]


def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, caption_column, is_train=True):
    prompt_embeds_list = []
    prompt_batch = batch[caption_column]

    captions = []
    for caption in prompt_batch:
        if random.random() < proportion_empty_prompts:
            captions.append("")
        elif isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # take a random caption if there are multiple
            captions.append(random.choice(caption) if is_train else caption[0])

    with torch.no_grad():
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            text_inputs = tokenizer(
                captions,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds = text_encoder(
                text_input_ids.to(text_encoder.device),
                output_hidden_states=True,
            )

            # We are only interested in the pooled output of the final (second) text encoder.
            pooled_prompt_embeds = prompt_embeds[0]
            # Use the penultimate hidden state of each encoder as the prompt embedding.
            prompt_embeds = prompt_embeds.hidden_states[-2]
            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
            prompt_embeds_list.append(prompt_embeds)

    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
    return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()}

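# Sketch (assumption, SDXL-base encoders): for a batch of B captions the dict
# above should contain prompt_embeds of shape (B, 77, 2048) — the 768-dim and
# 1280-dim penultimate hidden states concatenated — and pooled_prompt_embeds
# of shape (B, 1280) from the second (projection) encoder.
# example = encode_prompt(
#     {"text": ["a drawing of a green pokemon"]}, text_encoders, tokenizers, 0, "text"
# )
# print(example["prompt_embeds"].shape, example["pooled_prompt_embeds"].shape)
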
def compute_vae_encodings(batch, vae):
    images = batch.pop("pixel_values")
    pixel_values = torch.stack(list(images))
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    pixel_values = pixel_values.to(vae.device, dtype=vae.dtype)

    with torch.no_grad():
        model_input = vae.encode(pixel_values).latent_dist.sample()
    model_input = model_input * vae.config.scaling_factor
    return {"model_input": model_input.cpu()}

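# Sketch (assumption): with resolution = 128 and the SDXL VAE's 8x downsampling,
# `model_input` should come out as (B, 4, 16, 16) latents, already multiplied by
# vae.config.scaling_factor.
# latents = compute_vae_encodings({"pixel_values": [torch.zeros(3, resolution, resolution)]}, vae)
# print(latents["model_input"].shape)  # expected: torch.Size([1, 4, 16, 16])
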
compute_embeddings_fn = functools.partial(
    encode_prompt,
    text_encoders=text_encoders,
    tokenizers=tokenizers,
    proportion_empty_prompts=0,
    caption_column="text",
)
compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)

def map_train():
    # NOTE: only the text-embedding pass is mapped here; compute_vae_encodings_fn
    # is defined above but not applied in this function.
    return train_dataset.map(compute_embeddings_fn, batched=True, batch_size=4, num_proc=1)


datasets.disable_caching()
cache_path = "/home/.cache"

def dataset_map_multi_worker(
    dataset: datasets.Dataset, map_fn: Callable, *args, **kwargs
) -> datasets.Dataset:
    # If torch.distributed is not initialized (e.g. single-GPU), fall back to a plain map.
    try:
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
    except (RuntimeError, ValueError):
        return dataset.map(map_fn, *args, **kwargs)
    # Otherwise: each rank maps its own contiguous shard, saves it to disk,
    # then every rank reloads and concatenates all shards.
    ds_shard_filepaths = [
        os.path.join(cache_path, f"{dataset._fingerprint}_subshard_{w}.cache")
        for w in range(0, world_size)
    ]
    print(f"\tworker {rank} saving sub-shard to {ds_shard_filepaths[rank]}")
    ds_shard = dataset.shard(
        num_shards=world_size,
        index=rank,
        contiguous=True,
    )
    ds_shard = ds_shard.map(map_fn, *args, **kwargs)
    ds_shard.with_format(None).save_to_disk(ds_shard_filepaths[rank])
    print("rank", rank, "saving:", ds_shard_filepaths[rank])
    torch.distributed.barrier()
    full_dataset = datasets.concatenate_datasets(
        [datasets.load_from_disk(p) for p in ds_shard_filepaths]
    )
    torch.distributed.barrier()
    print("rank", rank, "deleting:", ds_shard_filepaths[rank])
    shutil.rmtree(ds_shard_filepaths[rank])
    return full_dataset

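
# Sketch (assumption; the helper above is defined but never called in this script):
# a distributed-aware equivalent of map_train(). The name `map_train_distributed`
# is hypothetical; compute_vae_encodings_fn could be applied through the same helper.
def map_train_distributed():
    return dataset_map_multi_worker(
        train_dataset, compute_embeddings_fn, batched=True, batch_size=4
    )
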


if __name__ == "__main__":
    mapped_dataset = map_train()
    # NOTE: the mapped dataset is only kept in memory here; persist it
    # (e.g. with save_to_disk) if the precomputed embeddings should be reused.
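
# Sketch (assumption): `set_start_method` is imported above but never called. If the
# .map() were run with num_proc > 1 while the encoders/VAE sit on GPU, the "spawn"
# start method would typically be needed before mapping, e.g.:
# if __name__ == "__main__":
#     set_start_method("spawn")
#     map_train()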