# Source metadata: 905 bytes, commit 3b6d764.
import torch.nn as nn
import torch

class LSTMModel(nn.Module):
    """Stacked LSTM followed by a linear layer applied at every time step.

    The LSTM is built with ``batch_first=True``, so a batched input has shape
    ``(batch, seq_len, input_size)``; ``nn.LSTM`` also accepts an unbatched
    input of shape ``(seq_len, input_size)``.

    Args:
        input_size: Number of features per time step.
        hidden_size: Size of the LSTM hidden and cell states.
        output_size: Number of output features produced per time step.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
        # Applied to the LSTM output of every time step, not just the last one.
        self.fc = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, x, h0=None, c0=None):
        """Run the LSTM over ``x`` and project every time step.

        Args:
            x: Input sequence, ``(batch, seq_len, input_size)`` or unbatched
                ``(seq_len, input_size)``.
            h0: Optional initial hidden state, ``(num_layers, batch, hidden_size)``
                (``(num_layers, hidden_size)`` for unbatched input).
            c0: Optional initial cell state, same shape as ``h0``.

        Returns:
            ``(out, (hn, cn))``. ``out`` is the linear projection of every
            LSTM output step, with last dimension ``output_size``. ``hn`` and
            ``cn`` are the final hidden and cell states returned by the LSTM.

        A missing state is zero-filled. When both are omitted, ``nn.LSTM``
        creates the zero states itself, matching ``x``'s device, dtype and
        batch layout (batched or unbatched).
        """
        if h0 is None and c0 is None:
            state = None
        else:
            # Keep whichever state the caller supplied and zero only the
            # missing one; h and c share a shape (no proj_size configured).
            if h0 is None:
                h0 = torch.zeros_like(c0)
            if c0 is None:
                c0 = torch.zeros_like(h0)
            state = (h0, c0)

        out, (hn, cn) = self.lstm(x, state)
        out = self.fc(out)
        return out, (hn, cn)