import { worldToLocal } from "@/alignment-tool/utils/alignment-transform";
import { GUID, assert } from "@faro-lotv/foundation";
import {
  CachedWorldTransform,
  computeCachedWorldTransform,
} from "@faro-lotv/project-source";
import {
  CaptureTreeEntity,
  RevisionScanEntity,
} from "@faro-lotv/service-wires";
import { Matrix4, Quaternion, Vector3 } from "three";
import { EntityMap } from "./revision-slice";

/**
 * Matrix to convert from a Z-up to a Y-up coordinate system: a rotation of -90° around the X axis,
 * which maps +Z onto +Y (and +Y onto -Z).
 *
 * This instance is shared module-wide, so it is frozen deeply. `Object.freeze` alone is shallow and
 * would leave the `elements` array writable. An accidental in-place operation on this constant would
 * then silently corrupt every transform computed afterwards. With the array frozen as well, such a
 * write fails (throwing a TypeError in strict-mode code) instead of going unnoticed.
 */
const Z_TO_Y_UP: Readonly<Matrix4> = (() => {
  const matrix = new Matrix4().makeRotationFromQuaternion(
    new Quaternion().setFromAxisAngle(new Vector3(1, 0, 0), -Math.PI / 2),
  );
  Object.freeze(matrix.elements);
  return Object.freeze(matrix);
})();

/**
 * Cached world transforms of revision entities, keyed by entity ID.
 * It may lack entries for entities that have not been processed yet. Updates only ever touch the
 * subtree they start from.
 */
export type RevisionTransformCache = Record<
  GUID,
  CachedWorldTransform | undefined
>;

/**
 * Local poses that replace the persisted local transforms of entities, keyed by entity ID.
 * When an entity has an override, it is used instead of the entity's stored `pose` during the
 * world-transform computation.
 */
export type EntityTransformOverrides = Record<
  GUID,
  RevisionScanEntity["pose"] | undefined
>;

/** Maps a parent entity ID to the IDs of its direct children (built by `generateChildrenMap`). */
type ChildrenMap = Record<GUID, GUID[] | undefined>;

/**
 * Builds the world-transform cache for a whole revision, starting from its root entity.
 *
 * The root is the first entity in `Object.values(entityMap)` order that is present and has no
 * `parentId`. Missing (`undefined`) map entries are skipped rather than mistaken for the root.
 *
 * @param entityMap All available revision entities.
 * @param transformOverrides The transform overrides to consider in the calculation. An override
 *   replaces the persisted pose of its entity.
 * @returns A new cache containing the world transforms of the root and every entity reachable from it.
 * @throws When the revision contains no root entity (raised through `assert`).
 */
export function generateRevisionTransformCache(
  entityMap: EntityMap,
  transformOverrides: EntityTransformOverrides,
): RevisionTransformCache {
  // Checking `!entity?.parentId` alone would also match an undefined entry and select it as the
  // "root", making the assertion below fail even though a real root exists.
  const root = Object.values(entityMap).find(
    (entity) => !!entity && !entity.parentId,
  );

  assert(root, "No root entity found in the revision");

  return updateRevisionTransformCache(
    {},
    entityMap,
    root.id,
    transformOverrides,
  );
}

/**
 * Recomputes, in place, the world transforms of `startFromId` and of every entity in its subtree.
 * Cache entries outside that subtree are left untouched.
 *
 * If the start entity has a parent, that parent must already have an entry in `cache`, because its
 * world matrix is the base the subtree is composed onto. A parentless (root) entity is composed onto
 * the Z-up to Y-up conversion instead.
 *
 * @param cache The cache to update; it is mutated and also returned.
 * @param entityMap All available revision entities
 * @param startFromId The entity to start updating from
 * @param transformOverrides The transform overrides to consider in the calculation
 * @returns the same `cache` object, updated
 */
export function updateRevisionTransformCache(
  cache: RevisionTransformCache,
  entityMap: EntityMap,
  startFromId: GUID,
  transformOverrides: EntityTransformOverrides,
): RevisionTransformCache {
  const startEntity = entityMap[startFromId];
  assert(startEntity, "Expected startFrom entity to be in entity map");

  const baseMatrix = resolveSubtreeParentMatrix(cache, startEntity.parentId);
  const childrenById = generateChildrenMap(entityMap);

  generateCacheRecursively(
    entityMap,
    childrenById,
    cache,
    baseMatrix,
    startEntity,
    transformOverrides,
  );

  return cache;
}

/**
 * Resolves the world matrix that a subtree hanging below `parentId` is composed onto.
 *
 * @param cache The transform cache to read the parent's world matrix from.
 * @param parentId The parent of the subtree's start entity, or a falsy value for the root.
 * @returns `Z_TO_Y_UP` for the root; otherwise a fresh Matrix4 built from the parent's cached world matrix.
 */
function resolveSubtreeParentMatrix(
  cache: RevisionTransformCache,
  parentId: CaptureTreeEntity["parentId"],
): Matrix4 {
  if (!parentId) {
    // At the root we convert from Z-up to Y-up, since the viewer uses Y-up
    return Z_TO_Y_UP;
  }

  const cachedParent = cache[parentId];
  assert(cachedParent, "Parent transform of entity not found in cache");

  return new Matrix4().fromArray(cachedParent.worldMatrix);
}

/**
 * Depth-first walk that writes the world transform of `curEntity` and of all its descendants into `cache`.
 *
 * The local pose comes from `transformOverrides` when the entity has an override, and otherwise from
 * its persisted `pose`. Missing pose components fall back to identity values: zero translation,
 * identity rotation and unit scale.
 *
 * @param entityMap All available revision entities.
 * @param childrenMap Parent ID to direct children IDs.
 * @param cache The transform cache to write into.
 * @param parentTransform World transform of the parent entity; it is only read, never modified.
 * @param curEntity The entity to process.
 * @param transformOverrides The active transform overrides.
 */
function generateCacheRecursively(
  entityMap: EntityMap,
  childrenMap: ChildrenMap,
  cache: RevisionTransformCache,
  parentTransform: Matrix4,
  curEntity: CaptureTreeEntity,
  transformOverrides: EntityTransformOverrides,
): void {
  const {
    pos: translation,
    rot: rotation,
    scale: scaling,
  } = transformOverrides[curEntity.id] ?? curEntity.pose;

  // world = parent * local. `premultiply` writes into the freshly composed matrix, not into parentTransform.
  const worldMatrix = new Matrix4()
    .compose(
      new Vector3(
        translation?.x ?? 0,
        translation?.y ?? 0,
        translation?.z ?? 0,
      ),
      new Quaternion(
        rotation?.x ?? 0,
        rotation?.y ?? 0,
        rotation?.z ?? 0,
        rotation?.w ?? 1,
      ),
      new Vector3(scaling?.x ?? 1, scaling?.y ?? 1, scaling?.z ?? 1),
    )
    .premultiply(parentTransform);

  cache[curEntity.id] = computeCachedWorldTransform(worldMatrix);

  const childIds = childrenMap[curEntity.id];
  if (!childIds) {
    return;
  }

  for (const childId of childIds) {
    const childEntity = entityMap[childId];
    if (!childEntity) {
      continue;
    }

    generateCacheRecursively(
      entityMap,
      childrenMap,
      cache,
      worldMatrix,
      childEntity,
      transformOverrides,
    );
  }
}

/**
 * Inverts the parent links of the revision entities into a parent to children lookup.
 * The Capture Tree API only provides each entity's `parentId`, not its children, so this is computed on the fly.
 *
 * @param entityMap The entities in the revision.
 * @returns A map from an entity ID to the IDs of its direct children, in `Object.values` order.
 *   Entities without children have no entry.
 */
function generateChildrenMap(entityMap: EntityMap): ChildrenMap {
  const childIdsByParent: ChildrenMap = {};

  Object.values(entityMap).forEach((candidate) => {
    // Missing entries and parentless entities (the root) contribute no parent-child edge.
    if (!candidate?.parentId) {
      return;
    }

    const siblings = childIdsByParent[candidate.parentId];
    if (siblings) {
      siblings.push(candidate.id);
    } else {
      childIdsByParent[candidate.parentId] = [candidate.id];
    }
  });

  return childIdsByParent;
}

/**
 * Derives an entity's local pose from its three.js world transform and its parent's world transform.
 *
 * @param worldTransform The three.js world transform of the entity.
 * @param parentWorldTransform The three.js world transform of the entity's parent. Omit it for the root
 *   entity; the default then undoes the Z-up to Y-up conversion applied when the transform cache is built.
 * @returns The local pose (position, rotation quaternion, scale) decomposed from the relative transform.
 */
export function computeLocalEntityPoseFromWorldTransforms(
  worldTransform: Matrix4,
  // At the root, the Z-up to Y-up conversion from the transform cache calculation has to be reverted
  parentWorldTransform: Matrix4 = Z_TO_Y_UP,
): CaptureTreeEntity["pose"] {
  const translation = new Vector3();
  const orientation = new Quaternion();
  const scaling = new Vector3();

  // NOTE(review): the default argument is the shared module-level Z_TO_Y_UP matrix. Confirm that
  // worldToLocal does not modify its first argument in place.
  const relativeTransform = worldToLocal(parentWorldTransform, worldTransform);
  relativeTransform.decompose(translation, orientation, scaling);

  return {
    pos: { x: translation.x, y: translation.y, z: translation.z },
    rot: {
      x: orientation.x,
      y: orientation.y,
      z: orientation.z,
      w: orientation.w,
    },
    scale: { x: scaling.x, y: scaling.y, z: scaling.z },
  };
}
