import type { FetchResult, NextLink, ObservableSubscription, Operation } from '@apollo/client';
import { ApolloLink, Observable } from '@apollo/client';
import { createOperation } from '@apollo/client/link/utils';
import { argumentsObjectFromField } from '@apollo/client/utilities';
import type { SubscriptionObserver } from '@storis/app_common.graphql';
import type { FieldNode, GraphQLError, OperationDefinitionNode } from 'graphql';
import { Kind, visit } from 'graphql';
import { get, set } from 'lodash';

/**
 * Resolved arguments of the `fetchAllByIds` directive.
 *
 * `fetchAllByIdsLink` fills in defaults (`'id'`, `10` and `30` respectively) for any
 * argument the directive does not specify.
 */
interface DirectiveArgValues {
	/** Name of the operation variable that holds the full list of ids to split into chunks. */
	idVariableName: string;
	/** Upper bound on how many chunk requests are started up front and kept in flight. */
	maxConcurrentRequests: number;
	/** Number of ids placed in each chunk request. */
	chunkSize: number;
}

/**
 * Builds a copy of `operation` whose `idVariableName` variable is replaced by `ids`.
 *
 * All other variables are carried over unchanged and the copy starts from the
 * original operation's current context.
 *
 * @param operation - The operation to clone.
 * @param idVariableName - Name of the variable that carries the id list.
 * @param ids - The ids this chunk should request.
 * @returns A new operation for the chunk.
 */
const createChunkOperation = (operation: Operation, idVariableName: string, ids: string[]) => {
	const context = operation.getContext();
	const chunkVariables = { ...operation.variables, [idVariableName]: ids };
	const chunkRequest = { ...operation, variables: chunkVariables };

	return createOperation(context, chunkRequest);
};

/**
 * Merges the results of all chunk requests into a single result and delivers it
 * (followed by completion) to every subscriber that is still registered.
 *
 * The lists found at `resultPath` in each chunk's `data` are concatenated in chunk
 * order, and the `errors` of all chunks are concatenated as well. Every other field
 * of the merged result is taken from the first chunk's result.
 *
 * @param results - One result per chunk, indexed by chunk position.
 * @param resultPath - Path (relative to `data`) of the list to merge.
 * @param subscribers - Registered subscribers; `null` entries have unsubscribed.
 */
const handleAllComplete = (
	results: FetchResult[],
	resultPath: string,
	subscribers: (SubscriptionObserver<FetchResult> | null)[],
): void => {
	const finalResultList: unknown[] = [];
	const finalErrorList: GraphQLError[] = [];
	results.forEach((result) => {
		const resultList = get(result.data, resultPath);
		if (resultList != null) {
			finalResultList.push(...resultList);
		}
		if (result.errors) {
			finalErrorList.push(...result.errors);
		}
	});

	// The spread below is only a shallow copy, so `finalResult.data` would still be the
	// very same object as `results[0].data`.
	const finalResult = { ...results[0] };
	if (finalResultList.length > 0) {
		// Give the merged result its own `data` object before writing into it; otherwise
		// `set` would write the merged list into the first chunk's response object.
		// Only this one level is copied, which is enough for a top-level field name.
		set(finalResult, 'data', { ...get(finalResult, 'data') });
		set(finalResult, `data.${resultPath}`, finalResultList);
	}
	if (finalErrorList.length > 0) {
		finalResult.errors = finalErrorList;
	}

	subscribers.forEach((subscriber) => {
		if (subscriber != null) {
			subscriber.next(finalResult);
			subscriber.complete();
		}
	});
};

/**
 * Creates an observable that fetches `operation` in chunks of ids and emits one merged result.
 *
 * The list stored in the `idVariableName` variable is split into chunks of `chunkSize`
 * ids. One copy of the operation is forwarded per chunk, starting with at most
 * `maxConcurrentRequests` requests; each time a chunk completes the next pending chunk
 * is started. When every chunk has completed the results are merged by
 * `handleAllComplete` and delivered to all subscribers.
 *
 * Behavior worth knowing:
 * - Requests start when the first subscriber attaches, so a result can never be
 *   produced before anyone is listening.
 * - If the variable is not an array or is empty there is nothing to chunk; the
 *   operation is forwarded once, unmodified.
 * - All subscribers share one execution. A subscriber attaching after the merged
 *   result (or an error) was delivered receives nothing.
 * - When the last subscriber unsubscribes, all chunk requests are unsubscribed.
 *
 * @param operation - The operation to execute (directive already removed from its query).
 * @param forward - The next link in the chain.
 * @param resultName - Response key of the top-level field whose lists are merged.
 * @param directiveArgValues - Resolved directive arguments.
 * @returns An observable emitting the merged result.
 */
const createFetchAllObservable = (
	operation: Operation,
	forward: NextLink,
	resultName: string,
	directiveArgValues: DirectiveArgValues,
): Observable<FetchResult> => {
	const { idVariableName } = directiveArgValues;
	const ids: unknown = operation.variables[idVariableName];

	// With no ids there would be zero chunks and the merged observable would never emit
	// or complete. Let the next link handle the operation as-is instead.
	if (!Array.isArray(ids) || ids.length === 0) {
		return forward(operation);
	}

	// A chunk size below 1 would make the chunking loop below never advance, and a
	// concurrency below 1 would never start a request; clamp both to at least 1.
	const toPositiveInteger = (value: number): number => (value >= 1 ? Math.floor(value) : 1);
	const chunkSize = toPositiveInteger(directiveArgValues.chunkSize);
	const maxConcurrentRequests = toPositiveInteger(directiveArgValues.maxConcurrentRequests);

	const resultPath = resultName;
	const subscribers: (SubscriptionObserver<FetchResult> | null)[] = [];
	let subscriberCount = 0;
	const results: FetchResult[] = [];
	const chunkSubscriptions: ObservableSubscription[] = [];
	let nextChunkIndex = 0;
	let chunksComplete = 0;
	let started = false;
	let failed = false;

	const idChunks: string[][] = [];
	for (let i = 0; i < ids.length; i += chunkSize) {
		idChunks.push(ids.slice(i, i + chunkSize));
	}

	const handleChunkError = (err: Error) => {
		if (failed) {
			return;
		}
		// Remember the failure so no further chunk requests are started.
		failed = true;
		subscribers.forEach((subscriber) => {
			if (subscriber != null) {
				subscriber.error(err);
			}
		});
	};

	const startNextChunk = () => {
		// Claim the index before subscribing. If the forwarded observable completes
		// synchronously, `complete` runs before the assignment below has happened, so the
		// length of `chunkSubscriptions` cannot be used to find the next chunk.
		const index = nextChunkIndex;
		nextChunkIndex += 1;
		chunkSubscriptions[index] = forward(
			createChunkOperation(operation, idVariableName, idChunks[index]),
		).subscribe({
			next: (data) => {
				results[index] = data;
			},
			complete: () => {
				if (failed) {
					return;
				}
				chunksComplete += 1;

				if (chunksComplete === idChunks.length) {
					handleAllComplete(results, resultPath, subscribers);
				} else if (nextChunkIndex < idChunks.length) {
					startNextChunk();
				}
			},
			error: handleChunkError,
		});
	};

	const startInitialChunks = () => {
		const initialRequests = Math.min(idChunks.length, maxConcurrentRequests);
		// Chunks that complete synchronously already start their successors, so also stop
		// once every chunk has been claimed or a chunk has failed.
		for (let i = 0; i < initialRequests && nextChunkIndex < idChunks.length && !failed; i += 1) {
			startNextChunk();
		}
	};

	return new Observable((subscriber) => {
		const subscriberIndex = subscribers.push(subscriber) - 1;
		subscriberCount += 1;

		// Start the requests only once somebody is listening.
		if (!started) {
			started = true;
			startInitialChunks();
		}

		return () => {
			// comment from apollo-link-retry source code:
			// Note that we are careful not to change the order or length of the array,
			// as we are often mid-iteration when calling this method.
			subscribers[subscriberIndex] = null;
			subscriberCount -= 1;
			if (subscriberCount === 0) {
				chunkSubscriptions.forEach((chunkSubscription) => {
					chunkSubscription.unsubscribe();
				});
			}
		};
	});
};

/**
 * Apollo link implementing the client-side `fetchAllByIds` operation directive.
 *
 * For an operation carrying the directive, the link removes the directive from the
 * query, reads the directive arguments (defaults: `idVariableName: 'id'`,
 * `maxConcurrentRequests: 10`, `chunkSize: 30`) and executes the operation through
 * `createFetchAllObservable`, which fetches the ids in chunks and merges the results.
 *
 * The operation must select exactly one top-level field and that selection must be a
 * field (not a fragment); otherwise the returned observable errors with a descriptive
 * message. Operations without the directive are forwarded untouched.
 */
const fetchAllByIdsLink = new ApolloLink(
	(operation: Operation, forward: NextLink): Observable<FetchResult> => {
		let hasFetchAllByIdsDirective = false;
		let directiveArgValues: DirectiveArgValues | undefined;
		let resultName = '';
		let queryError: Error | undefined;

		const transformedQuery = visit(operation.query, {
			enter: {
				OperationDefinition(node): OperationDefinitionNode | undefined {
					// check if this operation has a fetchAllByIds directive
					if (node.directives?.some((directive) => directive.name.value === 'fetchAllByIds')) {
						if (node.selectionSet.selections.length > 1) {
							queryError = new Error(
								'Operations with the fetchAllByIds directive cannot query more than one top level field.',
							);
						}

						// Checked last on purpose: when both problems apply, this message wins.
						if (node.selectionSet.selections[0].kind !== Kind.FIELD) {
							queryError = new Error(
								'The fetchAllByIds directive cannot be used with fragments on Query.',
							);
						}

						return { ...node };
					}
					// No fetchAllByIds directive; do nothing.
					return undefined;
				},
				Directive(node, key, parent, path, ancestors) {
					if (node.name.value === 'fetchAllByIds') {
						// The operation was already rejected above. Traversal still reaches this
						// directive, but the top-level selection may not be a field (an inline
						// fragment has no `name`), so reading its name here would throw instead of
						// reporting `queryError`. Just drop the directive and stop.
						if (queryError) {
							return null;
						}

						hasFetchAllByIdsDirective = true;

						// this will always be OperationDefinitionNode because fetchAllByIds is only valid on `QUERY`
						const ancestor = ancestors[ancestors.length - 1] as OperationDefinitionNode;

						// get the name of the top level field, use alias if available, but fallback to name
						// the OperationDefinition visitor has verified that this selection is a field
						const topLevelField = ancestor.selectionSet?.selections[0] as FieldNode;
						resultName = topLevelField.alias?.value ?? topLevelField.name.value;

						// parse directive arguments and apply default values for unspecified arguments
						directiveArgValues = {
							idVariableName: 'id',
							maxConcurrentRequests: 10,
							chunkSize: 30,
							...argumentsObjectFromField(node, operation.variables),
						};

						// Remove fetchAllByIds directive so it does not get sent to the server
						return null;
					}

					// This directive is not fetchAllByIds. Do nothing.
					return undefined;
				},
			},
		});

		if (queryError) {
			// Report the invalid query through the observable rather than by throwing.
			return new Observable((subscriber) => {
				subscriber.error(queryError);
			});
		}

		if (hasFetchAllByIdsDirective && directiveArgValues != null) {
			return createFetchAllObservable(
				createOperation(operation.getContext(), { ...operation, query: transformedQuery }),
				forward,
				resultName,
				directiveArgValues,
			);
		}

		// No fetchAllByIds directive. Just pass the operation along to the next link without any modification.
		return forward(operation);
	},
);

export default fetchAllByIdsLink;
