import torch.nn as nn
from abc import ABC, abstractmethod

class BaseModel(nn.Module, ABC):
    """Abstract base for all models; adds parameter freeze/unfreeze helpers."""

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, x):
        pass

    def freeze(self):
        # Disable gradient updates for every parameter in the module.
        for param in self.parameters():
            param.requires_grad = False

    def unfreeze(self):
        # Re-enable gradient updates for every parameter in the module.
        for param in self.parameters():
            param.requires_grad = True

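# Illustrative usage sketch (added for clarity, not part of the paste above):
# freeze()/unfreeze() simply toggle requires_grad on every parameter, which is
# handy in alternating-optimization schemes. TinyModel is a hypothetical
# concrete subclass used only to exercise the helpers.
if __name__ == "__main__":
    class TinyModel(BaseModel):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 2)

        def forward(self, x):
            return self.fc(x)

    m = TinyModel()
    m.freeze()
    assert not any(p.requires_grad for p in m.parameters())
    m.unfreeze()
    assert all(p.requires_grad for p in m.parameters())
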
#-----------------------------------------------------------------------------------------
#-----------------------------------------------------------------------------------------
#-----------------------------------------------------------------------------------------
#-----------------------------------------------------------------------------------------

import torch
import torch.nn as nn

from models.base import BaseModel

class Decoder(BaseModel):
    """Maps a latent vector to a 3x128x128 image."""

    def __init__(self, latent_dim):
        super().__init__()

        # Project the latent vector up to a flat 256*8*8 feature vector.
        self.fc_layers = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 256 * 8 * 8),
            nn.ReLU()
        )

        # Four stride-2 transposed convolutions upsample 8x8 -> 128x128;
        # Tanh squashes the output image into [-1, 1].
        self.deconv_layers = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 8x8 -> 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),   # 16x16 -> 32x32
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),    # 32x32 -> 64x64
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),     # 64x64 -> 128x128
            nn.Tanh()
        )

    def forward(self, z):
        x = self.fc_layers(z)
        x = x.view(-1, 256, 8, 8)   # reshape flat features into a 256x8x8 map
        x = self.deconv_layers(x)
        return x

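# Illustrative usage sketch (added for clarity; latent_dim=128 is an assumed
# value, not fixed by the paste): a batch of latent vectors should decode to
# 128x128 RGB images in [-1, 1].
if __name__ == "__main__":
    decoder = Decoder(latent_dim=128)
    z = torch.randn(4, 128)
    imgs = decoder(z)
    print(imgs.shape)                                # torch.Size([4, 3, 128, 128])
    assert imgs.shape == (4, 3, 128, 128)
    assert imgs.min() >= -1.0 and imgs.max() <= 1.0  # Tanh output range
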
#-----------------------------------------------------------------------------------------
#-----------------------------------------------------------------------------------------
#-----------------------------------------------------------------------------------------
#-----------------------------------------------------------------------------------------

import torch.nn as nn

from models.base import BaseModel

class Encoder(BaseModel):
    """Maps a 3x128x128 image to a latent vector of size latent_dim."""

    def __init__(self, latent_dim):
        super().__init__()

        # Four stride-2 convolutions downsample 128x128 -> 8x8 while
        # widening channels 3 -> 256.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),     # 128 -> 64
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),    # 64 -> 32
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),   # 32 -> 16
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),  # 16 -> 8
            nn.ReLU()
        )

        # Flatten the 256x8x8 feature map and compress it to the latent code.
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 8 * 8, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, latent_dim)
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = self.fc_layers(x)
        return x

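# Illustrative usage sketch (added for clarity; latent_dim=128 is an assumed
# value): 128x128 RGB input is what the fc_layers size implies, since the four
# stride-2 convs give 128 -> 64 -> 32 -> 16 -> 8.
if __name__ == "__main__":
    import torch

    encoder = Encoder(latent_dim=128)
    x = torch.randn(4, 3, 128, 128)
    z = encoder(x)
    print(z.shape)              # torch.Size([4, 128])
    assert z.shape == (4, 128)
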
#-----------------------------------------------------------------------------------------
#-----------------------------------------------------------------------------------------
#-----------------------------------------------------------------------------------------
#-----------------------------------------------------------------------------------------

import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """Scores a latent vector with a probability in (0, 1)."""

    def __init__(self, latent_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, z):
        return self.model(z)

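# Illustrative usage sketch (added for clarity): the Sigmoid output pairs with
# nn.BCELoss. Scoring latent codes suggests an adversarial-autoencoder-style
# setup where prior samples are "real" and encoder outputs are "fake", but the
# paste includes no training code, so that reading is an assumption.
if __name__ == "__main__":
    disc = Discriminator(latent_dim=128)
    z_prior = torch.randn(4, 128)             # stand-in samples from the prior
    p_real = disc(z_prior)                    # probabilities in (0, 1)
    loss = nn.BCELoss()(p_real, torch.ones_like(p_real))
    print(p_real.shape, loss.item())          # torch.Size([4, 1]) ...
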
#-----------------------------------------------------------------------------------------
#-----------------------------------------------------------------------------------------
#-----------------------------------------------------------------------------------------
#-----------------------------------------------------------------------------------------

import torch
import torch.nn as nn

class Classifier(nn.Module):
    """Predicts class logits from a latent vector."""

    def __init__(self, latent_dim, num_classes):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )

    def forward(self, z):
        return self.model(z)

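# Illustrative usage sketch (added for clarity): the classifier outputs raw
# logits, so it pairs with nn.CrossEntropyLoss, which applies log-softmax
# internally. Taking latent_dim inputs suggests an Encoder -> Classifier
# pipeline; that wiring is an assumption, not shown in the paste.
if __name__ == "__main__":
    clf = Classifier(latent_dim=128, num_classes=10)
    z = torch.randn(4, 128)                   # stand-in for encoder outputs
    logits = clf(z)
    labels = torch.randint(0, 10, (4,))
    loss = nn.CrossEntropyLoss()(logits, labels)
    print(logits.shape, loss.item())          # torch.Size([4, 10]) ...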