71 lines
2.4 KiB
Python
71 lines
2.4 KiB
Python
|
from flask import Flask, render_template
|
||
|
import numpy as np
|
||
|
from matplotlib import pyplot as plt
|
||
|
from matplotlib.colors import ListedColormap
|
||
|
from sklearn.model_selection import train_test_split
|
||
|
from sklearn.preprocessing import StandardScaler
|
||
|
from sklearn.datasets import make_moons
|
||
|
from sklearn.linear_model import LogisticRegression
|
||
|
from sklearn.preprocessing import PolynomialFeatures
|
||
|
from sklearn.pipeline import make_pipeline
|
||
|
from sklearn.metrics import accuracy_score
|
||
|
import io
|
||
|
from flask import Response
|
||
|
import matplotlib
|
||
|
import base64
|
||
|
|
||
|
# Render figures with the non-interactive Agg backend (no display on a server).
matplotlib.use('Agg')
# Turn off matplotlib's "too many open figures" warning.
# NOTE: this only silences the warning; figures that are never closed still
# occupy memory.
matplotlib.rcParams.update({'figure.max_open_warning': 0})

app = Flask(__name__)
# Create the data: the two-class "moons" toy dataset.
# NOTE(review): random_state=None draws a fresh dataset on every start-up, so
# the reported accuracies differ between runs; pass an int to make them
# reproducible.
moon_dataset = make_moons(noise=0.3, random_state=None)
X, y = moon_dataset

# Split BEFORE scaling and fit the scaler on the training portion only, so the
# held-out test points do not influence the scaling statistics (test-set
# leakage). The split indices depend only on the sample count and
# random_state, so the same rows land in train/test as before.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=42)
scaler = StandardScaler().fit(X_train)

# Keep X as the full, scaled dataset (as before), now scaled with
# training-only statistics.
X = scaler.transform(X)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)
# Create the models to compare. All three are LogisticRegression classifiers;
# the dictionary keys are the display names passed on to the page.
models = {
    # Linear decision boundary in the original 2-D feature space.
    # NOTE(review): the label says "linear regression" but the estimator is
    # logistic regression (a linear classifier).
    "Линейная регрессия": LogisticRegression(),
    # Degree-4 polynomial features with (practically) no regularization: a
    # very large C makes the L2 penalty negligible. LogisticRegression's
    # defaults are already penalty='l2', C=1.0, so without this the model was
    # identical to the "ridge" variant below. max_iter is raised because the
    # weakly regularized problem needs more solver iterations to converge.
    "Полиномиальная регрессия": make_pipeline(
        PolynomialFeatures(degree=4),
        LogisticRegression(C=1e6, max_iter=10000),
    ),
    # Same polynomial features with an explicit L2 (ridge) penalty, C=1.0.
    "Гребневая полиномиальная регрессия": make_pipeline(
        PolynomialFeatures(degree=4),
        LogisticRegression(penalty='l2', C=1.0),
    ),
}
# Colour palette.
# NOTE(review): the two background colours are not referenced anywhere in the
# visible code (presumably meant for a decision-boundary fill) — confirm or
# remove.
background_color1, background_color2 = '#CE5A57', '#78A5A3'
# Per-class colours for the plotted points (class 0, class 1).
data_color1, data_color2 = 'red', 'green'
# Train every model on the training split and score it on the test split.
model_results = {}
for name, model in models.items():
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    # Each entry bundles the fitted estimator with its accuracy and the test
    # data it was scored on (index() re-plots from these).
    model_results[name] = dict(
        accuracy=accuracy,
        X_test=X_test,
        y_test=y_test,
        model=model,
    )
@app.route('/')
def index():
    """Render the comparison page with one scatter plot per trained model.

    For each entry of ``model_results``, the stored test points are drawn
    coloured by that model's predicted class. Each plot is saved as a PNG and
    base64-encoded so the template can embed it inline.

    Returns:
        The rendered ``index.html`` page. The template receives
        ``model_results`` and ``plot_images``, which maps each model name to
        its base64-encoded PNG as text.
    """
    # Build figures with the object-oriented Figure API instead of pyplot.
    # pyplot keeps every figure in a global registry. The previous version
    # never closed its figures, leaking one per model per request (hidden by
    # ``figure.max_open_warning = 0``). That shared global state is also not
    # thread-safe, yet the app runs with ``threaded=True``. A bare Figure is
    # not registered anywhere and is reclaimed by the garbage collector.
    from matplotlib.figure import Figure

    # The colormap is identical for every plot, so build it once.
    cm_data = ListedColormap([data_color1, data_color2])

    plot_images = {}
    for model_name, results in model_results.items():
        points = results['X_test']
        predicted = results['model'].predict(points)

        fig = Figure(figsize=(8, 6))
        ax = fig.subplots()
        ax.scatter(points[:, 0], points[:, 1], c=predicted, cmap=cm_data, alpha=0.6)
        ax.set_xticks(())
        ax.set_yticks(())
        ax.set_title(model_name)

        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        plot_images[model_name] = base64.b64encode(buf.getvalue()).decode('utf-8')

    return render_template('index.html', model_results=model_results, plot_images=plot_images)
if __name__ == '__main__':
    # Start Flask's built-in development server, handling each request on its
    # own thread.
    app.run(threaded=True)