import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
from flask import Flask, request, jsonify, render_template

# Load and preprocess the training text
with open('your_text_file.txt', 'r', encoding='utf-8') as file:
    text = file.read()

# Build the character vocabulary and index lookups
chars = sorted(set(text))
char_to_idx = {char: idx for idx, char in enumerate(chars)}
idx_to_char = {idx: char for idx, char in enumerate(chars)}

# Prepare training data: sliding windows of characters and their target character
seq_length = 100
step = 3
sequences = []
next_chars = []

for i in range(0, len(text) - seq_length, step):
    seq = text[i:i + seq_length]
    target = text[i + seq_length]
    sequences.append(seq)
    next_chars.append(target)

# One-hot encode the data: X holds the input windows, y holds the target characters
X = np.zeros((len(sequences), seq_length, len(chars)), dtype=bool)
y = np.zeros((len(sequences), len(chars)), dtype=bool)

for i, seq in enumerate(sequences):
    for t, char in enumerate(seq):
        X[i, t, char_to_idx[char]] = 1
    y[i, char_to_idx[next_chars[i]]] = 1

# Build a more complex model: two stacked LSTM layers over one-hot character inputs
model = Sequential([
    LSTM(256, input_shape=(seq_length, len(chars)), return_sequences=True),
    LSTM(256),
    Dense(len(chars), activation='softmax')
])

model.compile(optimizer='adam', loss='categorical_crossentropy')

# Train for an increased number of epochs (100 epochs can be slow without a GPU)
model.fit(X, y, epochs=100, batch_size=128)

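# Optionally save the trained model so it can be reloaded without retraining
# (the file name here is only an illustrative placeholder):
# model.save('char_lstm_model.keras')
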
# Generate text from a seed string using temperature-based sampling
def generate_text(seed_text, model, length=100, temperature=1.0):
    generated_text = seed_text
    # Keep only the last seq_length characters as the model's context window
    seed_text = seed_text[-seq_length:]
    for _ in range(length):
        x = np.zeros((1, seq_length, len(chars)))
        for t, char in enumerate(seed_text):
            if char in char_to_idx:  # ignore characters unseen during training
                x[0, t, char_to_idx[char]] = 1
        preds = model.predict(x, verbose=0)[0].astype('float64')
        # Rescale the predicted distribution by the temperature and renormalize
        preds = np.log(preds + 1e-8) / temperature
        exp_preds = np.exp(preds)
        preds = exp_preds / np.sum(exp_preds)
        next_index = np.random.choice(len(chars), p=preds)
        next_char = idx_to_char[next_index]
        generated_text += next_char
        # Slide the context window forward by one character
        seed_text = seed_text[1:] + next_char
    return generated_text

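# Example usage (assumes the corpus provides a seed of up to seq_length characters):
# print(generate_text(text[:seq_length], model, length=200, temperature=0.8))
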
# Create the Flask application
app = Flask(__name__)

# Index page (expects a templates/index.html file in the project)
@app.route('/')
def index():
    return render_template('index.html')

# Endpoint for text generation
@app.route('/generate_text', methods=['POST'])
def generate_text_endpoint():
    data = request.get_json(silent=True) or {}
    seed_text = data.get('seed_text', '')
    generated_text = generate_text(seed_text, model)
    return jsonify({'generated_text': generated_text})

# Start the Flask server
if __name__ == '__main__':
    app.run(port=5000)
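
# Example request once the server is running:
# curl -X POST http://localhost:5000/generate_text \
#      -H "Content-Type: application/json" \
#      -d '{"seed_text": "once upon a time"}'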