40 lines
902 B
Python
40 lines
902 B
Python
|
import numpy as np
|
||
|
|
||
|
|
||
|
def read_text(filename):
    """Read *filename* as UTF-8 text and return its contents lower-cased.

    The file is opened in a ``with`` block so the handle is closed as soon as
    the contents have been read, instead of being left for the garbage
    collector to close.

    Args:
        filename: path of the text file to read.

    Returns:
        The whole file contents, converted with ``str.lower()``.
    """
    with open(filename, 'r', encoding='UTF-8') as handle:
        return handle.read().lower()
|
||
|
|
||
|
|
||
|
def peek_random_seed(data):
    """Return a uniformly random valid index into *data*.

    ``np.random.randint(low, high)`` excludes ``high``, so passing
    ``len(data)`` makes every index ``0 .. len(data) - 1`` reachable. The
    previous bound of ``len(data) - 1`` could never select the last element
    and raised for single-element data.

    Args:
        data: any sized sequence (only ``len(data)`` is used).

    Returns:
        An integer index in ``[0, len(data))``.

    Raises:
        ValueError: if *data* is empty.
    """
    size = len(data)
    if size == 0:
        raise ValueError("cannot pick a random seed index from empty data")
    return np.random.randint(0, size)
|
||
|
|
||
|
|
||
|
def prepare_text(text, ids, seq_length):
    """Slice *text* into fixed-length input windows and next-item targets.

    For every start offset ``k`` with ``k + seq_length < len(text)``, the
    window ``text[k:k + seq_length]`` is encoded through *ids* as one input
    row, and the item right after it, ``text[k + seq_length]``, is encoded as
    the matching target.

    Args:
        text: the sequence to slice (e.g. a string).
        ids: mapping from each item of *text* to its integer id.
        seq_length: number of items per input window.

    Returns:
        A ``(dataX, dataY)`` tuple: ``dataX`` is a list of encoded windows
        and ``dataY`` the list of encoded targets, aligned by position.
    """
    # Encode each (window, target) pair together so lookups happen in the
    # same order as a window-then-target loop would perform them.
    pairs = [
        ([ids[item] for item in text[k:k + seq_length]], ids[text[k + seq_length]])
        for k in range(len(text) - seq_length)
    ]
    dataX = [window for window, _ in pairs]
    dataY = [target for _, target in pairs]
    return dataX, dataY
|
||
|
|
||
|
|
||
|
def generate_text(model, pattern, chars, n_vocab, prediction_len):
    """Greedily generate ``prediction_len`` characters from a seed window.

    At each step the current window of ids is reshaped to
    ``(1, window_len, 1)``, divided by ``n_vocab`` and passed to
    ``model.predict``. The arg-max of the prediction is taken as the next id:
    its character is emitted, and the id is slid into the window (the oldest
    id is dropped, so the window length stays constant).

    Args:
        model: object exposing ``predict(x, verbose=0)``; presumably a Keras
            model (TODO confirm against callers).
        pattern: sequence of integer ids used as the seed window. It is
            copied, so the caller's sequence is left unchanged.
        chars: indexable lookup from id to character (string).
        n_vocab: vocabulary size used to normalise the ids.
        prediction_len: number of characters to generate.

    Returns:
        The generated characters joined into one string ("" when
        ``prediction_len`` is 0).
    """
    # Work on a private copy: the previous code appended to the caller's list
    # on the first step, silently growing it (e.g. corrupting a training row).
    window = list(pattern)
    scale = float(n_vocab)
    generated = []
    for _ in range(prediction_len):
        x = np.reshape(window, (1, len(window), 1)) / scale
        prediction = model.predict(x, verbose=0)
        index = np.argmax(prediction)
        generated.append(chars[index])
        # Slide the window forward by one: add the new id, drop the oldest.
        window.append(index)
        del window[0]
    # Build the result once rather than with repeated string concatenation.
    return "".join(generated)
|