import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense
from flask import Flask, request, jsonify, render_template


def load_and_preprocess_data(file_path, seq_length=100, step=3):
    """Read a text corpus and turn it into integer-encoded training windows."""
    with open(file_path, 'r', encoding='utf-8') as file:
        text = file.read()

    # Build the character vocabulary and lookup tables in both directions.
    chars = sorted(set(text))
    char_to_idx = {char: idx for idx, char in enumerate(chars)}
    idx_to_char = {idx: char for idx, char in enumerate(chars)}

    # Slide a window over the text: each sequence predicts the character that follows it.
    sequences, next_chars = [], []
    for i in range(0, len(text) - seq_length, step):
        seq = text[i:i + seq_length]
        target = text[i + seq_length]
        sequences.append(seq)
        next_chars.append(target)

    # Vectorise: X holds character indices, y the index of the next character.
    X = np.zeros((len(sequences), seq_length), dtype=np.int32)
    y = np.zeros((len(sequences),), dtype=np.int32)
    for i, seq in enumerate(sequences):
        for t, char in enumerate(seq):
            X[i, t] = char_to_idx[char]
        y[i] = char_to_idx[next_chars[i]]

    return X, y, len(chars), char_to_idx, idx_to_char

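# Hedged usage sketch (the corpus path below is a placeholder): X comes back with
# shape (num_windows, seq_length) of character indices and y with shape
# (num_windows,), which is the pairing sparse_categorical_crossentropy expects.
#
#   X, y, vocab_size, c2i, i2c = load_and_preprocess_data('corpus.txt', seq_length=40)
#   print(X.shape, y.shape, vocab_size)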

def build_model(seq_length, num_chars):
    # Character-level language model: embed each character, run an LSTM over the
    # window, and predict a distribution over the next character.
    model = Sequential([
        Embedding(num_chars, 50, input_length=seq_length),
        LSTM(128),
        Dense(num_chars, activation='softmax')
    ])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
    return model

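# Hedged sketch: the layer shapes can be inspected before training; the
# vocabulary size of 60 here is illustrative only.
#
#   m = build_model(seq_length=100, num_chars=60)
#   m.summary()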

def train_model(model, X, y, epochs=100, batch_size=128):
    model.fit(X, y, epochs=epochs, batch_size=batch_size)


def generate_text(seed_text, model, seq_length, char_to_idx, idx_to_char, length=100, temperature=1.0):
    generated_text = seed_text
    for _ in range(length):
        # Encode (at most) the last seq_length characters of the current seed;
        # characters outside the vocabulary fall back to index 0.
        x = np.zeros((1, seq_length), dtype=np.int32)
        for t, char in enumerate(seed_text[-seq_length:]):
            x[0, t] = char_to_idx.get(char, 0)

        # The model outputs one softmax distribution per input sequence.
        preds = model.predict(x, verbose=0)[0].astype('float64')

        # Temperature sampling: rescale the log-probabilities and renormalise.
        preds = np.log(preds + 1e-8) / temperature
        exp_preds = np.exp(preds)
        preds = exp_preds / np.sum(exp_preds)
        next_index = np.random.choice(len(preds), p=preds)
        next_char = idx_to_char[next_index]

        generated_text += next_char
        seed_text = seed_text[1:] + next_char

    return generated_text

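# Hedged usage sketch (assumes a trained `model` and the mappings built above;
# the seed string is illustrative): lower temperatures make sampling more
# conservative, higher ones more varied.
#
#   sample = generate_text('the quick brown fox ', model, seq_length,
#                          char_to_idx, idx_to_char, length=200, temperature=0.5)
#   print(sample)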

def create_flask_app(model, seq_length, char_to_idx, idx_to_char):
    app = Flask(__name__)

    @app.route('/')
    def index():
        return render_template('index.html')

    @app.route('/generate_text', methods=['POST'])
    def generate_text_endpoint():
        # Tolerate missing or invalid JSON bodies instead of raising.
        data = request.get_json(silent=True) or {}
        seed_text = data.get('seed_text', '')
        generated_text = generate_text(seed_text, model, seq_length, char_to_idx, idx_to_char)
        return jsonify({'generated_text': generated_text})

    return app


if __name__ == '__main__':
    # Path to the training corpus (placeholder; point this at a real file).
    file_path = 'your_text_file.txt'
    seq_length = 100

    X, y, num_chars, char_to_idx, idx_to_char = load_and_preprocess_data(file_path, seq_length=seq_length)
    model = build_model(seq_length, num_chars)
    train_model(model, X, y, epochs=100, batch_size=128)

    flask_app = create_flask_app(model, seq_length, char_to_idx, idx_to_char)
    flask_app.run(port=5000)
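
# Hedged client sketch: with the server running locally, the /generate_text
# endpoint can be exercised from another process, e.g. with `requests`
# (the seed text is just a placeholder):
#
#   import requests
#   resp = requests.post('http://localhost:5000/generate_text',
#                        json={'seed_text': 'once upon a time '})
#   print(resp.json()['generated_text'])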