69 lines
3.1 KiB
Python
69 lines
3.1 KiB
Python
import numpy as np
|
|
from keras.models import Sequential
|
|
from keras.layers import Embedding, LSTM, Dense
|
|
from keras.preprocessing.text import Tokenizer
|
|
from keras.preprocessing.sequence import pad_sequences
|
|
|
|
|
|
# Reading the training text from a file.
def load_text(file_path):
    """Read and return the entire UTF-8 encoded file at *file_path* as one string."""
    with open(file_path, mode='r', encoding='utf-8') as source:
        contents = source.read()
    return contents
|
|
|
|
|
|
# Building the tokenizer and the n-gram training sequences from the input text.
def create_tokenizer_and_sequences(text):
    """Fit a tokenizer on *text* and turn each line into next-word samples.

    Each line of *text* (split on newline characters) is converted to token
    ids and expanded into all of its prefixes of length 2..len(line). For each
    prefix the last id is the label and the ids before it are the predictors.
    All prefixes are left-padded (``padding='pre'``) to the length of the
    longest one.

    Args:
        text: The training corpus as a single string.

    Returns:
        A tuple ``(tokenizer, total_words, predictors, labels,
        max_sequence_length)`` where:
          * ``tokenizer`` is the fitted Keras ``Tokenizer``;
          * ``total_words`` is ``len(tokenizer.word_index) + 1`` (the extra
            slot is the padding id 0 used by ``pad_sequences``);
          * ``predictors`` has shape
            ``(num_samples, max_sequence_length - 1)``;
          * ``labels`` has shape ``(num_samples,)``;
          * ``max_sequence_length`` is the length of the longest prefix.

    Raises:
        ValueError: if no line of *text* yields at least two tokens, so not a
            single training sample can be built.
    """
    tokenizer = Tokenizer()
    tokenizer.fit_on_texts([text])
    # Tokenizer ids start at 1; +1 leaves room for the padding id 0.
    total_words = len(tokenizer.word_index) + 1

    input_sequences = []
    for line in text.split('\n'):
        token_list = tokenizer.texts_to_sequences([line])[0]
        # Every prefix of length >= 2 gives one (context -> next word) sample.
        for end in range(2, len(token_list) + 1):
            input_sequences.append(token_list[:end])

    if not input_sequences:
        # Without this check max() below fails with an opaque message.
        raise ValueError(
            "text must contain at least one line with two or more words "
            "to build training sequences"
        )

    max_sequence_length = max(len(seq) for seq in input_sequences)
    # Left padding keeps the label (the last id) in the final column.
    input_sequences = pad_sequences(input_sequences, maxlen=max_sequence_length, padding='pre')

    predictors, labels = input_sequences[:, :-1], input_sequences[:, -1]

    return tokenizer, total_words, predictors, labels, max_sequence_length
|
|
|
|
|
|
# Building and training the model.
def create_and_train_model(total_words, max_sequence_length, predictors, labels, *,
                           embedding_dim=256, lstm_units=1024, epochs=100,
                           batch_size=64, verbose=1):
    """Build and fit an Embedding -> LSTM -> softmax next-word model.

    The keyword-only hyper-parameters default to the values this script
    originally hard-coded, so existing calls behave exactly as before.

    Args:
        total_words: Vocabulary size. Used as the embedding input dimension
            and as the number of softmax output classes, so every id in
            *predictors* and *labels* must be below it.
        max_sequence_length: Length of the longest padded n-gram. The model
            consumes ``max_sequence_length - 1`` ids per sample.
        predictors: Integer id sequences. Expected shape is
            ``(num_samples, max_sequence_length - 1)``; confirm with caller.
        labels: Integer next-word ids, one per predictor row.
        embedding_dim: Size of the word embedding vectors.
        lstm_units: Number of units in the LSTM layer.
        epochs: Number of training epochs passed to ``model.fit``.
        batch_size: Mini-batch size passed to ``model.fit``.
        verbose: Verbosity passed to ``model.fit``.

    Returns:
        The trained Keras ``Sequential`` model.
    """
    model = Sequential()
    model.add(Embedding(total_words, embedding_dim, input_length=max_sequence_length - 1))
    model.add(LSTM(units=lstm_units))
    model.add(Dense(total_words, activation='softmax'))
    # Labels are integer ids, not one-hot vectors, hence the sparse loss.
    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    model.fit(predictors, labels, epochs=epochs, verbose=verbose, batch_size=batch_size)
    return model
|
|
|
|
|
|
# Text generation.
def generate_text(seed_text, next_words, model, max_sequence_length, tokenizer):
    """Extend *seed_text* by up to *next_words* greedily predicted words.

    At each step the current text is tokenized and left-padded/truncated to
    ``max_sequence_length - 1`` ids, matching the training predictors. The
    word with the highest predicted probability is then appended, separated
    by a single space.

    Generation stops early if the model predicts an id with no vocabulary
    word (e.g. the padding id 0). Appending nothing would leave the context
    unchanged, so a deterministic model would keep predicting the same id
    and only add trailing spaces.

    Args:
        seed_text: Starting text.
        next_words: Maximum number of words to append.
        model: Trained model whose ``predict`` returns per-id probabilities.
        max_sequence_length: Same value used when building training data.
        tokenizer: The fitted ``Tokenizer`` used for training.

    Returns:
        *seed_text* followed by the generated words.
    """
    # Reverse vocabulary, built once instead of scanning word_index per word.
    index_to_word = {index: word for word, index in tokenizer.word_index.items()}

    for _ in range(next_words):
        token_list = tokenizer.texts_to_sequences([seed_text])[0]
        # pad_sequences truncates from the front by default, so only the most
        # recent context that fits the model input is kept.
        token_list = pad_sequences([token_list], maxlen=max_sequence_length - 1, padding='pre')
        # verbose=0: avoid printing a progress bar for every single word.
        probabilities = model.predict(token_list, verbose=0)
        predicted = int(np.argmax(probabilities, axis=-1)[0])

        output_word = index_to_word.get(predicted)
        if output_word is None:
            break
        seed_text += " " + output_word

    return seed_text
|
|
|
|
|
|
# Using the functions defined above. The guard keeps importing this module
# from loading the corpora and training two large models as a side effect.
if __name__ == "__main__":
    eng_text = load_text('public/text/eng.txt')
    rus_text = load_text('public/text/rus.txt')

    tokenizer_eng, total_words_eng, predictors_eng, labels_eng, max_seq_len_eng = create_tokenizer_and_sequences(eng_text)
    tokenizer_rus, total_words_rus, predictors_rus, labels_rus, max_seq_len_rus = create_tokenizer_and_sequences(rus_text)

    model_eng = create_and_train_model(total_words_eng, max_seq_len_eng, predictors_eng, labels_eng)
    model_rus = create_and_train_model(total_words_rus, max_seq_len_rus, predictors_rus, labels_rus)

    print(generate_text("\"Event Horizon\"", 50, model_eng, max_seq_len_eng, tokenizer_eng))
    print(generate_text("\"Горизонт событий\"", 50, model_rus, max_seq_len_rus, tokenizer_rus))
|