IIS_2023_1/degtyarev_mikhail_lab_3/main.py

54 lines
2.3 KiB
Python
Raw Normal View History

2023-12-03 15:08:35 +04:00
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, mean_squared_error
from sklearn.preprocessing import LabelEncoder
# Загрузка данных и удаление столбца 'Unnamed: 0'
data = pd.read_csv('ds_salaries.csv').drop('Unnamed: 0', axis=1)
# Определение признаков и целевой переменной
features = ['experience_level', 'employment_type', 'company_location', 'company_size']
target = 'job_title'
# Преобразование категориальных признаков в числовые
label_encoder = LabelEncoder()
for feature in features:
data[feature] = label_encoder.fit_transform(data[feature])
# Преобразование целевой переменной в числовой формат
data[target] = label_encoder.fit_transform(data[target])
# Разделение данных на обучающий (99%) и тестовый (1%) наборы
train_data, test_data = train_test_split(data, test_size=0.01, random_state=42)
# Создание модели дерева решений
model = DecisionTreeClassifier(random_state=42)
# Обучение модели
model.fit(train_data[features], train_data[target])
# Предсказание на тестовом наборе
predictions = model.predict(test_data[features])
# Обратное преобразование числовых предсказаний в строковый формат
predictions_str = label_encoder.inverse_transform(predictions)
# Оценка точности модели на тестовом наборе
accuracy = accuracy_score(test_data[target], predictions)
print(f'Accuracy: {accuracy * 100:.2f}%')
# Средняя квадратичная ошибка в процентах
mse = mean_squared_error(test_data[target], predictions)
print(f'Mean Squared Error: {mse:.2f}%')
feature_importance = model.feature_importances_
feature_importance_dict = dict(zip(features, feature_importance))
sorted_feature_importance = sorted(feature_importance_dict.items(), key=lambda x: x[1], reverse=True)
print("Feature Importance:")
for feature, importance in sorted_feature_importance:
print(f"{feature}: {importance}")
print("First 5 rows of test data:")
print(test_data.head())