IIS_2023_1/orlov_artem_lab_7/app.py
2023-12-02 13:22:07 +04:00

85 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Embedding, LSTM, Dense
from flask import Flask, request, jsonify, render_template
# Load the training corpus (read as UTF-8) in a single read.
with open('your_text_file.txt', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

# Character vocabulary: every distinct character of the corpus in sorted
# order, plus lookup tables in both directions (character <-> integer index).
chars = sorted(set(text))
char_to_idx = {}
idx_to_char = {}
for position, symbol in enumerate(chars):
    char_to_idx[symbol] = position
    idx_to_char[position] = symbol
# Slice the corpus into overlapping training windows: every `step`
# characters, take `seq_length` consecutive characters as the input and the
# single character right after them as the prediction target.
seq_length = 100
step = 3
window_starts = range(0, len(text) - seq_length, step)
sequences = [text[start:start + seq_length] for start in window_starts]
next_chars = [text[start + seq_length] for start in window_starts]
# Convert the training windows into one-hot boolean tensors:
#   X[i, t, k] is True when character t of window i has vocabulary index k;
#   y[i, k] is True when the character following window i has index k.
if not sequences:
    # Without this check the failure only surfaces later, as an opaque error
    # from model.fit on an empty dataset.
    raise ValueError(
        f'The training text must contain more than {seq_length} characters '
        'to form at least one training sequence.'
    )
num_sequences = len(sequences)
X = np.zeros((num_sequences, seq_length, len(chars)), dtype=bool)
y = np.zeros((num_sequences, len(chars)), dtype=bool)
# Translate every character to its vocabulary index once, then set all
# one-hot entries with one advanced-indexing assignment per array instead of
# one NumPy scalar write per character. Every window is exactly
# `seq_length` characters long, which the reshape relies on.
sequence_codes = np.fromiter(
    (char_to_idx[char] for char in ''.join(sequences)),
    dtype=np.intp,
    count=num_sequences * seq_length,
).reshape(num_sequences, seq_length)
next_codes = np.fromiter(
    (char_to_idx[char] for char in next_chars),
    dtype=np.intp,
    count=num_sequences,
)
# The three index arrays broadcast to (num_sequences, seq_length):
# row i, timestep t, column sequence_codes[i, t].
sample_ids = np.arange(num_sequences)
X[sample_ids[:, np.newaxis], np.arange(seq_length), sequence_codes] = True
y[sample_ids, next_codes] = True
# Model: two stacked 256-unit LSTM layers over the one-hot character window,
# followed by a softmax layer over the vocabulary that predicts the next
# character.
vocab_size = len(chars)
model = Sequential()
model.add(LSTM(256, input_shape=(seq_length, vocab_size), return_sequences=True))
model.add(LSTM(256))
model.add(Dense(vocab_size, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam')

# Train for 100 epochs in mini-batches of 128. This runs at module level,
# so it happens every time the file is executed or imported.
model.fit(X, y, batch_size=128, epochs=100)
def _sample_index(preds, temperature):
    """Sample an index from probability vector ``preds`` re-weighted by ``temperature``.

    Each probability ``p`` becomes the weight ``exp(log(p) / temperature)``
    (i.e. ``p ** (1 / temperature)``); the weights are renormalised and one
    index is drawn with ``np.random.choice``.
    """
    # Compute in float64 for precision in the log/exp round trip.
    preds = np.asarray(preds, dtype=np.float64)
    # A zero probability gives log(0) = -inf, which correctly maps to weight 0;
    # only the divide-by-zero warning it triggers is silenced.
    with np.errstate(divide='ignore'):
        logits = np.log(preds) / temperature
    # Shift by the maximum before exp() so small temperatures cannot overflow.
    weights = np.exp(logits - np.max(logits))
    return np.random.choice(len(weights), p=weights / np.sum(weights))


# Text generation with a sampling-temperature parameter.
def generate_text(seed_text, model, length=100, temperature=1.0):
    """Continue ``seed_text`` by sampling ``length`` characters from ``model``.

    Each step feeds the model a one-hot window (shape
    ``(1, seq_length, len(chars))``) of the most recent characters and draws
    the next character from its prediction re-weighted by ``temperature``.

    Args:
        seed_text: Prompt to continue. Only its last ``seq_length`` characters
            are fed to the model; characters missing from ``char_to_idx`` are
            encoded as all-zero rows instead of raising ``KeyError``.
        model: Object whose ``predict(x, verbose=0)[0]`` yields the
            next-character probability vector for window ``x``.
        length: Number of characters to append.
        temperature: Positive sampling temperature; below 1 sampling becomes
            greedier, above 1 more random.

    Returns:
        ``seed_text`` followed by the ``length`` generated characters.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if not temperature > 0:
        raise ValueError(f'temperature must be positive, got {temperature!r}')
    generated_text = seed_text
    for _ in range(length):
        # Take the newest `seq_length` characters and right-align them, so the
        # latest character is always at the last timestep (as in the training
        # windows). A shorter context is left-padded with all-zero rows and
        # grows with every generated character.
        window = generated_text[-seq_length:]
        offset = seq_length - len(window)
        x = np.zeros((1, seq_length, len(chars)))
        for t, char in enumerate(window):
            idx = char_to_idx.get(char)
            if idx is not None:  # out-of-vocabulary characters stay all-zero
                x[0, offset + t, idx] = 1
        preds = model.predict(x, verbose=0)[0]
        generated_text += idx_to_char[_sample_index(preds, temperature)]
    return generated_text
# Flask application that serves the page and the generation API.
app = Flask(__name__)


# Page endpoint: renders the HTML front-end.
@app.route('/')
def index():
    """Render the ``index.html`` template as the landing page."""
    page_template = 'index.html'
    return render_template(page_template)
# API endpoint: generate a continuation of a seed text.
@app.route('/generate_text', methods=['POST'])
def generate_text_endpoint():
    """Return ``{"generated_text": ...}`` continuing the posted seed text.

    The request body must be a JSON object with these optional fields:

    * ``seed_text`` (string, default ``""``): the prompt to continue.
    * ``temperature`` (finite positive number, default ``1.0``): the sampling
      temperature passed to ``generate_text``.

    A malformed request gets HTTP 400 with ``{"error": <message>}``.
    """
    # silent=True returns None for a missing, non-JSON or unparsable body.
    # That case is then reported as a 400 below instead of a 415/500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object.'}), 400

    seed_text = data.get('seed_text', '')
    if not isinstance(seed_text, str):
        return jsonify({'error': "'seed_text' must be a string."}), 400

    temperature = data.get('temperature', 1.0)
    # bool is a subclass of int, so reject it explicitly. The chained
    # comparison also rejects NaN and infinity.
    if (
        isinstance(temperature, bool)
        or not isinstance(temperature, (int, float))
        or not 0 < temperature < float('inf')
    ):
        return jsonify({'error': "'temperature' must be a positive number."}), 400

    generated_text = generate_text(seed_text, model, temperature=float(temperature))
    return jsonify({'generated_text': generated_text})
# Start Flask's development server on port 5000 when this file is run
# directly (not when it is imported).
if __name__ == '__main__':
    server_port = 5000
    app.run(port=server_port)