import torch
import torch.nn as nn


# Define our network class by subclassing nn.Module
class ResBlockMLP(nn.Module):
    def __init__(self, input_size, output_size):
        super(ResBlockMLP, self).__init__()
        self.norm1 = nn.LayerNorm(input_size)
        self.fc1 = nn.Linear(input_size, input_size // 2)

        self.norm2 = nn.LayerNorm(input_size // 2)
        self.fc2 = nn.Linear(input_size // 2, output_size)

        # Linear projection so the skip connection matches the output size
        self.fc3 = nn.Linear(input_size, output_size)

        self.act = nn.ELU()

    def forward(self, x):
        x = self.act(self.norm1(x))
        skip = self.fc3(x)

        x = self.act(self.norm2(self.fc1(x)))
        x = self.fc2(x)

        return x + skip


class RNN(nn.Module):
    def __init__(self, seq_len, output_size, num_blocks=1, buffer_size=128):
        super(RNN, self).__init__()

        seq_data_len = seq_len * 2
        self.input_mlp = nn.Sequential(
            nn.Linear(seq_data_len, 4 * seq_data_len),
            nn.ELU(),
            nn.Linear(4 * seq_data_len, 128),
            nn.ELU(),
        )

        # "Recurrent" step: mixes the 128-d input embedding with the previous buffer.
        # 128 + buffer_size rather than a hard-coded 256, so buffer_size can vary.
        self.rnn = nn.Linear(128 + buffer_size, 128)

        blocks = [ResBlockMLP(128, 128) for _ in range(num_blocks)]
        self.res_blocks = nn.Sequential(*blocks)

        self.fc_out = nn.Linear(128, output_size)
        self.fc_buffer = nn.Linear(128, buffer_size)
        self.act = nn.ELU()

    def forward(self, input_seq, buffer_in):
        # Flatten the sequence to one vector per batch element: (B, seq_len, 2) -> (B, seq_len * 2)
        input_seq = input_seq.reshape(input_seq.shape[0], -1)
        input_vec = self.input_mlp(input_seq)

        # Concatenate the previous step's buffer with the current input embedding
        x_cat = torch.cat((buffer_in, input_vec), 1)
        x = self.rnn(x_cat)
        x = self.act(self.res_blocks(x))

        # tanh keeps the buffer bounded before it is fed back at the next step
        return self.fc_out(x), torch.tanh(self.fc_buffer(x))
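
A minimal usage sketch (not part of the original paste; the batch size, seq_len, and two-features-per-step input shape are assumptions): the buffer is zero-initialized for the first step and the returned buffer is fed back on each subsequent call.

# Minimal usage sketch -- names and shapes below are assumptions, not from the paste.
# The (B, seq_len, 2) input matches the seq_len * 2 flattening in forward().
model = RNN(seq_len=10, output_size=2, num_blocks=2, buffer_size=128)

input_seq = torch.randn(4, 10, 2)   # (batch=4, seq_len=10, 2 features per step)
buffer = torch.zeros(4, 128)        # zero buffer for the first step

for _ in range(5):                  # unroll a few steps, feeding the buffer back
    out, buffer = model(input_seq, buffer)

print(out.shape)     # torch.Size([4, 2])
print(buffer.shape)  # torch.Size([4, 128])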