import numpy as np
from keras.models import Sequential
from keras.layers import Embedding, LSTM, Dense, Dropout
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.utils import to_categorical


def tokenize(filename):
    """Read a UTF-8 text file and build next-word training data from it.

    The whole file is treated as a single token stream. Every prefix of that
    stream with at least two tokens yields one training row: the prefix
    without its last token is the input, and the last token is the target.

    Args:
        filename: Path of the UTF-8 encoded training text.

    Returns:
        A tuple ``(max_seq_length, uniq_words_amount, tokenizer, x, y)``:
        ``max_seq_length`` is the length of the longest prefix (the number of
        tokens in the file); ``uniq_words_amount`` is the vocabulary size plus
        one for the reserved index 0; ``tokenizer`` is the fitted
        ``Tokenizer``; ``x`` holds the left-padded input prefixes of length
        ``max_seq_length - 1``; ``y`` holds the one-hot encoded next words.

    Raises:
        ValueError: If the file yields fewer than two tokens, so no
            (input, target) pair can be formed.
    """
    with open(filename, encoding='utf-8') as file:
        text = file.read()

    tokenizer = Tokenizer()
    tokenizer.fit_on_texts([text])
    # Keras never assigns index 0 to a word (it is reserved for padding),
    # so the output layer needs one extra slot.
    uniq_words_amount = len(tokenizer.word_index) + 1

    tokens = tokenizer.texts_to_sequences([text])[0]
    if len(tokens) < 2:
        raise ValueError(
            f"{filename!r} must contain at least two words to build training data"
        )

    # NOTE: every prefix of the whole text becomes one row, so the padded
    # matrix grows quadratically with the number of tokens in the file.
    sequences = [tokens[:end] for end in range(2, len(tokens) + 1)]
    max_seq_length = len(tokens)  # the longest prefix is the full stream
    sequences = pad_sequences(sequences, maxlen=max_seq_length)

    x, y = sequences[:, :-1], sequences[:, -1]
    y = to_categorical(y, num_classes=uniq_words_amount)
    return max_seq_length, uniq_words_amount, tokenizer, x, y


def train_model(max_seq_length, uniq_words_amount, x, y, epochs):
    """Build, compile and fit an Embedding -> LSTM -> Dense next-word model.

    Args:
        max_seq_length: Longest prefix length returned by ``tokenize``; the
            model input length is one less (the target token is dropped).
        uniq_words_amount: Vocabulary size including the reserved index 0;
            used as both the embedding input size and the softmax width.
        x: Padded input sequences from ``tokenize``.
        y: One-hot encoded targets from ``tokenize``.
        epochs: Number of training epochs.

    Returns:
        The trained ``Sequential`` model.
    """
    model = Sequential()
    model.add(Embedding(uniq_words_amount, 128, input_length=max_seq_length - 1))
    model.add(LSTM(152, return_sequences=False))
    model.add(Dropout(0.2))
    model.add(Dense(uniq_words_amount, activation='softmax'))
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    model.fit(x, y, epochs=epochs, verbose=1)
    return model


def generate_text(text, tokenizer, model, max_seq_length, num_words=100):
    """Greedily extend ``text`` with up to ``num_words`` predicted words.

    Args:
        text: Seed text. Words unknown to ``tokenizer`` are ignored as
            model context but are kept verbatim in the returned string.
        tokenizer: The ``Tokenizer`` fitted by ``tokenize``.
        model: The model returned by ``train_model``.
        max_seq_length: The value returned by ``tokenize``; the model
            consumes the last ``max_seq_length - 1`` tokens as context.
        num_words: Maximum number of words to generate (default 100).

    Returns:
        The seed text followed by the generated words, separated by single
        spaces. Generation stops early if the predicted index has no word
        in ``tokenizer.index_word``.
    """
    tokens = tokenizer.texts_to_sequences([text])[0]
    words = []
    for _ in range(num_words):
        # pad_sequences truncates from the front by default, so only the most
        # recent max_seq_length - 1 tokens are fed to the model.
        window = pad_sequences([tokens], maxlen=max_seq_length - 1, padding='pre')
        probs = model.predict(window, verbose=0)[0]
        # Index 0 is the padding slot and never a training target; excluding
        # it prevents index_word.get() from returning None.
        predict_index = int(np.argmax(probs[1:])) + 1
        word = tokenizer.index_word.get(predict_index)
        if word is None:
            break
        tokens.append(predict_index)
        words.append(word)
    return " ".join([text, *words])


if __name__ == "__main__":
    EPOCHS = 140
    for label, filename, seed in (
        ("Rus: ", "text_rus.txt", "Кофе со специями"),
        ("Eng: ", "text_eng.txt", "The spiced coffee"),
    ):
        msl, uwa, tokenizer, x, y = tokenize(filename)
        model = train_model(msl, uwa, x, y, EPOCHS)
        print(label, generate_text(seed, tokenizer, model, msl))