Advertisement
kopyl

forward()

Jul 21st, 2023 (edited)
865
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.69 KB | None | 0 0
  1. def forward(self, samples):
  2.         latents = self.vae.encode(samples["tgt_image"].half()).latent_dist.sample()
  3.         latents = latents * 0.18215
  4.  
  5.         # Sample noise that we'll add to the latents
  6.         noise = torch.randn_like(latents)
  7.         bsz = latents.shape[0]
  8.         # Sample a random timestep for each image
  9.         timesteps = torch.randint(
  10.             0,
  11.             self.noise_scheduler.config.num_train_timesteps,
  12.             (bsz,),
  13.             device=latents.device,
  14.         )
  15.         timesteps = timesteps.long()
  16.  
  17.         # Add noise to the latents according to the noise magnitude at each timestep
  18.         # (this is the forward diffusion process)
  19.         noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
  20.         ctx_embeddings = self.forward_ctx_embeddings(
  21.             input_image=samples["inp_image"], text_input=samples["subject_text"]
  22.         )
  23.  
  24.         # Get the text embedding for conditioning
  25.         input_ids = self.tokenizer(
  26.             samples["caption"],
  27.             padding="do_not_pad",
  28.             truncation=True,
  29.             max_length=self.tokenizer.model_max_length,
  30.             return_tensors="pt",
  31.         ).input_ids.to(self.device)
  32.         encoder_hidden_states = self.text_encoder(
  33.             input_ids=input_ids,
  34.             ctx_embeddings=ctx_embeddings,
  35.             ctx_begin_pos=[self._CTX_BEGIN_POS] * input_ids.shape[0],
  36.         )[0]
  37.  
  38.         # Predict the noise residual
  39.         noise_pred = self.unet(
  40.             noisy_latents.float(), timesteps, encoder_hidden_states
  41.         ).sample
  42.  
  43.         loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
  44.  
  45.         return {"loss": loss}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement