IIS_2023_1/simonov_nikita_lab_1/lab1-web.py

71 lines
2.4 KiB
Python
Raw Normal View History

2023-11-05 13:31:50 +04:00
import base64
import io

import matplotlib
import numpy as np
from flask import Flask, render_template
from flask import Response
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
from matplotlib.figure import Figure
from sklearn.datasets import make_moons
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
# Non-interactive Agg backend: the server only renders PNGs into memory buffers.
matplotlib.use('Agg')
# NOTE(review): 0 disables matplotlib's "too many open figures" warning, which is
# the main signal that pyplot figures are being created without plt.close().
matplotlib.rcParams.update({'figure.max_open_warning': 0})

app = Flask(__name__)
# Build the dataset: two interleaving half-circles ("moons") with noise=0.3.
# NOTE(review): random_state=None draws a fresh sample on every process start,
# so the accuracies shown on the page change between runs; pass an int here if
# reproducible results are wanted.
moon_dataset = make_moons(noise=0.3, random_state=None)
X, y = moon_dataset
# Split BEFORE scaling so the scaler's mean/std are learned from the training
# rows only; fitting it on the full dataset would leak test-set statistics into
# preprocessing. The row split itself is the same as scaling-then-splitting,
# since train_test_split's permutation depends only on the sample count and
# random_state.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=42)
scaler = StandardScaler().fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)
# Candidate classifiers, keyed by display name (the key is also used as the plot
# title by the page handler). Insertion order is preserved.
# NOTE(review): despite the "регрессия" (regression) names, all three entries
# are LogisticRegression classifiers. The third entry's penalty='l2', C=1.0 are
# LogisticRegression's defaults, so it is configured identically to the second
# entry; a smaller C would be needed for it to actually differ.
models = {}
# Logistic regression on the input features as-is.
models["Линейная регрессия"] = LogisticRegression()
# Degree-4 polynomial feature expansion followed by logistic regression.
models["Полиномиальная регрессия"] = make_pipeline(
    PolynomialFeatures(degree=4),
    LogisticRegression(),
)
# Same pipeline with the L2 penalty spelled out explicitly.
models["Гребневая полиномиальная регрессия"] = make_pipeline(
    PolynomialFeatures(degree=4),
    LogisticRegression(penalty='l2', C=1.0),
)
# Plot palette. The two data colours form the two-entry colormap used to colour
# scatter points by predicted class.
# NOTE(review): the background colours are not referenced in the visible code,
# presumably left over from a planned decision-region plot.
background_color1, background_color2 = '#CE5A57', '#78A5A3'
data_color1, data_color2 = 'red', 'green'
# Fit every model on the training split and score it on the held-out split.
# Each result entry also carries the test data and the fitted model so the page
# handler can re-plot the model's predictions.
model_results = {}
for model_name, estimator in models.items():
    estimator.fit(X_train, y_train)
    test_accuracy = accuracy_score(y_test, estimator.predict(X_test))
    model_results[model_name] = dict(
        accuracy=test_accuracy,
        X_test=X_test,
        y_test=y_test,
        model=estimator,
    )
def _render_scatter_png(title, points, labels):
    """Draw a 2-D scatter of *points* coloured by *labels*; return base64 PNG text.

    Args:
        title: Plot title.
        points: Array indexed as ``points[:, 0]`` / ``points[:, 1]`` for x / y.
        labels: Per-point values mapped through the two-colour
            (``data_color1``, ``data_color2``) colormap.

    Returns:
        The PNG bytes, base64-encoded and decoded to a ``str`` so the template
        can embed them (presumably as a ``data:`` URI in index.html; confirm).
    """
    # Build the Figure directly rather than through pyplot. pyplot keeps a global
    # "current figure" plus a registry of open figures. That state is not
    # thread-safe (the app runs with threaded=True, so plt.savefig could save
    # another request's figure). Without plt.close() it also keeps every figure
    # alive for the life of the process. A bare Figure is garbage-collected once
    # this function returns.
    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    ax.scatter(points[:, 0], points[:, 1], c=labels,
               cmap=ListedColormap([data_color1, data_color2]), alpha=0.6)
    ax.set_xticks(())
    ax.set_yticks(())
    ax.set_title(title)
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    return base64.b64encode(buf.getvalue()).decode('utf-8')


@app.route('/')
def index():
    """Render index.html with each model's results and a plot of its test-set predictions.

    Points are coloured by the model's *predicted* class, not by the true label.
    ``plot_images`` is keyed by the same model names as ``model_results``.
    """
    plot_images = {
        model_name: _render_scatter_png(
            model_name,
            results['X_test'],
            results['model'].predict(results['X_test']),
        )
        for model_name, results in model_results.items()
    }
    return render_template('index.html', model_results=model_results, plot_images=plot_images)


if __name__ == '__main__':
    # threaded=True is safe for the plotting code: plots are drawn on standalone
    # Figure objects rather than on pyplot's shared global state.
    app.run(threaded=True)