2024-11-22 18:14:00 +04:00

1502 lines
58 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Вариант: Экономика стран"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 369 entries, 0 to 368\n",
"Data columns (total 14 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 stock index 369 non-null object \n",
" 1 country 369 non-null object \n",
" 2 year 369 non-null float64\n",
" 3 index price 317 non-null float64\n",
" 4 log_indexprice 369 non-null float64\n",
" 5 inflationrate 326 non-null float64\n",
" 6 oil prices 369 non-null float64\n",
" 7 exchange_rate 367 non-null float64\n",
" 8 gdppercent 350 non-null float64\n",
" 9 percapitaincome 368 non-null float64\n",
" 10 unemploymentrate 348 non-null float64\n",
" 11 manufacturingoutput 278 non-null float64\n",
" 12 tradebalance 365 non-null float64\n",
" 13 USTreasury 369 non-null float64\n",
"dtypes: float64(12), object(2)\n",
"memory usage: 40.5+ KB\n"
]
}
],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import LabelEncoder\n",
"from imblearn.over_sampling import RandomOverSampler\n",
"from imblearn.under_sampling import RandomUnderSampler\n",
"import featuretools as ft\n",
"import numpy as np\n",
"\n",
"label_encoder = LabelEncoder()\n",
"\n",
"# Функция для применения oversampling\n",
"def apply_oversampling(X, y):\n",
" oversampler = RandomOverSampler(random_state=42)\n",
" X_resampled, y_resampled = oversampler.fit_resample(X, y)\n",
" return X_resampled, y_resampled\n",
"\n",
"# Функция для применения undersampling\n",
"def apply_undersampling(X, y):\n",
" undersampler = RandomUnderSampler(random_state=42)\n",
" X_resampled, y_resampled = undersampler.fit_resample(X, y)\n",
" return X_resampled, y_resampled\n",
"\n",
"def split_stratified_into_train_val_test(\n",
" df_input,\n",
" stratify_colname=\"y\",\n",
" frac_train=0.6,\n",
" frac_val=0.15,\n",
" frac_test=0.25,\n",
" random_state=None,\n",
"):\n",
" \"\"\"\n",
" Splits a Pandas dataframe into three subsets (train, val, and test)\n",
" following fractional ratios provided by the user, where each subset is\n",
" stratified by the values in a specific column (that is, each subset has\n",
" the same relative frequency of the values in the column). It performs this\n",
" splitting by running train_test_split() twice.\n",
"\n",
" Parameters\n",
" ----------\n",
" df_input : Pandas dataframe\n",
" Input dataframe to be split.\n",
" stratify_colname : str\n",
" The name of the column that will be used for stratification. Usually\n",
" this column would be for the label.\n",
" frac_train : float\n",
" frac_val : float\n",
" frac_test : float\n",
" The ratios with which the dataframe will be split into train, val, and\n",
" test data. The values should be expressed as float fractions and should\n",
" sum to 1.0.\n",
" random_state : int, None, or RandomStateInstance\n",
" Value to be passed to train_test_split().\n",
"\n",
" Returns\n",
" -------\n",
" df_train, df_val, df_test :\n",
" Dataframes containing the three splits.\n",
" \"\"\"\n",
"\n",
" if frac_train + frac_val + frac_test != 1.0:\n",
" raise ValueError(\n",
" \"fractions %f, %f, %f do not add up to 1.0\"\n",
" % (frac_train, frac_val, frac_test)\n",
" )\n",
"\n",
" if stratify_colname not in df_input.columns:\n",
" raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
"\n",
" X = df_input # Contains all columns.\n",
" y = df_input[\n",
" [stratify_colname]\n",
" ] # Dataframe of just the column on which to stratify.\n",
"\n",
" # Split original dataframe into train and temp dataframes.\n",
" df_train, df_temp, y_train, y_temp = train_test_split(\n",
" X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
" )\n",
"\n",
" # Split the temp dataframe into val and test dataframes.\n",
" relative_frac_test = frac_test / (frac_val + frac_test)\n",
" df_val, df_test, y_val, y_test = train_test_split(\n",
" df_temp,\n",
" y_temp,\n",
" stratify=y_temp,\n",
" test_size=relative_frac_test,\n",
" random_state=random_state,\n",
" )\n",
"\n",
" assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
"\n",
" return df_train, df_val, df_test\n",
"\n",
"data = pd.read_csv(\"../data/Economic.csv\")\n",
"data.info()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Определение бизнес-целей\n",
"1. Прогнозирование ВВП на душу населения (GDP per capita) для каждой из 9 стран на следующие 5 лет. Это позволит бизнесу и правительствам принимать обоснованные решения в области экономической политики и инвестиций.\n",
"\n",
"2. Определение факторов, наиболее сильно влияющих на ВВП на душу населения. Это поможет выявить области, требующие особого внимания для стимулирования экономического роста."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Определение целей технического проекта\n",
"1. Разработка модели машинного обучения, способной с высокой точностью прогнозировать ВВП на душу населения на основе исторических данных.\n",
"\n",
"2. Анализ корреляции между различными экономическими показателями и ВВП на душу населения, выявление наиболее значимых факторов."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Дополнение данных"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"stock index 0\n",
"country 0\n",
"year 0\n",
"index price 52\n",
"log_indexprice 0\n",
"inflationrate 43\n",
"oil prices 0\n",
"exchange_rate 2\n",
"gdppercent 19\n",
"percapitaincome 1\n",
"unemploymentrate 21\n",
"manufacturingoutput 91\n",
"tradebalance 4\n",
"USTreasury 0\n",
"dtype: int64\n"
]
}
],
"source": [
"print(data.isnull().sum())"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>stock index</th>\n",
" <th>country</th>\n",
" <th>year</th>\n",
" <th>index price</th>\n",
" <th>log_indexprice</th>\n",
" <th>inflationrate</th>\n",
" <th>oil prices</th>\n",
" <th>exchange_rate</th>\n",
" <th>gdppercent</th>\n",
" <th>percapitaincome</th>\n",
" <th>unemploymentrate</th>\n",
" <th>manufacturingoutput</th>\n",
" <th>tradebalance</th>\n",
" <th>USTreasury</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>NASDAQ</td>\n",
" <td>United States of America</td>\n",
" <td>1980.0</td>\n",
" <td>168.61</td>\n",
" <td>2.23</td>\n",
" <td>0.14</td>\n",
" <td>21.59</td>\n",
" <td>1.00</td>\n",
" <td>0.09</td>\n",
" <td>12575.0</td>\n",
" <td>0.07</td>\n",
" <td>1.00</td>\n",
" <td>-13.06</td>\n",
" <td>0.11</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>NASDAQ</td>\n",
" <td>United States of America</td>\n",
" <td>1981.0</td>\n",
" <td>203.15</td>\n",
" <td>2.31</td>\n",
" <td>0.10</td>\n",
" <td>31.77</td>\n",
" <td>1.00</td>\n",
" <td>0.12</td>\n",
" <td>13976.0</td>\n",
" <td>0.08</td>\n",
" <td>1.00</td>\n",
" <td>-12.52</td>\n",
" <td>0.14</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>NASDAQ</td>\n",
" <td>United States of America</td>\n",
" <td>1982.0</td>\n",
" <td>188.98</td>\n",
" <td>2.28</td>\n",
" <td>0.06</td>\n",
" <td>28.52</td>\n",
" <td>1.00</td>\n",
" <td>0.04</td>\n",
" <td>14434.0</td>\n",
" <td>0.10</td>\n",
" <td>1.00</td>\n",
" <td>-19.97</td>\n",
" <td>0.13</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>NASDAQ</td>\n",
" <td>United States of America</td>\n",
" <td>1983.0</td>\n",
" <td>285.43</td>\n",
" <td>2.46</td>\n",
" <td>0.03</td>\n",
" <td>26.19</td>\n",
" <td>1.00</td>\n",
" <td>0.09</td>\n",
" <td>15544.0</td>\n",
" <td>0.10</td>\n",
" <td>1.00</td>\n",
" <td>-51.64</td>\n",
" <td>0.11</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>NASDAQ</td>\n",
" <td>United States of America</td>\n",
" <td>1984.0</td>\n",
" <td>248.89</td>\n",
" <td>2.40</td>\n",
" <td>0.04</td>\n",
" <td>25.88</td>\n",
" <td>1.00</td>\n",
" <td>0.11</td>\n",
" <td>17121.0</td>\n",
" <td>0.08</td>\n",
" <td>1.00</td>\n",
" <td>-102.73</td>\n",
" <td>0.12</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>364</th>\n",
" <td>IEX 35</td>\n",
" <td>Spain</td>\n",
" <td>2016.0</td>\n",
" <td>9352.10</td>\n",
" <td>3.97</td>\n",
" <td>0.00</td>\n",
" <td>51.97</td>\n",
" <td>1.11</td>\n",
" <td>0.03</td>\n",
" <td>26523.0</td>\n",
" <td>0.20</td>\n",
" <td>139.01</td>\n",
" <td>49.16</td>\n",
" <td>0.02</td>\n",
" </tr>\n",
" <tr>\n",
" <th>365</th>\n",
" <td>IEX 35</td>\n",
" <td>Spain</td>\n",
" <td>2017.0</td>\n",
" <td>10043.90</td>\n",
" <td>4.00</td>\n",
" <td>0.02</td>\n",
" <td>57.88</td>\n",
" <td>1.13</td>\n",
" <td>0.03</td>\n",
" <td>28170.0</td>\n",
" <td>0.17</td>\n",
" <td>148.80</td>\n",
" <td>47.33</td>\n",
" <td>0.02</td>\n",
" </tr>\n",
" <tr>\n",
" <th>366</th>\n",
" <td>IEX 35</td>\n",
" <td>Spain</td>\n",
" <td>2018.0</td>\n",
" <td>8539.90</td>\n",
" <td>3.93</td>\n",
" <td>0.02</td>\n",
" <td>49.52</td>\n",
" <td>1.18</td>\n",
" <td>0.02</td>\n",
" <td>30389.0</td>\n",
" <td>0.15</td>\n",
" <td>158.33</td>\n",
" <td>38.70</td>\n",
" <td>0.03</td>\n",
" </tr>\n",
" <tr>\n",
" <th>367</th>\n",
" <td>IEX 35</td>\n",
" <td>Spain</td>\n",
" <td>2019.0</td>\n",
" <td>9549.20</td>\n",
" <td>3.98</td>\n",
" <td>0.01</td>\n",
" <td>59.88</td>\n",
" <td>1.12</td>\n",
" <td>0.02</td>\n",
" <td>29565.0</td>\n",
" <td>0.14</td>\n",
" <td>155.49</td>\n",
" <td>41.94</td>\n",
" <td>0.02</td>\n",
" </tr>\n",
" <tr>\n",
" <th>368</th>\n",
" <td>IEX 35</td>\n",
" <td>Spain</td>\n",
" <td>2020.0</td>\n",
" <td>8073.70</td>\n",
" <td>3.91</td>\n",
" <td>0.00</td>\n",
" <td>47.02</td>\n",
" <td>1.14</td>\n",
" <td>-0.11</td>\n",
" <td>27057.0</td>\n",
" <td>0.16</td>\n",
" <td>143.05</td>\n",
" <td>19.10</td>\n",
" <td>0.01</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>369 rows × 14 columns</p>\n",
"</div>"
],
"text/plain": [
" stock index country year index price \\\n",
"0 NASDAQ United States of America 1980.0 168.61 \n",
"1 NASDAQ United States of America 1981.0 203.15 \n",
"2 NASDAQ United States of America 1982.0 188.98 \n",
"3 NASDAQ United States of America 1983.0 285.43 \n",
"4 NASDAQ United States of America 1984.0 248.89 \n",
".. ... ... ... ... \n",
"364 IEX 35 Spain 2016.0 9352.10 \n",
"365 IEX 35 Spain 2017.0 10043.90 \n",
"366 IEX 35 Spain 2018.0 8539.90 \n",
"367 IEX 35 Spain 2019.0 9549.20 \n",
"368 IEX 35 Spain 2020.0 8073.70 \n",
"\n",
" log_indexprice inflationrate oil prices exchange_rate gdppercent \\\n",
"0 2.23 0.14 21.59 1.00 0.09 \n",
"1 2.31 0.10 31.77 1.00 0.12 \n",
"2 2.28 0.06 28.52 1.00 0.04 \n",
"3 2.46 0.03 26.19 1.00 0.09 \n",
"4 2.40 0.04 25.88 1.00 0.11 \n",
".. ... ... ... ... ... \n",
"364 3.97 0.00 51.97 1.11 0.03 \n",
"365 4.00 0.02 57.88 1.13 0.03 \n",
"366 3.93 0.02 49.52 1.18 0.02 \n",
"367 3.98 0.01 59.88 1.12 0.02 \n",
"368 3.91 0.00 47.02 1.14 -0.11 \n",
"\n",
" percapitaincome unemploymentrate manufacturingoutput tradebalance \\\n",
"0 12575.0 0.07 1.00 -13.06 \n",
"1 13976.0 0.08 1.00 -12.52 \n",
"2 14434.0 0.10 1.00 -19.97 \n",
"3 15544.0 0.10 1.00 -51.64 \n",
"4 17121.0 0.08 1.00 -102.73 \n",
".. ... ... ... ... \n",
"364 26523.0 0.20 139.01 49.16 \n",
"365 28170.0 0.17 148.80 47.33 \n",
"366 30389.0 0.15 158.33 38.70 \n",
"367 29565.0 0.14 155.49 41.94 \n",
"368 27057.0 0.16 143.05 19.10 \n",
"\n",
" USTreasury \n",
"0 0.11 \n",
"1 0.14 \n",
"2 0.13 \n",
"3 0.11 \n",
"4 0.12 \n",
".. ... \n",
"364 0.02 \n",
"365 0.02 \n",
"366 0.03 \n",
"367 0.02 \n",
"368 0.01 \n",
"\n",
"[369 rows x 14 columns]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.fillna({\"index_price\": 1, \"inflationrate\": 0, \"gpdpercent\": 0, \"percapitaincome\": 100,\"unemploymentrate\": 0, \"manufacturingoutput\": 1, \"tradebalance\": -350})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Разбиение данных на выборки и оценка сбалансированности выборки\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Обучающая выборка: (221, 14)\n",
"country\n",
"Spain 25\n",
"India 25\n",
"Germany 25\n",
"China 25\n",
"United Kingdom 25\n",
"Hong Kong 24\n",
"Japan 24\n",
"United States of America 24\n",
"France 24\n",
"Name: count, dtype: int64\n",
"Контрольная выборка: (74, 14)\n",
"country\n",
"United States of America 9\n",
"Japan 9\n",
"France 8\n",
"Germany 8\n",
"Hong Kong 8\n",
"Spain 8\n",
"India 8\n",
"China 8\n",
"United Kingdom 8\n",
"Name: count, dtype: int64\n",
"Тестовая выборка: (74, 14)\n",
"country\n",
"France 9\n",
"Hong Kong 9\n",
"United Kingdom 8\n",
"China 8\n",
"India 8\n",
"Spain 8\n",
"United States of America 8\n",
"Japan 8\n",
"Germany 8\n",
"Name: count, dtype: int64\n"
]
}
],
"source": [
"main_data = data.copy()\n",
"\n",
"value_counts = data[\"country\"].value_counts()\n",
" \n",
"df_train, df_val, df_test = split_stratified_into_train_val_test(\n",
" data, stratify_colname=\"country\", frac_train=0.60, frac_val=0.20, frac_test=0.20)\n",
"\n",
"print(\"Обучающая выборка: \", df_train.shape)\n",
"print(df_train[\"country\"].value_counts())\n",
"\n",
"print(\"Контрольная выборка: \", df_val.shape)\n",
"print(df_val[\"country\"].value_counts())\n",
"\n",
"print(\"Тестовая выборка: \", df_test.shape)\n",
"print(df_test[\"country\"].value_counts())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Конструирование признаков"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. Унитарное кодирование категориальных признаков"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>stock index</th>\n",
" <th>year</th>\n",
" <th>index price</th>\n",
" <th>log_indexprice</th>\n",
" <th>inflationrate</th>\n",
" <th>oil prices</th>\n",
" <th>exchange_rate</th>\n",
" <th>gdppercent</th>\n",
" <th>percapitaincome</th>\n",
" <th>unemploymentrate</th>\n",
" <th>...</th>\n",
" <th>USTreasury</th>\n",
" <th>country_China</th>\n",
" <th>country_France</th>\n",
" <th>country_Germany</th>\n",
" <th>country_Hong Kong</th>\n",
" <th>country_India</th>\n",
" <th>country_Japan</th>\n",
" <th>country_Spain</th>\n",
" <th>country_United Kingdom</th>\n",
" <th>country_United States of America</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>NASDAQ</td>\n",
" <td>1980.0</td>\n",
" <td>168.61</td>\n",
" <td>2.23</td>\n",
" <td>0.14</td>\n",
" <td>21.59</td>\n",
" <td>1.0</td>\n",
" <td>0.09</td>\n",
" <td>12575.0</td>\n",
" <td>0.07</td>\n",
" <td>...</td>\n",
" <td>0.11</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>NASDAQ</td>\n",
" <td>1981.0</td>\n",
" <td>203.15</td>\n",
" <td>2.31</td>\n",
" <td>0.10</td>\n",
" <td>31.77</td>\n",
" <td>1.0</td>\n",
" <td>0.12</td>\n",
" <td>13976.0</td>\n",
" <td>0.08</td>\n",
" <td>...</td>\n",
" <td>0.14</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>NASDAQ</td>\n",
" <td>1982.0</td>\n",
" <td>188.98</td>\n",
" <td>2.28</td>\n",
" <td>0.06</td>\n",
" <td>28.52</td>\n",
" <td>1.0</td>\n",
" <td>0.04</td>\n",
" <td>14434.0</td>\n",
" <td>0.10</td>\n",
" <td>...</td>\n",
" <td>0.13</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>NASDAQ</td>\n",
" <td>1983.0</td>\n",
" <td>285.43</td>\n",
" <td>2.46</td>\n",
" <td>0.03</td>\n",
" <td>26.19</td>\n",
" <td>1.0</td>\n",
" <td>0.09</td>\n",
" <td>15544.0</td>\n",
" <td>0.10</td>\n",
" <td>...</td>\n",
" <td>0.11</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>NASDAQ</td>\n",
" <td>1984.0</td>\n",
" <td>248.89</td>\n",
" <td>2.40</td>\n",
" <td>0.04</td>\n",
" <td>25.88</td>\n",
" <td>1.0</td>\n",
" <td>0.11</td>\n",
" <td>17121.0</td>\n",
" <td>0.08</td>\n",
" <td>...</td>\n",
" <td>0.12</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 22 columns</p>\n",
"</div>"
],
"text/plain": [
" stock index year index price log_indexprice inflationrate oil prices \\\n",
"0 NASDAQ 1980.0 168.61 2.23 0.14 21.59 \n",
"1 NASDAQ 1981.0 203.15 2.31 0.10 31.77 \n",
"2 NASDAQ 1982.0 188.98 2.28 0.06 28.52 \n",
"3 NASDAQ 1983.0 285.43 2.46 0.03 26.19 \n",
"4 NASDAQ 1984.0 248.89 2.40 0.04 25.88 \n",
"\n",
" exchange_rate gdppercent percapitaincome unemploymentrate ... \\\n",
"0 1.0 0.09 12575.0 0.07 ... \n",
"1 1.0 0.12 13976.0 0.08 ... \n",
"2 1.0 0.04 14434.0 0.10 ... \n",
"3 1.0 0.09 15544.0 0.10 ... \n",
"4 1.0 0.11 17121.0 0.08 ... \n",
"\n",
" USTreasury country_China country_France country_Germany \\\n",
"0 0.11 False False False \n",
"1 0.14 False False False \n",
"2 0.13 False False False \n",
"3 0.11 False False False \n",
"4 0.12 False False False \n",
"\n",
" country_Hong Kong country_India country_Japan country_Spain \\\n",
"0 False False False False \n",
"1 False False False False \n",
"2 False False False False \n",
"3 False False False False \n",
"4 False False False False \n",
"\n",
" country_United Kingdom country_United States of America \n",
"0 False True \n",
"1 False True \n",
"2 False True \n",
"3 False True \n",
"4 False True \n",
"\n",
"[5 rows x 22 columns]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Пример для кодирования стран\n",
"data = pd.get_dummies(data, columns=['country'], prefix='country')\n",
"data.head()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>stock index</th>\n",
" <th>year</th>\n",
" <th>index price</th>\n",
" <th>log_indexprice</th>\n",
" <th>inflationrate</th>\n",
" <th>oil prices</th>\n",
" <th>exchange_rate</th>\n",
" <th>gdppercent</th>\n",
" <th>percapitaincome</th>\n",
" <th>unemploymentrate</th>\n",
" <th>...</th>\n",
" <th>USTreasury</th>\n",
" <th>country_China</th>\n",
" <th>country_France</th>\n",
" <th>country_Germany</th>\n",
" <th>country_Hong Kong</th>\n",
" <th>country_India</th>\n",
" <th>country_Japan</th>\n",
" <th>country_Spain</th>\n",
" <th>country_United Kingdom</th>\n",
" <th>country_United States of America</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>364</th>\n",
" <td>IEX 35</td>\n",
" <td>2016.0</td>\n",
" <td>9352.1</td>\n",
" <td>3.97</td>\n",
" <td>NaN</td>\n",
" <td>51.97</td>\n",
" <td>1.11</td>\n",
" <td>0.03</td>\n",
" <td>26523.0</td>\n",
" <td>0.20</td>\n",
" <td>...</td>\n",
" <td>0.02</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>365</th>\n",
" <td>IEX 35</td>\n",
" <td>2017.0</td>\n",
" <td>10043.9</td>\n",
" <td>4.00</td>\n",
" <td>0.02</td>\n",
" <td>57.88</td>\n",
" <td>1.13</td>\n",
" <td>0.03</td>\n",
" <td>28170.0</td>\n",
" <td>0.17</td>\n",
" <td>...</td>\n",
" <td>0.02</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>366</th>\n",
" <td>IEX 35</td>\n",
" <td>2018.0</td>\n",
" <td>8539.9</td>\n",
" <td>3.93</td>\n",
" <td>0.02</td>\n",
" <td>49.52</td>\n",
" <td>1.18</td>\n",
" <td>0.02</td>\n",
" <td>30389.0</td>\n",
" <td>0.15</td>\n",
" <td>...</td>\n",
" <td>0.03</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>367</th>\n",
" <td>IEX 35</td>\n",
" <td>2019.0</td>\n",
" <td>9549.2</td>\n",
" <td>3.98</td>\n",
" <td>0.01</td>\n",
" <td>59.88</td>\n",
" <td>1.12</td>\n",
" <td>0.02</td>\n",
" <td>29565.0</td>\n",
" <td>0.14</td>\n",
" <td>...</td>\n",
" <td>0.02</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>368</th>\n",
" <td>IEX 35</td>\n",
" <td>2020.0</td>\n",
" <td>8073.7</td>\n",
" <td>3.91</td>\n",
" <td>NaN</td>\n",
" <td>47.02</td>\n",
" <td>1.14</td>\n",
" <td>-0.11</td>\n",
" <td>27057.0</td>\n",
" <td>0.16</td>\n",
" <td>...</td>\n",
" <td>0.01</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 22 columns</p>\n",
"</div>"
],
"text/plain": [
" stock index year index price log_indexprice inflationrate \\\n",
"364 IEX 35 2016.0 9352.1 3.97 NaN \n",
"365 IEX 35 2017.0 10043.9 4.00 0.02 \n",
"366 IEX 35 2018.0 8539.9 3.93 0.02 \n",
"367 IEX 35 2019.0 9549.2 3.98 0.01 \n",
"368 IEX 35 2020.0 8073.7 3.91 NaN \n",
"\n",
" oil prices exchange_rate gdppercent percapitaincome unemploymentrate \\\n",
"364 51.97 1.11 0.03 26523.0 0.20 \n",
"365 57.88 1.13 0.03 28170.0 0.17 \n",
"366 49.52 1.18 0.02 30389.0 0.15 \n",
"367 59.88 1.12 0.02 29565.0 0.14 \n",
"368 47.02 1.14 -0.11 27057.0 0.16 \n",
"\n",
" ... USTreasury country_China country_France country_Germany \\\n",
"364 ... 0.02 False False False \n",
"365 ... 0.02 False False False \n",
"366 ... 0.03 False False False \n",
"367 ... 0.02 False False False \n",
"368 ... 0.01 False False False \n",
"\n",
" country_Hong Kong country_India country_Japan country_Spain \\\n",
"364 False False False True \n",
"365 False False False True \n",
"366 False False False True \n",
"367 False False False True \n",
"368 False False False True \n",
"\n",
" country_United Kingdom country_United States of America \n",
"364 False False \n",
"365 False False \n",
"366 False False \n",
"367 False False \n",
"368 False False \n",
"\n",
"[5 rows x 22 columns]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.tail()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"2. Дискретизация числовых признаков"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" oil prices oil_price_category\n",
"0 21.59 cheap\n",
"1 31.77 cheap\n",
"2 28.52 cheap\n",
"3 26.19 cheap\n",
"4 25.88 cheap\n",
"5 24.09 cheap\n",
"6 12.51 cheap\n",
"7 15.40 cheap\n",
"8 12.58 cheap\n",
"9 15.86 cheap\n",
"10 27.28 cheap\n",
"11 19.50 cheap\n",
"12 19.41 cheap\n",
"13 14.52 cheap\n",
"14 17.16 cheap\n",
"15 19.03 cheap\n",
"16 25.23 cheap\n",
"17 18.33 cheap\n",
"18 11.35 cheap\n",
"19 26.10 cheap\n",
"20 28.44 cheap\n",
"21 19.39 cheap\n",
"22 29.46 cheap\n",
"23 32.13 cheap\n",
"24 43.15 normal\n",
"25 59.41 normal\n",
"26 61.96 normal\n",
"27 91.69 rich\n",
"28 41.12 normal\n",
"29 74.47 rich\n"
]
}
],
"source": [
"# Пример для дискретизации года\n",
"bin = [0, 40, 70, float('inf')]\n",
"label = [\"cheap\", \"normal\", \"rich\"]\n",
"\n",
"data[\"oil_price_category\"] = pd.cut(data['oil prices'], bins=bin, labels=label)\n",
"print(data[[\"oil prices\", \"oil_price_category\"]].head(30))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"3. \"Ручной\" синтез признаков"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>stock index</th>\n",
" <th>year</th>\n",
" <th>index price</th>\n",
" <th>log_indexprice</th>\n",
" <th>inflationrate</th>\n",
" <th>oil prices</th>\n",
" <th>exchange_rate</th>\n",
" <th>gdppercent</th>\n",
" <th>percapitaincome</th>\n",
" <th>unemploymentrate</th>\n",
" <th>...</th>\n",
" <th>country_France</th>\n",
" <th>country_Germany</th>\n",
" <th>country_Hong Kong</th>\n",
" <th>country_India</th>\n",
" <th>country_Japan</th>\n",
" <th>country_Spain</th>\n",
" <th>country_United Kingdom</th>\n",
" <th>country_United States of America</th>\n",
" <th>oil_price_category</th>\n",
" <th>Economic_Growth</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>NASDAQ</td>\n",
" <td>1980.0</td>\n",
" <td>168.61</td>\n",
" <td>2.23</td>\n",
" <td>0.14</td>\n",
" <td>21.59</td>\n",
" <td>1.0</td>\n",
" <td>0.09</td>\n",
" <td>12575.0</td>\n",
" <td>0.07</td>\n",
" <td>...</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>cheap</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>NASDAQ</td>\n",
" <td>1981.0</td>\n",
" <td>203.15</td>\n",
" <td>2.31</td>\n",
" <td>0.10</td>\n",
" <td>31.77</td>\n",
" <td>1.0</td>\n",
" <td>0.12</td>\n",
" <td>13976.0</td>\n",
" <td>0.08</td>\n",
" <td>...</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>cheap</td>\n",
" <td>0.03</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>NASDAQ</td>\n",
" <td>1982.0</td>\n",
" <td>188.98</td>\n",
" <td>2.28</td>\n",
" <td>0.06</td>\n",
" <td>28.52</td>\n",
" <td>1.0</td>\n",
" <td>0.04</td>\n",
" <td>14434.0</td>\n",
" <td>0.10</td>\n",
" <td>...</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>cheap</td>\n",
" <td>-0.08</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>NASDAQ</td>\n",
" <td>1983.0</td>\n",
" <td>285.43</td>\n",
" <td>2.46</td>\n",
" <td>0.03</td>\n",
" <td>26.19</td>\n",
" <td>1.0</td>\n",
" <td>0.09</td>\n",
" <td>15544.0</td>\n",
" <td>0.10</td>\n",
" <td>...</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>cheap</td>\n",
" <td>0.05</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>NASDAQ</td>\n",
" <td>1984.0</td>\n",
" <td>248.89</td>\n",
" <td>2.40</td>\n",
" <td>0.04</td>\n",
" <td>25.88</td>\n",
" <td>1.0</td>\n",
" <td>0.11</td>\n",
" <td>17121.0</td>\n",
" <td>0.08</td>\n",
" <td>...</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>cheap</td>\n",
" <td>0.02</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 24 columns</p>\n",
"</div>"
],
"text/plain": [
" stock index year index price log_indexprice inflationrate oil prices \\\n",
"0 NASDAQ 1980.0 168.61 2.23 0.14 21.59 \n",
"1 NASDAQ 1981.0 203.15 2.31 0.10 31.77 \n",
"2 NASDAQ 1982.0 188.98 2.28 0.06 28.52 \n",
"3 NASDAQ 1983.0 285.43 2.46 0.03 26.19 \n",
"4 NASDAQ 1984.0 248.89 2.40 0.04 25.88 \n",
"\n",
" exchange_rate gdppercent percapitaincome unemploymentrate ... \\\n",
"0 1.0 0.09 12575.0 0.07 ... \n",
"1 1.0 0.12 13976.0 0.08 ... \n",
"2 1.0 0.04 14434.0 0.10 ... \n",
"3 1.0 0.09 15544.0 0.10 ... \n",
"4 1.0 0.11 17121.0 0.08 ... \n",
"\n",
" country_France country_Germany country_Hong Kong country_India \\\n",
"0 False False False False \n",
"1 False False False False \n",
"2 False False False False \n",
"3 False False False False \n",
"4 False False False False \n",
"\n",
" country_Japan country_Spain country_United Kingdom \\\n",
"0 False False False \n",
"1 False False False \n",
"2 False False False \n",
"3 False False False \n",
"4 False False False \n",
"\n",
" country_United States of America oil_price_category Economic_Growth \n",
"0 True cheap NaN \n",
"1 True cheap 0.03 \n",
"2 True cheap -0.08 \n",
"3 True cheap 0.05 \n",
"4 True cheap 0.02 \n",
"\n",
"[5 rows x 24 columns]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Пример синтеза признака \"Экономический рост\"\n",
"data['Economic_Growth'] = data['gdppercent'].diff()\n",
"data.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"4. Масштабирование признаков"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Масштабирование признаков на основе нормировки и стандартизации в рамках данного набора данных не является необходимым"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Конструирование признаков с использованием Featuretools"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\laba_MII\\aimvenv\\Lib\\site-packages\\featuretools\\entityset\\entityset.py:1733: UserWarning: index index not found in dataframe, creating new integer column\n",
" warnings.warn(\n",
"d:\\laba_MII\\aimvenv\\Lib\\site-packages\\woodwork\\type_sys\\utils.py:33: UserWarning: Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.\n",
" pd.to_datetime(\n",
"d:\\laba_MII\\aimvenv\\Lib\\site-packages\\woodwork\\type_sys\\utils.py:33: UserWarning: Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.\n",
" pd.to_datetime(\n",
"d:\\laba_MII\\aimvenv\\Lib\\site-packages\\woodwork\\type_sys\\utils.py:33: UserWarning: Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.\n",
" pd.to_datetime(\n",
"d:\\laba_MII\\aimvenv\\Lib\\site-packages\\featuretools\\synthesis\\deep_feature_synthesis.py:169: UserWarning: Only one dataframe in entityset, changing max_depth to 1 since deeper features cannot be created\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Built 24 features\n",
"Elapsed: 00:00 | Progress: 100%|██████████\n",
" stock index year index price log_indexprice inflationrate \\\n",
"index \n",
"0 NASDAQ 1980.0 168.61 2.23 0.14 \n",
"1 NASDAQ 1981.0 203.15 2.31 0.10 \n",
"2 NASDAQ 1982.0 188.98 2.28 0.06 \n",
"3 NASDAQ 1983.0 285.43 2.46 0.03 \n",
"4 NASDAQ 1984.0 248.89 2.40 0.04 \n",
"5 NASDAQ 1985.0 290.25 2.46 0.04 \n",
"6 NASDAQ 1986.0 366.97 2.56 0.02 \n",
"7 NASDAQ 1987.0 402.57 2.60 0.04 \n",
"8 NASDAQ 1988.0 374.43 2.57 0.04 \n",
"9 NASDAQ 1989.0 437.80 2.64 0.05 \n",
"\n",
" oil prices exchange_rate gdppercent percapitaincome \\\n",
"index \n",
"0 21.59 1.0 0.09 12575 \n",
"1 31.77 1.0 0.12 13976 \n",
"2 28.52 1.0 0.04 14434 \n",
"3 26.19 1.0 0.09 15544 \n",
"4 25.88 1.0 0.11 17121 \n",
"5 24.09 1.0 0.07 18237 \n",
"6 12.51 1.0 0.06 19071 \n",
"7 15.40 1.0 0.06 20039 \n",
"8 12.58 1.0 0.08 21417 \n",
"9 15.86 1.0 0.08 22857 \n",
"\n",
" unemploymentrate ... country_France country_Germany \\\n",
"index ... \n",
"0 0.07 ... False False \n",
"1 0.08 ... False False \n",
"2 0.10 ... False False \n",
"3 0.10 ... False False \n",
"4 0.08 ... False False \n",
"5 0.07 ... False False \n",
"6 0.07 ... False False \n",
"7 0.06 ... False False \n",
"8 0.05 ... False False \n",
"9 0.05 ... False False \n",
"\n",
" country_Hong Kong country_India country_Japan country_Spain \\\n",
"index \n",
"0 False False False False \n",
"1 False False False False \n",
"2 False False False False \n",
"3 False False False False \n",
"4 False False False False \n",
"5 False False False False \n",
"6 False False False False \n",
"7 False False False False \n",
"8 False False False False \n",
"9 False False False False \n",
"\n",
" country_United Kingdom country_United States of America \\\n",
"index \n",
"0 False True \n",
"1 False True \n",
"2 False True \n",
"3 False True \n",
"4 False True \n",
"5 False True \n",
"6 False True \n",
"7 False True \n",
"8 False True \n",
"9 False True \n",
"\n",
" oil_price_category Economic_Growth \n",
"index \n",
"0 cheap NaN \n",
"1 cheap 0.03 \n",
"2 cheap -0.08 \n",
"3 cheap 0.05 \n",
"4 cheap 0.02 \n",
"5 cheap -0.04 \n",
"6 cheap -0.01 \n",
"7 cheap 0.00 \n",
"8 cheap 0.02 \n",
"9 cheap 0.00 \n",
"\n",
"[10 rows x 24 columns]\n"
]
}
],
"source": [
"# Определение сущностей\n",
"es = ft.EntitySet(id='economic')\n",
"es = es.add_dataframe(dataframe_name=\"dataEconomic\", dataframe=data, index='index', make_index=False)\n",
"\n",
"# Автоматическое конструирование признаков\n",
"feature_matrix, feature_defs = ft.dfs(entityset=es, target_dataframe_name=\"dataEconomic\", max_depth=2, verbose=1, n_jobs=1)\n",
"print(feature_matrix.head(10))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Оценка качества наборов признаков"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Все наборы признаков имеют неплохую предсказательную способность, высокую скорость вычисления, высокую надежность, при правильности их предварительной обработки, высокую корреляцию и цельность. Данные могут быть использованы для дальнейшего улучшения модели и принятия обоснованных бизнес-решений в области экономики."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "aimvenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}