import { Alert, Backdrop, CircularProgress } from '@mui/material'
import React from 'react'
import { css } from '@emotion/react'
import { FetchState } from '../utility/api/useCachedFetch'
import { ErrorBoundary } from '../utility/ErrorBoundary'

/**
 * Maps each position of a tuple (or array) of fetch states to the type of the
 * data it carries: a `FetchState<D>` position becomes `D`, an `undefined`
 * position stays `undefined`, and any other position becomes `never`.
 */
type ExtractDataTypes<TStates> = {
  [Index in keyof TStates]: TStates[Index] extends FetchState<infer TData>
    ? TData
    : TStates[Index] extends undefined
      ? undefined
      : never
}

/**
 * Collects the `data` value of every fetch state, keeping the same positions
 * as the input so the result lines up with `ExtractDataTypes<T>`.
 *
 * @param parents - The fetch states to read from.
 * @returns A new array whose i-th entry is `parents[i].data`.
 */
function getData<T extends Array<FetchState<unknown>>>(
  parents: [...T],
): ExtractDataTypes<T> {
  const dataValues = parents.map(({ data }) => data)
  // `map` widens the result to a plain array and loses the per-position tuple
  // typing, so the mapped tuple type has to be re-asserted here.
  return dataValues as any
}

/** Props shared by every `DataLoader` call signature. */
interface CommonProps {
  /**
   * Whether to show the transparent backdrop spinner while any fetch state
   * reports `isLoading`. `DataLoader` defaults this to `true`.
   */
  showSpinner?: boolean
}

/**
 * Makes every element of a tuple (or array, or every property of an object
 * type) non-nullable.
 *
 * `NonNullable<[A | undefined, B | undefined]>` only strips `null`/`undefined`
 * from the tuple type itself and leaves the element types untouched; this
 * homomorphic mapped type narrows each element instead.
 */
type NonNullableElements<T> = { [P in keyof T]: NonNullable<T[P]> }

/**
 * Call signatures of `DataLoader`.
 *
 * `fetchState` is either a single fetch state, in which case `children`
 * receives its `data`, or a tuple of fetch states, in which case `children`
 * receives a tuple of their `data` values in the same order.
 *
 * With `dontRenderWhileLoading: true`, `children` is only called once every
 * `data` is defined, so the data is typed as non-nullable. Otherwise
 * `children` may be called before the data has arrived.
 *
 * NOTE(review): the runtime guard only checks `data === undefined`. `null`
 * data would still reach `children`, even though `NonNullable` also strips
 * `null` from the type.
 */
type DataLoaderOverload = {
  // Tuple of fetch states; children are only rendered once all data exists.
  // Each element is narrowed individually (see NonNullableElements).
  <TFetchStates extends Array<FetchState<unknown>>>(
    props: {
      fetchState: [...TFetchStates]
      children: (
        data: NonNullableElements<ExtractDataTypes<TFetchStates>>,
        isLoading: boolean,
      ) => React.ReactNode
      dontRenderWhileLoading: true
    } & CommonProps,
  ): React.ReactElement
  // Tuple of fetch states; children may render before the data has loaded.
  // NOTE(review): whether the element types include `undefined` here depends
  // on how `FetchState` declares `data` (through `infer D`). Verify this
  // against useCachedFetch.
  <TFetchStates extends Array<FetchState<unknown>>>(
    props: {
      fetchState: [...TFetchStates]
      children: (
        data: ExtractDataTypes<TFetchStates>,
        isLoading: boolean,
      ) => React.ReactNode
      dontRenderWhileLoading?: false
    } & CommonProps,
  ): React.ReactElement
  // Single fetch state; children are only rendered once its data exists.
  <TFetchState extends FetchState<unknown>>(
    props: {
      fetchState: TFetchState
      children: (
        data: NonNullable<TFetchState['data']>,
        isLoading: boolean,
      ) => React.ReactNode
      dontRenderWhileLoading: true
    } & CommonProps,
  ): React.ReactElement
  // Single fetch state; children may render before its data has loaded.
  <TFetchState extends FetchState<unknown>>(
    props: {
      fetchState: TFetchState
      children: (
        data: TFetchState['data'] | undefined,
        isLoading: boolean,
      ) => React.ReactNode
      dontRenderWhileLoading?: false
    } & CommonProps,
  ): React.ReactElement
}

/**
 * Renders `children` with the data of one fetch state or of a tuple of fetch
 * states. It shows a spinner while fetches are loading and an alert when a
 * fetch has an error.
 *
 * Pass a single `FetchState` to receive its `data` directly. Pass an array of
 * fetch states to receive a tuple of their `data` values in the same order.
 * With `dontRenderWhileLoading`, `children` is not called until every `data`
 * is defined. `showSpinner` defaults to `true`.
 */
export const DataLoader: DataLoaderOverload = (props: any) => {
  const { fetchState, children, dontRenderWhileLoading, showSpinner = true } = props

  // The inner component always works on an array. Remember which shape the
  // caller used so that a single state's data can be unwrapped again.
  const isTuple = Array.isArray(fetchState)
  const fetchStates = isTuple ? fetchState : [fetchState]

  return (
    <DataLoaderInner
      fetchState={fetchStates}
      dontRenderWhileLoading={dontRenderWhileLoading}
      showSpinner={showSpinner}
    >
      {(data, isLoading) => {
        const payload = isTuple ? data : data[0]
        return children(payload as any, isLoading)
      }}
    </DataLoaderInner>
  )
}

/**
 * Calls `render` while rendering itself and returns its result.
 *
 * DataLoaderInner mounts this inside its ErrorBoundary so that the `children`
 * render prop runs during a *descendant's* render. An error boundary only
 * catches errors thrown while rendering the tree below it. If the render prop
 * were called directly in DataLoaderInner's body, an exception from it would
 * be thrown by DataLoaderInner itself and would escape the boundary.
 */
const RenderChildren = ({ render }: { render: () => React.ReactNode }) => (
  <>{render()}</>
)

/**
 * Shared implementation behind `DataLoader`. It always works on an array of
 * fetch states.
 *
 * - When `showSpinner` is set, shows a transparent `Backdrop` spinner while
 *   any fetch state reports `isLoading`.
 * - Shows an `Alert` with the error of the first fetch state that has one.
 * - Calls `children` with the `data` values (in input order) and a flag that
 *   is true while any `data` is still `undefined`. With
 *   `dontRenderWhileLoading`, `children` is not called in that case.
 * - Wraps the content in an `ErrorBoundary`. If the boundary catches an error
 *   while a fetch is in flight, it is remounted (via `key`) once no fetch is
 *   loading, so freshly fetched data gets a chance to render.
 */
const DataLoaderInner = <TFetchStates extends Array<FetchState<unknown>>>({
  fetchState,
  dontRenderWhileLoading,
  showSpinner,
  children,
}: {
  fetchState: [...TFetchStates]
  dontRenderWhileLoading?: boolean
  showSpinner: boolean
  children: (data: ExtractDataTypes<TFetchStates>, isLoading: boolean) => React.ReactNode
}) => {
  // Bumping this key forces React to discard and remount the ErrorBoundary.
  const [remountKey, setRemountKey] = React.useState(0)
  // Set by onError when an error is caught while some fetch is still loading.
  const doRemountOnLoadRef = React.useRef(false)

  // If the boundary caught an error during a fetch, remount it once nothing is
  // loading any more.
  React.useEffect(() => {
    if (!fetchState.some((s) => s.isLoading) && doRemountOnLoadRef.current) {
      doRemountOnLoadRef.current = false
      setRemountKey((k) => k + 1)
    }
  }, [fetchState])

  return (
    <>
      {showSpinner ? (
        <Backdrop
          css={(theme) => css`
            /* color: ${theme.palette.text.primary}; */
            background-color: transparent;
            z-index: ${theme.zIndex.drawer + 1};
          `}
          open={fetchState.some((s) => s.isLoading)}
        >
          <CircularProgress color="inherit" />
        </Backdrop>
      ) : null}

      {/* Catch errors in children because they may be caused by an out-of-date
          cached model which does not match the current API shape.
          Hopefully, any currently in-flight fetches will complete and update the
          content with something that renders correctly. */}
      <ErrorBoundary
        key={remountKey}
        onError={(error, componentStack) => {
          console.error('Error in DataLoader (caught)', error, componentStack)
          // Only schedule a remount when a fetch is in flight; the effect above
          // performs it once loading finishes. Otherwise the boundary stays in
          // its caught-error state (presumably showing its fallback).
          if (fetchState.some((s) => s.isLoading)) {
            doRemountOnLoadRef.current = true
          }
        }}
      >
        {fetchState.some((s) => s.error !== undefined) ? (
          <Alert severity="error">
            {fetchState.find((s) => s.error !== undefined)?.error}
          </Alert>
        ) : null}
        <RenderChildren
          render={() => {
            // This flag means "some `data` is still undefined". It is derived
            // from `data`, not from each state's `isLoading` (which only drives
            // the spinner above). NOTE(review): a fetch that fails without ever
            // producing data presumably keeps this true; confirm that is intended.
            const isLoading = fetchState.some((s) => s.data === undefined)
            if (isLoading && dontRenderWhileLoading) {
              return null
            }
            const data = getData(fetchState)
            return children(data, isLoading)
          }}
        />
      </ErrorBoundary>
    </>
  )
}
