Advertisement
kopyl

Untitled

Nov 21st, 2023 (edited)
488
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.05 KB | None | 0 0
  1. import argparse
  2. from onediff.infer_compiler import oneflow_load_compiled
  3. from onediff.infer_compiler import oneflow_compile
  4. from onediff import EulerDiscreteScheduler, rewrite_self_attention
  5. from diffusers import StableDiffusionPipeline
  6. import oneflow as flow
  7. import torch
  8. import os
  9.  
  10.  
  11. def parse_args():
  12.     parser = argparse.ArgumentParser(description="Simple demo of image generation.")
  13.     parser.add_argument("--compiled_graph_path", type=str, default="compiled-graph")
  14.     parser.add_argument(
  15.         "--prompt", type=str, default="an icon of a dog"
  16.     )
  17.     parser.add_argument(
  18.         "--model_id", type=str, default="runwayml/stable-diffusion-v1-5",
  19.     )
  20.     parser.add_argument("--height", type=int, default=512)
  21.     parser.add_argument("--width", type=int, default=512)
  22.     parser.add_argument("--steps", type=int, default=30)
  23.     parser.add_argument("--warmup", type=int, default=1)
  24.     parser.add_argument("--seed", type=int, default=1)
  25.     args = parser.parse_args()
  26.     return args
  27.  
  28.  
  29. args = parse_args()
  30.  
  31. scheduler = EulerDiscreteScheduler.from_pretrained(args.model_id, subfolder="scheduler")
  32. pipe = StableDiffusionPipeline.from_pretrained(
  33.     args.model_id,
  34.     scheduler=scheduler,
  35.     use_auth_token=True,
  36.     revision="fp16",
  37.     variant="fp16",
  38.     torch_dtype=torch.float16,
  39.     safety_checker=None,
  40. )
  41. pipe = pipe.to("cuda")
  42.  
  43.  
  44. compiled_graph_exists = os.path.exists(args.compiled_graph_path)
  45.  
  46. if not compiled_graph_exists:
  47.     rewrite_self_attention(pipe.unet)
  48.     pipe.unet = oneflow_compile(pipe.unet)
  49. else:
  50.     pipe.unet = oneflow_load_compiled(pipe.unet, args.compiled_graph_path)
  51.  
  52. prompt = args.prompt
  53. with flow.autocast("cuda"):
  54.     torch.manual_seed(args.seed)
  55.  
  56.     images = pipe(
  57.         prompt, height=args.height, width=args.width, num_inference_steps=args.steps
  58.     ).images
  59.     if not compiled_graph_exists:
  60.         print("Saving compiled graph")
  61.         pipe.unet.save_graph(args.compiled_graph_path)
  62.     for i, image in enumerate(images):
  63.         image.save(f"{prompt}-of-{i}-seed-{args.seed}.png")
  64.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement