import type { FetchResult, NextLink, ObservableSubscription, Operation } from '@apollo/client';
import { ApolloLink, Observable } from '@apollo/client';
import { createOperation } from '@apollo/client/link/utils';
import { argumentsObjectFromField } from '@apollo/client/utilities';
import { Kind, visit } from 'graphql';
import type { FieldNode, GraphQLError, OperationDefinitionNode } from 'graphql';
import { cloneDeep, get, set } from 'lodash';
import type { SubscriptionObserver } from '../types';
import { FetchAllResultFragment } from './fragments.gql';

/**
 * Arguments of the `@fetchAll` directive, after defaults have been filled in for omitted arguments.
 */
interface DirectiveArgValues {
	/** Key, inside the top-level field's result, of the list to concatenate across pages. Defaults to the top-level field's response key. */
	listName: string;
	/** Name of the top-level field argument that receives the page number. Defaults to `'page'`. */
	pageArgument: string;
	/** Upper bound on page requests issued concurrently once page 1 has loaded. Defaults to `2`. */
	maxConcurrentRequests: number;
}

/**
 * Name of the operation variable that carries the page number of a `@fetchAll` request.
 * fetchAllLink declares it on the operation as `Int!` and passes it to the top-level field's
 * page argument; createPageOperation sets it to 1, 2, ... for each page request.
 */
export const pageVariableName = '_fetchAllPage';

/**
 * Derive the operation for a single page of a `@fetchAll` query.
 *
 * The returned operation keeps the query and every existing variable of `operation`, adds
 * `pageVariableName` set to `pageNum`, and is seeded with `operation.getContext()`.
 *
 * @param operation - the already-transformed fetchAll operation
 * @param pageNum - 1-based page number to request
 * @returns a new operation for that page; `operation` itself is not modified
 */
const createPageOperation = (operation: Operation, pageNum: number): Operation => {
	const pageVariables = { ...operation.variables, [pageVariableName]: pageNum };
	const startingContext = operation.getContext();
	return createOperation(startingContext, { ...operation, variables: pageVariables });
};

/**
 * Merge the results of every page into one FetchResult and deliver it to each live subscriber.
 *
 * The lists found at `resultPath` inside each page's `data` are concatenated in page order, and the
 * GraphQL errors of all pages are collected into a single array. Every other field comes from the
 * first page. Each non-null subscriber receives the merged result via `next` and is then completed.
 *
 * @param results - page results, index 0 being page 1; holes (pages that emitted nothing) are skipped
 * @param resultPath - lodash path, relative to `data`, of the list to concatenate
 * @param subscribers - observers to notify; `null` entries are skipped
 */
const handleAllComplete = (
	results: FetchResult[],
	resultPath: string,
	subscribers: (SubscriptionObserver<FetchResult> | null)[],
): void => {
	// List items are whatever the server returns for the list field (typically objects), not strings.
	const finalResultList: unknown[] = [];
	const finalErrorList: GraphQLError[] = [];
	results.forEach((result) => {
		const resultList = get(result.data, resultPath);
		if (resultList != null) {
			finalResultList.push(...resultList);
		}
		if (result.errors) {
			finalErrorList.push(...result.errors);
		}
	});

	const finalResult = { ...results[0] };
	if (finalResultList.length > 0) {
		// The spread above copies only the top level, and `set` writes through every object on the
		// path. Clone `data` first so the first page's own result object is left untouched.
		finalResult.data = cloneDeep(finalResult.data);
		set(finalResult, `data.${resultPath}`, finalResultList);
	}
	if (finalErrorList.length > 0) {
		finalResult.errors = finalErrorList;
	}

	subscribers.forEach((subscriber) => {
		if (subscriber != null) {
			subscriber.next(finalResult);
			subscriber.complete();
		}
	});
};

/**
 * Fetch every page of a `@fetchAll` operation and emit a single merged result.
 *
 * Each subscription runs its own fetch, and page 1 is requested when the Observable is subscribed to.
 * Every page response may update the page count, which is read from `<resultName>.meta.totalPages`.
 * After page 1 completes, up to `maxConcurrentRequests` further pages (at least one) are kept in
 * flight until pages 1..totalPages have all been requested. Once every requested page has completed,
 * the pages are merged by `handleAllComplete`, which emits the result and completes the Observable.
 * If any page errors, the remaining page requests are unsubscribed and the error is forwarded.
 * Unsubscribing cancels every outstanding page request.
 *
 * @param operation - operation already rewritten to declare the page variable and pass it to the top-level field
 * @param forward - next link in the chain; called once per page
 * @param resultName - response key of the top-level field (its alias, or its name)
 * @param directiveArgValues - parsed `@fetchAll` arguments; only `listName` and `maxConcurrentRequests` are read here
 * @returns an Observable that emits the merged result once and then completes
 */
const createFetchAllObservable = (
	operation: Operation,
	forward: NextLink,
	resultName: string | undefined,
	directiveArgValues: DirectiveArgValues,
): Observable<FetchResult> => {
	const { listName, maxConcurrentRequests } = directiveArgValues;
	const totalPagesPath = `${resultName}.meta.totalPages`;
	const resultPath = `${resultName}.${listName}`;
	// A limit below 1 (or NaN) would never request page 2 and leave the Observable pending forever.
	const concurrencyLimit = maxConcurrentRequests >= 1 ? maxConcurrentRequests : 1;

	// Start fetching lazily, once per subscriber. Starting eagerly means a result (or error) produced
	// before the consumer subscribes is delivered to nobody, and the consumer then waits forever.
	return new Observable<FetchResult>((subscriber) => {
		const results: FetchResult[] = [];
		const pageSubscriptions: ObservableSubscription[] = [];
		let totalPages = 0;
		let pagesRequested = 0;
		let pagesComplete = 0;
		// Set once the subscriber has received a terminal event or has unsubscribed.
		let finished = false;

		const cancelPageRequests = () => {
			pageSubscriptions.forEach((pageSubscription) => {
				pageSubscription.unsubscribe();
			});
		};

		const handlePageError = (err: unknown) => {
			if (finished) {
				return;
			}
			finished = true;
			cancelPageRequests();
			subscriber.error(err);
		};

		const handlePageComplete = () => {
			if (finished) {
				return;
			}
			pagesComplete += 1;

			// Top up in-flight requests to the concurrency limit while pages remain to be requested.
			while (
				!finished &&
				pagesRequested < totalPages &&
				pagesRequested - pagesComplete < concurrencyLimit
			) {
				requestPage();
			}

			// Nothing in flight and nothing left to request: all pages are in. `finished` is re-checked
			// because a synchronous `forward` can finish (or fail) the whole query inside `requestPage`.
			if (!finished && pagesComplete === pagesRequested) {
				finished = true;
				handleAllComplete(results, resultPath, [subscriber]);
			}
		};

		const requestPage = () => {
			// Claim the page slot before subscribing: a synchronous `forward` can complete this page and
			// request the next one from inside `subscribe()`, before the subscription below is stored.
			const index = pagesRequested;
			pagesRequested += 1;
			const pageSubscription = forward(createPageOperation(operation, index + 1)).subscribe({
				next: (data) => {
					results[index] = data;
					totalPages = get(data.data, totalPagesPath, totalPages);
				},
				complete: handlePageComplete,
				error: handlePageError,
			});
			pageSubscriptions[index] = pageSubscription;
			if (finished) {
				// The query finished or was cancelled while this page was being subscribed; don't leave it running.
				pageSubscription.unsubscribe();
			}
		};

		requestPage();

		return () => {
			finished = true;
			cancelPageRequests();
		};
	});
};

/**
 * Apollo link implementing the client-side `@fetchAll` operation directive.
 *
 * For an operation annotated with `@fetchAll(pageArgument?, maxConcurrentRequests?, listName?)` it:
 *  - declares a non-null `Int` variable named `pageVariableName` on the operation,
 *  - passes that variable to the single top-level field through the `pageArgument` argument,
 *  - adds a `FetchAllResultFragment` spread to that field and appends the fragment's definition
 *    to the document (the page count is later read from `<field>.meta.totalPages`),
 *  - removes the `@fetchAll` directive so it is not sent to the server,
 * and then requests every page, emitting one result whose list holds the items of all pages.
 *
 * Operations without `@fetchAll` are forwarded unchanged. Misuse of the directive (more than one
 * top-level selection, a fragment as the top-level selection, or placement anywhere other than an
 * operation definition) produces an Observable that errors without sending any request.
 */
const fetchAllLink = new ApolloLink((operation, forward): Observable<FetchResult> => {
	let hasFetchAllDirective = false;
	let directiveArgValues: DirectiveArgValues | undefined;
	let resultName = '';
	let queryError: Error | undefined;
	const transformedQuery = visit(operation.query, {
		enter: {
			OperationDefinition(node): OperationDefinitionNode | undefined {
				// check if this operation has a fetchAll directive
				if (node.directives?.some((directive) => directive.name.value === 'fetchAll')) {
					if (node.selectionSet.selections.length > 1) {
						queryError = new Error(
							'Operations with the fetchAll directive cannot query more than one top level field.',
						);
					}

					if (node.selectionSet.selections[0].kind !== Kind.FIELD) {
						queryError = new Error(
							'The fetchAll directive cannot be used with fragments on Query.',
						);
					}

					return {
						...node,
						// add the page variable to the operation
						variableDefinitions: [
							// apollo client returns [] instead of undefined for node.variableDefinitions when no variables are defined
							...(node.variableDefinitions ?? []),
							{
								kind: Kind.VARIABLE_DEFINITION,
								type: {
									kind: Kind.NON_NULL_TYPE,
									type: { kind: Kind.NAMED_TYPE, name: { kind: Kind.NAME, value: 'Int' } },
								},
								variable: {
									kind: Kind.VARIABLE,
									name: { kind: Kind.NAME, value: pageVariableName },
								},
							},
						],
					};
				}
				// No fetchAll directive; do nothing.
				return undefined;
			},
			Directive(node, _key, _parent, _path, ancestors) {
				if (node.name.value !== 'fetchAll') {
					// This directive is not fetchAll. Do nothing.
					return undefined;
				}
				hasFetchAllDirective = true;

				// The node owning the directive is expected to be an OperationDefinitionNode (fetchAll is
				// declared on `QUERY`), but nothing client-side enforces that, so verify it.
				const ancestor = ancestors[ancestors.length - 1] as OperationDefinitionNode | undefined;
				const topLevelSelection =
					ancestor?.kind === Kind.OPERATION_DEFINITION ? ancestor.selectionSet.selections[0] : undefined;

				if (topLevelSelection?.kind !== Kind.FIELD) {
					// Either the directive is not on an operation, or the first top-level selection is a
					// fragment. OperationDefinition has already recorded the more specific error for the
					// latter. Bail out here instead of reading `name` off a node that may not have one
					// (inline fragments), which would throw out of the link.
					if (queryError == null) {
						queryError = new Error(
							'The fetchAll directive can only be used on an operation definition.',
						);
					}
					return undefined;
				}

				// get the name of the top level field, use alias if available, but fallback to name
				const topLevelField: FieldNode = topLevelSelection;
				resultName = topLevelField.alias?.value ?? topLevelField.name.value;

				// parse directive arguments and apply default values for unspecified arguments
				directiveArgValues = {
					pageArgument: 'page',
					maxConcurrentRequests: 2,
					listName: resultName,
					...argumentsObjectFromField(node, operation.variables),
				};

				// Remove fetchAll directive so it does not get sent to the server
				return null;
			},
			Field(node, _key, _parent, _path, ancestors) {
				// ancestors end with [..., owner of the selection set, SelectionSet]; the owner is an
				// OperationDefinitionNode for top-level fields.
				const parentNode = ancestors[ancestors.length - 2];
				// Check if this is the top-level field of an operation with a fetchAll directive
				if (
					'directives' in parentNode &&
					parentNode.directives?.some((directive) => directive.name.value === 'fetchAll')
				) {
					return {
						...node,
						// Add the page argument with the value set to the page variable
						arguments: [
							// fall back to [] in case the field node has no arguments array
							...(node.arguments ?? []),
							{
								kind: Kind.ARGUMENT,
								name: { kind: Kind.NAME, value: directiveArgValues?.pageArgument },
								value: { kind: Kind.VARIABLE, name: { kind: Kind.NAME, value: pageVariableName } },
							},
						],
						// Inject totalPages into the query
						selectionSet: {
							...node.selectionSet,
							selections: [
								...(node.selectionSet?.selections ?? []),
								{
									kind: Kind.FRAGMENT_SPREAD,
									name: { kind: Kind.NAME, value: 'FetchAllResultFragment' },
								},
							],
						},
					};
				}
				// This is not the top-level field of an operation with a fetchAll directive. Do nothing.
				return undefined;
			},
		},
		leave: {
			Document(node) {
				if (hasFetchAllDirective) {
					return {
						...node,
						// Add the FetchAllResultFragment definition to the query document
						definitions: [...node.definitions, ...FetchAllResultFragment.definitions],
					};
				}
				// No fetchAll directive; do nothing.
				return undefined;
			},
		},
	});

	// Report misuse of the directive through the Observable rather than sending a malformed request.
	if (queryError) {
		return new Observable((subscriber) => {
			subscriber.error(queryError);
		});
	}

	if (hasFetchAllDirective && directiveArgValues != null) {
		return createFetchAllObservable(
			createOperation(operation.getContext(), { ...operation, query: transformedQuery }),
			forward,
			resultName,
			directiveArgValues,
		);
	}

	// No fetchAll directive. Just pass the operation along to the next link without any modification.
	return forward(operation);
});

export default fetchAllLink;
