Advertisement
yusufbrima

Torchvision Model Wrapping

May 7th, 2024
776
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.83 KB | Source Code | 0 0
  1. import torch
  2. import torch.nn as nn
  3. from torchvision import models
  4.  
  5. modellist = [models.vgg19, models.resnet50, models.densenet121,models.efficientnet_b0]
  6. modelnames = [model.__name__ for model in modellist]
  7.  
  8.  
  9. class BackboneModel(nn.Module):
  10.     def __init__(self, Backbone, n_channels=1, num_classes=10):
  11.         super(BackboneModel, self).__init__()
  12.         self.n_channels = n_channels
  13.         self.num_classes = num_classes
  14.  
  15.         # Initialize backbone model
  16.         self.model = Backbone(weights=None)
  17.  
  18.         # Modify layers based on backbone type
  19.         if isinstance(self.model, models.vgg.VGG):
  20.             self.modify_vgg()
  21.         elif isinstance(self.model, models.resnet.ResNet):
  22.             self.modify_resnet()
  23.         elif isinstance(self.model, models.densenet.DenseNet):
  24.             self.modify_densenet()
  25.         elif isinstance(self.model, models.efficientnet.EfficientNet):
  26.             self.modify_efficientnet()
  27.         else:
  28.             raise NotImplementedError("Backbone type not supported.")
  29.  
  30.     def modify_vgg(self):
  31.         # Modify first layer to accept n_channels input
  32.         self.model.features[0] = nn.Conv2d(self.n_channels, 64, kernel_size=3, padding=1)
  33.  
  34.         # Modify classification layer
  35.         num_features = self.model.classifier[-1].in_features
  36.         self.model.classifier[-1] = nn.Linear(num_features, self.num_classes)
  37.  
  38.     def modify_resnet(self):
  39.         # Modify first convolutional layer to accept n_channels input
  40.         self.model.conv1 = nn.Conv2d(self.n_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
  41.  
  42.         # Modify classification layer
  43.         num_features = self.model.fc.in_features
  44.         self.model.fc = nn.Linear(num_features, self.num_classes)
  45.  
  46.     def modify_densenet(self):
  47.         # Modify first convolutional layer to accept n_channels input
  48.         self.model.features.conv0 = nn.Conv2d(self.n_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
  49.  
  50.         # Modify classification layer
  51.         num_features = self.model.classifier.in_features
  52.         self.model.classifier = nn.Linear(num_features, self.num_classes)
  53.  
  54.     def modify_efficientnet(self):
  55.         # Modify first convolutional layer to accept n_channels input
  56.         self.model.features[0][0] = nn.Conv2d(self.n_channels, 32, kernel_size=3, stride=2, padding=1, bias=False)
  57.  
  58.         # Modify classification layer
  59.         num_features = self.model.classifier[-1].in_features
  60.         self.model.classifier[-1] = nn.Linear(num_features, self.num_classes)
  61.  
  62.     def forward(self, x):
  63.         return self.model(x)
  64.  
  65.  
  66.  
  67. # Example usage:
  68. model = BackboneModel(modellist[3], n_channels=1, num_classes=10)
  69. input_tensor = torch.randn(1, 1, 225, 225)  # Example input tensor with 1 channel
  70. output = model(input_tensor)
  71. print(output.shape)  # Example output shape
  72. # print(model)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement