IIS_2023_1/faskhutdinov_idris_lab_5/main.py

55 lines
2.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
def main():
# Чтение данных из датасета
data = pd.read_csv('Clean Data_pakwheels.csv')
# Выбор переменных для модели
features = ['Registration Status', 'Model Year', 'Mileage']
# Выбор лишь части значений для оптимизации работы программы
data = data.sample(frac=.1)
# Отбор нужных столбцов
df = data[features]
# Преобразование строковых значений о регистрации авто в числовые
labelencoder = LabelEncoder()
df['Registration Status'] = labelencoder.fit_transform(df['Registration Status'])
# Разделение на признаки и целевую переменную, представленную как Mileage
X = df.drop('Mileage', axis=1)
y = df['Mileage']
# Разделение данных на тренировочный и тестовый наборы
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.9, random_state=0)
# Создание и обучение логистической регрессии
model = LogisticRegression()
model.fit(X_train, y_train)
# Предсказание на тестовом наборе
y_pred = model.predict(X_test)
# Оценка качества модели
accuracy = accuracy_score(y_test, y_pred)
class_report = classification_report(y_test, y_pred)
print(f'Точность: {accuracy}')
print(f'Классификация:\n{class_report}')
# Визуализация результатов
plt.scatter(X_test['Registration Status'], y_test, color='red', label='Actual')
plt.scatter(X_test['Registration Status'], y_pred, color='green', label='Predicted', marker='x')
plt.xlabel('Registration Status')
plt.ylabel('Mileage')
plt.legend()
plt.savefig(f"image.png")
plt.show()
main()