import React, { useRef, useEffect, useCallback, useState, useMemo } from 'react';
import { useSelector, useDispatch } from 'react-redux';
import assign from 'lodash/assign';
import isEmpty from 'lodash/isEmpty';
import reduce from 'lodash/reduce';
import isFinite from 'lodash/isFinite';
import { sum } from 'd3-array';

// constants
import {
    COLOR_BY_OPTIONS,
    FETCH_STATE,
    TAXONOMY_ID_ACCESSOR,
    ZERO,
} from '../../constants';

// hooks and utils
import useWindowWidth from '../../utils/use-window-width';
import useWindowHeight from '../../utils/use-window-height';

// components
import { Loader } from '../../components/indicators';
import {
    CellCountItem,
    CellCountHeader,
    CellExpressionHeader,
    CellExpressionHistogram,
    ScatterPlotChart,
} from '../../components/scatter-plot';
import InfoBox from '../../components/info-box';
import ScaleBarContainer from '../scale-bar';

// selectors
import {
    getScatterPlotPointDataFetchState,
    getScatterPlotPointData,
    getScatterPlotGeneExpressionFetchState,
    getScatterData,
    getScatterPlotGeneExpression,
    getClusterInfoById,
    getViewState,
    hasScatterPlotPointData,
} from '../../selectors/scatter-plot-selectors';
import { getSelectedDatasetName, getSelectedDatasetScatterPlotName } from '../../selectors';
import { getColorByFeature, getColorByFeatureValue } from '../../selectors/color-by-selectors';
import {
    getNodeByFunction,
    getSelectedLeafIdMap,
} from '../../selectors/taxonomy-selectors';
import { getColorScale } from '../../selectors/color-scale-selectors';

// queries
import {
    getScatterPlotPointDataQuery,
    getScatterPlotGeneExpressionQuery,
} from '../../queries';

// actions
import { fetchData, changeClusterSelection, removeClusterSelection } from '../../actions';
import {
    SCATTER_PLOT_FETCH_POINT_DATA,
    SCATTER_PLOT_RECEIVE_POINT_DATA,
    SCATTER_PLOT_ERROR_POINT_DATA,
    SCATTER_PLOT_FETCH_GENE_EXPRESSION,
    SCATTER_PLOT_RECEIVE_GENE_EXPRESSION,
    SCATTER_PLOT_ERROR_GENE_EXPRESSION,
    updateViewStateAction,
} from '../../actions/scatter-plot-actions';

import './style.scss';
import useReferenceWrapper from '../../utils/use-reference-wrapper';

/**
 * Dispatches the fetch action that loads the scatter plot point data.
 *
 * @param {Function} dispatch - redux dispatch function.
 * @param {*} selectedDatasetName - forwarded to the query builder and echoed in the action metadata.
 * @param {*} selectedDatasetScatterPlotName - forwarded to the query builder and echoed in the action metadata.
 */
const fetchScatterPlotPointData = (dispatch, selectedDatasetName, selectedDatasetScatterPlotName) => {
    const fetchAction = fetchData(
        getScatterPlotPointDataQuery(selectedDatasetName, selectedDatasetScatterPlotName),
        SCATTER_PLOT_FETCH_POINT_DATA,
        SCATTER_PLOT_RECEIVE_POINT_DATA,
        SCATTER_PLOT_ERROR_POINT_DATA,
        // metadata travels with the fetch/receive/error actions
        { selectedDatasetName, selectedDatasetScatterPlotName },
    );

    dispatch(fetchAction);
};

/**
 * Dispatches the fetch action that loads gene expression values for the scatter plot.
 *
 * @param {Function} dispatch - redux dispatch function.
 * @param {*} selectedDataset - forwarded to the query builder and echoed in the action metadata.
 * @param {*} colorByFeatureValue - forwarded to the query builder only (not part of the metadata).
 */
const fetchScatterPlotGeneExpression = (dispatch, selectedDataset, colorByFeatureValue) => {
    const fetchAction = fetchData(
        getScatterPlotGeneExpressionQuery(selectedDataset, colorByFeatureValue),
        SCATTER_PLOT_FETCH_GENE_EXPRESSION,
        SCATTER_PLOT_RECEIVE_GENE_EXPRESSION,
        SCATTER_PLOT_ERROR_GENE_EXPRESSION,
        // metadata travels with the fetch/receive/error actions
        { selectedDataset },
    );

    dispatch(fetchAction);
};

/**
 * Computes the count-weighted average of `clusterAverage` across items.
 *
 * Only items whose `count` AND `clusterAverage` are both finite numbers take part.
 * d3's `sum` silently skips NaN/undefined terms, so without this filter an item with
 * a valid count but a missing average would add weight to the denominator while
 * contributing nothing to the numerator, dragging the result toward zero.
 *
 * @param {Array<{count: number, clusterAverage: number}>} itemData - items to average.
 * @returns {number} the weighted average, or NaN when the total weight is zero
 *     (no items, no usable items, or all counts are zero).
 */
const weightedAverage = (itemData) => {
    const usableItems = itemData.filter(
        d => Number.isFinite(d.count) && Number.isFinite(d.clusterAverage)
    );
    const denominator = sum(usableItems.map(d => d.count));
    if (denominator === 0) {
        return NaN;
    }

    const numerator = sum(usableItems.map(d => d.count * d.clusterAverage));
    return numerator / denominator;
};

/**
 * Assembles the props passed to the InfoBox for the hovered (preferred) or selected clusters.
 *
 * Hovered clusters take precedence over the selection; ids without an entry in
 * `clusterInfoById` are skipped. When coloring by gene expression with the scale bar
 * showing, the box lists expression histograms; otherwise it lists cell counts.
 *
 * @param {Object} selectedClustersHash - map keyed by selected cluster id (values are unused).
 * @param {Array} hoveredClusters - ids of the currently hovered clusters.
 * @param {Object} clusterInfoById - cluster info lookup keyed by cluster id.
 * @param {number} height - height of the plot area; the box is capped at `height - 100`.
 * @param {*} colorByFeature - active color-by option, compared against COLOR_BY_OPTIONS.GENE_EXPRESSION.
 * @param {*} colorByFeatureValue - forwarded to the gene expression header.
 * @param {*} showColorScaleBar - truthy when the color scale bar is displayed.
 * @returns {Object} props for the InfoBox component.
 */
const computeInfoBoxProps = (selectedClustersHash, hoveredClusters, clusterInfoById, height, colorByFeature, colorByFeatureValue, showColorScaleBar) => {
    // Appends the info entry for `id` when one exists; unknown ids are dropped.
    const collectKnownCluster = (found, id) => {
        const info = clusterInfoById[id];
        if (info) {
            found.push(info);
        }
        return found;
    };

    let itemData = [];
    if (!isEmpty(hoveredClusters)) {
        itemData = hoveredClusters.reduce((found, id) => collectKnownCluster(found, id), []);
    } else if (!isEmpty(selectedClustersHash)) {
        // the hash is keyed by cluster id, so the key (third argument) is what we look up
        itemData = reduce(selectedClustersHash, (found, _, id) => collectKnownCluster(found, id), []);
    }

    // total cell count, ignoring entries whose count is not a finite number
    let countSum = 0;
    itemData.forEach(({ count }) => {
        if (isFinite(count)) {
            countSum += count;
        }
    });

    const headerProps = { countSum, numberOfCellTypes: itemData.length };
    const baseProps = {
        wrapperClassName: 'scatter-plot__info-box',
        itemData,
        itemKeyAccessor: 'id',
        // subtracting 100 leaves room for the footer (empirically chosen value)
        wrapperStyle: { maxHeight: height - 100 },
    };

    const isGeneExpressionView = colorByFeature === COLOR_BY_OPTIONS.GENE_EXPRESSION && showColorScaleBar;
    if (!isGeneExpressionView) {
        return assign(baseProps, {
            itemClassName: 'scatter-plot__info-box-item--cell-type',
            itemComponent: CellCountItem,
            headerProps,
            headerComponent: CellCountHeader,
        });
    }

    // largest finite per-cluster maximum, floored at 0
    let maxExpression = 0;
    itemData.forEach(({ clusterMaxValue }) => {
        if (isFinite(clusterMaxValue)) {
            maxExpression = Math.max(maxExpression, clusterMaxValue);
        }
    });

    // weighted average across all listed items
    const itemsAverage = weightedAverage(itemData);

    // each histogram item carries its own fixed drawing size
    const histogramSize = { width: 200, height: 25 };
    const expressionItemData = itemData.map(item => assign({}, item, histogramSize));

    return assign(baseProps, {
        itemClassName: 'scatter-plot__info-box-item--gene-expression',
        itemComponent: CellExpressionHistogram,
        headerProps: assign(headerProps, { maxExpression, colorByFeatureValue, itemsAverage }),
        headerComponent: CellExpressionHeader,
        itemData: expressionItemData,
    });
};

const ScatterPlotContainer = () => {
    const windowWidth = useWindowWidth();
    const windowHeight = useWindowHeight();
    const scatterPlotContainerRef = useRef(null);
    const scatterPlotPointDataFetchState = useSelector(getScatterPlotPointDataFetchState);
    const scatterPlotGeneExpressionFetchState = useSelector(getScatterPlotGeneExpressionFetchState);
    const markerGeneFetchState = useSelector(state => state.fetch.markerGeneFetchState);
    const scatterPlotPointData = useSelector(getScatterPlotPointData);
    const scatterPlotPointDataAlreadyFetched = useSelector(hasScatterPlotPointData);
    const scatterData = useSelector(getScatterData);
    const selectedDatasetName = useSelector(getSelectedDatasetName);
    const selectedDatasetScatterPlotName = useSelector(getSelectedDatasetScatterPlotName);
    const dispatch = useDispatch();
    const colorByFeature = useSelector(getColorByFeature);
    const colorByFeatureValue = useSelector(getColorByFeatureValue);
    const selectedClustersHash = useSelector(getSelectedLeafIdMap);
    const scatterPlotGeneExpression = useSelector(getScatterPlotGeneExpression);
    const clusterInfoById = useSelector(getClusterInfoById);
    const colorScale = useSelector(getColorScale);
    const viewState = useSelector(getViewState);
    const getNodeFromId = useSelector(getNodeByFunction);
    const [hoveredClusters, setHoveredClusters] = useState([]);

    const scatterPlotTop = scatterPlotContainerRef.current && scatterPlotContainerRef.current.getBoundingClientRect().top;
    const okToFetchScatterPlotPointData = scatterPlotPointDataFetchState === FETCH_STATE.INIT && !scatterPlotPointDataAlreadyFetched;
    const scatterPlotPointDataIsFetching = scatterPlotPointDataFetchState === FETCH_STATE.FETCHING;
    const okToFetchScatterPlotGeneExpression = (
        colorByFeature === COLOR_BY_OPTIONS.GENE_EXPRESSION &&
        scatterPlotGeneExpressionFetchState === FETCH_STATE.INIT &&
        markerGeneFetchState === FETCH_STATE.COMPLETE
    );
    const scatterPlotGeneExpressionIsFetching = scatterPlotGeneExpressionFetchState === FETCH_STATE.FETCHING;

    const showLoading = !scatterPlotPointData || scatterPlotPointDataIsFetching || scatterPlotGeneExpressionIsFetching;
    const showInfoBox = !isEmpty(hoveredClusters) || !isEmpty(selectedClustersHash);
    const showColorScaleBar = colorByFeature === COLOR_BY_OPTIONS.GENE_EXPRESSION && scatterPlotGeneExpression;

    useEffect(
        () => {if (okToFetchScatterPlotPointData) fetchScatterPlotPointData(dispatch, selectedDatasetName, selectedDatasetScatterPlotName);},
        [dispatch, selectedDatasetName, selectedDatasetScatterPlotName, okToFetchScatterPlotPointData]
    );

    useEffect(
        () => {if (okToFetchScatterPlotGeneExpression) fetchScatterPlotGeneExpression(dispatch, selectedDatasetName, colorByFeatureValue);},
        [dispatch, selectedDatasetName, colorByFeatureValue, okToFetchScatterPlotGeneExpression]
    );

    const selectedClustersHashRef = useReferenceWrapper(selectedClustersHash);

    const selectClusterCB = useCallback(
        (clusterId, { srcEvent }) => {
            const isMultiSelect = srcEvent.shiftKey || srcEvent.metaKey || srcEvent.ctrlKey;

            // Select cluster id only if there are no other clusters selected or user is selecting an additional node.
            // Provides deselect functionality.
            if (clusterId && (isEmpty(selectedClustersHashRef.current) || isMultiSelect)) {
                const node = getNodeFromId(clusterId, TAXONOMY_ID_ACCESSOR);
                dispatch(changeClusterSelection(node, selectedDatasetName));
            } else {
                dispatch(removeClusterSelection(selectedDatasetName));
            }
        },
        [dispatch, selectedDatasetName, selectedClustersHashRef, getNodeFromId]
    );

    const hoveredClusterCB = useCallback(
        (clusterIds) => {
            setHoveredClusters(clusterIds);
        },
        []
    );

    const viewStateChangeCB = useCallback(
        (nextViewState) => {
            dispatch(updateViewStateAction(nextViewState, selectedDatasetName));
        },
        [dispatch, selectedDatasetName]
    );

    // height for the ScatterPlot plot canvas
    const computedHeight = windowHeight - scatterPlotTop;

    const infoBoxProps = useMemo(
        () => !showLoading && computeInfoBoxProps(selectedClustersHash, hoveredClusters, clusterInfoById, computedHeight, colorByFeature, colorByFeatureValue, showColorScaleBar),
        [showLoading, selectedClustersHash, hoveredClusters, clusterInfoById, computedHeight, colorByFeature, colorByFeatureValue, showColorScaleBar]
    );

    return (
        <div className='scatter-plot__container' ref={scatterPlotContainerRef}>
            {showLoading ? (
                <Loader />
            ) : (
                <ScatterPlotChart
                    scatterData={scatterData}
                    height={computedHeight}
                    hoveredClusterCB={hoveredClusterCB}
                    hoveredClusters={hoveredClusters}
                    pointData={scatterPlotPointData}
                    selectClusterCB={selectClusterCB}
                    selectedClustersHash={selectedClustersHash}
                    viewStateChangeCB={viewStateChangeCB}
                    topOffset={scatterPlotTop}
                    width={windowWidth}
                    viewState={viewState}
                >
                    {showColorScaleBar && (
                        <div className='scatter-plot__scalebar'>
                            <ScaleBarContainer
                                colorScheme={colorScale}
                                minValue={ZERO}
                                maxValue={scatterPlotGeneExpression.expressionMax}
                                precision={2}
                                width={400}
                                height={10}
                            />
                        </div>
                    )}
                    {showInfoBox && <InfoBox {...infoBoxProps} />}
                </ScatterPlotChart>
            )}
        </div>
    );
};

export default ScatterPlotContainer;
