import functools
import gc
import logging
import math
import os
import random
import shutil
from pathlib import Path

import accelerate
import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from accelerate.utils import InitProcessGroupKwargs
from datasets import load_dataset, DatasetDict, load_from_disk
from huggingface_hub import create_repo, upload_folder
from packaging import version
from torchvision import transforms
from torchvision.transforms.functional import crop
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel, compute_snr
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from datetime import timedelta
from multiprocess import set_start_method

resolution = 128
dataset_name = "lambdalabs/pokemon-blip-captions"
cache_dir = "/workspace/dataset-cache"
pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
pretrained_vae_model_name_or_path = "madebyollin/sdxl-vae-fp16-fix"
max_train_samples = None

check_min_version("0.24.0.dev0")

logger = get_logger(__name__)

DATASET_NAME_MAPPING = {
    "lambdalabs/pokemon-blip-captions": ("image", "text"),
}
dataset = load_dataset(
    dataset_name,
    cache_dir=cache_dir,
)
if max_train_samples is not None:
    dataset["train"] = dataset["train"].select(range(max_train_samples))
column_names = dataset["train"].column_names

train_resize = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR)
train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
train_crop = transforms.RandomCrop(resolution)

def preprocess_train(examples):
    image_column = column_names[0]
    images = [image.convert("RGB") for image in examples[image_column]]
    # image aug
    original_sizes = []
    all_images = []
    crop_top_lefts = []
    for image in images:
        original_sizes.append((image.height, image.width))
        image = train_resize(image)
        y1, x1, h, w = train_crop.get_params(image, (resolution, resolution))
        image = crop(image, y1, x1, h, w)
        crop_top_left = (y1, x1)
        crop_top_lefts.append(crop_top_left)
        image = train_transforms(image)
        all_images.append(image)
    examples["original_sizes"] = original_sizes
    examples["crop_top_lefts"] = crop_top_lefts
    examples["pixel_values"] = all_images
    return examples
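
# The original_sizes / crop_top_lefts recorded above feed SDXL's size- and
# crop-conditioning ("add_time_ids"). A minimal sketch of how they are usually
# turned into the extra conditioning tensor at training time; this helper is
# illustrative and not part of the paste above.
def compute_time_ids(original_size, crop_top_left, target_size=(resolution, resolution)):
    # add_time_ids = (orig_h, orig_w, crop_top, crop_left, target_h, target_w)
    add_time_ids = list(original_size) + list(crop_top_left) + list(target_size)
    return torch.tensor([add_time_ids])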

train_dataset = dataset["train"].with_transform(preprocess_train)
print(type(train_dataset))
print(train_dataset)

def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    model_class = text_encoder_config.architectures[0]
    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection
    else:
        raise ValueError(f"{model_class} is not supported.")

text_encoder_cls_one = import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path, None
)
text_encoder_cls_two = import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path, None, subfolder="text_encoder_2"
)
text_encoder_one = text_encoder_cls_one.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder"
).to("cuda")
text_encoder_two = text_encoder_cls_two.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder_2"
).to("cuda")
# Note: pretrained_vae_model_name_or_path above points at the fp16-fix SDXL VAE
# (madebyollin/sdxl-vae-fp16-fix); load that checkpoint here instead if you
# precompute latents in float16.
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path, subfolder="vae"
).to("cuda")
tokenizer_one = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path, subfolder="tokenizer", use_fast=False
)
tokenizer_two = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path, subfolder="tokenizer_2"
)
text_encoders = [text_encoder_one, text_encoder_two]
tokenizers = [tokenizer_one, tokenizer_two]


def encode_prompt(batch, rank, text_encoders, tokenizers, proportion_empty_prompts, caption_column, is_train=True):
    # `rank` is supplied by datasets.map(..., with_rank=True); printed here only for debugging.
    print(rank)
    prompt_embeds_list = []
    prompt_batch = batch[caption_column]
    captions = []
    for caption in prompt_batch:
        if random.random() < proportion_empty_prompts:
            captions.append("")
        elif isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # take a random caption if there are multiple
            captions.append(random.choice(caption) if is_train else caption[0])
    with torch.no_grad():
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            text_inputs = tokenizer(
                captions,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds = text_encoder(
                text_input_ids.to(text_encoder.device),
                output_hidden_states=True,
            )
            # We are only interested in the pooled output of the final text encoder
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]
            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
            prompt_embeds_list.append(prompt_embeds)
    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
    return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()}

def compute_vae_encodings(batch, vae):
    images = batch.pop("pixel_values")
    pixel_values = torch.stack(list(images))
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    pixel_values = pixel_values.to(vae.device, dtype=vae.dtype)
    with torch.no_grad():
        model_input = vae.encode(pixel_values).latent_dist.sample()
    model_input = model_input * vae.config.scaling_factor
    return {"model_input": model_input.cpu()}

compute_embeddings_fn = functools.partial(
    encode_prompt,
    text_encoders=text_encoders,
    tokenizers=tokenizers,
    proportion_empty_prompts=0,
    caption_column="text",
)
compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)


def map_train():
    return train_dataset.map(compute_embeddings_fn, batched=True, batch_size=2, with_rank=True, num_proc=2)
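
# A minimal sketch of how the pieces above can be tied together (assumptions:
# this file is run as a script on a CUDA machine). Because the text encoders
# live on CUDA and map() uses num_proc=2, the worker processes must be started
# with the "spawn" method, which is why set_start_method is imported from
# multiprocess above. The VAE-encoding map() call is not in the paste; it
# mirrors how compute_vae_encodings_fn is typically applied.
if __name__ == "__main__":
    set_start_method("spawn")
    train_dataset_with_embeddings = map_train()
    train_dataset_with_vae = train_dataset.map(
        compute_vae_encodings_fn, batched=True, batch_size=2
    )
    print(train_dataset_with_embeddings)
    print(train_dataset_with_vae)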