# models/base.py (filename inferred from the `from models.base import BaseModel` imports below)
import torch.nn as nn
from abc import ABC, abstractmethod


class BaseModel(nn.Module, ABC):
    """Abstract base class: subclasses must implement forward()."""

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, x):
        pass

    def freeze(self):
        # Exclude all parameters from gradient computation.
        for param in self.parameters():
            param.requires_grad = False

    def unfreeze(self):
        # Re-enable gradient computation for all parameters.
        for param in self.parameters():
            param.requires_grad = True
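# A minimal smoke test for freeze()/unfreeze() (not part of the original paste;
# TinyModel and its dimensions are illustrative assumptions).
if __name__ == "__main__":
    import torch

    class TinyModel(BaseModel):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 2)

        def forward(self, x):
            return self.fc(x)

    m = TinyModel()
    m.freeze()
    assert all(not p.requires_grad for p in m.parameters())
    m.unfreeze()
    assert all(p.requires_grad for p in m.parameters())
    print(m(torch.randn(1, 4)).shape)  # torch.Size([1, 2])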
#-----------------------------------------------------------------------------------------
# models/decoder.py (filename assumed from the models.* import convention)
from models.base import BaseModel
import torch.nn as nn


class Decoder(BaseModel):
    """Maps a latent vector to a 3x128x128 image with values in [-1, 1]."""

    def __init__(self, latent_dim):
        super().__init__()
        # Expand the latent vector into a flattened 256x8x8 feature map.
        self.fc_layers = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 256 * 8 * 8),
            nn.ReLU()
        )
        # Four stride-2 transposed convolutions upsample 8 -> 16 -> 32 -> 64 -> 128.
        self.deconv_layers = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()  # squashes outputs into [-1, 1]
        )

    def forward(self, z):
        x = self.fc_layers(z)
        x = x.view(-1, 256, 8, 8)  # reshape flat features into a spatial map
        x = self.deconv_layers(x)
        return x
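# Quick shape check (illustrative, not part of the original paste): the four
# stride-2 transposed convs imply 128x128 outputs; latent_dim=64 is arbitrary.
if __name__ == "__main__":
    import torch

    dec = Decoder(latent_dim=64)
    out = dec(torch.randn(2, 64))
    print(out.shape)  # torch.Size([2, 3, 128, 128])
    assert out.min() >= -1.0 and out.max() <= 1.0  # Tanh output range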
#-----------------------------------------------------------------------------------------
# models/encoder.py (filename assumed from the models.* import convention)
from models.base import BaseModel
import torch.nn as nn


class Encoder(BaseModel):
    """Maps a 3x128x128 image to a latent vector of size latent_dim."""

    def __init__(self, latent_dim):
        super().__init__()
        # Four stride-2 convolutions downsample 128 -> 64 -> 32 -> 16 -> 8.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU()
        )
        # Flatten the 256x8x8 feature map and project it to the latent space.
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 8 * 8, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, latent_dim)
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = self.fc_layers(x)
        return x
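# Round-trip shape check (illustrative): the Encoder's four stride-2 convs
# assume 128x128 inputs, matching the Decoder's 128x128 outputs above.
if __name__ == "__main__":
    import torch

    enc = Encoder(latent_dim=64)
    z = enc(torch.randn(2, 3, 128, 128))
    print(z.shape)  # torch.Size([2, 64])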
#-----------------------------------------------------------------------------------------
# models/discriminator.py (filename assumed from the models.* import convention)
import torch.nn as nn


class Discriminator(nn.Module):
    """Scores latent vectors; the Sigmoid output is the probability of 'real'."""

    def __init__(self, latent_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, z):
        return self.model(z)
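# Illustrative use with BCELoss (the training loop is not in this paste): the
# discriminator operates on latent vectors, e.g. prior samples vs. encoder
# outputs, as in an adversarial-autoencoder setup; N(0, I) is an assumed prior.
if __name__ == "__main__":
    import torch

    disc = Discriminator(latent_dim=64)
    z_prior = torch.randn(8, 64)   # samples from the assumed prior
    p_real = disc(z_prior)         # probabilities in (0, 1), shape (8, 1)
    loss = nn.BCELoss()(p_real, torch.ones(8, 1))
    print(p_real.shape, loss.item())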
#-----------------------------------------------------------------------------------------
# models/classifier.py (filename assumed from the models.* import convention)
import torch.nn as nn


class Classifier(nn.Module):
    """Predicts class logits from a latent vector (no softmax; pair with CrossEntropyLoss)."""

    def __init__(self, latent_dim, num_classes):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )

    def forward(self, z):
        return self.model(z)
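# Illustrative use with CrossEntropyLoss, which expects the raw logits this
# head produces; latent_dim, num_classes, and batch size are arbitrary choices.
if __name__ == "__main__":
    import torch

    clf = Classifier(latent_dim=64, num_classes=10)
    logits = clf(torch.randn(8, 64))
    labels = torch.randint(0, 10, (8,))
    loss = nn.CrossEntropyLoss()(logits, labels)
    print(logits.shape, loss.item())  # torch.Size([8, 10]) and a scalar loss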