import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Lasso
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
import matplotlib.pyplot as plt

# Load the salary dataset (CSV path is relative to the working directory).
file_path = 'ds_salaries.csv'
data = pd.read_csv(filepath_or_buffer=file_path)

# Preprocessing: standardize the numeric column and one-hot encode the
# categorical ones. handle_unknown='ignore' keeps categories that were not seen
# during fitting from raising an error at prediction time.
categorical_features = ['experience_level', 'employment_type', 'company_location', 'company_size']
numeric_features = ['work_year']

preprocessor = ColumnTransformer(
    transformers=[
        ('num', StandardScaler(), numeric_features),
        ('cat', OneHotEncoder(handle_unknown='ignore'), categorical_features),
    ])

# Feature selection: derive the model inputs from the two groups above instead
# of repeating the column names, so the selected columns cannot drift out of
# sync with what the preprocessor expects.
features = numeric_features + categorical_features
X = data[features]
y = data['salary_in_usd']

# Hold out 20% of the rows for testing; the fixed seed makes the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# Chain the preprocessor and the Lasso regressor into one estimator so that
# both are fitted on the training rows only.
alpha = 0.01
lasso_model = Pipeline(steps=[
    ('preprocessor', preprocessor),
    ('lasso', Lasso(alpha=alpha)),
])
lasso_model.fit(X_train, y_train)

# Predict salaries for the held-out test rows.
y_pred = lasso_model.predict(X_test)

# Evaluate the model. `score` on this regression pipeline yields R^2 (as the
# printed label says); the variable keeps the name `accuracy`, but it is not a
# classification accuracy.
accuracy = lasso_model.score(X_test, y_test)
mse = mean_squared_error(y_true=y_test, y_pred=y_pred)  # in squared USD

print(f"R^2 Score: {accuracy:.2f}")
print(f"Mean Squared Error: {mse:.2f}")

# Show actual and predicted salaries side by side, indexed like the test set.
predictions_df = y_test.to_frame(name='Actual').assign(Predicted=y_pred)
print(predictions_df)

# Visualize the fitted Lasso weights (coefficients), labeled by feature.
lasso_step = lasso_model.named_steps['lasso']
# Look the encoder up by its transformer name rather than by position.
onehot = lasso_model.named_steps['preprocessor'].named_transformers_['cat']
# `get_feature_names` was deprecated in scikit-learn 1.0 and removed in 1.2;
# use `get_feature_names_out`, falling back only for very old versions.
if hasattr(onehot, 'get_feature_names_out'):
    category_names = onehot.get_feature_names_out(categorical_features)
else:
    category_names = onehot.get_feature_names(categorical_features)
# Transformed columns follow the ColumnTransformer's order: the numeric
# ('num') columns first, then the one-hot ('cat') columns.
coefficients = pd.Series(
    lasso_step.coef_,
    index=numeric_features + list(category_names),
)
plt.figure(figsize=(10, 6))
coefficients.sort_values().plot(kind='barh')
plt.title('Lasso Regression Coefficients')
plt.tight_layout()  # keep the long one-hot labels from being clipped
plt.show()