62 lines
2.1 KiB
Python
62 lines
2.1 KiB
Python
import numpy as np
|
|
from keras.layers import LSTM, Dense
|
|
from keras.models import Sequential
|
|
from keras.preprocessing.sequence import pad_sequences
|
|
from keras.preprocessing.text import Tokenizer
|
|
|
|
# Чтение текста из файла
|
|
# with open('russian.txt', 'r', encoding='utf-8') as file:
|
|
# text = file.read()
|
|
with open('english.txt', 'r', encoding='utf-8') as file:
|
|
text = file.read()
|
|
|
|
# Обучение Tokenizer на тексте
|
|
tokenizer = Tokenizer(char_level=True)
|
|
tokenizer.fit_on_texts([text])
|
|
sequences = tokenizer.texts_to_sequences([text])[0]
|
|
|
|
# Создание x, y последовательностей
|
|
X_data, y_data = [], []
|
|
seq_length = 10
|
|
for i in range(seq_length, len(sequences)):
|
|
sequence = sequences[i - seq_length:i]
|
|
target = sequences[i]
|
|
X_data.append(sequence)
|
|
y_data.append(target)
|
|
|
|
# Преобразование в массивы
|
|
X_mass = pad_sequences(X_data, maxlen=seq_length)
|
|
y_mass = np.array(y_data)
|
|
|
|
# Создание модели
|
|
vocab_size = len(tokenizer.word_index) + 1
|
|
model = Sequential()
|
|
model.add(LSTM(256, input_shape=(seq_length, 1), return_sequences=True))
|
|
model.add(LSTM(128, input_shape=(seq_length, 1)))
|
|
model.add(Dense(vocab_size, activation='softmax'))
|
|
|
|
# Компиляция
|
|
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
|
|
|
|
# Обучение
|
|
model.fit(X_mass, y_mass, epochs=100, verbose=1)
|
|
|
|
# Функция генерации
|
|
def generate_text(_text, gen_length):
|
|
generated_text = _text
|
|
for _ in range(gen_length):
|
|
seq = tokenizer.texts_to_sequences([_text])[0]
|
|
seq = pad_sequences([seq], maxlen=seq_length)
|
|
prediction = model.predict(seq)[0]
|
|
predicted_index = np.argmax(prediction)
|
|
predicted_char = tokenizer.index_word[predicted_index]
|
|
generated_text += predicted_char
|
|
_text += predicted_char
|
|
_text = _text[1:]
|
|
return generated_text
|
|
|
|
# Генерация текста
|
|
# _text = "Она сверкала"
|
|
_text = "It sparkled and smoked"
|
|
generate_text = generate_text(_text, 250)
|
|
print(generate_text) |