163 lines
6.0 KiB
Plaintext
163 lines
6.0 KiB
Plaintext
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"import numpy as np\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"from sklearn.feature_selection import mutual_info_regression\n",
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"from imblearn.over_sampling import ADASYN\n",
|
|||
|
"from sklearn.preprocessing import LabelEncoder\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"import seaborn as sns\n",
|
|||
|
"\n",
|
|||
# Load the mobile-phone price dataset.
df = pd.read_csv("data/mobile_prices.csv")
print(df.columns)

# Flag "noisy" columns: those where more than 10% of the values are missing.
# This must run on the raw frame, before label encoding: the encoder replaces
# every value of a text column with an integer code, so missing entries in
# those columns would no longer be reported by isnull() afterwards.
missing_share = df.isnull().mean()
noisy_features = missing_share[missing_share > 0.1].index.tolist()
print(f"Зашумленные столбцы: {noisy_features}")

# Encode every string (object-dtype) column as integer codes and keep each
# fitted encoder in `label_encoders`, keyed by column name.
label_encoders = {}
for col in df.select_dtypes(include=["object"]).columns:
    encoder = LabelEncoder()
    df[col] = encoder.fit_transform(df[col])
    label_encoders[col] = encoder
|
|||
|
"\n",
|
|||
# Skewness check: per-column skewness of the frame.
skewness = df.skew()
print(f"Смещение: {skewness}")

# Columns whose absolute skewness exceeds 1 are reported as heavily skewed.
skewed_features = skewness.loc[skewness.abs() > 1].index.tolist()
print(f"Сильно смещенные столбцы: {skewed_features}")
|
|||
|
"\n",
|
|||
# Outlier search: for every numeric column, report the values lying more than
# 1.5 * IQR below the first quartile or above the third quartile.
for col in df.select_dtypes(include=["number"]).columns:
    values = df[col]
    first_quartile = values.quantile(0.25)
    third_quartile = values.quantile(0.75)
    spread = third_quartile - first_quartile
    lower_bound = first_quartile - 1.5 * spread
    upper_bound = third_quartile + 1.5 * spread
    is_outlier = (values < lower_bound) | (values > upper_bound)
    outliers = values[is_outlier]
    print(f"Выбросы в столбце '{col}':\n{outliers}\n")
|
|||
|
"\n",
|
|||
# Visualise outliers: one boxplot per numeric column, laid out on a grid that
# is three plots wide.
numeric_cols = df.select_dtypes(include=["number"]).columns

GRID_COLS = 3
# Ceiling division, so a column count that is an exact multiple of GRID_COLS
# does not reserve an extra, empty row at the bottom of the figure.
grid_rows = (len(numeric_cols) + GRID_COLS - 1) // GRID_COLS

plt.figure(figsize=(12, 8))
for i, col in enumerate(numeric_cols, 1):
    plt.subplot(grid_rows, GRID_COLS, i)
    sns.boxplot(data=df, x=col)
    plt.title(f"Boxplot for {col}")
plt.tight_layout()
plt.show()
|
|||
|
"\n",
|
|||
# Fill missing 'Battery' values with the column mean first, so that the derived
# log feature below does not inherit NaNs from the raw column.
df["Battery"] = df["Battery"].fillna(df["Battery"].mean())

# Log-transform 'Battery'. log1p(x) computes log(1 + x) and keeps precision
# for values close to zero.
df["log_Battery"] = np.log1p(df["Battery"])
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
# Helper that splits a DataFrame into stratified train/validation/test parts.
def split_stratified_into_train_val_test(
    df_input,
    stratify_colname="y",
    frac_train=0.6,
    frac_val=0.15,
    frac_test=0.25,
    random_state=None,
):
    """Split a DataFrame into train, validation and test subsets.

    The split is stratified on one column and done in two passes with
    ``train_test_split``: first train versus the remainder, then the remainder
    into validation versus test.

    Args:
        df_input: DataFrame to split. The returned subsets keep all of its
            columns, including the stratification column.
        stratify_colname: Name of the column used for stratification.
        frac_train: Fraction of rows assigned to the training subset.
        frac_val: Fraction of rows assigned to the validation subset.
        frac_test: Fraction of rows assigned to the test subset.
        random_state: Forwarded unchanged to both ``train_test_split`` calls.

    Returns:
        Tuple ``(df_train, df_val, df_test)``.

    Raises:
        ValueError: If the three fractions do not sum to 1 (within a small
            floating-point tolerance) or if ``stratify_colname`` is not a
            column of ``df_input``.
    """
    import math  # local import keeps this block self-contained

    # Tolerant comparison: sums such as 0.7 + 0.2 + 0.1 evaluate to
    # 0.9999999999999999 in binary floating point and must not be rejected.
    total = frac_train + frac_val + frac_test
    if not math.isclose(total, 1.0, rel_tol=0.0, abs_tol=1e-9):
        raise ValueError(
            "fractions %f, %f, %f do not add up to 1.0"
            % (frac_train, frac_val, frac_test)
        )

    if stratify_colname not in df_input.columns:
        raise ValueError("%s is not a column in the dataframe" % (stratify_colname))

    X = df_input
    y = df_input[[stratify_colname]]

    # Pass 1: carve off the training subset; everything else goes to `df_temp`.
    df_train, df_temp, _, y_temp = train_test_split(
        X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state
    )

    # Pass 2: split the remainder. The test share is re-expressed relative to
    # the size of the remainder rather than the full frame.
    relative_frac_test = frac_test / (frac_val + frac_test)
    df_val, df_test, _, _ = train_test_split(
        df_temp,
        y_temp,
        stratify=y_temp,
        test_size=relative_frac_test,
        random_state=random_state,
    )

    # Sanity check: every input row ended up in exactly one subset.
    assert len(df_input) == len(df_train) + len(df_val) + len(df_test)

    return df_train, df_val, df_test
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
# Train/validation/test split on a three-column subset of the data.
data = df[["Ram", "Price", "company"]].copy()

print("@data", data)

# Keep only companies with at least 5 rows. Stratified splitting needs several
# members of every class; presumably 5 is the minimum that survives both passes
# of the 60/20/20 split -- verify if the fractions change.
data = data.groupby("company").filter(lambda group: len(group) > 4)

# Fixed seed so the split is identical on every Restart & Run All.
SPLIT_RANDOM_STATE = 42

df_train, df_val, df_test = split_stratified_into_train_val_test(
    data,
    stratify_colname="company",
    frac_train=0.60,
    frac_val=0.20,
    frac_test=0.20,
    random_state=SPLIT_RANDOM_STATE,
)

# Report the shape of each subset and its distribution of 'Ram' values.
for title, subset in (
    ("Обучающая выборка: ", df_train),
    ("Контрольная выборка: ", df_val),
    ("Тестовая выборка: ", df_test),
):
    print(title, subset.shape)
    print(subset["Ram"].value_counts())
|
|||
|
"\n",
|
|||
|
"# # Применение ADASYN для oversampling\n",
|
|||
|
"# ada = ADASYN(n_neighbors=2)\n",
|
|||
|
"# X_resampled, y_resampled = ada.fit_resample(df_train, df_train[\"company\"])\n",
|
|||
|
"# df_train_adasyn = pd.DataFrame(X_resampled)\n",
|
|||
|
"\n",
|
|||
|
"# print(\"Обучающая выборка после oversampling: \", df_train_adasyn.shape)\n",
|
|||
|
"# print(df_train_adasyn[\"Ram\"].value_counts())"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": ".venv",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.12.5"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 2
|
|||
|
}
|