from flask import Flask, render_template, request
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import OneHotEncoder
import joblib

app = Flask(__name__)

# Загрузка данных
data = pd.read_csv('top_240_restaurants_recommended_in_los_angeles_2.csv')

# Выбор нужных столбцов
selected_columns = ['Rank', 'StarRating', 'NumberOfReviews', 'Style']
data = data[selected_columns]

# Кодирование столбца Style
encoder = OneHotEncoder(sparse=False)
encoded_styles = encoder.fit_transform(data[['Style']])
encoded_styles_df = pd.DataFrame(encoded_styles, columns=encoder.get_feature_names_out(['Style']))
data = pd.concat([data, encoded_styles_df], axis=1).drop('Style', axis=1)

# Разделение данных
X = data.drop('Rank', axis=1)
y = data['Rank']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Обучение модели
mlp_model = MLPClassifier(hidden_layer_sizes=(100,), max_iter=1000)
mlp_model.fit(X_train, y_train)

# Сохранение модели
joblib.dump(mlp_model, 'mlp_model.joblib')

# Загрузка модели
mlp_model = joblib.load('mlp_model.joblib')


@app.route('/')
def index():
    return render_template('index.html')


@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        # Получение данных из формы
        input_data = {
            'StarRating': float(request.form['StarRating']),
            'NumberOfReviews': int(request.form['NumberOfReviews']),
            'Style': request.form['Style']
        }

        # Кодирование стиля
        input_style_encoded = encoder.transform([[input_data['Style']]])
        input_data.pop('Style')
        input_data.update(dict(zip(encoded_styles_df.columns, input_style_encoded[0])))

        # Преобразование данных в DataFrame
        input_df = pd.DataFrame([input_data])

        # Предсказание
        prediction = mlp_model.predict(input_df)[0]

        return render_template('index.html', prediction=prediction)


if __name__ == '__main__':
    app.run(debug=True)