IIS_2023_1/malkova_anastasia_lab_7/text_rnn.py
2023-11-17 01:50:49 +04:00

30 lines
1.1 KiB
Python

import torch
import torch.nn as nn
class TextRNN(nn.Module):
    """Embedding -> LSTM -> dropout -> linear model over a discrete vocabulary.

    The final linear layer maps each LSTM output back to ``input_size``
    values, i.e. one score per vocabulary entry.

    Args:
        input_size: Number of distinct token ids; used both as the embedding
            table size and as the width of the output layer.
        hidden_size: Number of features in the LSTM hidden state.
        embedding_size: Dimensionality of the token embeddings.
        device: Device on which :meth:`init_hidden` allocates the initial
            state (anything ``torch.zeros(device=...)`` accepts). The
            constructor does not move the module's parameters to this device.
        n_layers: Number of stacked LSTM layers.
        dropout: Probability used by the dropout layer applied to the LSTM
            output before the final projection.
    """

    def __init__(self, input_size, hidden_size, embedding_size, device,
                 n_layers=1, dropout=0.2):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.embedding_size = embedding_size
        self.n_layers = n_layers
        self.device = device

        # Layers are created in the same order as before so that parameter
        # initialisation under a fixed random seed is unchanged.
        self.encoder = nn.Embedding(self.input_size, self.embedding_size)
        # batch_first is left at its default, so the LSTM consumes
        # sequence-first input.
        self.lstm = nn.LSTM(self.embedding_size, self.hidden_size, self.n_layers)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(self.hidden_size, self.input_size)

    def forward(self, x, hidden):
        """Run one forward pass.

        Args:
            x: Integer tensor of token ids.
                NOTE(review): presumably shaped (seq_len, batch, 1); the
                ``squeeze(2)`` below drops that trailing singleton axis after
                the embedding lookup - confirm against the caller.
            hidden: ``(h_0, c_0)`` tuple, e.g. from :meth:`init_hidden`.

        Returns:
            ``(scores, (h_n, c_n))`` where ``scores`` has ``input_size``
            values along its last axis and ``(h_n, c_n)`` is the LSTM state
            after the last step.
        """
        # squeeze(2) only removes axis 2 when its size is 1; otherwise it is
        # a no-op.
        embedded = self.encoder(x).squeeze(2)
        lstm_out, (h_n, c_n) = self.lstm(embedded, hidden)
        scores = self.fc(self.dropout(lstm_out))
        return scores, (h_n, c_n)

    def init_hidden(self, batch_size=1):
        """Return a zero-filled ``(h_0, c_0)`` state on ``self.device``.

        Each tensor has shape ``(n_layers, batch_size, hidden_size)``.

        The tensors are allocated directly on the target device instead of
        being built on the CPU and copied over, which avoids a transfer and
        keeps them leaf tensors. ``requires_grad=True`` is retained for
        backward compatibility with existing callers.
        """
        state_shape = (self.n_layers, batch_size, self.hidden_size)
        h_0 = torch.zeros(state_shape, device=self.device, requires_grad=True)
        c_0 = torch.zeros(state_shape, device=self.device, requires_grad=True)
        return (h_0, c_0)