import pygrib as pgrb
import numpy as np
import pandas as pd
from config_cray import Config as cfg              #!!
from pathlib import Path
import time
import os
import sys
from utils import uv2df
from concurrent.futures import ThreadPoolExecutor

t0 = time.time()

# ── Run identification ────────────────────────────────────────────────────────
# DATEHH is the single command-line argument. Its last two characters are used
# as the cycle-hour sub-directory, everything before them as the date
# (sub-directory name and pd.Timestamp input). Fail with a usage message rather
# than an IndexError traceback when it is missing.
if len(sys.argv) < 2:
    sys.exit(f'usage: {sys.argv[0]} DATEHH')
DATEHH = sys.argv[1]
GRIBS_DIR = Path(cfg.path_gribs.format(DATEHH))
OUTPUT_DIR = Path(f'{cfg.path_output}/{DATEHH[:-2]}/{DATEHH[-2:]}')
os.makedirs(OUTPUT_DIR / 'profiles', exist_ok=True)

NLAT, NLON = cfg.nlat, cfg.nlon
datestamp = pd.Timestamp(DATEHH[:-2])

# ── Cities ────────────────────────────────────────────────────────────────────
cities     = pd.read_csv(cfg.cities_csv)
n_cities   = len(cities)
i_idx      = cities['i'].values           # (n_cities,) 0-based grid column
j_idx      = cities['j'].values           # (n_cities,) 0-based grid row
utc_offset = cities['utc_offset'].values  # (n_cities,) hours ahead of UTC

# ── Forecast steps for reading gribs ──────────────────────────────────────────
# Every 3 h from 0 to 72 h. In the GRIB file name a step is written as one day
# digit followed by a two-digit hour of day: 0 h -> '000', 45 h -> '121',
# 72 h -> '300'.
hrs_grbs = [f"{h // 24}{h % 24:02d}" for h in range(0, 73, 3)]
NH = len(hrs_grbs)
print(f'hrs_grbs: {hrs_grbs}, {NH} hours')

# {file suffix: [message descriptors]} — one GRIB file per (step, suffix).
grbs_msgs_profiles = cfg.grbs_msgs_profiles

#####################################################################################################################################################
### ──────────────────────────────────────────────────  Reading raw data from gribs ──────────────────────────────────────────────────────────────###
#####################################################################################################################################################

def smooth_at_cities(field, j_idx, i_idx):
    """Return the 3x3 box mean of a 2-D ``field`` around each (j, i) point.

    ``field`` is indexed as ``field[row, col]``. ``j_idx`` holds the row
    indices and ``i_idx`` the column indices, one entry per point.
    Neighbour rows/columns falling outside the grid are clamped to the edge,
    so points on the border re-use edge values and still average 9 samples.
    The sum is accumulated in float32 and the result is float32.
    """
    last_row = field.shape[0] - 1
    last_col = field.shape[1] - 1
    rows = [np.clip(j_idx + step, 0, last_row) for step in (-1, 0, 1)]
    cols = [np.clip(i_idx + step, 0, last_col) for step in (-1, 0, 1)]

    total = np.zeros(len(j_idx), dtype=np.float32)
    # Same summation order as a row-major sweep of the 3x3 window.
    for r in rows:
        for c in cols:
            total += field[r, c]
    return total / np.float32(9)

def read_messages_from_grib(fl: str,
                            descs: list[dict]) -> dict:  # --> {var: np.array}
    """Read the messages described by ``descs`` from GRIB file ``fl`` and sample them at the cities.

    Each descriptor is unpacked positionally into
    ``(messages, shortName, levels, smooth, name)``. This relies on the key
    order of the config dicts; 'messages', 'levels' and 'name' are the keys
    used elsewhere in this file. Every message is checked against the
    expected shortName and level before it is used.

    Returns ``{name: array}``. Single-message variables give shape
    ``(n_cities,)``; multi-message variables are stacked to
    ``(n_messages, n_cities)``.

    Raises ValueError if a message index points at the wrong variable or level.
    The GRIB file is closed even when that happens.
    """
    def prcs(values, smth: bool):
        # 3x3 box mean around each city, or the single grid point of the city.
        if smth:
            return smooth_at_cities(values, j_idx, i_idx)
        return values[j_idx, i_idx]

    out = {}
    grbs = pgrb.open(fl)
    try:
        for desc in descs:
            # NOTE(review): positional unpacking depends on the config dict key
            # order; consider switching to explicit keys.
            msgs, var, levs, smth, name = desc.values()
            rows = []
            for msg, lev in zip(msgs, levs):
                grb = grbs.message(msg)
                # Explicit raise (not assert) so the check survives `python -O`.
                if grb.shortName != var or grb.level != lev:
                    raise ValueError(f"Wrong message idx for {fl}: {desc}")
                rows.append(prcs(grb.values, smth))
            out[name] = rows[0] if len(msgs) == 1 else np.stack(rows)  # (levs, n_cities) when stacked
    finally:
        grbs.close()
    return out

def _empty_accumulator(desc):
    # One float32 slot per forecast step and city. Variables read from several
    # messages get an extra middle axis with one row per message.
    n_msgs = len(desc['messages'])
    shape = (NH, n_cities) if n_msgs == 1 else (NH, n_msgs, n_cities)
    return np.empty(shape, dtype=np.float32)


# Buffers filled by the reader threads: {var: np.array((NH [,n_levs], n_cities))}
raw_vars = {desc['name']: _empty_accumulator(desc)
            for group in grbs_msgs_profiles.values()
            for desc in group}

def read_gribs_one_task(args):
    """Read one GRIB file and copy its city samples into ``raw_vars``.

    ``args`` is a ``(hr_grb, sfx, descs)`` tuple. ``hr_grb`` is an entry of
    ``hrs_grbs`` and selects the time row written. ``sfx`` completes the file
    name. ``descs`` lists the messages to read from that file.
    """
    print(f'Reading {args}')
    hr_grb, sfx, descs = args
    step = hrs_grbs.index(hr_grb)
    grib_file = GRIBS_DIR / f'igfi0{hr_grb}0000{sfx}.ENA6km.grb'
    samples = read_messages_from_grib(str(grib_file), descs)
    for var_name, values in samples.items():
        raw_vars[var_name][step] = values
    print(f'{args} done')

# One task per (forecast step, file suffix) pair, step-major.
tasks = [
    (hr_grb, sfx, descs)
    for hr_grb in hrs_grbs
    for sfx, descs in grbs_msgs_profiles.items()
]
with ThreadPoolExecutor(max_workers=32) as pool:
    # Exhaust the result iterator so an exception from any worker is re-raised here.
    for _ in pool.map(read_gribs_one_task, tasks):
        pass

print(f"Reading gribs done in: {(time.time() - t0)/60:.1f}m \n")
t0 = time.time()

# ── Unit conversions ──────────────────────────────────────────────────────────
# 0 °C is 273.15 K. The previous 273.16 K is the triple point of water, not the
# Celsius zero; subtracting it biased every temperature low by 0.01 °C.
KELVIN_AT_0C = np.float32(273.15)

raw_vars['prmsl'] *= np.float32(0.01)      # Pa → hPa
raw_vars['2t'] -= KELVIN_AT_0C             # K → °C
raw_vars['t_925'] -= KELVIN_AT_0C          # K → °C
raw_vars['t'] -= KELVIN_AT_0C              # K → °C (model levels)

# ── Precipitation deaccumulation ──────────────────────────────────────────────
# 'tp' and 'tp_raw' are treated as accumulated totals: np.diff along the step
# axis turns them into per-interval amounts.
# Step 000 is forced to 0 first. As a result the first output row is 0 and the
# 0→3 h amount is the step-003 total measured from a zero baseline. Presumably
# the step-0 message carries no meaningful accumulation (TODO confirm).
raw_vars['tp'][0] = 0.0
raw_vars['tp_raw'][0] = 0.0
raw_vars['tp'] = np.diff(raw_vars['tp'], axis=0, prepend=0)
raw_vars['tp_raw'] = np.diff(raw_vars['tp_raw'], axis=0, prepend=0)
# Keep the larger of the two variants per step and city.
prec = np.maximum(raw_vars['tp_raw'], raw_vars['tp'])
del raw_vars['tp_raw'], raw_vars['tp']
# Zero out tiny and negative increments (differences of accumulated fields can
# presumably dip slightly below zero).
prec[prec < 0.001] = 0.0


# ── Wind ──────────────────────────────────────────────────────────────────────
# uv2df (from utils) converts u/v components into the speed/direction pairs
# written to the output.
V10m, dd10m = uv2df(raw_vars['10u'], raw_vars['10v'])
del raw_vars['10u'], raw_vars['10v']

# 925 level: speed is the larger of the '_raw' pair's speed (computed directly
# with hypot) and the plain pair's speed. Direction comes from the plain pair only.
V925_raw = np.hypot(raw_vars['u_925_raw'], raw_vars['v_925_raw'])
del raw_vars['u_925_raw'], raw_vars['v_925_raw']
V925_smt, dd_925 = uv2df(raw_vars['u_925'], raw_vars['v_925'])
del raw_vars['u_925'], raw_vars['v_925']

V925 = np.maximum(V925_raw, V925_smt)
del V925_raw, V925_smt

print(f"Processing vars done in: {(time.time() - t0)/60:.1f}m \n")
t0 = time.time()

# ── Interpolation ───────────────────────────────────────────────────────────
# t/u/v on model levels are stored as (NH, n_levs, n_cities). Transpose to
# (n_levs, NH, n_cities) so a per-city level index can be applied on axis 0.
t_mlevs = raw_vars['t'].transpose(1, 0, 2)
u_mlevs = raw_vars['u'].transpose(1, 0, 2)
v_mlevs = raw_vars['v'].transpose(1, 0, 2)
del raw_vars['t'], raw_vars['u'], raw_vars['v']

z_targets  = np.array([100., 200., 300., 500.], dtype=np.float32)  # target heights (m), output as t100m … t500m
n_z        = len(z_targets)
n_levs_ml  = t_mlevs.shape[0]

with np.load(cfg.constfields) as f:
    zm_cities = f['zm'][:, j_idx, i_idx].copy()  # (9, n_cities): real heights of model levels 66-74 above ground, ascending

# The heights must describe exactly the levels read from the gribs. Otherwise
# the bracket search below either indexes past zm_cities or silently pairs
# heights with the wrong levels.
if zm_cities.shape[0] != n_levs_ml:
    raise ValueError(f'constfields zm has {zm_cities.shape[0]} levels, '
                     f'but {n_levs_ml} model levels were read from the gribs')

print(f'Интерполяция {zm_cities[:, 0]} --> {z_targets}')

# ── Precompute interpolation weights once (zm_cities is constant across hours) ──
nc_idx = np.arange(n_cities)
k_idx  = np.empty((n_z, n_cities), dtype=np.int32)   # lower bracket level index per target height per city
w_arr  = np.empty((n_z, n_cities), dtype=np.float32) # linear interp weight per target height per city

for nz, zt in enumerate(z_targets):
    # Lower bracket = highest model level at or below zt (levels are ascending).
    # It is clipped so that k+1 is always a valid level. Targets below the
    # lowest or above the highest level therefore get w < 0 or w > 1,
    # i.e. linear extrapolation.
    k = np.clip(np.sum(zm_cities <= zt, axis=0) - 1, 0, n_levs_ml - 2)  # (n_cities,)
    zm_k  = zm_cities[k,   nc_idx]  # (n_cities,) height of lower bracket level for each city
    zm_k1 = zm_cities[k+1, nc_idx]  # (n_cities,) height of upper bracket level for each city
    k_idx[nz] = k
    w_arr[nz]  = (zt - zm_k) / (zm_k1 - zm_k)  # weight: 0 = at lower level, 1 = at upper level

# ── Apply weights — vectorized over all cities and hours ─────────────────────
tz = np.empty((n_z, NH, n_cities), dtype=np.float32)
uz = np.empty((n_z, NH, n_cities), dtype=np.float32)
vz = np.empty((n_z, NH, n_cities), dtype=np.float32)

for nz in range(n_z):
    k = k_idx[nz]           # (n_cities,): lower bracket level index for this target height
    w = w_arr[nz, :, None]  # (n_cities, 1): weight, None adds axis to broadcast over NH
    # Advanced indices k and nc_idx are separated by a slice, so their broadcast
    # dimension comes first: t_mlevs[k, :, nc_idx] → (n_cities, NH); .T → (NH, n_cities).
    tz[nz] = ((1-w) * t_mlevs[k, :, nc_idx] + w * t_mlevs[k+1, :, nc_idx]).T
    uz[nz] = ((1-w) * u_mlevs[k, :, nc_idx] + w * u_mlevs[k+1, :, nc_idx]).T
    vz[nz] = ((1-w) * v_mlevs[k, :, nc_idx] + w * v_mlevs[k+1, :, nc_idx]).T

del t_mlevs, u_mlevs, v_mlevs

Vz, ddz = uv2df(uz, vz)
del uz, vz

print(f"Interpolation done in: {(time.time() - t0)/60:.1f}m \n")
t0 = time.time()

# ── Output ───────────────────────────────────────────────────────────
# One Excel sheet per city: profiles/<fo_name>/<name_rus>.xlsx, one row per forecast step.

cities_ru = cities['name_rus'].values
cities_utc_offsets = cities['utc_offset'].values
fos = cities['fo_name'].values

hours_utc = list(range(0, 73, 3))
int_cols = ['P', 'V10m', 'dd10m',
            'V100m', 'dd100m', 'V200m', 'dd200m', 'V300m', 'dd300m',
            'V500m', 'dd500m', 'V925', 'dd925']
utc_labels = [f'{h%24:02d}' for h in hours_utc]  # identical for every city

for nc in range(n_cities):
    fo = fos[nc]
    os.makedirs(OUTPUT_DIR / 'profiles' / fo, exist_ok=True)
    file_name = OUTPUT_DIR / 'profiles' / fo / f'{cities_ru[nc]}.xlsx'
    offset = cities_utc_offsets[nc]

    table = pd.DataFrame({'utc_time':   utc_labels,
                          # Wrap into 0-23: e.g. 21 UTC at UTC+7 is 04 local, not 28.
                          'local_time': [f'{(h + offset) % 24:02d}' for h in hours_utc],
                          'P':          raw_vars['prmsl'][:, nc],
                          't2m':        raw_vars['2t'][:, nc],
                          'V10m':       V10m[:, nc],
                          'dd10m':      dd10m[:, nc],
                          't100m':      tz[0, :, nc],
                          'V100m':      Vz[0, :, nc],
                          'dd100m':     ddz[0, :, nc],
                          't200m':      tz[1, :, nc],
                          'V200m':      Vz[1, :, nc],
                          'dd200m':     ddz[1, :, nc],
                          't300m':      tz[2, :, nc],
                          'V300m':      Vz[2, :, nc],
                          'dd300m':     ddz[2, :, nc],
                          't500m':      tz[3, :, nc],
                          'V500m':      Vz[3, :, nc],
                          'dd500m':     ddz[3, :, nc],
                          't925':       raw_vars['t_925'][:, nc],
                          'V925':       V925[:, nc],
                          'dd925':      dd_925[:, nc],
                          'prec':       prec[:, nc]})

    # Only int_cols are rounded to whole numbers and cast. The remaining float
    # columns (temperatures, precipitation) keep one decimal via float_format.
    # Rounding the whole table to 0 decimals made float_format a no-op and
    # turned e.g. 0.4 mm of precipitation into 0.
    table = (table.round({c: 0 for c in int_cols})
                  .astype({c: np.int32 for c in int_cols}))
    table.to_excel(file_name, float_format='%.1f', index=False)