"""Utils."""

import os
import boto3
import collections
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
#import cartopy.crs as ccrs
#import skimage.measure
#from pybufr_ecmwf.raw_bufr_file import RawBUFRFile
#from pybufr_ecmwf.bufr_interface_ecmwf import BUFRInterfaceECMWF
import sklearn
from sklearn.neighbors import KNeighborsRegressor
from .constants import *


def upload_model(model_name, bucket_folder=''):
    '''
    Uploads a model file from /tmp to the S3 bucket, keeping its file name.
    Args:
        model_name: file name of the model; the file is read from
            '/tmp/' + model_name.
        bucket_folder: key prefix inside the bucket. The object key is
            bucket_folder + '/' + model_name, so with the default empty
            prefix the key starts with '/'.
    '''
    # credentials and region come from the star-imported constants module.
    client = boto3.client(
        's3',
        aws_access_key_id=AWS_SERVER_PUBLIC_KEY,
        aws_secret_access_key=AWS_SERVER_SECRET_KEY,
        region_name=REGION_NAME,
    )

    local_path = '/tmp/' + model_name
    object_key = bucket_folder + '/' + model_name
    client.upload_file(local_path, 'phd-interpolation-bucket', object_key)


def decode_message(words, section_sizes, section_start_locations,
                   tables_dir='/Users/mohamedakramzaytar/data/bufr_tables/'):
    '''
    Decodes a BUFR message into a Pandas Dataframe.
    Input:
        words: the raw message, passed to BUFRInterfaceECMWF as `encoded_message`.
        section_sizes: passed through to BUFRInterfaceECMWF.
        section_start_locations: passed through to BUFRInterfaceECMWF.
        tables_dir: directory handed to `setup_tables`; defaults to the
            path that used to be hard-coded here.
    Output: Decoded Pandas Dataframe holding only the columns of interest and
        only the rows with -17.1 < LONGITUDE < -0.84 and 21.21 < LATITUDE < 35.98
        (bounds exclusive). A message with fewer than 37 values yields an
        empty dataframe.
    '''
    # we read the msg.
    bufr = BUFRInterfaceECMWF(encoded_message=words,
                              section_sizes=section_sizes,
                              section_start_locations=section_start_locations,
                              verbose=False)

    # NOTE(review): the decoding sequence below is kept in its original order;
    # the method names suggest sections 0-2 are decoded before the tables are
    # set up and the data/section 3 afterwards -- confirm against pybufr_ecmwf.
    bufr.decode_sections_012()
    bufr.setup_tables(tables_dir=tables_dir)
    bufr.decode_data()
    bufr.decode_sections_0123()
    bufr.fill_descriptor_list_subset(subset=1)

    # OZONE PRODUCT COLUMNS (several labels are intentionally repeated).
    columns = ['SAT_id', 'CENTER_id', 'SAT_class', 'NADIR_x', 'NADIR_y',
               'YEAR', 'MONTH', 'DAY', 'HOUR', 'MINUTE',
               'SECOND', 'LATITUDE', 'LONGITUDE', 'ROW_num', 'COL_num',
               'TOTAL_ozone', 'Quality_info', 'Bit_map', 'indicator', 'indicator',
               'indicator', 'indicator', 'indicator', 'indicator', 'indicator',
               'indicator', 'indicator', 'indicator', 'indicator', 'indicator',
               'indicator', 'indicator', 'indicator', 'indicator',
               'CENTER_id', 'APP', 'TOTAL_Ozone_Quality']

    # COLS OF INTEREST.
    cols = ['YEAR', 'MONTH', 'DAY', 'HOUR', 'MINUTE', 'SECOND',
            'LATITUDE', 'LONGITUDE', 'TOTAL_ozone', 'Quality_info',
            'TOTAL_Ozone_Quality']

    # We split the long flat message into rows of len(columns) (= 37) values.
    # Trailing values that do not fill a whole row are dropped, and the
    # reshape keeps the result 2-D even when there are no complete rows.
    # assumes bufr.values is a flat (1-D) sequence -- TODO confirm.
    n_cols = len(columns)
    values = np.asarray(bufr.values)
    n_rows = len(values) // n_cols
    rows = values[:n_rows * n_cols].reshape(n_rows, n_cols)

    # Finally we can create the dataframe with the columns of interest.
    df = pd.DataFrame(rows, columns=columns)[cols]

    # return only points in Morocco (bounding box, exclusive bounds).
    inside_box = ((df['LONGITUDE'] > -17.1) & (df['LONGITUDE'] < -0.84)
                  & (df['LATITUDE'] > 21.21) & (df['LATITUDE'] < 35.98))
    return df.loc[inside_box].copy()


def bufr_to_dataframe(filepath):
    '''
    Decodes one BUFR file into one Pandas Dataframe.
    Input: BUFR file Path.
    Output: Associated Pandas Dataframe: the per-message dataframes returned
        by `decode_message`, concatenated in file order with their original
        indexes. A file without messages yields an empty dataframe that still
        has the expected columns.
    '''

    # columns of the result, used when the file holds no message at all.
    cols = ['YEAR', 'MONTH', 'DAY', 'HOUR', 'MINUTE', 'SECOND',
            'LATITUDE', 'LONGITUDE', 'TOTAL_ozone', 'Quality_info',
            'TOTAL_Ozone_Quality']

    # let's read the file.
    bufr_file = RawBUFRFile(verbose=False)
    bufr_file.open(filename=filepath, mode='rb')

    # let's loop over the messages and decode them; the frames are collected
    # and concatenated once at the end (DataFrame.append no longer exists in
    # pandas >= 2.0 and copied the whole frame on every call).
    frames = []
    try:
        for _ in range(bufr_file.get_num_bufr_msgs()):
            (ws, s_s, s_s_l) = bufr_file.get_next_raw_bufr_msg()
            frames.append(decode_message(ws, s_s, s_s_l))
    finally:
        # we close the file, even when decoding a message fails.
        bufr_file.close()

    if not frames:
        return pd.DataFrame(columns=cols)

    return pd.concat(frames)


def location_to_index(step=0.01):
    '''
    Builds the lookup tables that map a latitude/longitude to its index in a
    grid of resolution `step`.
    Input: step, the spacing between two consecutive grid coordinates.
    Output: (latitude_index, longitude_index), two dictionaries whose keys are
        the grid coordinates rounded to 2 decimals and whose values are their
        positions along the axis.
    '''

    def build_lookup(axis_values):
        # coordinate (rounded to 2 decimals) -> position along the axis.
        return {np.around(value, decimals=2): position
                for position, value in enumerate(axis_values)}

    # latitudes start at 21.22-0.01 and stop before 35.97+0.01.
    lat_index = build_lookup(
        np.arange(start=21.22-0.01, stop=35.97+0.01, step=step))

    # longitudes start at -17.09-0.01 and stop before -0.85+0.01.
    lon_index = build_lookup(
        np.arange(start=-17.09-0.01, stop=-0.85+0.01, step=step))

    return (lat_index, lon_index)


def scan_to_matrix(scan):
    '''
    Produces the Matrix Equivalent of a scan dataframe, and the associated location to index.
    Input: Scan Dataframe with 'LATITUDE', 'LONGITUDE' and 'TOTAL_ozone' columns.
    Output: (grid, grid_lat_index, grid_lon_index) where grid is the 9x9
        max-pooled version of the 1477x1627 fine grid and the two dictionaries
        are the location to index lookups built with a 0.09 step.
    Raises: KeyError when a scan coordinate, rounded to 2 decimals, is not a
        point of the fine (0.01 step) grid.
    '''

    # 0. we get the location to index dictionaries for the fine and final grids.
    lat_index, lon_index = location_to_index(0.01)
    grid_lat_i, grid_lon_i = location_to_index(0.09)

    # 1. we create the original fine grid (cells without measurement stay 0).
    fine_grid = np.zeros(shape=(1477, 1627), dtype='float32')

    # 2. we fill the fine grid with the scan pollution values.
    # Coordinates are normalised to 2-decimal python floats so that they match
    # the (2-decimal) dictionary keys whatever the dtype of the scan columns.
    for lat, lon, ozone in zip(scan['LATITUDE'], scan['LONGITUDE'],
                               scan['TOTAL_ozone']):
        fine_grid[lat_index[round(float(lat), 2)],
                  lon_index[round(float(lon), 2)]] = ozone

    # 3. we max-pool it with 9x9 blocks, getting the new grid.
    # The grid is zero-padded at the end of each axis up to a multiple of the
    # block size, then every block is reduced with max. This is done with
    # numpy only because `skimage` is not imported in this module.
    block = 9
    pad_rows = -fine_grid.shape[0] % block
    pad_cols = -fine_grid.shape[1] % block
    padded = np.pad(fine_grid, ((0, pad_rows), (0, pad_cols)),
                    mode='constant', constant_values=0)
    grid = padded.reshape(padded.shape[0] // block, block,
                          padded.shape[1] // block, block).max(axis=(1, 3))

    return grid, grid_lat_i, grid_lon_i


def matrix_to_df(M, lat_index, lon_index):
    '''
    Converts a grid matrix into a long-format DataFrame, one row per cell.
    Input: Grid Matrix M and its location to index dictionaries; only the
        dictionary keys are used, in their iteration order.
    Output: Pandas Dataframe with the columns 'lats', 'lons' and 'o3', where
        'o3' is M flattened.
    '''
    # every latitude is repeated once per longitude ...
    lats = [lat for lat in lat_index for _ in lon_index]
    # ... and the full run of longitudes is repeated once per latitude.
    lons = [lon for _ in lat_index for lon in lon_index]

    return pd.DataFrame(data={'lats': lats, 'lons': lons, 'o3': M.flatten()})


def plot_df(df):
    '''
    Shows a final dataframe as a scatter plot over morocco.
    Input: The Pollution Dataframe, with the columns 'lats', 'lons' and 'o3'.
    Output: None (the figure is displayed with plt.show()).
    '''

    plt.clf()

    fig = plt.figure(figsize=(10, 10), dpi=99)
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
    ax.set_extent([-17.1, -0.84, 21.21, 35.98], crs=ccrs.PlateCarree())
    ax.coastlines(resolution='10m',)

    # rows with a value are coloured by it, rows equal to 0 are drawn in gray.
    measured = df[df.o3 != 0]
    missing = df[df.o3 == 0]

    # coloured points sit above (zorder=5) the gray ones (zorder=1).
    ax.scatter(measured['lons'].values, measured['lats'].values,
               marker='.', s=7, c=measured['o3'].values, zorder=5, cmap='jet')
    ax.scatter(missing['lons'].values, missing['lats'].values,
               marker='.', s=3, c='gray', zorder=1)

    plt.show()


def extract_patches(M, N, eps_max, eps_min=0.0):
    '''
    Outputs the list of all N by N patches of matrix M whose fraction of
    zeros lies between `eps_min` and `eps_max` (both bounds included).
    Input:
        M: 2-D numpy matrix to extract the patches from.
        N: width/height of the patches.
        eps_max: highest accepted fraction of zeros in a patch.
        eps_min: lowest accepted fraction of zeros in a patch (default: 0).
    Output:
        List of patches, scanned row by row (top-left corner moving along the
        columns first). Each patch is a view on M, not a copy. The list is
        empty when M is smaller than N along any axis.
    '''

    # top-left corners: i walks the rows (axis 0), j walks the columns (axis 1).
    n_rows, n_cols = M.shape[0], M.shape[1]
    cells = N * N

    # now we init the patches list and append to it as we go.
    patches = []

    for i in range(n_rows - N + 1):
        for j in range(n_cols - N + 1):
            patch = M[i:i+N, j:j+N]
            zeros_fraction = np.count_nonzero(patch == 0) / cells
            if eps_min <= zeros_fraction <= eps_max:
                patches.append(patch)
    return patches


def interpolate_matrix(M):
    '''
    Fills the zeros of a raw Matrix with the mean of the non-zero values among
    their (up to 8) direct neighbors.
    The matrix is scanned row by row and modified in place, so a cell filled
    earlier in the scan counts as a non-zero neighbor for the cells that
    follow. A zero cell without any non-zero neighbor receives the mean of the
    whole matrix (zeros included) at that moment.
    Input: Raw Matrix M (2-D numpy array, any shape).
    Output: The same Matrix object, interpolated.
    '''

    # rows and columns are taken from the shape so non-square matrices work.
    n_rows, n_cols = M.shape

    # neighbor offsets (row, column): same row first, then the row below,
    # then the row above.
    offsets = ((0, -1), (0, 1),
               (1, 0), (1, -1), (1, 1),
               (-1, 0), (-1, -1), (-1, 1))

    # we now check each point and fill the empty ones.
    for i in range(n_rows):
        for j in range(n_cols):
            if M[i, j] != 0:
                continue

            neighbors = []
            for d_i, d_j in offsets:
                row, col = i + d_i, j + d_j
                if 0 <= row < n_rows and 0 <= col < n_cols and M[row, col] != 0:
                    neighbors.append(M[row, col])

            if neighbors:
                M[i, j] = sum(neighbors)/len(neighbors)
            else:
                # if there are no neighbors, fill in the average of the whole scan.
                M[i, j] = M.sum()/M.size
    return M


def scans_to_patches(scans_folder='', scans_count=7, image_dim=30, missing_rate=5):
    '''
    Extracts all patches from scans in `scans_folder` and return them as a numpy array.
    Input:
        scans_folder: path of an existing folder where the scans live (with
            or without a trailing separator).
        scans_count: max number of scans to be extracted from, from the
            scans_folder; the files are taken in sorted name order.
        image_dim: the width/height of the patches you want extracted (default: 30x30).
        missing_rate: if the missing pixels are <= missing_rate/100, the patch will be extracted, if not ignored.
    Output:
        The extracted patches as a numpy array.
    '''
    # first we list the scan files names, sorted so that the selected scans
    # do not depend on the arbitrary order returned by os.listdir.
    scans_file_names = sorted(os.listdir(scans_folder))

    # we init the patches list.
    extracted_patches = []

    # for each scan file, we extract the patches then add them into the extracted patches list.
    for scan_file_name in scans_file_names[:scans_count]:
        scan_df = bufr_to_dataframe(os.path.join(scans_folder, scan_file_name))
        scan_matrix, _, _ = scan_to_matrix(scan_df)
        # eps_max = missing_rate/100, eps_min = 0: no lower bound on the zeros.
        extracted_patches.extend(
            extract_patches(scan_matrix, image_dim, missing_rate/100, 0))

    # after extracting all possible patches, we turn them into a numpy array and return them.
    return np.array(extracted_patches)
