import torch
import torch.nn as nn
import numpy as np


class SampleNet(nn.Module):
    def __init__(self):
        super(SampleNet, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(20, 5)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
class ComplexNet(nn.Module):
    def __init__(self, input_size=10, hidden_sizes=(64, 128, 256), output_size=5, dropout_prob=0.3):
        super(ComplexNet, self).__init__()
        self.layers = nn.ModuleList()
        prev_size = input_size
        for hidden_size in hidden_sizes:
            self.layers.append(nn.Linear(prev_size, hidden_size))
            self.layers.append(nn.BatchNorm1d(hidden_size))  # Batch normalization for stability
            self.layers.append(nn.ReLU())
            self.layers.append(nn.Dropout(dropout_prob))     # Dropout for regularization
            prev_size = hidden_size
        self.output_layer = nn.Linear(prev_size, output_size)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        x = self.output_layer(x)
        return x
def forward_hook(module, inp, out):
    # Accumulates the binary-activation kernel into `network.K`; the hook
    # reads the module-level `network` assigned in the __main__ block below.
    try:
        if not module.visited_backwards:
            # Skip ReLUs that never took part in a backward pass.
            return
        if isinstance(inp, tuple):
            inp = inp[0]
        inp = inp.view(inp.size(0), -1)            # flatten to (batch, features)
        binary_activation = (inp > 0).float()      # on/off code of each input
        K = binary_activation @ binary_activation.T
        K2 = (1. - binary_activation) @ (1. - binary_activation).T
        network.K = network.K + K.cpu().numpy() + K2.cpu().numpy()
    except AttributeError:
        # visited_backwards is not set until the first backward pass has run.
        pass


def counting_backward_hook(module, grad_input, grad_output):
    module.visited_backwards = True


def hooklogdet(K):
    # slogdet avoids the overflow that np.linalg.det hits for large kernels.
    s, ld = np.linalg.slogdet(K)
    return ld
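
# Why K works as a score (illustrative sketch; this helper is hypothetical and
# is not called by compute_score): each row of the binary matrix B below stands
# for the ReLU on/off pattern one input induces. With K = B @ B.T + (1-B) @ (1-B).T,
# K[i, j] counts the units where inputs i and j agree (both on or both off).
# Inputs the network cannot tell apart give identical rows, K becomes singular,
# and the log-determinant score collapses.
def _kernel_intuition_demo():
    B = np.array([[1., 0., 1., 1.],
                  [1., 0., 1., 1.],   # same on/off pattern as the first input
                  [0., 1., 0., 1.]])  # a distinct pattern
    K = B @ B.T + (1. - B) @ (1. - B).T
    print(K)                      # diagonal = 4 units, K[0, 1] = 4, K[0, 2] = 1
    print(np.linalg.slogdet(K))   # sign 0, log-det -inf: duplicate rows make K singular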
def compute_score(network, dataloader, device, maxOfn=1):
    network = network.to(device)
    network.eval()
    # Register the hooks once (registering them inside the loop below would
    # stack duplicate hooks and double-count K); keep the handles for cleanup.
    handles = []
    for name, module in network.named_modules():
        if isinstance(module, nn.ReLU):
            handles.append(module.register_forward_hook(forward_hook))
            handles.append(module.register_full_backward_hook(counting_backward_hook))
    scores = []
    for _ in range(maxOfn):
        x, target = next(iter(dataloader))
        x, target = x.to(device), target.to(device)
        # One forward/backward pass so the backward hooks mark the ReLUs that
        # actually participate in the graph (sets module.visited_backwards).
        network.K = np.zeros((x.shape[0], x.shape[0]))  # scratch; reset again below
        network(x).sum().backward()
        network.zero_grad()
        # Second forward pass: now the forward hooks accumulate K.
        network.K = np.zeros((x.shape[0], x.shape[0]))
        with torch.no_grad():
            network(x)
        scores.append(hooklogdet(network.K))
    for handle in handles:
        handle.remove()
    return np.mean(scores)
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    network = ComplexNet().to(device)
    dataloader = [(torch.randn(128, 10), torch.randint(0, 5, (128,)))]
    score = compute_score(network, dataloader, device)
    print(f"NASWOT score: {score}")