import os
import shutil
import subprocess
import time
import gc
import sys

import torch
import random
from collections import OrderedDict
from types import SimpleNamespace
from omegaconf import OmegaConf

sys.path.insert(0, "src")
import clip

from ldm.util import instantiate_from_config
from helpers.render import (
    render_animation,
    render_input_video,
    render_image_batch,
    render_interpolation,
)
from helpers.model_load import make_linear_decode
from helpers.aesthetics import load_aesthetics_model


MODEL_CACHE = os.getenv("SD_MODELS_DIR", "diffusion_models_cache")
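# MODEL_CACHE is resolved once at import time; set SD_MODELS_DIR before launching
# to point at another checkpoint directory (the default assumes the checkpoints
# live in diffusion_models_cache/ next to this script).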


class Predictor:
    def predict(
        self,
        model_checkpoint="Protogen_V2.2.ckpt",
        max_frames=10,
        animation_prompts="0: a beautiful apple, trending on Artstation | 5: a beautiful banana, trending on Artstation",
        width=512,
        height=512,
        num_inference_steps=50,
        guidance_scale=7,
        sampler="euler_ancestral",
        seed=None,
        fps=15,
        clip_name="ViT-L/14",
        use_init=False,
        init_image=None,
        strength=0.5,
        use_mask=False,
        mask_file=None,
        invert_mask=False,
        animation_mode="2D",
        border="replicate",
        angle="0:(0)",
        zoom="0:(1.04)",
        translation_x="0:(10*sin(2*3.14*t/10))",
        translation_y="0:(0)",
        translation_z="0:(10)",
        rotation_3d_x="0:(0)",
        rotation_3d_y="0:(0)",
        rotation_3d_z="0:(0)",
        flip_2d_perspective=False,
        perspective_flip_theta="0:(0)",
        perspective_flip_phi="0:(t%15)",
        perspective_flip_gamma="0:(0)",
        perspective_flip_fv="0:(53)",
        noise_schedule="0: (0.02)",
        strength_schedule="0: (0.65)",
        contrast_schedule="0: (1.0)",
        hybrid_video_comp_alpha_schedule="0:(1)",
        hybrid_video_comp_mask_blend_alpha_schedule="0:(0.5)",
        hybrid_video_comp_mask_contrast_schedule="0:(1)",
        hybrid_video_comp_mask_auto_contrast_cutoff_high_schedule="0:(100)",
        hybrid_video_comp_mask_auto_contrast_cutoff_low_schedule="0:(0)",
        kernel_schedule="0: (5)",
        sigma_schedule="0: (1.0)",
        amount_schedule="0: (0.2)",
        threshold_schedule="0: (0.0)",
        color_coherence="Match Frame 0 LAB",
        color_coherence_video_every_N_frames=1,
        diffusion_cadence="1",
        use_depth_warping=True,
        midas_weight=0.3,
        near_plane=200,
        far_plane=10000,
        fov=40,
        padding_mode="border",
        sampling_mode="bicubic",
        video_init_path=None,
        extract_nth_frame=1,
        overwrite_extracted_frames=True,
        use_mask_video=False,
        video_mask_path=None,
        hybrid_video_generate_inputframes=False,
        hybrid_video_use_first_frame_as_init_image=True,
        hybrid_video_motion="None",
        hybrid_video_flow_method="Farneback",
        hybrid_video_composite=False,
        hybrid_video_comp_mask_type="None",
        hybrid_video_comp_mask_inverse=False,
        hybrid_video_comp_mask_equalize="None",
        hybrid_video_comp_mask_auto_contrast=False,
        hybrid_video_comp_save_extra_frames=False,
        hybrid_video_use_video_as_mse_image=False,
        interpolate_key_frames=False,
        interpolate_x_frames=4,
        resume_from_timestring=False,
        resume_timestring="",
    ) -> str:
        """Run a single prediction on the model"""

        # sanity checks:
        if use_init:
            assert init_image, "Please provide init_image when use_init is set to True."
        if use_mask:
            assert mask_file, "Please provide mask_file when use_mask is set to True."

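        # Prompts arrive as "frame_num: prompt" pairs separated by "|", e.g.
        # "0: a beautiful apple | 5: a beautiful banana" assigns the apple
        # prompt to frame 0 and the banana prompt to frame 5.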
        animation_prompts_dict = {}
        animation_prompts = animation_prompts.split("|")
        assert len(animation_prompts) > 0, "Please provide a valid prompt for the animation."
        if len(animation_prompts) == 1:
            animation_prompts = {0: animation_prompts[0].strip()}
        else:
            for frame_prompt in animation_prompts:
                # split on the first colon only, so prompts may themselves contain colons
                frame_prompt = frame_prompt.split(":", 1)
                assert (
                    len(frame_prompt) == 2
                ), "Please follow the 'frame_num: prompt' format."
                frame_id, prompt = frame_prompt[0].strip(), frame_prompt[1].strip()
                assert (
                    frame_id.isdigit() and 0 <= int(frame_id) <= max_frames
                ), "frame_num should be an integer with 0 <= frame_num <= max_frames"
                assert (
                    int(frame_id) not in animation_prompts_dict
                ), f"Duplicate prompts for frame_num {frame_id}."
                assert len(prompt) > 0, "prompt cannot be empty"
                animation_prompts_dict[int(frame_id)] = prompt
            animation_prompts = OrderedDict(sorted(animation_prompts_dict.items()))

        root = {"device": "cuda", "models_path": "models", "configs_path": "configs"}
        # SD 2.x checkpoints need the v2 inference config; everything else falls back to v1
        model_config = (
            "v2-inference.yaml"
            if model_checkpoint
            in ["v2-1_768-ema-pruned.ckpt", "v2-1_512-ema-pruned.ckpt"]
            else "v1-inference.yaml"
        )
        ckpt_config_path = f"configs/{model_config}"
        ckpt_path = os.path.join(MODEL_CACHE, model_checkpoint)
        local_config = OmegaConf.load(ckpt_config_path)

        model = load_model_from_config(local_config, ckpt_path, map_location="cuda")
        model.to("cuda")
        root["model"] = model

        root = SimpleNamespace(**root)

        autoencoder_version = (
            "sd-v1"  # TODO this will be different for different models
        )
        root.model.linear_decode = make_linear_decode(autoencoder_version, "cuda")
        # using some of the default settings for simplicity
        args_dict = {
            "W": width,
            "H": height,
            "bit_depth_output": 8,
            "seed": seed,
            "sampler": sampler,
            "steps": num_inference_steps,
            "scale": guidance_scale,
            "ddim_eta": 0.0,
            "dynamic_threshold": None,
            "static_threshold": None,
            "save_samples": False,
            "save_settings": False,
            "display_samples": False,
            "save_sample_per_step": False,
            "show_sample_per_step": False,
            "prompt_weighting": True,
            "normalize_prompt_weights": True,
            "log_weighted_subprompts": False,
            "n_batch": 1,
            "batch_name": "StableFun",
            "filename_format": "{timestring}_{index}_{prompt}.png",
            "seed_behavior": "iter",
            "seed_iter_N": 1,
            "make_grid": False,
            "grid_rows": 2,
            "outdir": "cog_temp_output",
            "use_init": use_init,
            "strength": strength,
            "strength_0_no_init": True,
            "init_image": init_image,
            "use_mask": use_mask,
            "use_alpha_as_mask": False,
            "mask_file": mask_file,
            "invert_mask": invert_mask,
            "mask_brightness_adjust": 1.0,
            "mask_contrast_adjust": 1.0,
            "overlay_mask": True,
            "mask_overlay_blur": 5,
            "mean_scale": 0,
            "var_scale": 0,
            "exposure_scale": 0,
            "exposure_target": 0.5,
            "colormatch_scale": 0,
            "colormatch_image": "https://www.saasdesign.io/wp-content/uploads/2021/02/palette-3-min-980x588.png",
            "colormatch_n_colors": 4,
            "ignore_sat_weight": 0,
            "clip_name": clip_name,
            "clip_scale": 0,
            "aesthetics_scale": 0,
            "cutn": 1,
            "cut_pow": 0.0001,
            "init_mse_scale": 0,
            "init_mse_image": "https://cdn.pixabay.com/photo/2022/07/30/13/10/green-longhorn-beetle-7353749_1280.jpg",
            "blue_scale": 0,
            "gradient_wrt": "x0_pred",
            "gradient_add_to": "both",
            "decode_method": "linear",
            "grad_threshold_type": "dynamic",
            "clamp_grad_threshold": 0.2,
            "clamp_start": 0.2,
            "clamp_stop": 0.01,
            "grad_inject_timing": [1, 2, 3, 4, 5, 6, 7, 8, 9],
            "cond_uncond_sync": True,
            "n_samples": 1,
            "precision": "autocast",
            "C": 4,
            "f": 8,
            "prompt": "",
            "timestring": "",
            "init_latent": None,
            "init_sample": None,
            "init_sample_raw": None,
            "mask_sample": None,
            "init_c": None,
            "seed_internal": 0,
        }

        anim_args_dict = {
            "animation_mode": animation_mode,
            "max_frames": max_frames,
            "border": border,
            "angle": angle,
            "zoom": zoom,
            "translation_x": translation_x,
            "translation_y": translation_y,
            "translation_z": translation_z,
            "rotation_3d_x": rotation_3d_x,
            "rotation_3d_y": rotation_3d_y,
            "rotation_3d_z": rotation_3d_z,
            "flip_2d_perspective": flip_2d_perspective,
            "perspective_flip_theta": perspective_flip_theta,
            "perspective_flip_phi": perspective_flip_phi,
            "perspective_flip_gamma": perspective_flip_gamma,
            "perspective_flip_fv": perspective_flip_fv,
            "noise_schedule": noise_schedule,
            "strength_schedule": strength_schedule,
            "contrast_schedule": contrast_schedule,
            "hybrid_video_comp_alpha_schedule": hybrid_video_comp_alpha_schedule,
            "hybrid_video_comp_mask_blend_alpha_schedule": hybrid_video_comp_mask_blend_alpha_schedule,
            "hybrid_video_comp_mask_contrast_schedule": hybrid_video_comp_mask_contrast_schedule,
            "hybrid_video_comp_mask_auto_contrast_cutoff_high_schedule": hybrid_video_comp_mask_auto_contrast_cutoff_high_schedule,
            "hybrid_video_comp_mask_auto_contrast_cutoff_low_schedule": hybrid_video_comp_mask_auto_contrast_cutoff_low_schedule,
            "kernel_schedule": kernel_schedule,
            "sigma_schedule": sigma_schedule,
            "amount_schedule": amount_schedule,
            "threshold_schedule": threshold_schedule,
            "color_coherence": color_coherence,
            "color_coherence_video_every_N_frames": color_coherence_video_every_N_frames,
            "diffusion_cadence": diffusion_cadence,
            "use_depth_warping": use_depth_warping,
            "midas_weight": midas_weight,
            "near_plane": near_plane,
            "far_plane": far_plane,
            "fov": fov,
            "padding_mode": padding_mode,
            "sampling_mode": sampling_mode,
            "save_depth_maps": False,
            # avoid turning a missing path into the literal string "None"
            "video_init_path": str(video_init_path) if video_init_path else None,
            "extract_nth_frame": extract_nth_frame,
            "overwrite_extracted_frames": overwrite_extracted_frames,
            "use_mask_video": use_mask_video,
            "video_mask_path": str(video_mask_path) if video_mask_path else None,
            "hybrid_video_generate_inputframes": hybrid_video_generate_inputframes,
            "hybrid_video_use_first_frame_as_init_image": hybrid_video_use_first_frame_as_init_image,
            "hybrid_video_motion": hybrid_video_motion,
            "hybrid_video_flow_method": hybrid_video_flow_method,
            "hybrid_video_composite": hybrid_video_composite,
            "hybrid_video_comp_mask_type": hybrid_video_comp_mask_type,
            "hybrid_video_comp_mask_inverse": hybrid_video_comp_mask_inverse,
            "hybrid_video_comp_mask_equalize": hybrid_video_comp_mask_equalize,
            "hybrid_video_comp_mask_auto_contrast": hybrid_video_comp_mask_auto_contrast,
            "hybrid_video_comp_save_extra_frames": hybrid_video_comp_save_extra_frames,
            "hybrid_video_use_video_as_mse_image": hybrid_video_use_video_as_mse_image,
            "interpolate_key_frames": interpolate_key_frames,
            "interpolate_x_frames": interpolate_x_frames,
            "resume_from_timestring": resume_from_timestring,
            "resume_timestring": resume_timestring,
        }

        args = SimpleNamespace(**args_dict)
        anim_args = SimpleNamespace(**anim_args_dict)

        if os.path.exists(args.outdir):
            shutil.rmtree(args.outdir)
        os.makedirs(args.outdir, exist_ok=True)

        args.timestring = time.strftime("%Y%m%d%H%M%S")
        args.strength = max(0.0, min(1.0, args.strength))

        # Load the CLIP model only when CLIP or aesthetics guidance is enabled
        if (args.clip_scale > 0) or (args.aesthetics_scale > 0):
            root.clip_model = (
                clip.load(args.clip_name, jit=False)[0]
                .eval()
                .requires_grad_(False)
                .to(root.device)
            )
            if args.aesthetics_scale > 0:
                root.aesthetics_model = load_aesthetics_model(args, root)

        if args.seed is None:
            args.seed = random.randint(0, 2 ** 32 - 1)
        if not args.use_init:
            args.init_image = None
        if args.sampler == "plms" and (
            args.use_init or anim_args.animation_mode != "None"
        ):
            print("Init images aren't supported with PLMS yet, switching to KLMS")
            args.sampler = "klms"
        if args.sampler != "ddim":
            args.ddim_eta = 0

        if anim_args.animation_mode == "None":
            anim_args.max_frames = 1
        elif anim_args.animation_mode == "Video Input":
            args.use_init = True

        # clean up unused memory
        gc.collect()
        torch.cuda.empty_cache()

        # dispatch to the appropriate renderer
        if anim_args.animation_mode == "2D" or anim_args.animation_mode == "3D":
            render_animation(args, anim_args, animation_prompts, root)
        elif anim_args.animation_mode == "Video Input":
            render_input_video(args, anim_args, animation_prompts, root)
        elif anim_args.animation_mode == "Interpolation":
            render_interpolation(args, anim_args, animation_prompts, root)

        # stitch the rendered frames into an mp4; -crf 17 is visually
        # near-lossless and yuv420p keeps the output playable in most players
        image_path = os.path.join(args.outdir, f"{args.timestring}_%05d.png")
        mp4_path = "/tmp/out.mp4"

        cmd = [
            "ffmpeg",
            "-y",
            "-vcodec",
            "png",
            "-r",
            str(fps),
            "-start_number",
            str(0),
            "-i",
            image_path,
            "-frames:v",
            str(anim_args.max_frames),
            "-c:v",
            "libx264",
            "-vf",
            f"fps={fps}",
            "-pix_fmt",
            "yuv420p",
            "-crf",
            "17",
            "-preset",
            "veryfast",
            mp4_path,
        ]
        process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout, stderr = process.communicate()
        if process.returncode != 0:
            error = stderr.decode(errors="replace")
            print(error)
            raise RuntimeError(error)

        return mp4_path


def load_model_from_config(
    config, ckpt, verbose=False, device="cuda", print_flag=False, map_location="cuda"
):
    print("..loading model")
    _, extension = os.path.splitext(ckpt)
    if extension.lower() == ".safetensors":
        import safetensors.torch

        pl_sd = safetensors.torch.load_file(ckpt, device=map_location)
    else:
        pl_sd = torch.load(ckpt, map_location=map_location)
    # Lightning checkpoints nest the weights under "state_dict";
    # plain state dicts are used as-is
    sd = pl_sd.get("state_dict", pl_sd)
    # instantiate in fp16 to halve peak memory, then restore the default dtype
    torch.set_default_dtype(torch.float16)
    model = instantiate_from_config(config.model)
    torch.set_default_dtype(torch.float32)
    m, u = model.load_state_dict(sd, strict=False)
    if print_flag and verbose:
        if len(m) > 0:
            print("missing keys:")
            print(m)
        if len(u) > 0:
            print("unexpected keys:")
            print(u)

    model = model.half().to(device)
    model.eval()
    return model
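

# Minimal usage sketch: assumes the default Protogen checkpoint exists under
# MODEL_CACHE and that a CUDA device is available. The prompt string mirrors
# the default shown in Predictor.predict above; seed and fps are illustrative.
if __name__ == "__main__":
    predictor = Predictor()
    video_path = predictor.predict(
        animation_prompts="0: a beautiful apple, trending on Artstation | 5: a beautiful banana, trending on Artstation",
        max_frames=10,
        animation_mode="2D",
        seed=42,
        fps=15,
    )
    print(f"Video written to {video_path}")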