"""Feature-importance analysis with decision trees on the Titanic and Sberbank housing datasets."""

from sklearn.impute import SimpleImputer, MissingIndicator
from sklearn.pipeline import FeatureUnion, make_pipeline
from sklearn.preprocessing import MinMaxScaler
from sklearn.tree import DecisionTreeClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
import pandas as pd
import random as rand
import numpy as np
from matplotlib import pyplot as plt


def rank_to_dict(ranks, names, n_features):
    """Scale ranking scores to [0, 1] and map them to feature names.

    Kept for reference; neither this helper nor its n_features argument
    is used by the current pipeline.
    """
    ranks = np.abs(ranks)
    minmax = MinMaxScaler()
    ranks = minmax.fit_transform(np.array(ranks).reshape(len(ranks), 1)).ravel()
    ranks = map(lambda x: round(x, 2), ranks)
    return dict(zip(names, ranks))


def part_one():
    print('Titanic data analysis\n')
    data = pd.read_csv('titanic_data.csv', index_col='PassengerId')
    x = data[['Pclass', 'Name', 'Sex']]
    y = data[['Survived']]

    # Collapse the free-text 'Name' column into a single numeric feature:
    # TF-IDF-vectorize the names, then sum each row's weights.
    names = pd.DataFrame(TfidfVectorizer().fit_transform(x['Name']).toarray())
    col_names = names[names.columns[1:]].apply(lambda el: sum(el.dropna().astype(float)), axis=1)
    # Re-index so the derived column aligns with PassengerId, which starts at 1.
    col_names.index = np.arange(1, len(col_names) + 1)

    # Encode 'Sex' as a binary feature: male -> 1, female -> 0.
    col_sexes = []
    for index, row in x.iterrows():
        if row['Sex'] == 'male':
            col_sexes.append(1)
        else:
            col_sexes.append(0)

    x = x.drop(columns=['Sex', 'Name'])
    x['Sex'] = col_sexes
    x['Name'] = col_names

    dtc = DecisionTreeClassifier(random_state=rand.randint(0, 250))
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.05,
                                                        random_state=rand.randint(0, 250))
    dtc.fit(x_train, y_train)
    print('model score: ' + str(dtc.score(x_test, y_test)))

    res = dict(zip(['Pclass', 'Sex', 'Name'], dtc.feature_importances_))
    print('feature importances: ' + str(res))


def part_two():
    print('\n---------------------------------------------------------------------------\n'
          'Sberbank data analysis\n')
    data = pd.read_csv('sberbank_data.csv', index_col='id')
    x = data.drop(columns='price_doc')
    y = data[['price_doc']]

    # Map categorical and ordinal string values to numbers.
    x = x.replace(
        ['NA', 'no', 'yes', 'Investment', 'OwnerOccupier', 'poor', 'satisfactory', 'no data', 'good', 'excellent'],
        [0, 0, 1, 0, 1, -1, 0, 0, 1, 2])
    x.fillna(0, inplace=True)

    # Collapse the 'sub_area' district names into a single numeric feature via summed TF-IDF weights.
    names = pd.DataFrame(TfidfVectorizer().fit_transform(x['sub_area']).toarray())
    col_area = names[names.columns[1:]].apply(lambda el: sum(el.dropna().astype(float)), axis=1)
    col_area.index = np.arange(1, len(col_area) + 1)

    # Keep only the year from the 'YYYY-MM-DD' timestamp, as an integer.
    col_date = []
    for val in x['timestamp']:
        col_date.append(int(val.split('-', 1)[0]))

    x = x.drop(columns=['sub_area', 'timestamp'])
    x['sub_area'] = col_area
    x['timestamp'] = col_date

    # Bin the continuous sale price into classes so a classifier can be used.
    col_price = []
    for val in y['price_doc']:
        if val < 1500000:
            col_price.append('low')
        elif val < 3000000:
            col_price.append('medium')
        elif val < 5500000:
            col_price.append('high')
        elif val < 10000000:
            col_price.append('premium')
        else:
            col_price.append('oligarch')
    y = pd.DataFrame(col_price)

    # Mean-impute any remaining gaps and append missing-value indicator columns.
    transformer = FeatureUnion(
        transformer_list=[
            ('features', SimpleImputer(strategy='mean')),
            ('indicators', MissingIndicator())])
    dtr = make_pipeline(transformer, DecisionTreeClassifier(random_state=rand.randint(0, 250)))

    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.01,
                                                        random_state=rand.randint(0, 250))
    dtr.fit(x_train, y_train)
    features = list(x.columns)
    print('model score: ' + str(dtr.score(x_test, y_test)))

    # Because fillna(0) already removed the NaNs, MissingIndicator adds no columns,
    # so the tree's importances line up one-to-one with the original feature names.
    res = sorted(dict(zip(features, dtr.steps[-1][1].feature_importances_)).items(),
                 key=lambda el: el[1], reverse=True)

    # Print and plot the eight most important features.
    view_y = []
    view_x = []
    count = 0
    print('feature importances:')
    for val in res:
        if count == 8:
            break
        print(val[0] + " - " + str(val[1]))
        view_y.append(val[0])
        view_x.append(val[1])
        count = count + 1

    plt.figure(1, figsize=(16, 9))
    plt.bar(view_y, view_x)
    plt.show()


def start():
    part_one()
    part_two()


if __name__ == '__main__':
    start()