Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import torch
- import torch.nn as nn
- from torchvision import models
# torchvision constructors for the architecture families BackboneModel handles.
modellist = [
    models.vgg19,
    models.resnet50,
    models.densenet121,
    models.efficientnet_b0,
]
# Constructor names, index-aligned with modellist.
modelnames = [factory.__name__ for factory in modellist]
class BackboneModel(nn.Module):
    """Wrap a torchvision classifier with a custom input depth and class count.

    The backbone is created by calling ``Backbone(weights=None)``. Its first
    convolution is then replaced so that it accepts ``n_channels`` input
    channels, and its last linear layer is replaced so that it produces
    ``num_classes`` outputs.

    Args:
        Backbone: Callable accepting ``weights=None`` and returning a VGG,
            ResNet, DenseNet or EfficientNet instance from torchvision
            (for example ``models.resnet50``).
        n_channels: Number of channels in the input images.
        num_classes: Number of outputs of the final linear layer.

    Raises:
        NotImplementedError: If the created model is none of the four
            supported torchvision classes.
    """

    def __init__(self, Backbone, n_channels: int = 1, num_classes: int = 10):
        super().__init__()
        self.n_channels = n_channels
        self.num_classes = num_classes

        # Build the backbone without requesting pretrained weights.
        self.model = Backbone(weights=None)

        # Each architecture family stores its stem and head under different
        # attribute names, so dispatch on the concrete model class.
        if isinstance(self.model, models.vgg.VGG):
            self.modify_vgg()
        elif isinstance(self.model, models.resnet.ResNet):
            self.modify_resnet()
        elif isinstance(self.model, models.densenet.DenseNet):
            self.modify_densenet()
        elif isinstance(self.model, models.efficientnet.EfficientNet):
            self.modify_efficientnet()
        else:
            raise NotImplementedError("Backbone type not supported.")

    @staticmethod
    def _conv_like(conv: nn.Conv2d, in_channels: int) -> nn.Conv2d:
        """Return a fresh Conv2d that mirrors ``conv`` except for its input depth.

        Copying the hyper-parameters from the layer being replaced (instead of
        hard-coding them) keeps the new stem compatible with the layers that
        follow it, whatever stem width the chosen variant uses.
        """
        return nn.Conv2d(
            in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            groups=conv.groups,
            bias=conv.bias is not None,
            padding_mode=conv.padding_mode,
        )

    @staticmethod
    def _head_like(linear: nn.Linear, num_classes: int) -> nn.Linear:
        """Return a fresh Linear with ``linear``'s input width and ``num_classes`` outputs."""
        return nn.Linear(linear.in_features, num_classes)

    def modify_vgg(self):
        """Replace ``features[0]`` and ``classifier[-1]`` of a VGG model."""
        self.model.features[0] = self._conv_like(
            self.model.features[0], self.n_channels
        )
        self.model.classifier[-1] = self._head_like(
            self.model.classifier[-1], self.num_classes
        )

    def modify_resnet(self):
        """Replace ``conv1`` and ``fc`` of a ResNet model."""
        self.model.conv1 = self._conv_like(self.model.conv1, self.n_channels)
        self.model.fc = self._head_like(self.model.fc, self.num_classes)

    def modify_densenet(self):
        """Replace ``features.conv0`` and ``classifier`` of a DenseNet model."""
        self.model.features.conv0 = self._conv_like(
            self.model.features.conv0, self.n_channels
        )
        self.model.classifier = self._head_like(
            self.model.classifier, self.num_classes
        )

    def modify_efficientnet(self):
        """Replace ``features[0][0]`` and ``classifier[-1]`` of an EfficientNet model."""
        self.model.features[0][0] = self._conv_like(
            self.model.features[0][0], self.n_channels
        )
        self.model.classifier[-1] = self._head_like(
            self.model.classifier[-1], self.num_classes
        )

    def forward(self, x):
        """Run ``x`` through the adapted backbone and return its output."""
        return self.model(x)
# Demo: adapt EfficientNet-B0 (modellist[3]) to single-channel images and
# ten output classes, then push one random image through it.
model = BackboneModel(modellist[3], n_channels=1, num_classes=10)
# One random sample laid out as (batch, channels, height, width).
input_tensor = torch.randn((1, 1, 225, 225))
output = model(input_tensor)
# Expected to report one row of ten class scores.
print(output.size())
# To inspect the adapted layer layout, print(model).
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement