47 lines
2.1 KiB
Python
47 lines
2.1 KiB
Python
|
import pandas as pd
|
|||
|
import matplotlib.pyplot as plt
|
|||
|
import numpy as np
|
|||
|
from sklearn.model_selection import train_test_split
|
|||
|
from sklearn.neural_network import MLPClassifier
|
|||
|
from sklearn.preprocessing import StandardScaler, LabelEncoder
|
|||
|
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
|
|||
|
|
|||
|
|
|||
|
def main():
|
|||
|
# Чтение данных из файла
|
|||
|
data = pd.read_csv('Clean Data_pakwheels.csv')
|
|||
|
# Выбор лишь части значений для оптимизации работы программы
|
|||
|
data = data.sample(frac=.1)
|
|||
|
# Выбор необходимых столбцов
|
|||
|
features = ['Model Year', 'Mileage', 'Registration Status']
|
|||
|
# Выбор данных из датасета
|
|||
|
df = data[features]
|
|||
|
# Split into features and target variable
|
|||
|
y = df['Registration Status']
|
|||
|
X = df.drop('Registration Status', axis=1)
|
|||
|
# Разделение на обучающую и тестовую выборки
|
|||
|
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
|
|||
|
# Создание и обучение модели нейросети MLPClassifier
|
|||
|
model = MLPClassifier(random_state=0)
|
|||
|
model.fit(X_train, y_train)
|
|||
|
# Предсказания на тестовом наборе
|
|||
|
y_pred = model.predict(X_test)
|
|||
|
# Оценка модели
|
|||
|
accuracy = accuracy_score(y_test, y_pred)
|
|||
|
conf_matrix = confusion_matrix(y_test, y_pred)
|
|||
|
class_report = classification_report(y_test, y_pred)
|
|||
|
print(f'Accuracy: {accuracy}')
|
|||
|
print(f'Classification Report:\n{class_report}')
|
|||
|
|
|||
|
# Создание графика, его отображение и сохранение
|
|||
|
plt.hist(y_pred, bins=np.arange(3) - 0.5, alpha=0.75, color='Red', label='Предсказываемые')
|
|||
|
plt.hist(y_test, bins=np.arange(3) - 0.5, alpha=0.5, color='Black', label='Действительные')
|
|||
|
|
|||
|
plt.xticks([0, 1], ['Зарегистрирована', 'Не зарегистрирована'])
|
|||
|
plt.legend()
|
|||
|
plt.savefig(fname = 'image.png')
|
|||
|
plt.show()
|
|||
|
|
|||
|
|
|||
|
main()
|