import { FoundNodeInfo } from '@common/prosemirror/model/node';
import { resolvedPosFind } from '@common/prosemirror/model/resolved-pos';
import { ResolvedPos } from 'prosemirror-model';
import { EditorState } from 'prosemirror-state';
import { CellSelection, Rect, selectedRect } from 'prosemirror-tables';

/**
 * Looks up the table node surrounding the given position.
 *
 * The search itself is delegated to `resolvedPosFind`; a node matches when its
 * type is the `table` node type of its own schema.
 *
 * @param $pos the resolved position to search from.
 * @returns the matched node together with the position of the match's
 * `$nodePos`, or `null` when nothing matches.
 */
export function tableNode($pos: ResolvedPos): FoundNodeInfo | null {
  const match = resolvedPosFind($pos, (candidate) => {
    const tableType = candidate.type.schema.nodes.table;
    return candidate.type === tableType;
  });
  return match ? { node: match.node, pos: match.$nodePos.pos } : null;
}

/**
 * Gets the top most table nodes covered by the current selection.
 *
 * - Returns an empty list when the selection is not a `CellSelection`.
 * - For a column selection, returns the column nodes overlapping the selected
 *   rectangle, but only when the column nodes found in the table add up to the
 *   table map width; otherwise it falls back to the selected cells.
 * - For a row selection, returns the selected row nodes.
 * - In every other case it returns the selected cells.
 *
 * The table node itself is never returned.
 *
 * @param state the editor state whose selection is inspected.
 * @returns the selected nodes with their positions.
 */
export function selectedTableNodes(state: EditorState): FoundNodeInfo[] {
  const selection = state.selection;
  if (!(selection instanceof CellSelection)) return [];

  if (selection.isColSelection()) {
    const table = tableNode(selection.$anchorCell);
    // tableNode may return null; without a table there are no column nodes to
    // inspect, so fall through to the cell fallback instead of dereferencing null
    if (table) {
      const rect = selectedRect(state);
      const selectedColumns: FoundNodeInfo[] = [];
      const columnCount = getTableColumns(table, rect, selectedColumns);
      // return the columns only when the column nodes account for the full table width,
      // otherwise return cells
      if (columnCount === rect.map.width) return selectedColumns;
    }
  } else if (selection.isRowSelection()) {
    // get the first selected cell
    const $start = selection.$anchorCell.min(selection.$headCell);
    // get the position directly before its parent (the row holding that cell)
    const start = $start.before($start.depth);
    // the top-level children of the selection slice are the selected rows;
    // their offsets inside the slice are shifted by the first row's position
    const rows: FoundNodeInfo[] = [];
    selection.content().content.forEach((row, offset) => {
      rows.push({ node: row, pos: start + offset });
    });
    return rows;
  }
  return selectedCells(state);
}

/**
 * Adds the columns that overlap the selected rectangle to the selected columns list.
 *
 * Only direct children of type `colgroup` or `col` are considered. A matching
 * child that has children of its own is descended into; a childless one is
 * treated as a column that occupies `span` consecutive column indexes.
 * A missing, non-numeric or smaller-than-one `span` counts as 1, so a malformed
 * attribute cannot corrupt the running column index.
 *
 * @param columnParent the columns' parent node and its position.
 * @param rect the selected rectangle; `left` is inclusive and `right` exclusive.
 * @param selectedColumns output list the overlapping columns are appended to.
 * @param colIndex the column index the first child starts at, accounting for span widths.
 * @returns the updated colIndex to give to the parent after iterating through children.
 */
export function getTableColumns(columnParent: FoundNodeInfo, rect: Rect, selectedColumns: FoundNodeInfo[], colIndex: number = 0): number {
  let nextIndex = colIndex;

  columnParent.node.forEach((column, offset) => {
    const typeName = column.type.name;
    if (typeName !== 'colgroup' && typeName !== 'col') {
      return;
    }

    // Add one to move the position inside the parent
    const colPos = columnParent.pos + 1 + offset;

    if (column.childCount) {
      nextIndex = getTableColumns({ node: column, pos: colPos }, rect, selectedColumns, nextIndex);
      return;
    }

    // normalise the span to an integer >= 1; NaN or 0 would otherwise poison
    // the index for every following column
    const rawSpan = Number(column.attrs?.['span'] ?? 1);
    const colSpan = Number.isFinite(rawSpan) && rawSpan >= 1 ? Math.floor(rawSpan) : 1;

    // the column covers [nextIndex, nextIndex + colSpan); keep it when that
    // range intersects [rect.left, rect.right)
    if (nextIndex + colSpan > rect.left && nextIndex < rect.right) {
      selectedColumns.push({ node: column, pos: colPos });
    }
    nextIndex += colSpan;
  });

  return nextIndex;
}

/**
 * Gets the cell nodes of the current cell selection, ordered by document position.
 *
 * For every selection range the parent of the range start is reported as the
 * cell, positioned one step before the range start. This assumes each range
 * starts at the first position inside its cell.
 *
 * @param state the editor state whose selection is inspected.
 * @returns the selected cells, or an empty list when the selection is not a
 * `CellSelection`.
 */
export function selectedCells(state: EditorState): FoundNodeInfo[] {
  const { selection } = state;
  if (!(selection instanceof CellSelection)) return [];

  const found = selection.ranges.map(({ $from }): FoundNodeInfo => {
    return { node: $from.parent, pos: $from.pos - 1 };
  });
  // selection.ranges are not sorted, so order the cells by their document position
  found.sort((a, b) => a.pos - b.pos);
  return found;
}
