IIS_2023_1/degtyarev_mikhail_lab_6/main.py

61 lines
2.3 KiB
Python
Raw Normal View History

2023-12-23 01:05:21 +04:00
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Lasso
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
import matplotlib.pyplot as plt
# Загрузка данных
file_path = 'ds_salaries.csv'
data = pd.read_csv(file_path)
# Предварительная обработка данных
categorical_features = ['experience_level', 'employment_type', 'company_location', 'company_size']
numeric_features = ['work_year']
preprocessor = ColumnTransformer(
transformers=[
('num', StandardScaler(), numeric_features),
('cat', OneHotEncoder(handle_unknown='ignore'), categorical_features)
])
# Выбор признаков
features = ['work_year', 'experience_level', 'employment_type', 'company_location', 'company_size']
X = data[features]
y = data['salary_in_usd']
# Разделение данных на обучающий и тестовый наборы
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Создание и обучение модели с использованием предварительного обработчика данных
alpha = 0.01
lasso_model = Pipeline([
('preprocessor', preprocessor),
('lasso', Lasso(alpha=alpha))
])
lasso_model.fit(X_train, y_train)
# Получение прогнозов
y_pred = lasso_model.predict(X_test)
# Оценка точности модели
accuracy = lasso_model.score(X_test, y_test)
mse = mean_squared_error(y_test, y_pred)
print(f"R^2 Score: {accuracy:.2f}")
print(f"Mean Squared Error: {mse:.2f}")
# Вывод предсказанных и фактических значений
predictions_df = pd.DataFrame({'Actual': y_test, 'Predicted': y_pred})
print(predictions_df)
# Визуализация весов (коэффициентов) модели
coefficients = pd.Series(lasso_model.named_steps['lasso'].coef_, index=numeric_features + list(lasso_model.named_steps['preprocessor'].transformers_[1][1].get_feature_names(categorical_features)))
plt.figure(figsize=(10, 6))
coefficients.sort_values().plot(kind='barh')
plt.title('Lasso Regression Coefficients')
plt.show()