import numpy as np
from pathlib import Path
from config_cray import Config as cfg                    #!!
from scipy.ndimage import uniform_filter, maximum_filter
import pandas as pd
import os
import pygrib as pgrb
import sys
from utils import uv2df, nint, select_period_bounds
import time  
from concurrent.futures import ThreadPoolExecutor

t0 = time.time()

# Single CLI argument: the forecast run stamp. Its last two characters are used as the
# run-hour sub-directory; the rest is parsed by pandas as the run date.
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} DATEHH")
DATEHH = sys.argv[1]
GRIBS_DIR = Path(cfg.path_gribs.format(DATEHH))
OUTPUT_DIR = Path(f'{cfg.path_output}/{DATEHH[:-2]}/{DATEHH[-2:]}')
(OUTPUT_DIR / 'tables').mkdir(parents=True, exist_ok=True)

NLAT, NLON, NH = cfg.nlat, cfg.nlon, cfg.nhours   # 1000, 2000, 49
NCELLS = NLAT * NLON                               # 2 000 000
datestamp = pd.Timestamp(DATEHH[:-2])

# Write the decoded GRIB fields to OUTPUT_DIR/grib_vars.npz after reading; this is
# switched off when that cache file already exists and is loaded instead.
save_grib_vars_to_npz = True

#####################################################################################################################################################
### ──────────────────────────────────────────────────  Reading raw data from gribs ──────────────────────────────────────────────────────────────###
#####################################################################################################################################################

# Forecast-hour tags used in GRIB file names: day digit followed by the zero-padded
# hour of day (h=0 -> '000', h=23 -> '023', h=24 -> '100', ..., h=48 -> '200').
hrs_grbs = [f"{h // 24}{h % 24:02d}" for h in range(NH)]
grbs_msgs = cfg.grbs_msgs_mpz

# Advection field by season: t850 for months Feb..Sep, t925 otherwise.
adv_lvl = 925
if 2 <= datestamp.month <= 9:
    adv_lvl = 850
print(f'Advection field: {adv_lvl}  (n_month={datestamp.month})')

# One (hour, lat, lon) float32 cube per decoded field, filled hour by hour by read_gribs.
_grid_shape = (NH, NLAT, NLON)
grb_vars = {
    field: np.empty(_grid_shape, dtype=np.float32)
    for field in ('2t', f't_{adv_lvl}', 'u_70', 'v_70', 't_70')
}
# Accumulated precipitation starts as zeros: read_gribs never reads 'tp' for hour '000'.
grb_vars['tp'] = np.zeros(_grid_shape, dtype=np.float32)

def read_gribs(hr_grb):
    """Decode the configured GRIB messages of one forecast hour into ``grb_vars``.

    Parameters
    ----------
    hr_grb : str
        Forecast-hour tag from ``hrs_grbs`` (e.g. ``'000'``, ``'123'``). Its position
        in ``hrs_grbs`` is the time index written into every ``grb_vars[name]``.

    For each ``(sfx, msgs)`` entry of ``grbs_msgs`` the file
    ``GRIBS_DIR/igfi0<hr_grb>0000<sfx>.ENA6km.grb`` is opened, and every message
    index in ``msgs`` is read. ``desc[0]`` is the expected GRIB shortName. When
    ``desc[1]`` is present it is the expected level, and the field is stored as
    ``'<shortName>_<level>'``. In the ``'a_pl'`` file only the advection level
    ``adv_lvl`` is read. ``'tp'`` is skipped for hour ``'000'``.

    Raises
    ------
    ValueError
        If a message's shortName/level does not match its configured descriptor.
        This is an explicit raise rather than ``assert`` so the check survives ``python -O``.
    """
    h = hrs_grbs.index(hr_grb)
    print(f'Reading gribs {hr_grb}')

    for sfx, msgs in grbs_msgs.items():
        fl = str(GRIBS_DIR / f'igfi0{hr_grb}0000{sfx}.ENA6km.grb')
        grbs = pgrb.open(fl)
        try:  # always release the GRIB handle, even if a lookup or check fails
            for idx, desc in msgs.items():
                if (sfx == 'a_pl') and (desc[1] != adv_lvl):
                    continue  # only the seasonal advection level is needed
                if len(desc) > 1:
                    name = f'{desc[0]}_{desc[1]}'
                    msg = grbs.message(idx)
                    matches = (msg.shortName == desc[0]) and (msg.level == desc[1])
                else:
                    name = f'{desc[0]}'
                    if (name == 'tp') and (hr_grb == '000'):
                        continue  # 'tp' slab for hour '000' stays as initialised (zeros)
                    msg = grbs.message(idx)
                    matches = msg.shortName == desc[0]
                if not matches:
                    raise ValueError(f"Wrong message idx for {sfx}: {desc}")
                grb_vars[name][h] = msg.values
        finally:
            grbs.close()


_grib_cache = OUTPUT_DIR / "grib_vars.npz"
if not _grib_cache.is_file():
    # No cache yet: decode all forecast hours concurrently. Iterating the map
    # re-raises any exception from a worker.
    with ThreadPoolExecutor(max_workers=32) as pool:
        for _ in pool.map(read_gribs, hrs_grbs):
            pass
else:
    # Cache present: reuse it and do not rewrite it below.
    save_grib_vars_to_npz = False
    with np.load(_grib_cache) as cached:
        for field in grb_vars:
            grb_vars[field] = np.asarray(cached[field], dtype=np.float32)

print(f"Reading gribs done in: {(time.time() - t0)/60:.1f}m \n")
t0 = time.time()

if save_grib_vars_to_npz:
    np.savez_compressed(_grib_cache, **grb_vars)

# Quick sanity output: per-hour spatial mean of every field.
for field, cube in grb_vars.items():
    print(f'{field}: {cube.mean(axis=(-1, -2))}')

#####################################################################################################################################################
### ───────────────────────────────────────── Vars processing - stacking, smoothing, etc. ────────────────────────────────────────────────────────###
#####################################################################################################################################################

def smooth_3_3(arr):
    """Return the 3x3 box mean of a (T, H, W) array over its last two axes.

    Borders are handled by replicating edge values (``np.pad(..., mode='edge')``),
    so the result has the same shape as ``arr``. The leading axis is never mixed.
    The nine shifted windows are summed in row-major order and divided by
    ``np.float32(9)``, so float32 input stays float32.
    """
    padded = np.pad(arr, ((0, 0), (1, 1), (1, 1)), mode='edge')
    n_rows = padded.shape[1] - 2
    n_cols = padded.shape[2] - 2
    total = None
    for d_row in range(3):
        for d_col in range(3):
            window = padded[:, d_row:d_row + n_rows, d_col:d_col + n_cols]
            total = window if total is None else total + window
    return total / np.float32(9)

#Or like that (same 3x3 box mean computed via a 2-D prefix sum):
#def box_filter_spatial(arr, k=3):
#   p = k // 2  # padding size: for k=3, p=1

    # Pad edges by 1 pixel on each spatial side, replicate border values
    # (T,H,W) --> (T, H+2, W+2)
#   a = np.pad(arr, ((0,0),(p,p),(p,p)), mode='edge')

    # Cumulative sum along rows, then columns
    # Each cell now contains sum of all elements above-left of it
#   cs = np.cumsum(np.cumsum(a, axis=1), axis=2)

    # Extract box sums using the 2D prefix sum formula:
    # sum of rectangle = bottom_right - bottom_left - top_right + top_left
    # Then divide by k*k to get the mean
#   return (cs[:,k:,k:] - cs[:,k:,:-k] - cs[:,:-k,k:] + cs[:,:-k,:-k]) / k**2

#t2m_grid = uniform_filter(np.stack(t2m_list,  axis=0) - 273.16, size=(1, 3, 3), mode='nearest') # stacking hours --> smoothing 3*3 --> (49,1000,2000)
#tz_grid = uniform_filter(np.stack(tz_list,  axis=0) - 273.16, size=(1, 3, 3), mode='nearest') # stacking hours --> smoothing 3*3 --> (49,1000,2000)

# ── Precipitation ────
# 'tp' is accumulated over the forecast; de-accumulate per hour:
# h0 = tp[0] (zero: hour '000' is not read), h1 = tp[1], h2 = tp[2] - tp[1], ...
# The prepended slab uses tp's own dtype. A bare ``prepend=0`` is an int64 scalar,
# which makes np.diff concatenate int64 with float32 and silently promote the whole
# (NH, NLAT, NLON) grid to float64, doubling the memory of this and every derived grid.
tp = grb_vars['tp']
Pr_grid = np.diff(tp, axis=0, prepend=np.zeros_like(tp[:1]))
del tp, grb_vars['tp']
# Keep the larger of each cell and its 3x3 neighbourhood mean (in place).
np.maximum(Pr_grid, smooth_3_3(Pr_grid), out=Pr_grid)
Pr_grid[Pr_grid < cfg.pr_min] = 0.0  # zero out amounts below the configured minimum

# ── Advection ────
# Only the raw temperature field is kept here. Period advection is derived later
# from the per-city series.
tadv_grid = grb_vars[f't_{adv_lvl}']
del grb_vars[f't_{adv_lvl}']

# ── Vz ────
# First output of uv2df on the 'u_70'/'v_70' fields is used as Vz; the second is discarded.
Vz_grid, _ = uv2df(grb_vars['u_70'], grb_vars['v_70'])
del grb_vars['u_70'], grb_vars['v_70'], _
Vz_grid = smooth_3_3(Vz_grid)            # 3x3 smoothing
np.maximum(Vz_grid, 1.0, out=Vz_grid)    # floor at 1.0, in place
Vz_grid = nint(Vz_grid)                  # round to int, nint rounds exactly like Fortran

# ── Constant fields ────
with np.load(cfg.constfields) as npz:
    dz = np.asarray(npz['dz'], dtype=np.float32)  # layer thickness grid
    lons = np.asarray(npz['icon6_geo_lon'], dtype=np.float32)
    lats = np.asarray(npz['icon6_geo_lat'], dtype=np.float32)

# ── Str: (2 m temperature - t_70) / dz, then 3x3 smoothing ────
Str_grid = grb_vars['2t'] - grb_vars['t_70']
del grb_vars['2t'], grb_vars['t_70']
Str_grid /= dz
Str_grid = smooth_3_3(Str_grid)

print(f"Processing done in: {(time.time() - t0)/60:.1f}m \n")
t0 = time.time()

#####################################################################################################################################################
### ──────────────────────────────────────────────────────── MPRZ decision tree ──────────────────────────────────────────────────────────────────###
#####################################################################################################################################################

# ── Cities metadata ─────
cities = pd.read_csv(cfg.cities_csv)
# Grid position of each city (0-based): i -> column (lon), j -> row (lat).
i_idx = cities['i'].values             # (n_cities,)
j_idx = cities['j'].values             # (n_cities,)
utc_offset = cities['utc_offset'].values  # (n_cities,), added to UTC hour indices below
n_cities = len(cities)
fo_names, cities_name_ru, cities_name_en = (
    cities[column].values for column in ('fo_name', 'name_rus', 'name')
)

# ── MPZ decision tree ─────
def get_typeMPZ_hour(Vz_grid, Str_grid, Pr_grid):
    """Classify each cell/hour into an MPZ subtype in {1.1, ..., 3.3}.

    Rules are checked in order and the first match wins (``np.select``). Cells
    matching no rule get 1.1. Inputs must broadcast against each other; NaN fails
    every comparison.

    Parameters
    ----------
    Vz_grid, Str_grid, Pr_grid : array_like
        Vz, Str and hourly precipitation fields (any broadcast-compatible shape).

    Returns
    -------
    ndarray of float32
        Subtype per element.
    """
    has_precip = Pr_grid != 0
    str_strong = Str_grid >= 0.5
    str_weak = Str_grid >= 0.1
    rules = (
        (Vz_grid > 6,                   3.3),  # NOTE(review): original flagged this threshold "Check"
        (Str_grid >= 0.9,               3.2),
        (Pr_grid >= 1.0,                3.1),
        (str_strong & has_precip,       2.3),
        (str_strong & (Vz_grid >= 5),   2.2),
        (str_strong,                    2.1),
        (str_weak & has_precip,         2.3),
        (str_weak & (Vz_grid >= 5),     2.2),
        (str_weak & (Vz_grid > 3),      2.1),
        (str_weak,                      1.3),
        (Vz_grid > 5,                   2.1),
        (Vz_grid > 4,                   1.2),
    )
    masks = [mask for mask, _ in rules]
    subtypes = [subtype for _, subtype in rules]
    classified = np.select(masks, subtypes, default=1.1)
    return classified.astype(np.float32)

typeMPZ_grid = get_typeMPZ_hour(Vz_grid, Str_grid, Pr_grid)  # hourly subtype grid, float32

# ── Saving for maps ─────
# Hourly subtype grid plus geolocation arrays, for plotting.
maps_path = OUTPUT_DIR / f'MPZ_{DATEHH}.npz'
np.savez_compressed(maps_path, typeMPZ=typeMPZ_grid, lons=lons, lats=lats)

print(f"MPZ grid done in: {(time.time() - t0)/60:.1f}m \n")
t0 = time.time()


# ── Cities values extraction ─────
def _city_series(grid):
    """Sample ``grid`` at each city's (row j, column i) for every hour -> (n_cities, nhours)."""
    return grid[:, j_idx, i_idx].T


Vz_cities = _city_series(Vz_grid)
Str_cities = _city_series(Str_grid)
Pr_cities = _city_series(Pr_grid)
tadv_cities = _city_series(tadv_grid)
typeMPZ_cities = _city_series(typeMPZ_grid)
# Full grids are not needed past this point; release them.
del Vz_grid, Str_grid, Pr_grid, tadv_grid, typeMPZ_grid

#per_city_vars['Vz'] = nint(per_city_vars['Vz'])  

# ── Summer afternoon worsening  ────
def apply_afternoon_worsening(typeMPZ_arr):
    """Add +1.0 to hourly subtypes below 3.0 during local afternoon (months 3..11).

    Local hour is ``(UTC forecast hour + utc_offset) % 24``. The afternoon is local
    hours 13..18 inclusive. Outside March..November the input is returned unchanged.

    Parameters
    ----------
    typeMPZ_arr : ndarray
        (n_cities, nhours) deterministic subtypes, or an array with one extra leading
        axis (perturbed ensemble variants); the afternoon mask is broadcast over it.

    Returns
    -------
    ndarray
        float32 array with the worsening applied, or the input itself when not applied.
    """
    if not 3 <= datestamp.month <= 11:
        print(f'Summer afternoon worsening not applied (n_month={datestamp.month})')
        return typeMPZ_arr

    # NOTE(review): original left "check time_mect = nh-1 + dtimeCity(nc,n)" — confirm hour convention.
    hours_utc = np.arange(cfg.nhours)
    local_hour = (hours_utc[None, :] + utc_offset[:, None]) % 24   # (n_cities, nhours)
    afternoon = (13 <= local_hour) & (local_hour <= 18)
    if typeMPZ_arr.ndim != 2:
        afternoon = afternoon[None]                                  # broadcast over variants
    worsen = afternoon & (typeMPZ_arr < 3.0)
    typeMPZ_arr = np.where(worsen, typeMPZ_arr + 1.0, typeMPZ_arr).astype(np.float32)
    print(f'Summer afternoon worsening applied (n_month={datestamp.month})')
    return typeMPZ_arr

typeMPZ_cities = apply_afternoon_worsening(typeMPZ_cities)

_elapsed_min = (time.time() - t0) / 60
print(f"Cities extraction done in: {_elapsed_min:.1f}m \n")
t0 = time.time()
#####################################################################################################################################################
### ─────────────────────────────────────────────────────────── Periodic MPZ ─────────────────────────────────────────────────────────────────────###
#####################################################################################################################################################

# ── Period definitions ─────
# bounds: per-period hour indices for the selected season. Columns used below:
#   0 = start of the 3-hour Str/Vz window, 1 / 2 = start / end of the full period window.
# adv_b: per-period hour indices of the two advection tendencies (columns 0 and 1 used).
_season = select_period_bounds(datestamp)
bounds  = np.array(cfg.prd_bounds[_season], dtype=np.int32)   # (nperiods, 3)
adv_b   = np.array(cfg.adv_bounds,   dtype=np.int32)   # (nperiods, 3)

# Per-city hour shift relative to Moscow time (UTC+3). Every bound hour below is
# moved by -(utc_offset - 3) for each city.
# NOTE(review): the original author asked "why shift from Moscow time?" — presumably
# the configured bounds are Moscow-referenced hour indices; confirm.
shift = utc_offset - 3 # (n_cities,)  relative to Moscow UTC+3

ih = np.arange(cfg.nhours) # (nhours,) forecast hour indices

# Shifted windows can start before hour 0 for some cities/periods. Masks and gathered
# values below are still computed for them (negative indices wrap around in
# np.take_along_axis). Those (city, period) pairs are later overwritten with NaN by the
# out_of_range mask built after the period decision tree.
# NOTE(review): only negative hours count as out of range. An advection index >= nhours
# makes np.take_along_axis raise IndexError, and a Str/Pr window past the last hour is
# silently truncated or empty — confirm the configured bounds never exceed the forecast.

# ── Str/Vz 3h aggregation ───────────────────────────────────────
# NOTE(review): the original author asked why only 3 hours of the period are aggregated.
# (n_cities, nperiods) starts of 3-hour aggregation windows "we should get 10AM, in Dvstk MSK[5] = Dvstk 10AM"
str_start = bounds[np.newaxis, :, 0] - shift[:, np.newaxis] # contains negative vals for periods/cities out of range
# (n_cities, nperiods, nhours), True where hour in [str_start, str_start + 2]
str_mask  = (ih[np.newaxis, np.newaxis, :] >= str_start[:, :, np.newaxis]) \
          & (ih[np.newaxis, np.newaxis, :] <= str_start[:, :, np.newaxis] + 2) # contains some invalid Trues for periods/cities out of range

# Mean over the window hours only: hours outside the mask become NaN and are ignored.
Str_p = np.nanmean(np.where(str_mask, Str_cities[:, np.newaxis, :], np.nan), axis=2) #(n_cities, nperiods) Str aggregated values
# NOTE(review): NaN survives this line only if nint() (from utils) keeps NaN for NaN input — confirm.
Vz_p  = nint(np.nanmean(np.where(str_mask, Vz_cities[:, np.newaxis, :], np.nan), axis=2)).astype(np.float32)   # float32 to preserve NaN for out-of-range city-period combos

# ── Precipitation sum over the full period window [pr_start, pr_end] ───────
pr_start = bounds[np.newaxis, :, 1] - shift[:, np.newaxis]
pr_end   = bounds[np.newaxis, :, 2] - shift[:, np.newaxis] # starts and ends of aggregation periods, with time shifts
pr_mask  = (ih[np.newaxis, np.newaxis, :] >= pr_start[:, :, np.newaxis]) \
         & (ih[np.newaxis, np.newaxis, :] <= pr_end[:, :, np.newaxis]) # (n_cities, nperiods, nhours), True where hour in aggregation period

# nansum over an all-NaN (empty) window gives 0.0; sums below cfg.pr_min are zeroed.
Pr_p = np.nansum(np.where(pr_mask, Pr_cities[:, np.newaxis, :], np.nan), axis=2).astype(np.float32)  # (ncities, n_periods)
Pr_p = np.where(Pr_p < cfg.pr_min, 0.0, Pr_p)

# ── Advection: two 3-hour tadv tendencies, city-shifted ──────────────────────
# NOTE(review): the original author asked how representative these two tendencies are of the period.
tadv    = tadv_cities # (n_cities, nhours)
dt1_utc = adv_b[np.newaxis, :, 0] - shift[:, np.newaxis] # (n_cities, nperiods): hour indices for adv vals
dt2_utc = adv_b[np.newaxis, :, 1] - shift[:, np.newaxis]

# Each tendency is tadv[hour] - tadv[hour - 3], gathered per city and period.
dt1_val = np.take_along_axis(tadv, dt1_utc, axis=1) \
        - np.take_along_axis(tadv, dt1_utc - 3, axis=1) # (ncities, n_periods): advection in 3h period 1
dt2_val = np.take_along_axis(tadv, dt2_utc, axis=1) \
        - np.take_along_axis(tadv, dt2_utc - 3, axis=1) # (ncities, n_periods): advection in 3h period 2

# Same sign -> sum of both tendencies. Opposite signs -> the later tendency (dt2).
# If either tendency is exactly 0 -> 0.0.
same_sign = ((dt1_val > 0) & (dt2_val > 0)) | ((dt1_val < 0) & (dt2_val < 0)) # (ncities, n_periods): bool, True if both 3h tendencies have the same sign
opp_sign  = ((dt1_val > 0) & (dt2_val < 0)) | ((dt1_val < 0) & (dt2_val > 0))
Adv_p  = np.where(same_sign, dt1_val + dt2_val,
                  np.where(opp_sign,  dt2_val, 0.0))  # (ncities, n_periods): advection vals used for correction later

# ── typeMPZ decision tree ──────────────────────────────────────────────────────
# Ordered (condition, subtype) rules on the period aggregates. np.select picks the first
# matching rule per city/period and falls back to 1.1.
_period_rules = [
    (Vz_p > 6,                                  3.3),
    (Str_p >= 0.9,                              3.2),
    (Pr_p >= 1.0,                               3.1),
    ((Str_p >= 0.5) & (Pr_p != 0),              2.3),
    (Str_p >= 0.5,                              2.2),
    ((Str_p >= 0.1) & (Vz_p > 5) & (Pr_p != 0), 2.3),
    ((Str_p >= 0.1) & (Vz_p > 5),               2.1),
    (Str_p >= 0.1,                              1.3),  # Vz<=5, weak inv.
    (Vz_p > 5,                                  2.1),  # no inv., strong wind
    (Vz_p > 4,                                  1.2),  # no inv., moderate wind
]
conditions = [cond for cond, _ in _period_rules]
choices = [subtype for _, subtype in _period_rules]
typeMPZ_p = np.select(conditions, choices, default=1.1).astype(np.float32)

# Summer afternoon bump (Mar–Nov): +1 for periods 1 and 5 (Fortran np=2 and np=6),
# only where Vz_p <= 5 and the subtype is still below 2.0.
if 3 <= datestamp.month <= 11:
    afternoon_periods = np.zeros(cfg.nperiods, dtype=bool)
    afternoon_periods[[1, 5]] = True
    bump = afternoon_periods[np.newaxis, :] & (Vz_p <= 5) & (typeMPZ_p < 2.0)
    typeMPZ_p = np.where(bump, typeMPZ_p + 1.0, typeMPZ_p).astype(np.float32)

# ── Out-of-forecast-range (city, period) pairs -> NaN, e.g. MSK-8h cities miss period 0 ─
out_of_range = ((str_start < 0) | (pr_start < 0) | (pr_end < 0)
                | ((dt1_utc - 3) < 0) | ((dt2_utc - 3) < 0))   # (n_cities, nperiods)
typeMPZ_p, Str_p, Vz_p, Pr_p, Adv_p = (
    np.where(out_of_range, np.nan, values)
    for values in (typeMPZ_p, Str_p, Vz_p, Pr_p, Adv_p)
)

# ── periodic typeMPZ correction from hourly typeMPZ ───────────────────────────────────────────────────
hourly = typeMPZ_cities  # (n_cities, nhours): hourly subtypes after afternoon worsening
# (n_cities, n_periods): filled period by period below with typeMPZ {1.1...3.3}
typeMPZ_p_res = np.empty((n_cities, cfg.nperiods), dtype=np.float32)

# Replicates a Fortran stale-variable bug (saved on purpose): parAdv_p in the Fortran
# output loop keeps the value from the last iteration of the preceding loop
# (np=Nperiods), so every period is corrected with the LAST period's advection.
adv_last = Adv_p[:, -1]  # (n_cities,), loop-invariant; NaN for out-of-range cities

for p in range(cfg.nperiods):
    # Raw tree value for this period. It is NaN for out-of-range cities, and NaN is
    # False in every comparison below.
    tree_val = typeMPZ_p[:, p]
    # Tree subtypes 3.2 / 3.3 are kept as-is (no correction).
    is_high = (tree_val >= 3.2) & (tree_val <= 3.3)

    p_window = pr_mask[:, p, :]  # (n_cities, nhours): hours of the full period window
    any_31 = ((hourly == 3.1) & p_window).any(axis=1)  # some hour in the window is 3.1
    any_23 = ((hourly == 2.3) & p_window).any(axis=1)  # some hour in the window is 2.3
    # Corrected result of the previous period was 3.1 (never for p == 0).
    if p > 0:
        prev_is_31 = typeMPZ_p_res[:, p - 1] == 3.1
    else:
        prev_is_31 = np.zeros(n_cities, dtype=bool)

    # Priority: hourly 3.1 > hourly 2.3 > previous period 3.1 (-> 2.3) > tree value.
    result = np.where(any_31,    3.1,
             np.where(any_23,    2.3,
             np.where(prev_is_31, 2.3,
                      tree_val))).astype(np.float32)

    # ── Advection correction, only for subtypes {2.1...2.3} ──
    is_2x = (result > 2.0) & (result < 2.4)
    result = np.where(is_2x & (adv_last < cfg.par_adv_minus_cr), 2.3, result)
    result = np.where(is_2x & (adv_last >= cfg.par_adv_plus_cr), result - 1, result)

    typeMPZ_p_res[:, p] = np.where(is_high, tree_val, result)

# ── Integer MPZ class 1/2/3 ──────────────────────────────────────────────────
# Out-of-range (city, period) pairs get an explicit 0 (the out-of-range value the
# probability step expects for `det`). NaN must not reach .astype(np.int32): a NaN
# float->int cast is undefined (numpy warns and typically yields INT_MIN).
MPZ_p = np.where(out_of_range, 0,
                 np.floor(np.nan_to_num(typeMPZ_p_res, nan=0.0))).astype(np.int32)
typeMPZ_p_res = np.where(out_of_range, np.nan, typeMPZ_p_res)  # cleaning out_of_range

print(f"Periodic MPZ done in: {(time.time() - t0)/60:.1f}m \n")
t0 = time.time()

#####################################################################################################################################################
### ─────────────────────────────────────────────────────────── Probablistic MPZ ─────────────────────────────────────────────────────────────────###
#####################################################################################################################################################

# IMPORTANT: the perturbed hourly subtypes (typeMPZ_cities_pert) must get the summer
# afternoon worsening (months 3–11, the step the old summ_wors_hr flag referred to) so that
# they match Fortran's subtypeMPZ_city.

P_city = np.zeros((n_cities, cfg.nperiods, 3), dtype=np.int32) # (n_cities, nperiods, 3): class-1/2/3 probability in %, filled below for cfg.prds_to_write

# ── Ensemble variants for Vz and Str: one per entry of cfg.dv / cfg.dstr ───
# (comments below assume 9 variants; cfg.dv and cfg.dstr must have matching lengths)
dv_arr   = np.array(cfg.dv,   dtype=np.float32)[:, None, None]  # perturbation dVz variants for probabilistic forecast
dStr_arr = np.array(cfg.dstr, dtype=np.float32)[:, None, None]  # perturbation dStr variants for probabilistic forecast

# Hourly subtypes for every perturbed (Vz, Str) pair; precipitation is not perturbed.
typeMPZ_cities_pert = get_typeMPZ_hour(Vz_cities[None] + dv_arr,
                                       Str_cities[None] + dStr_arr,
                                       Pr_cities[None])

typeMPZ_cities_pert = apply_afternoon_worsening(typeMPZ_cities_pert)

for p in cfg.prds_to_write:      # 0-based p=2..6 → Fortran np=3..Nperiods
    # ── Raw probabilities  ──────────────────────────────────────────────────
    # cfg.prob_srarts (sic) is the config attribute holding per-period window starts.
    nh0  = cfg.prob_srarts[p] - shift  # (n_cities,) city-shifted prob window start hour
    nlen = int(cfg.prob_lens[p]) # scalar window length, same for all cities

    # Window covers hours nh0 .. nh0 + nlen - 2, i.e. nlen - 1 hours.
    prob_mask = (ih[np.newaxis, :] >= nh0[:, np.newaxis]) & \
            (ih[np.newaxis, :] <  nh0[:, np.newaxis] + (nlen-1)) # (n_cities, nhours): bool, True where window for prob aggregating          ! changed nlen-1 like in Fortran

    m = prob_mask[np.newaxis]      # (1, n_cities, nhours): bool, mask broadcast over the ensemble variants
    t = typeMPZ_cities_pert        # (n_variants, n_cities, nhours): {1.1, ..., 3.3}, perturbed subtypes

    # (n_cities,): int, number of (variant, hour) samples of class 1 / 2 / 3 within the window
    cnt1 = np.sum((t < 2.0) & m, axis=(0, 2)).astype(np.int32)#                                                                              ! changed .astype(np.float32)
    cnt2 = np.sum((t >= 2.0) & (t < 3.0) & m, axis=(0, 2)).astype(np.int32) #...
    cnt3 = np.sum((t >= 3.0) & m, axis=(0, 2)).astype(np.int32)
    N    = cnt1 + cnt2 + cnt3 # (n_cities,): total number of samples in the window

    # NOTE(review): if a city's window contains no forecast hours (N == 0) these
    # divisions give NaN and the int32 cast is undefined — presumably impossible
    # for the configured windows; confirm.
    P1 = nint(cnt1 * 100.0 / N).astype(np.int32) # (n_cities,): % of class 1
    P2 = nint(cnt2 * 100.0 / N).astype(np.int32) # ...
    P3 = nint(cnt3 * 100.0 / N).astype(np.int32)

    # ─────── Tie-breaking: when rounded percentages tie, add one count from the deterministic result. ───────
    # Fortran only breaks specific tie patterns (not all ties):
    #   all_eq:   P1==P2==P3  → any det class eligible
    #   p12_p3z:  P1==P2, P3=0 → only det==1 or det==2
    #   p1z_p23:  P1=0, P2==P3 → only det==2 or det==3

    # MPZ_p is int32, so np.isnan(...) is always False here and det == MPZ_p[:, p].
    # A value outside 1..3 (out-of-range marker) matches no tie-break rule below.
    det = np.where(np.isnan(MPZ_p[:, p]), 0, MPZ_p[:, p]).astype(np.int32) # (n_cities,): deterministic integer MPZ class

    all_eq  = (P1 == P2) & (P1 == P3)  # (n_cities,): bool
    p12_p3z = (~all_eq) & (P1 == P2) & (P3 == 0)
    p1z_p23 = (~all_eq) & (P1 == 0)  & (P2 == P3)

    # Note: in the all_eq case N is incremented even when det is not 1..3.
    cnt1 += ((all_eq | p12_p3z) & (det == 1)).astype(np.int32)#                                                                              ! changed .astype(np.float32)
    cnt2 += ((all_eq | p12_p3z | p1z_p23) & (det == 2)).astype(np.int32)
    cnt3 += ((all_eq | p1z_p23) & (det == 3)).astype(np.int32)
    N    += (all_eq |
             (p12_p3z & ((det == 1) | (det == 2))) |
             (p1z_p23 & ((det == 2) | (det == 3)))).astype(np.int32)

    P1 = nint(cnt1 * 100.0 / N).astype(np.int32)
    P2 = nint(cnt2 * 100.0 / N).astype(np.int32)
    P3 = nint(cnt3 * 100.0 / N).astype(np.int32)

    # ───────  Normalize to sum=100: add remainder to the largest class; subtract from smallest if over. ───────
    # Fortran sequential if-chain → first-wins for ties, matching np.argmax/argmin.
    Parr   = np.stack([P1, P2, P3], axis=1)  # (n_cities, 3)
    dP     = (100 - Parr.sum(axis=1)) # (n_cities,): remainder of 100 - sum
    target = np.where(dP > 0, np.argmax(Parr, axis=1), np.argmin(Parr, axis=1))
    Parr[np.arange(n_cities), target] += dP
    P_city[:, p, :] = Parr # (n_cities, nperiods, 3): probability in %

# ── Probability-based typeMPZ_p_res correction ────────────────────────────────────────
# Applied in Fortran's output table loop for np=3..Nperiods.
# If one class dominates with P > 55 %, shift the integer part of typeMPZ_p_res to match it.

for p in cfg.prds_to_write:
    # A class dominates when its probability is strictly above 55 %. Since the three
    # probabilities are normalised to sum to 100, at most one class can; -1 marks the rest.
    P_arr = P_city[:, p, :]                      # (n_cities, 3): int probability in %
    cand = np.where(P_arr > 55, P_arr, -1.0)     # (n_cities, 3)
    Pmax_val = cand.max(axis=1)                  # (n_cities,): -1 when no class exceeds 55
    Pmax_idx = np.argmax(cand, axis=1) + 1       # (n_cities,): dominant class {1, 2, 3}

    current = typeMPZ_p_res[:, p]                # (n_cities,): NaN for out-of-range cities
    valid = ~np.isnan(current)
    has_dominant = (Pmax_val > 0) & valid
    # Integer part of the current subtype. NaN rows are zeroed before the int cast
    # (a NaN->int cast is undefined); they are left untouched (NaN) below anyway.
    curr_int = np.floor(np.where(valid, current, 0.0)).astype(np.int32)
    # Shift only the integer part to the dominant class, keeping the decimal subtype,
    # e.g. 3.3 with dominant class 1 -> 3.3 - 2 = 1.3.
    class_shift = (curr_int - Pmax_idx).astype(np.float32)

    typeMPZ_p_res[:, p] = np.where(has_dominant, current - class_shift, current)

print(f"Probablistic MPZ done in: {(time.time() - t0)/60:.1f}m \n")
t0 = time.time()

#####################################################################################################################################################
### ─────────────────────────────────────────────────────────── Output tables ────────────────────────────────────────────────────────────────────###
#####################################################################################################################################################


# Periods written to the tables (0-based indices from the config, e.g. 2..6).
# Data columns and labels are both derived from this list so they always agree.
prds = list(cfg.prds_to_write)

periods = pd.DataFrame(typeMPZ_p_res[:, prds],
                       columns=[f'p_{i}_py' for i in prds],
                       index=cities['name_rus'])

# Probabilities flattened period-major: p0_prob1, p0_prob2, p0_prob3, p1_prob1, ...
# where p0 is the first written period.
probs = pd.DataFrame(P_city[:, prds, :].reshape(n_cities, -1),
                     columns=[f'p{x}_prob{i}' for x in range(len(prds)) for i in range(1, 4)],
                     index=cities['name_rus'])

full_tbl = pd.concat([periods, probs], axis=1)
full_tbl['fo'] = fo_names

# One Excel table per 'fo_name' group.
for fo in np.unique(fo_names):
    full_tbl[full_tbl['fo'] == fo].to_excel(OUTPUT_DIR / 'tables' / f'mpz_{fo}.xlsx')

print(f"All MPZ done in: {(time.time() - t0)/60:.1f}m \n")
