#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Created on Fri Jun 10 11:29:45 2016

@author: mints

A module for partitioning N-dimensional point distributions.
It produces a rectangular grid of cells, each containing at
least a given number of points. Cells are split recursively
at the median, with an option to align cell borders to a given grid.
"""
import numpy as np


def split_data(arr, column, value):
    """
    Partition the rows of ``arr`` by the value in ``column``.

    Rows whose entry in ``column`` is <= ``value`` form the first
    returned array; every other row (including NaN entries) forms
    the second.
    """
    at_or_below = np.less_equal(arr[:, column], value)
    remaining = np.logical_not(at_or_below)
    return arr[at_or_below], arr[remaining]


def get_bounding_box(data, ndims=2):
    """
    Return the axis-aligned bounding box of the first ``ndims`` columns.

    The result is a 1-D array: the ``ndims`` column minima followed by
    the ``ndims`` column maxima.
    """
    leading = data[:, :ndims]
    return np.concatenate([leading.min(axis=0), leading.max(axis=0)])


def bring_to_grid(items, grid, lower_side=True):
    """
    Get grid point indices for items on the grid.

    :items: a scalar, list or array of values.
    :grid: sorted 1-D numpy array of grid points.
    :lower_side: if True, return for each item the index of the last
                 grid point strictly below it (clipped to 0); otherwise
                 return the index of the first grid point that is >= the
                 item (clipped to len(grid) - 1).

    :result: numpy integer array of indices; always at least 1-D, so a
             scalar input gives a one-element array.
    """
    # atleast_1d also covers 0-d arrays: for those, searchsorted returns
    # a NumPy scalar that cannot be modified with boolean indexing.
    pos = grid.searchsorted(np.atleast_1d(items))
    if lower_side:
        return np.maximum(pos, 1) - 1
    return np.minimum(pos, len(grid) - 1)


def get_bounding_box_grid(data, x_bins, y_bins):
    """
    Return the grid-aligned bounding box of data as grid indices.

    The result is a list [left, bottom, right, top]. left/bottom come
    from bring_to_grid with lower_side=True applied to the minimal x/y
    values; right/top come from bring_to_grid with lower_side=False
    applied to the maximal x/y values.
    """
    lowest = data.min(axis=0)
    highest = data.max(axis=0)
    left = bring_to_grid(lowest[0], x_bins)[0]
    bottom = bring_to_grid(lowest[1], y_bins)[0]
    right = bring_to_grid(highest[0], x_bins, False)[0]
    top = bring_to_grid(highest[1], y_bins, False)[0]
    return [left, bottom, right, top]


def get_shannon(x):
    """
    Return the Shannon entropy S of the values in x.

    The values are histogrammed into len(x) equal-width bins. Bin
    frequencies are used as probabilities, and empty bins are skipped.
    """
    n_points = len(x)
    counts, _ = np.histogram(x, bins=n_points)
    probs = counts[counts > 0] / float(n_points)
    return -np.sum(probs * np.log(probs))


def get_shannon_h(h):
    """
    Return the Shannon entropy of a histogram.

    :h: sequence of non-negative weights (probabilities or raw counts).
        The weights are normalised to unit sum first, so raw histogram
        counts give the same entropy as the matching probabilities.
        Zero bins are ignored.

    :result: entropy (natural logarithm); 0 for empty or all-zero input.
    """
    h = np.asarray(h, dtype=float)
    h = h[h > 0]
    total = h.sum()
    if total == 0:
        return 0.0
    # Without normalisation, -sum(h*log(h)) is not an entropy for raw
    # counts (e.g. a marginal of a background histogram).
    p = h / total
    return -np.sum(p * np.log(p))


def get_next_axis_entropy(data, axis_flags, background=None):
    """
    Choose the enabled axis with the lowest entropy score.

    :data: array of points, one per row; column ``axis`` is used for
           the entropy of that axis (see get_shannon).
    :axis_flags: boolean flags, one per axis; only axes flagged True
                 are candidates.
    :background: optional histogram; if given, the entropy of
                 ``background.sum(axis=axis)`` is added to each
                 candidate's score.

    :result: index of the enabled axis with the smallest score
             (0 if no axis is enabled).
    """
    # Disabled axes get +inf so argmin never selects them. With zeros
    # they would always win, because entropies are non-negative.
    entropies = np.full(len(axis_flags), np.inf)
    for axis in np.flatnonzero(axis_flags):
        entropies[axis] = get_shannon(data[:, axis])
        if background is not None:
            entropies[axis] += get_shannon_h(background.sum(axis=axis))
    return np.argmin(entropies)


def get_next_axis(current, axis_flags):
    """
    Get next column for an array (cycle through),
    using only those for which axis_flags == True.

    :current: index of the current axis.
    :axis_flags: sequence of booleans, one per axis.

    :result: the first index after :current: (wrapping around to 0)
             whose flag is True. This is :current: itself if it is the
             only enabled axis.
    :raises ValueError: if no flag is True, since there is no axis
                        to return.
    """
    n_axes = len(axis_flags)
    if not any(axis_flags):
        raise ValueError("axis_flags must contain at least one True value")
    # Check every axis once, starting right after `current`.
    for step in range(1, n_axes + 1):
        candidate = (current + step) % n_axes
        if axis_flags[candidate]:
            return candidate


def binary_split_2d(data, boundaries=None, split_axis=0,
                    min_occupation=20):
    """
    Split 2D data into cells containing at least :min_occupation:
    points.

    :data:  numpy array of points, one per row; the first two columns
            are used for splitting;
    :boundaries: array of 4 values (left, bottom, right, top)
                 defining the position of cell boundary
                 (default: bounding box of the data);
    :split_axis: number of axis (0 or 1) along which a
                 first division will be made.
    :min_occupation: minimum number of points in a cell

    :result: an iterator with (boundaries, data) for
    each cell. Here boundaries is 4 numbers (as in
    input parameter) and data is part of the input
    data in a returned cell.

    Points equal to the median go to the left/bottom part. If many
    points share the median value, a split can leave fewer than
    :min_occupation: points on one side. In that case the other axis
    is tried, and if that fails too the cell is returned unsplit
    (this avoids endless recursion on tied data).
    """
    if boundaries is None:
        boundaries = get_bounding_box(data)
    if len(data) < 2*min_occupation:
        # Cannot split further, return and exit
        yield boundaries, data
        return
    # Median split along the requested axis, falling back to the other.
    for axis in (split_axis, 1 - split_axis):
        division = np.median(data[:, axis])
        left_data, right_data = split_data(data, axis, division)
        if (len(left_data) >= min_occupation and
                len(right_data) >= min_occupation):
            break
    else:
        # Neither axis gives a valid split: keep the cell as it is.
        yield boundaries, data
        return
    # Prepare boundaries for the left and right parts.
    left_boundaries = np.copy(boundaries)
    right_boundaries = np.copy(boundaries)
    if axis == 0:
        left_boundaries[2] = division
        right_boundaries[0] = division
    else:
        left_boundaries[3] = division
        right_boundaries[1] = division
    # Process left and right parts, alternating the axis.
    for answer in binary_split_2d(left_data, left_boundaries,
                                  1 - axis, min_occupation):
        yield answer
    for answer in binary_split_2d(right_data, right_boundaries,
                                  1 - axis, min_occupation):
        yield answer


def _crop_background(bg_data, boundaries):
    """
    Return the part of the full-grid histogram :bg_data: inside a cell
    with grid-index boundaries (left, bottom, right, top), or None if
    :bg_data: is None.
    """
    if bg_data is None:
        return None
    return bg_data[boundaries[0]:boundaries[2],
                   boundaries[1]:boundaries[3]]


def binary_split_2d_grid(data, x_bins, y_bins,
                         boundaries=None, split_axis=0,
                         repeated_attempt=False, min_occupation=20,
                         bg_data=None):
    """
    Split 2D data into cells containing at least :min_occupation:
    points. Cell borders are aligned with x_bins or y_bins positions.

    :data:  numpy array of points, one per row; the first two columns
            are used for splitting.
    :x_bins: sorted array of allowed positions for cell borders, axis=0
    :y_bins: sorted array of allowed positions for cell borders, axis=1
    :boundaries: array of 4 grid indices (left, bottom, right, top) into
                 x_bins/y_bins defining the cell
                 (default: grid-aligned bounding box of the data).
    :split_axis: axis to split along. Only used on a repeated attempt;
                 otherwise the axis is chosen automatically.
    :repeated_attempt: for internal use only.
    :min_occupation: minimum number of points in a cell.
    :bg_data: optional background histogram defined on the full grid
              (presumably shape (len(x_bins) - 1, len(y_bins) - 1) --
              TODO confirm). The part inside the current cell adds to
              the entropy used to choose the split axis.

    :result: an iterator with (boundaries, data) for
    each cell. Here boundaries is 4 grid indices (as in
    input parameter) and data is part of the input
    data in a returned cell.
    """
    if boundaries is None:
        boundaries = get_bounding_box_grid(data, x_bins, y_bins)
    if len(data) < 2*min_occupation:
        # Cannot split further, return and exit
        yield boundaries, data
        return
    if not repeated_attempt:
        # A cell one grid step wide cannot be split along that axis;
        # otherwise let the entropy criterion decide.
        if boundaries[2] - boundaries[0] == 1:
            split_axis = 1
        elif boundaries[3] - boundaries[1] == 1:
            split_axis = 0
        else:
            split_axis = get_next_axis_entropy(
                data, np.ones(2, dtype=bool),
                background=_crop_background(bg_data, boundaries))
    # On a repeated attempt the caller's split_axis (the axis not tried
    # yet) is kept.
    bins = x_bins if split_axis == 0 else y_bins
    # Define the location of split point and move it to a grid point.
    division = np.median(data[:, split_axis])
    div_bin = bring_to_grid(division, bins)[0]
    div_pos = bins[div_bin]
    left_data, right_data = split_data(data, split_axis, div_pos)
    if len(left_data) < min_occupation or len(right_data) < min_occupation:
        # Split resulted in small occupancy,
        # try splitting by the next grid point...
        div_bin = div_bin + 1
        if div_bin < len(bins):
            div_pos = bins[div_bin]
        left_data, right_data = split_data(data, split_axis, div_pos)
        if len(left_data) < min_occupation or len(right_data) < min_occupation:
            # ...still low occupancy. Try to split by the other axis.
            if not repeated_attempt:
                for answer in binary_split_2d_grid(data, x_bins, y_bins,
                                                   boundaries,
                                                   1-split_axis, True,
                                                   min_occupation,
                                                   bg_data=bg_data):
                    yield answer
            else:
                yield boundaries, data
            return
    # Prepare boundaries for the left and right parts.
    left_boundaries = np.copy(boundaries)
    right_boundaries = np.copy(boundaries)
    if split_axis == 0:
        left_boundaries[2] = div_bin
        right_boundaries[0] = div_bin
    else:
        left_boundaries[3] = div_bin
        right_boundaries[1] = div_bin
    # Process left and right parts. bg_data stays the full-grid histogram
    # because boundaries are absolute grid indices; each level crops it.
    for answer in binary_split_2d_grid(left_data, x_bins, y_bins,
                                       left_boundaries,
                                       1-split_axis,
                                       min_occupation=min_occupation,
                                       bg_data=bg_data):
        yield answer
    for answer in binary_split_2d_grid(right_data, x_bins, y_bins,
                                       right_boundaries,
                                       1-split_axis,
                                       min_occupation=min_occupation,
                                       bg_data=bg_data):
        yield answer


def _pick_active_axis(data, active_axes):
    """
    Choose the next split axis among the enabled ones.

    Delegates to get_next_axis_entropy. If that returns a disabled axis,
    the first enabled axis is used instead, so that every failed split
    attempt disables a new axis and the search terminates.
    Assumes at least one axis in active_axes is enabled.
    """
    axis = get_next_axis_entropy(data, active_axes)
    if not active_axes[axis]:
        axis = np.flatnonzero(active_axes)[0]
    return axis


def binary_split_nd(data, boundaries=None, ndims=None, min_occupation=20,
                    split_axis=0, active_axes=None, minimum_cell_size=None):
    """
    Split data in n dimensions using recursive median division.

    Parameters
    ----------

    data : np.array Input data, one point per row
    boundaries : float[2*ndims] Cell boundaries: ndims lower bounds
                 followed by ndims upper bounds (default: detect from
                 the data itself)
    ndims : int Number of leading columns of data used for splitting
            (default: all columns of data)
    min_occupation : int Minimum number of points in the cell (default = 20)
    split_axis : int First axis to split along (default = 0); replaced
                 by an enabled axis if it is disabled in active_axes
    active_axes : bool array Flags indicating which axes may be split
                  (default = all). The caller's array is not modified.
    minimum_cell_size : float or list Minimum size of a cell, one value
                        for all axes or a list with one value per axis
                        (default = 0)

    Yields
    ------
    (boundaries, data) for each resulting cell.
    """
    if ndims is None:
        ndims = data.shape[1]
    if boundaries is None:
        boundaries = get_bounding_box(data, ndims)
    if active_axes is None:
        active_axes = np.ones(ndims, dtype=bool)
    else:
        # Private copy: axes get disabled below, and the caller's array
        # must not change.
        active_axes = np.array(active_axes, dtype=bool)
    if minimum_cell_size is None:
        minimum_cell_size = np.zeros(ndims, dtype=float)
    elif not isinstance(minimum_cell_size, list):
        # A single value (or array) is broadcast to every axis.
        minimum_cell_size = np.ones(ndims) * minimum_cell_size
    if len(data) < 2*min_occupation or not np.any(active_axes):
        # Cannot split further, return and exit
        yield boundaries, data
        return
    if not active_axes[split_axis]:
        split_axis = _pick_active_axis(data, active_axes)
    minimum_cell = minimum_cell_size[split_axis]
    while True:
        # Skip axes along which the cell is too narrow to hold two
        # cells of minimum size.
        while (boundaries[ndims + split_axis] - boundaries[split_axis]
               < 2.*minimum_cell):
            active_axes[split_axis] = False
            if not np.any(active_axes):
                yield boundaries, data
                return
            split_axis = _pick_active_axis(data, active_axes)
            minimum_cell = minimum_cell_size[split_axis]
        # Median division, moved if needed so that both halves are at
        # least minimum_cell wide.
        division = np.median(data[:, split_axis])
        if division - boundaries[split_axis] < minimum_cell:
            division = boundaries[split_axis] + minimum_cell
        elif boundaries[ndims + split_axis] - division < minimum_cell:
            division = boundaries[ndims + split_axis] - minimum_cell
        left_data, right_data = split_data(data, split_axis, division)
        if (len(left_data) >= min_occupation and
                len(right_data) >= min_occupation):
            break
        # One side got too few points: disable this axis and retry.
        active_axes[split_axis] = False
        if not np.any(active_axes):
            yield boundaries, data
            return
        split_axis = _pick_active_axis(data, active_axes)
        minimum_cell = minimum_cell_size[split_axis]
    # Prepare boundaries for the left and right parts.
    left_boundaries = np.copy(boundaries)
    right_boundaries = np.copy(boundaries)
    left_boundaries[ndims + split_axis] = division
    right_boundaries[split_axis] = division
    lactive_axes = np.copy(active_axes)
    ractive_axes = np.copy(active_axes)
    # Process left and right parts
    next_axis = _pick_active_axis(data, active_axes)
    for answer in binary_split_nd(left_data, left_boundaries, ndims,
                                  min_occupation, next_axis, lactive_axes,
                                  minimum_cell_size):
        yield answer
    for answer in binary_split_nd(right_data, right_boundaries, ndims,
                                  min_occupation, next_axis, ractive_axes,
                                  minimum_cell_size):
        yield answer


def _bbox(bound):
    """
    Convert boundaries to bounding box point coordinates.
    """
    pos = [[bound[0], bound[1]],
           [bound[0], bound[3]],
           [bound[2], bound[3]],
           [bound[2], bound[1]],
           [bound[0], bound[1]]]
    return np.array(pos)


def bbox_data(data, box):
    """
    Return the rows of data that fall inside box.

    box is (left, bottom, right, top). A point is inside when
    left < x <= right and bottom < y <= top.
    """
    xs = data[:, 0]
    ys = data[:, 1]
    inside = ((xs > box[0]) & (xs <= box[2]) &
              (ys > box[1]) & (ys <= box[3]))
    return data[inside]


def plot(result, grids=None,
         shrink=False, show=True, bbox_full=None):
    """
    Plot result of binary_split_* functions.

    :result: iterable of (boundaries, data) pairs as yielded by the
             binary_split_* functions.
    :grids: optional pair (x_grid, y_grid). If given, boundaries are
            assumed to be indices into these grids.
    :shrink: if True, draw cells shrunk towards their data. Without
             grids, the bounding box of the data is drawn. With grids,
             only sides that coincide with :bbox_full: are moved to the
             grid points enclosing the data.
    :show: call pylab.show() at the end.
    :bbox_full: grid-index boundaries of the whole area; used only
                together with :shrink: and :grids:.

    Boundaries in :result: are not modified.
    """
    from itertools import cycle
    import pylab as plt

    def grid_coords(corners):
        # Replace grid indices of the polygon corners by grid positions.
        coords = np.asarray(corners, dtype=float)
        coords[:, 0] = grids[0][corners[:, 0]]
        coords[:, 1] = grids[1][corners[:, 1]]
        return coords

    colors = cycle('krbmcyg')
    for box, data in result:
        color = next(colors)
        if shrink and grids is None:
            # Draw the bounding box of the data itself.
            bboxx = _bbox(get_bounding_box(data))
        else:
            if shrink and bbox_full is not None:
                # Work on a copy so the caller's boundaries stay intact.
                # bring_to_grid keeps the indices inside the grid.
                box = np.array(box)
                if box[0] == bbox_full[0]:
                    box[0] = bring_to_grid(data[:, 0].min(), grids[0])[0]
                if box[1] == bbox_full[1]:
                    box[1] = bring_to_grid(data[:, 1].min(), grids[1])[0]
                if box[2] == bbox_full[2]:
                    box[2] = bring_to_grid(data[:, 0].max(), grids[0],
                                           False)[0]
                if box[3] == bbox_full[3]:
                    box[3] = bring_to_grid(data[:, 1].max(), grids[1],
                                           False)[0]
            bboxx = _bbox(box)
            if grids is not None:
                bboxx = grid_coords(bboxx)
        plt.plot(bboxx[:, 0], bboxx[:, 1], color=color)
        plt.text(bboxx[:, 0].mean(),
                 bboxx[:, 1].mean(),
                 str(len(data)), )
        plt.scatter(data[:, 0], data[:, 1], c=color, s=5)
    if show:
        plt.show()


if __name__ == '__main__':
    # Demo: split a correlated 2-D Gaussian sample and plot the cells.
    # size=1000 gives 1000 points of dimension 2, i.e. shape (1000, 2);
    # size=(1000, 2) would produce shape (1000, 2, 2).
    sample = np.random.multivariate_normal(mean=(5., 5.),
                                           cov=[[1., 0.29], [0.29, 1.]],
                                           size=1000)
    # Keep only points strictly inside the (0, 20) x (0, 20) square.
    inside = np.all((sample > 0) & (sample < 20), axis=1)
    sample = sample[inside]
    plot(binary_split_nd(sample, min_occupation=50,
                         minimum_cell_size=[1., 1.]))
