import contextlib

import pandas as pd
import psycopg

class PostgresClient:
    """Data-access layer for experiment parameters and line-section results.

    Holds one psycopg connection (``self.connection``) for its whole lifetime.
    All SQL runs through ``_cursor()``, which commits on success and rolls
    back (then re-raises) on error. A single failed statement therefore never
    leaves the connection stuck in an aborted transaction that would break
    every later call. Instances can be used as a context manager
    (``with PostgresClient(...) as db:``) to close the connection
    automatically.
    """

    # Header of the direction/position column in the exported CSV. After the
    # first column, every column whose header contains this text is dropped
    # (presumably the export repeats it once per quantity - confirm against a
    # real export).
    _CSV_DIRECTION_HEADER = "Line Section: Direction [-1,0,0] (m)"

    # Exported CSV headers -> experiment_data column names.
    _CSV_RENAME = {
        "Line Section: Direction [-1,0,0] (m)": "Direction",
        "Line Section: Temperature (K)": "Temperature",
        "Line Section: Mass Fraction of Nitrogen Oxide Emission": "NOx",
        "Line Section: Mass Fraction of CO2": "CO2",
        "Line Section: Mass Fraction of CO": "CO",
    }

    # experiment_data value columns, in the order used by the INSERT in
    # save_csv_to_postgres (file_id is appended after them).
    _DATA_COLUMNS = ("Direction", "Temperature", "NOx", "CO2", "CO")

    # Result-dict keys, in the same order as the SELECT lists below.
    _LOAD_KEYS = (
        'load',
        'primary_air_consumption',
        'secondary_air_consumption',
        'gas_inlet_consumption',
    )
    _RECYCLING_KEYS = ('recycling_level', 'CO2', 'N2', 'H2O', 'O2')
    _EXPERIMENT_KEYS = (
        'outer_blades_count',
        'outer_blades_length',
        'outer_blades_angle',
        'middle_blades_count',
        'load',
        'recycling',
        'experiment_hash',
    )

    def __init__(self, dbname, user, password, host, port):
        """Connect to PostgreSQL and create the schema if it is missing.

        If schema creation fails, the freshly opened connection is closed
        before the error propagates, so a failed constructor does not leak it.
        """
        self.connection = psycopg.connect(
            dbname=dbname,
            user=user,
            password=password,
            host=host,
            port=port,
        )
        try:
            self.init_db()
        except BaseException:
            self.connection.close()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Never suppresses the exception (returns None).
        self.close()

    @contextlib.contextmanager
    def _cursor(self):
        """Yield a cursor; commit when the block succeeds, roll back if it raises."""
        try:
            with self.connection.cursor() as cur:
                yield cur
            # Also ends the implicit transaction opened by read-only queries.
            self.connection.commit()
        except BaseException:
            # Without a rollback psycopg keeps the transaction in the failed
            # state and every following statement on this connection errors.
            self.connection.rollback()
            raise

    def _query_dataframe(self, query, params=None):
        """Run a SELECT and return all rows as a DataFrame named after the result columns."""
        with self._cursor() as cur:
            cur.execute(query, params)
            data = cur.fetchall()
            columns = [desc[0] for desc in cur.description]
        return pd.DataFrame(data, columns=columns)

    def init_db(self):
        """Create the tables (and an index on experiment_data.file_id) if they do not exist."""
        with self._cursor() as cur:
            # Several statements in one execute() are allowed because no
            # query parameters are passed.
            cur.execute("""
            CREATE TABLE IF NOT EXISTS load_parameters (
                id SERIAL PRIMARY KEY,
                load NUMERIC NOT NULL UNIQUE,
                primary_air_consumption NUMERIC NOT NULL,
                secondary_air_consumption NUMERIC NOT NULL,
                gas_inlet_consumption NUMERIC NOT NULL
            );

            CREATE TABLE IF NOT EXISTS recycling_parameters (
                id SERIAL PRIMARY KEY,
                load_id INTEGER NOT NULL,
                recycling_level NUMERIC NOT NULL,
                CO2 NUMERIC NOT NULL,
                N2 NUMERIC NOT NULL,
                H2O NUMERIC NOT NULL,
                O2 NUMERIC NOT NULL,
                UNIQUE(load_id, recycling_level),
                FOREIGN KEY (load_id) REFERENCES load_parameters(id) ON DELETE CASCADE
            );

            CREATE TABLE IF NOT EXISTS experiment_parameters (
                id SERIAL PRIMARY KEY,
                outer_blades_count INTEGER NOT NULL,
                outer_blades_length NUMERIC NOT NULL,
                outer_blades_angle NUMERIC NOT NULL,
                middle_blades_count INTEGER NOT NULL,
                load_id INTEGER NOT NULL,
                recycling_id INTEGER NOT NULL,
                experiment_hash CHAR(64) NOT NULL UNIQUE,
                FOREIGN KEY (load_id) REFERENCES load_parameters(id)  ON DELETE CASCADE,
                FOREIGN KEY (recycling_id) REFERENCES recycling_parameters(id)  ON DELETE CASCADE
            );

            CREATE TABLE IF NOT EXISTS experiment_data (
                    id BIGSERIAL PRIMARY KEY,
                    Direction DOUBLE PRECISION,
                    Temperature DOUBLE PRECISION,
                    NOx DOUBLE PRECISION,
                    CO2 DOUBLE PRECISION,
                    CO DOUBLE PRECISION,
                    file_id CHAR(64) NOT NULL
            );

            -- save_csv_to_postgres deletes by file_id and get_data partitions
            -- by it; without an index both scan the whole table.
            CREATE INDEX IF NOT EXISTS experiment_data_file_id_idx
                ON experiment_data (file_id);
            """)

    def insert_load_parameters(self, load_parameters):
        """Return the id of the load_parameters row for ``load_parameters['load']``.

        Inserts the row if no row with that load exists. An existing row is
        returned unchanged; its other columns are not updated.
        """
        with self._cursor() as cur:
            cur.execute("SELECT id FROM load_parameters WHERE load = %s", (load_parameters['load'],))
            row = cur.fetchone()
            if row is None:
                cur.execute("""
                    INSERT INTO load_parameters (load, primary_air_consumption, secondary_air_consumption, gas_inlet_consumption)
                    VALUES (%s, %s, %s, %s)
                    RETURNING id;
                """, (load_parameters['load'], load_parameters['primary_air_consumption'],
                      load_parameters['secondary_air_consumption'], load_parameters['gas_inlet_consumption']))
                row = cur.fetchone()
        return row[0]

    def insert_recycling_parameters(self, recycling_parameters, load_id):
        """Return the id of the recycling_parameters row for (load_id, recycling_level).

        Inserts the row if it does not exist yet. An existing row is returned
        unchanged.
        """
        with self._cursor() as cur:
            cur.execute("SELECT id FROM recycling_parameters WHERE load_id = %s AND recycling_level = %s",
                        (load_id, recycling_parameters['recycling_level']))
            row = cur.fetchone()
            if row is None:
                cur.execute("""
                    INSERT INTO recycling_parameters (load_id, recycling_level, CO2, N2, H2O, O2)
                    VALUES (%s, %s, %s, %s, %s, %s)
                    RETURNING id;
                """, (load_id, recycling_parameters['recycling_level'], recycling_parameters['CO2'],
                      recycling_parameters['N2'], recycling_parameters['H2O'], recycling_parameters['O2']))
                row = cur.fetchone()
        return row[0]

    def insert_experiment_parameters(self, experiment_parameters, load_id, recycling_id, file_id):
        """Insert the experiment identified by ``file_id`` (its experiment_hash) if it is new.

        An existing experiment with the same hash is left untouched. Returns
        the experiment's id (earlier versions returned None; callers that
        ignore the result are unaffected).
        """
        with self._cursor() as cur:
            cur.execute("SELECT id FROM experiment_parameters WHERE experiment_hash = %s", (file_id,))
            row = cur.fetchone()
            if row is None:
                cur.execute("""
                    INSERT INTO experiment_parameters (outer_blades_count, outer_blades_length, outer_blades_angle, middle_blades_count, load_id, recycling_id, experiment_hash)
                    VALUES (%s, %s, %s, %s, %s, %s, %s)
                    RETURNING id;
                """, (experiment_parameters['outer_blades_count'], experiment_parameters['outer_blades_length'],
                      experiment_parameters['outer_blades_angle'], experiment_parameters['middle_blades_count'], load_id,
                      recycling_id, file_id))
                row = cur.fetchone()
        return row[0]

    def get_load_parameters(self, load):
        """Return the load_parameters row whose ``load`` value equals ``load`` as a dict, or None."""
        with self._cursor() as cur:
            cur.execute("""
                SELECT load, primary_air_consumption, secondary_air_consumption, gas_inlet_consumption
                FROM load_parameters
                WHERE load = %s
            """, (load,))
            row = cur.fetchone()
        if row is None:
            return None
        return dict(zip(self._LOAD_KEYS, row))

    def get_recycling_parameters(self, load, recycling_level):
        """Return the recycling parameters for a load *value* and recycling level, or None.

        The returned dict's ``'load'`` entry is the ``load`` argument itself.
        """
        with self._cursor() as cur:
            cur.execute("""
                SELECT rp.recycling_level, rp.CO2, rp.N2, rp.H2O, rp.O2
                FROM recycling_parameters rp
                JOIN load_parameters lp ON rp.load_id = lp.id
                WHERE lp.load = %s AND rp.recycling_level = %s
            """, (load, recycling_level))
            row = cur.fetchone()
        if row is None:
            return None
        return {'load': load, **dict(zip(self._RECYCLING_KEYS, row))}

    def get_experiment_parameters(self, experiment_hash):
        """Return the experiment with the given hash, with its load and recycling level resolved.

        The load value and recycling level are obtained by joining on the
        stored foreign-key ids. (The previous version passed those ids to
        lookups that expect the load/level *values*.) Returns None if the hash
        is unknown.
        """
        with self._cursor() as cur:
            cur.execute("""
                SELECT ep.outer_blades_count,
                       ep.outer_blades_length,
                       ep.outer_blades_angle,
                       ep.middle_blades_count,
                       lp.load,
                       rp.recycling_level,
                       ep.experiment_hash
                FROM experiment_parameters ep
                JOIN load_parameters lp ON ep.load_id = lp.id
                JOIN recycling_parameters rp ON ep.recycling_id = rp.id
                WHERE ep.experiment_hash = %s
            """, (experiment_hash,))
            row = cur.fetchone()
        if row is None:
            return None
        return dict(zip(self._EXPERIMENT_KEYS, row))

    def get_experiments(self):
        """Return every experiment's geometry with its load and recycling level as a DataFrame.

        Columns: file_id (the experiment hash), outer_blades_count,
        outer_blades_length, outer_blades_angle, middle_blades_count, load,
        recycling_level.
        """
        query = """
                SELECT
                    ep.experiment_hash AS file_id,
                    ep.outer_blades_count,
                    ep.outer_blades_length,
                    ep.outer_blades_angle,
                    ep.middle_blades_count,
                    lp.load,
                    rp.recycling_level
                FROM
                    experiment_parameters ep
                JOIN
                    load_parameters lp ON ep.load_id = lp.id
                JOIN
                    recycling_parameters rp ON ep.recycling_id = rp.id
            """
        return self._query_dataframe(query)

    @classmethod
    def _prepare_experiment_rows(cls, csv_path, file_id):
        """Read an exported line-section CSV and return rows for experiment_data.

        The first column is kept. Any later column whose header contains the
        direction header is dropped. The known headers are renamed, and the
        value columns are then selected *by name*, so the rows are correct
        whatever order the quantities appear in. Missing cells (NaN) become
        None so they are stored as SQL NULL rather than NaN.

        Returns a list of ``(Direction, Temperature, NOx, CO2, CO, file_id)``
        tuples. Raises KeyError if an expected column is absent.
        """
        df = pd.read_csv(csv_path)
        first_col = df.columns[0]
        keep = [first_col] + [
            col for col in df.columns[1:] if cls._CSV_DIRECTION_HEADER not in col
        ]
        values = df[keep].rename(columns=cls._CSV_RENAME)[list(cls._DATA_COLUMNS)]
        return [
            tuple(None if pd.isna(value) else value for value in record) + (file_id,)
            for record in values.itertuples(index=False, name=None)
        ]

    def save_csv_to_postgres(self, csv_path, file_id):
        """Replace the experiment_data rows of ``file_id`` with the contents of a CSV.

        The DELETE of the old rows and the INSERT of the new ones run in one
        transaction. If anything fails, the old data is kept and the
        connection is rolled back, so it remains usable.

        Returns "Success", or "Failed: <error>" on any error (the error is
        not raised).
        """
        try:
            rows = self._prepare_experiment_rows(csv_path, file_id)
            with self._cursor() as cur:
                cur.execute("DELETE FROM experiment_data WHERE file_id = %s", (file_id,))
                if rows:
                    cur.executemany('''
                        INSERT INTO experiment_data (Direction, Temperature, NOx, CO2, CO, file_id)
                        VALUES (%s, %s, %s, %s, %s, %s)
                        ''', rows)
            return "Success"

        except Exception as e:
            return f"Failed: {str(e)}"

    def get_data(self):
        """Return, per file_id, the peak temperature, CO2, CO and NOx and where each occurs.

        For each quantity the row with the largest value is chosen with
        ROW_NUMBER over file_id. NULL and NaN readings are ranked last.
        (PostgreSQL puts NULLs first in DESC order and orders NaN above every
        number, so otherwise a data gap could be reported as the maximum.)
        Ties go to the lowest row id, which makes the result deterministic.

        Returns a DataFrame with columns file_id, direction_for_max_temp,
        max_temperature, direction_for_max_co2, max_co2,
        direction_for_max_co, max_co, direction_for_max_nox, max_nox.
        """
        query = """
            WITH max_temp AS (
                SELECT
                    file_id,
                    temperature AS max_temperature,
                    direction AS direction_for_max_temp,
                    ROW_NUMBER() OVER (
                        PARTITION BY file_id
                        ORDER BY NULLIF(temperature, 'NaN') DESC NULLS LAST, id
                    ) AS temp_rank
                FROM
                    experiment_data
            ),
            max_co2 AS (
                SELECT
                    file_id,
                    co2 AS max_co2,
                    direction AS direction_for_max_co2,
                    ROW_NUMBER() OVER (
                        PARTITION BY file_id
                        ORDER BY NULLIF(co2, 'NaN') DESC NULLS LAST, id
                    ) AS co2_rank
                FROM
                    experiment_data
            ),
            max_co AS (
                SELECT
                    file_id,
                    co AS max_co,
                    direction AS direction_for_max_co,
                    ROW_NUMBER() OVER (
                        PARTITION BY file_id
                        ORDER BY NULLIF(co, 'NaN') DESC NULLS LAST, id
                    ) AS co_rank
                FROM
                    experiment_data
            ),
            max_nox AS (
                SELECT
                    file_id,
                    nox AS max_nox,
                    direction AS direction_for_max_nox,
                    ROW_NUMBER() OVER (
                        PARTITION BY file_id
                        ORDER BY NULLIF(nox, 'NaN') DESC NULLS LAST, id
                    ) AS nox_rank
                FROM
                    experiment_data
            )
            SELECT
                t.file_id,
                t.direction_for_max_temp,
                t.max_temperature,
                cx.direction_for_max_co2,
                cx.max_co2,
                c.direction_for_max_co,
                c.max_co,
                n.direction_for_max_nox,
                n.max_nox
            FROM
                (SELECT * FROM max_temp WHERE temp_rank = 1) t
            LEFT JOIN
                (SELECT * FROM max_nox WHERE nox_rank = 1) n ON t.file_id = n.file_id
            LEFT JOIN
                (SELECT * FROM max_co2 WHERE co2_rank = 1) cx ON t.file_id = cx.file_id
            LEFT JOIN
                (SELECT * FROM max_co WHERE co_rank = 1) c ON t.file_id = c.file_id;
            """
        return self._query_dataframe(query)

    def close(self):
        """Close the database connection."""
        self.connection.close()


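# Example usage script (currently disabled). If re-enabled, note that the
# `experiment_parameters` dict below has no 'experiment_hash' key, so the
# `experiment_parameters['experiment_hash']` lookup would raise KeyError.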
# def main():
#     # Данные
#     experiment_parameters = {
#         'outer_blades_count': 24,
#         'outer_blades_length': 74.0,
#         'outer_blades_angle': 65.0,
#         'middle_blades_count': 18,
#         'load': 315.0,
#         'recycling': 8.0,
#     }
#
#     load_parameters = {
#         'load': 315.0,
#         'primary_air_consumption': 15.2239,
#         'secondary_air_consumption': 63.9876,
#         'gas_inlet_consumption': 0.8648
#     }
#
#     recycling_parameters = {
#         'load': 315.0,
#         'recycling_level': 8.0,
#         'CO2': 0.04,
#         'N2': 0.70,
#         'H2O': 0.06,
#         'O2': 0.20
#     }
#
#     # Инициализация базы данных
#     db = PostgresClient(
#         dbname="your_db_name",
#         user="your_db_user",
#         password="your_db_password",
#         host="your_db_host",
#         port="your_db_port"
#     )
#
#     try:
#
#         # Извлечение и печать данных
#         retrieved_experiment = db.get_experiment_parameters(experiment_parameters['experiment_hash'])
#         print("Retrieved experiment parameters:", retrieved_experiment)
#
#         retrieved_load = db.get_load_parameters(load_parameters['load'])
#         print("Retrieved load parameters:", retrieved_load)
#
#         retrieved_recycling = db.get_recycling_parameters(recycling_parameters['load'],
#                                                           recycling_parameters['recycling_level'])
#         print("Retrieved recycling parameters:", retrieved_recycling)
#     finally:
#         db.close()
#
#
# if __name__ == "__main__":
#     main()