Advertisement
HawkeyeHS

NASWOT

Feb 8th, 2025 (edited)
62
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.02 KB | Software | 0 0
  1. import torch
  2. import torch.nn as nn
  3. import numpy as np
  4.  
  5.  
  6. class SampleNet(nn.Module):
  7.     def __init__(self):
  8.         super(SampleNet, self).__init__()
  9.         self.fc1 = nn.Linear(10, 20)
  10.         self.relu = nn.ReLU()
  11.         self.fc2 = nn.Linear(20, 5)
  12.  
  13.     def forward(self, x):
  14.         x = self.fc1(x)
  15.         x = self.relu(x)
  16.         x = self.fc2(x)
  17.         return x
  18.    
  19. class ComplexNet(nn.Module):
  20.     def __init__(self, input_size=10, hidden_sizes=[64, 128, 256], output_size=5, dropout_prob=0.3):
  21.         super(ComplexNet, self).__init__()
  22.        
  23.         self.layers = nn.ModuleList()
  24.         prev_size = input_size
  25.        
  26.         for hidden_size in hidden_sizes:
  27.             self.layers.append(nn.Linear(prev_size, hidden_size))
  28.             self.layers.append(nn.BatchNorm1d(hidden_size))  # Batch normalization for stability
  29.             self.layers.append(nn.ReLU())
  30.             self.layers.append(nn.Dropout(dropout_prob))  # Dropout for regularization
  31.             prev_size = hidden_size
  32.        
  33.         self.output_layer = nn.Linear(prev_size, output_size)
  34.  
  35.     def forward(self, x):
  36.         for layer in self.layers:
  37.             x = layer(x)
  38.         x = self.output_layer(x)
  39.         return x
  40.  
  41.  
  42. def forward_hook(module, input, output):
  43.     try:
  44.         if not module.visited_backwards:
  45.             return
  46.         if isinstance(input, tuple):
  47.             input = input[0]  
  48.  
  49.         binary_activation = (input > 0).float()
  50.         K = binary_activation @ binary_activation.T
  51.         K2 = (1. - binary_activation) @ (1. - binary_activation.T)
  52.  
  53.         network.K = network.K + K.cpu().numpy() + K2.cpu().numpy()
  54.        
  55.     except:
  56.         pass
  57.    
  58. def counting_backward_hook(module, input, output):
  59.     module.visited_backwards = True
  60.  
  61.  
  62. def hooklogdet(K, labels=None):
  63.     s, ld = np.linalg.slogdet(K)
  64.     return ld
  65.  
  66.  
  67. def compute_score(network, dataloader, device, maxOfn=1):
  68.     network = network.to(device)
  69.     network.eval()
  70.     scores = []
  71.  
  72.     for _ in range(maxOfn):
  73.         data_iterator = iter(dataloader)
  74.         x, target = next(data_iterator)
  75.         x, target = x.to(device), target.to(device)
  76.  
  77.         network.K = np.zeros((x.shape[0], x.shape[0]))
  78.  
  79.         for name, module in network.named_modules():
  80.             if 'ReLU' in str(type(module)):
  81.                 # module.K = network.K  
  82.                 module.register_forward_hook(forward_hook)
  83.                 module.register_backward_hook(counting_backward_hook)
  84.        
  85.         network(x.to(device))
  86.         print(f"det(K): {np.linalg.det(network.K)}")
  87.         print(network.K)
  88.         score = hooklogdet(network.K)
  89.         scores.append(score)
  90.        
  91.     print(scores)
  92.     return np.mean(scores)
  93.  
  94.  
  95. if __name__ == "__main__":
  96.     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  97.     network = ComplexNet().to(device)
  98.     dataloader = [(torch.randn(128, 10), torch.randint(0, 5, (128,)))]
  99.     score = compute_score(network, dataloader, device)
  100.     print(f"NASWOT score: {score}")
  101.  
Tags: ML
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement