# IIS_2023_1/abanin_daniil_lab_7/lab7.py
# Captured from a repository web view (75 lines, 2.3 KiB, Python);
# commit dated 2023-10-31 00:50:28 +04:00.
from keras import Sequential
from keras.layers import LSTM, Dense, Dropout
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
import numpy as np
# Load the whole training corpus into one string. The file name suggests it
# holds Russian text (TODO confirm); get_model_data reads this module-level
# `text` when building the tokenizer and the training windows.
with open('rus_text.txt', encoding='utf-8') as file:
    text = ''.join(file)
def create_sequences(text, seq_len):
    """Slice *text* into overlapping fixed-length windows and their successors.

    For every start position ``i`` where a full window plus one following
    element fits, the window is ``text[i:i + seq_len]`` and its target is the
    element immediately after it, ``text[i + seq_len]``. Works on any
    sliceable, indexable sequence: a string, or a list of token ids as
    passed in by ``get_model_data``.

    Args:
        text: Sequence to slice.
        seq_len: Window length.

    Returns:
        A ``(sequences, next_chars)`` pair of equally long lists. For a
        non-negative ``seq_len`` both are empty when ``text`` has no more
        than ``seq_len`` elements.
    """
    starts = range(len(text) - seq_len)
    sequences = [text[start:start + seq_len] for start in starts]
    next_chars = [text[start + seq_len] for start in starts]
    return sequences, next_chars
def get_model_data(seq_length, source_text=None):
    """Tokenize a corpus at character level and build (window, next-id) pairs.

    Args:
        seq_length: Number of character tokens in each input window.
        source_text: Corpus to train on. Defaults to the module-level
            ``text`` read at import time, so the original one-argument call
            keeps working.

    Returns:
        ``(x, y, vocab_size, tokenizer)``:
        - ``x``: 2-D array with one row of ``seq_length`` token ids per window.
        - ``y``: 1-D array with the id of the character following each window.
        - ``vocab_size``: number of distinct characters plus one. Keras'
          ``Tokenizer`` never assigns id 0, which is reserved for padding.
        - ``tokenizer``: the fitted ``Tokenizer``.
    """
    if source_text is None:
        source_text = text
    tokenizer = Tokenizer(char_level=True)
    tokenizer.fit_on_texts([source_text])
    token_text = tokenizer.texts_to_sequences([source_text])[0]
    sequences, next_chars = create_sequences(token_text, seq_length)
    vocab_size = len(tokenizer.word_index) + 1
    # Every window already holds exactly seq_length ids; pad_sequences is used
    # to stack the list of lists into a single 2-D integer array.
    x = pad_sequences(sequences, maxlen=seq_length)
    y = np.array(next_chars)
    return x, y, vocab_size, tokenizer
def model_build(model, vocab_size, seq_len=None):
    """Append a two-layer LSTM next-character classifier to *model* and compile it.

    Layers: LSTM(256, full sequence output) -> LSTM(128) -> Dropout(0.2) ->
    Dense(vocab_size, softmax). The model is compiled with Adam and sparse
    categorical cross-entropy, so training targets are integer class ids.

    Args:
        model: A ``Sequential`` model; layers are added to it in place.
        vocab_size: Number of output classes (distinct tokens + 1).
        seq_len: Input window length. Defaults to the module-level
            ``seq_length`` so the original two-argument call still works.
    """
    if seq_len is None:
        seq_len = seq_length
    # Input is (seq_len, 1): one feature per time step. The script feeds raw
    # token ids as a 2-D array and relies on Keras adding the trailing axis —
    # NOTE(review): confirm with the installed Keras version.
    model.add(LSTM(256, input_shape=(seq_len, 1), return_sequences=True))
    # Only the first layer needs an input shape; Keras infers the rest, so
    # no input_shape is passed to the following layers.
    model.add(LSTM(128))
    model.add(Dropout(0.2))
    model.add(Dense(vocab_size, activation='softmax'))
    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
def generate_text(seed_text, gen_length, tokenizer, model, seq_len=None):
    """Greedily extend *seed_text* by *gen_length* predicted characters.

    At each step the current window is tokenized and reduced to its last
    ``seq_len`` ids. The model's most probable next character is appended to
    the output, and the window slides forward by one character.

    Args:
        seed_text: Prompt to continue. Characters unknown to the tokenizer are
            skipped by ``texts_to_sequences`` (no OOV token is configured)
            when building model input, but remain in the returned text.
        gen_length: Number of characters to generate.
        tokenizer: The fitted character-level ``Tokenizer``.
        model: Trained model whose output row is a score per token id.
        seq_len: Window length fed to the model. Defaults to the
            module-level ``seq_length``.

    Returns:
        ``seed_text`` followed by the generated characters.
    """
    if seq_len is None:
        seq_len = seq_length
    generated_text = seed_text
    window = seed_text
    for _ in range(gen_length):
        sequence = tokenizer.texts_to_sequences([window])[0]
        # Left-pad / left-truncate so the model always sees exactly seq_len
        # ids, keeping the most recent characters.
        sequence = pad_sequences([sequence], maxlen=seq_len)
        # verbose=0: don't print a progress bar for every generated character.
        prediction = model.predict(sequence, verbose=0)[0]
        # Id 0 is the padding id and has no entry in tokenizer.index_word;
        # an under-trained model could still rank it highest, which would
        # raise KeyError. Pick the best real character id instead.
        predicted_index = int(np.argmax(prediction[1:])) + 1
        predicted_char = tokenizer.index_word[predicted_index]
        generated_text += predicted_char
        # Slide the window: append the new character, drop the oldest one.
        window = (window + predicted_char)[1:]
    return generated_text
# Length of the character window the model sees. model_build and generate_text
# read this module-level name, so it must be assigned before they are called.
seq_length = 10
# Prompt that generate_text continues once the model is trained.
seed_text = "господин осматривал свою"
# Create a Tokenizer instance and fit it on the text (done inside
# get_model_data), producing the training windows X and next-character ids y.
X, y, vocab_size, tokenizer = get_model_data(seq_length)
# Build and compile the network, then train it for 100 epochs.
model = Sequential()
model_build(model, vocab_size)
model.fit(X, y, epochs=100, verbose=1)
# Generate 200 characters after the seed and print seed + continuation.
generated_text = generate_text(seed_text, 200, tokenizer, model)
print(generated_text)