import os
from multiprocessing import Pool
from pathlib import Path

import numpy as np
from astropy import units as u

import starships.planet_obs as pl_obs
from starships import correlation as corr
from starships.orbite import rv_theo_t
from starships.planet_obs import Observations

# Seed the random number generator.
rstate = np.random.default_rng(736109)

# Use scratch if available, use home if not.
try:
    base_dir = Path(os.environ['SCRATCH'])
except KeyError:
    base_dir = Path.home()

# Where the reduced data is saved.
high_res_path = base_dir / Path('DataAnalysis/SPIRou/Reductions/WASP-127b_genest')

# The stem of the files where the info is saved (retrieval-ready files).
high_res_file_stem = 'retrieval_input_3-pc_mask_wings90_tr1'

path_model_file = Path('.')
specfile = Path('WASP-127_best_fit_H2O_atmosphere.npz')

# Read the model file.
model_file = np.load(path_model_file / specfile)
wv_high, model_high = model_file['wl'], model_file['dppm']

# %%
pl_name = 'WASP-127 b'

# --- Data parameters
pl_kwargs = {}
obs = Observations(name=pl_name, pl_kwargs=pl_kwargs)

planet = obs.planet
Kp_scale = (planet.M_pl / planet.M_star).decompose().value

kind_trans = 'transmission'

# - Which sequences are taken (always take only 1).
do_tr = [1]

# - Selecting bad exposures if wanted/needed.
bad_indexs = None

# --- Additional global variables
inj_alpha = 'ones'
idx_orders = slice(None)
nolog = True

# Choose over which axis the logl is summed.
# -1 (or equivalently 2) should always be present (sum over the spectral axis).
# It is possible to sum over multiple axes, like orders and spectra,
# with (-2, -1) or, equivalently, (1, 2).
axis_sum = -1  # Axis along which the logl is summed.

# -----------------------------------------------------------
# LOAD HIGH-RES DATA
data_info, data_trs = pl_obs.load_sequences(high_res_file_stem, do_tr, path=high_res_path)
data_info['bad_indexs'] = bad_indexs

# Save (big) arrays that will be shared between processes.
# NOTE: Their values must not be changed by the processes.
out = [(obj, key) for key, obj in data_trs['0'].items() if isinstance(obj, np.ndarray)]
shared_arrays, shared_keys = zip(*out)  # Transpose the result.
shared_arrays = np.array(shared_arrays, dtype=object)

# If non-array objects are needed, it is better to define them separately (below)
# as global variables.
pca = data_trs['0']['pca']


# Utility function to get the index of a key in the shared arrays.
def get_shared_array_index(*args):
    """Get the index of each given key in the shared arrays."""
    idx_list = [shared_keys.index(key) for key in args]
    return idx_list


def calc_chi2_terms(model, idx_ord=None, axis=None):
    """Compute the model-dependent terms of the chi2."""
    # Get values from global variables if not provided.
    if idx_ord is None:
        idx_ord = idx_orders
    if axis is None:
        axis = axis_sum

    # Get the index of the shared arrays and the arrays themselves.
    idx_list = get_shared_array_index('flux', 'noise')
    flux, noise = shared_arrays[idx_list]

    # Divide the model by the uncertainty.
    # NOTE: flux is already divided by the uncertainty.
    model = model[:, idx_ord] / noise[:, idx_ord]

    # Compute each term of the chi2.
    f_x_g = np.ma.sum(model * flux[:, idx_ord], axis=axis)
    s2g = np.ma.sum(model ** 2, axis=axis)

    return f_x_g, s2g
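# The two terms returned by calc_chi2_terms are the only model-dependent pieces of
#   chi2 = sum(f**2 / sig**2) - 2 * sum(f * g / sig**2) + sum(g**2 / sig**2)
#        = s2f - 2 * f_x_g + s2g,
# where f is the data (here already divided by the uncertainty) and g the model.
# A minimal sketch of how the pieces combine is given below; `combine_chi2_terms`
# is an illustrative helper, not part of the starships API or of this pipeline.
def combine_chi2_terms(s2f, f_x_g, s2g):
    """Illustrative only: combine the pre-computed terms into a chi2 value."""
    return s2f - 2. * f_x_g + s2g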
def get_chi2_detailed(theta):
    """Return the model-dependent terms of the chi2 for each exposure and each order."""
    v_sys, kp = theta

    # --- Computing the logL for all sequences.
    # We need from data_tr: RV_const, t_start, wave, sep, pca, params.
    idx_list = get_shared_array_index('RV_const', 't_start', 'wave', 'sep', 'params')
    rv_const, t_start, wave, sep, params = shared_arrays[idx_list]

    # Planet orbital radial velocity at each exposure time.
    vrp_orb = rv_theo_t(kp, t_start * u.d, planet.mid_tr, planet.period, plnt=True).value

    # Get the model sequence.
    n_pc = int(params[5])
    velocities = v_sys + vrp_orb - vrp_orb * Kp_scale + rv_const
    model_seq = corr.gen_model_sequence_noinj(velocities,
                                              data_wave=wave,
                                              data_sep=sep,
                                              data_pca=pca,
                                              data_npc=n_pc,
                                              planet=planet,
                                              model_wave=wv_high[20:-20],
                                              model_spec=model_high[20:-20],
                                              kind_trans=kind_trans,
                                              alpha=data_info['trall_alpha_frac'])

    # Calculate the model-dependent chi2 terms.
    chi2_terms = calc_chi2_terms(model_seq)

    return chi2_terms


# Quick single-point check that the function runs before launching the pool.
get_chi2_detailed(np.array([0, 130]))

# (vsys, kp) grid over which the detailed chi2 terms are computed.
kp, vsys = np.meshgrid(np.linspace(0., 280., 60), np.linspace(-80., 50., 120))

n_process = 2 * 32
print(f"Preparing the map with pool of {n_process} processes...")
with Pool(n_process) as pool:
    outputs = pool.map(get_chi2_detailed, np.array([np.ravel(vsys), np.ravel(kp)]).T)

outputs = np.array(outputs)
data_shape = outputs.shape[2:]
cross_terms, squared_terms = [np.reshape(outputs[:, idx], (*kp.shape, *data_shape))
                              for idx in range(2)]

# Compute useful quantities for logl computations.
for data_tr in data_trs.values():
    # Pre-compute all values for the logl that are independent of the model, for all orders.
    uncert_sum = np.sum(np.ma.log(data_tr['noise'][:, idx_orders]), axis=axis_sum)
    s2f = np.sum(data_tr['flux'][:, idx_orders] ** 2, axis=axis_sum)

# Save the chi2 terms and the kp, vsys grid for later use.
out_path = Path('chi2_maps')
# Create the output directory if it doesn't exist.
if not out_path.exists():
    out_path.mkdir()

filename = f'chi2_map_detailed_{high_res_file_stem}_large.npz'
print(f"Saving to {out_path / filename}")

# Check the arrays to be saved.
saved_arrays = dict(squared_terms=squared_terms,
                    cross_terms=cross_terms,
                    kp=kp,
                    vsys=vsys,
                    alpha_frac=data_info['trall_alpha_frac'],
                    icorr=data_info['trall_icorr'],
                    N=data_info['trall_N'],
                    bad_indexs=data_info['bad_indexs'],
                    s2f=s2f,
                    uncert_sum=uncert_sum)

# Replace entries by empty arrays if they are not np.ndarray or if their dtype
# is object, so that np.savez does not need to pickle them.
for key, obj in saved_arrays.items():
    if isinstance(obj, np.ndarray):
        if obj.dtype == 'object':
            print(f"{key}: {obj.shape}, {obj.dtype}")
            print(f"Saving {key} as an empty array.")
            saved_arrays[key] = np.empty(0)
    else:
        print(f"NOT AN ARRAY: type({key}): {type(obj)}")
        print(f"Saving {key} as an empty array.")
        saved_arrays[key] = np.empty(0)

np.savez(out_path / filename, **saved_arrays)
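

# --- Usage sketch (illustrative, not executed as part of this script) ---------
# The terms are kept per grid point, exposure and order so that chi2 maps can be
# rebuilt for any subset afterwards. A minimal sketch of the reconstruction,
# assuming the decomposition noted above (chi2 = s2f - 2*cross_terms + squared_terms)
# and a sum over all exposures and orders; `example_build_chi2_map` is a
# hypothetical helper name, not part of the pipeline or of starships.
def example_build_chi2_map(npz_file):
    """Illustrative only: rebuild the (vsys, kp) chi2 map from the saved terms."""
    saved = np.load(npz_file)
    cross, squared = saved['cross_terms'], saved['squared_terms']
    # s2f has shape (n_exposure, n_order); it broadcasts against the trailing
    # axes of cross/squared, whose leading axes span the (vsys, kp) grid.
    chi2 = saved['s2f'] - 2. * cross + squared
    # Sum over exposures and orders to get one chi2 value per grid point.
    return chi2.sum(axis=(-2, -1))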