from datetime import datetime
from pathlib import Path

import metpy.calc
import numpy as np
import requests
import torch
import xarray as xr
from aurora import AuroraSmall, Batch, Metadata
from metpy.units import units


def get_download_paths(date):
    """Создает список путей для загрузки данных."""
    download_path = Path("~/downloads/hres_0.1").expanduser()
    downloads = {}
    var_nums = {
        "2t": "167", "10u": "165", "10v": "166", "msl": "151", "t": "130",
        "u": "131", "v": "132", "q": "133", "z": "129", "slt": "043", "lsm": "172",
    }
    for v in ["2t", "10u", "10v", "msl", "z", "slt", "lsm"]:
        downloads[download_path / date.strftime(f"surf_{v}_%Y-%m-%d.grib")] = (
            f"https://data.rda.ucar.edu/ds113.1/"
            f"ec.oper.an.sfc/{date.year}{date.month:02d}/ec.oper.an.sfc.128_{var_nums[v]}_{v}."
            f"regn1280sc.{date.year}{date.month:02d}{date.day:02d}.grb"
        )
    for v in ["z", "t", "u", "v", "q"]:
        for hour in [0, 6, 12, 18]:
            prefix = "uv" if v in {"u", "v"} else "sc"
            downloads[download_path / date.strftime(f"atmos_{v}_%Y-%m-%d_{hour:02d}.grib")] = (
                f"https://data.rda.ucar.edu/ds113.1/"
                f"ec.oper.an.pl/{date.year}{date.month:02d}/ec.oper.an.pl.128_{var_nums[v]}_{v}."
                f"regn1280{prefix}.{date.year}{date.month:02d}{date.day:02d}{hour:02d}.grb"
            )
    return downloads, download_path


def download_data(downloads):
    """Скачивает файлы, если они отсутствуют в целевой директории."""
    for target, source in downloads.items():
        if not target.exists():
            print(f"Downloading {source}")
            target.parent.mkdir(parents=True, exist_ok=True)
            response = requests.get(source)
            response.raise_for_status()
            with open(target, "wb") as f:
                f.write(response.content)
            print("Downloads finished!")


def load_surf(v, v_in_file, download_path, date):
    """Загружает переменные поверхностного уровня или статические переменные."""
    ds = xr.open_dataset(download_path / date.strftime(f"surf_{v}_%Y-%m-%d.grib"), engine="cfgrib")
    data = ds[v_in_file].values[:2]
    data = data[None]
    return torch.from_numpy(data)


def load_atmos(v, download_path, date, levels):
    """Загружает атмосферные переменные для заданных уровней давления."""
    ds_00 = xr.open_dataset(
        download_path / date.strftime(f"atmos_{v}_%Y-%m-%d_00.grib"), engine="cfgrib"
    )
    ds_06 = xr.open_dataset(
        download_path / date.strftime(f"atmos_{v}_%Y-%m-%d_06.grib"), engine="cfgrib"
    )
    ds_00 = ds_00[v].sel(isobaricInhPa=list(levels))
    ds_06 = ds_06[v].sel(isobaricInhPa=list(levels))
    data = np.stack((ds_00.values, ds_06.values), axis=0)
    data = data[None]
    return torch.from_numpy(data)


def create_batch(date, levels, downloads, download_path):
    """Создает объект Batch с данными для модели."""
    ds = xr.open_dataset(next(iter(downloads.keys())), engine="cfgrib")
    batch = Batch(
        surf_vars={
            "2t": load_surf("2t", "t2m", download_path, date),
            "10u": load_surf("10u", "u10", download_path, date),
            "10v": load_surf("10v", "v10", download_path, date),
            "msl": load_surf("msl", "msl", download_path, date),
        },
        static_vars={
            "z": load_surf("z", "z", download_path, date)[0, 0],
            "slt": load_surf("slt", "slt", download_path, date)[0, 0],
            "lsm": load_surf("lsm", "lsm", download_path, date)[0, 0],
        },
        atmos_vars={
            "t": load_atmos("t", download_path, date, levels),
            "u": load_atmos("u", download_path, date, levels),
            "v": load_atmos("v", download_path, date, levels),
            "q": load_atmos("q", download_path, date, levels),
            "z": load_atmos("z", download_path, date, levels),
        },
        metadata=Metadata(
            lat=torch.from_numpy(ds.latitude.values),
            lon=torch.from_numpy(ds.longitude.values),
            time=(date.replace(hour=6),),
            atmos_levels=levels,
        ),
    )
    return batch.regrid(res=0.1)


def create_batch_random(levels: tuple[int], date: tuple):
    """Создает объект Batch с рандомными данными для модели."""
    return Batch(
        surf_vars={k: torch.randn(1, 2, 17, 32) for k in ("2t", "10u", "10v", "msl")},
        static_vars={k: torch.randn(17, 32) for k in ("lsm", "z", "slt")},
        atmos_vars={k: torch.randn(1, 2, 4, 17, 32) for k in ("z", "u", "v", "t", "q")},
        metadata=Metadata(
            lat=torch.linspace(90, -90, 17),
            lon=torch.linspace(0, 360, 32 + 1)[:-1],
            time=date,
            atmos_levels=levels,
        ),
    )


def run_model(batch):
    """Инициализирует модель AuroraSmall и выполняет предсказание."""
    model = AuroraSmall()
    model.load_checkpoint("microsoft/aurora", "aurora-0.25-small-pretrained.ckpt")
    model.eval()
    model = model.to("cpu")
    with torch.inference_mode():
        prediction = model.forward(batch)
    return prediction


def get_wind_speed_and_direction(prediction, batch: Batch, lat: float, lon: float):
    target_lat = lat
    target_lon = lon

    lat_idx = torch.abs(batch.metadata.lat - target_lat).argmin()
    lon_idx = torch.abs(batch.metadata.lon - target_lon).argmin()

    u_values = prediction.atmos_vars["u"][:, :, :, lat_idx, lon_idx]
    v_values = prediction.atmos_vars["v"][:, :, :, lat_idx, lon_idx]
    wind_speeds=[]
    wind_directions=[]
    for i in range(u_values.numel()):
        u_scalar = u_values.view(-1)[i].item()  # Разворачиваем тензор в одномерный и берем элемент
        v_scalar = v_values.view(-1)[i].item()

        print("u value:", u_scalar)
        print("v value:", v_scalar)

        u_with_units = u_scalar * units("m/s")
        v_with_units = v_scalar * units("m/s")

        # Рассчитайте направление и скорость ветра
        wind_dir = metpy.calc.wind_direction(u_with_units, v_with_units)
        wind_speed = metpy.calc.wind_speed(u_with_units, v_with_units)

        wind_speeds.append(wind_speed.magnitude.item())
        wind_directions.append(wind_dir.magnitude.item())

    return wind_speeds,wind_directions


def wind_direction_to_text(wind_dir_deg):
    directions = [
        "север", "северо-восток", "восток", "юго-восток",
        "юг", "юго-запад", "запад", "северо-запад"
    ]
    idx = int((wind_dir_deg + 22.5) // 45) % 8
    return directions[idx]


def main():
    levels = (100,)

    date1 = datetime(2024, 11, 27, 12)
    date2 = datetime(2024, 11, 28, 12)
    date_tuple = (date1, date2,)
    # downloads, download_path = get_download_paths(date)
    # download_data(downloads)  # Скачиваем данные, если их нет
    # batch_actual = create_batch(date, levels, downloads, download_path)
    batch_actual = create_batch_random(levels, date_tuple)
    prediction_actual = run_model(batch_actual)
    wind_speed_and_direction = get_wind_speed_and_direction(prediction_actual, batch_actual, 50, 20)
    return wind_speed_and_direction


if __name__ == "__main__":
    main()
    print("Prediction completed!")