// Copyright 2020-2024 Luminary Cloud, Inc. All Rights Reserved.
import { Message } from '@bufbuild/protobuf';

import * as ProtoDescriptor from '../ProtoDescriptor';
import { paramGroupDesc } from '../SimulationParamDescriptor';
import * as basepb from '../proto/base/base_pb';
import * as simulationpb from '../proto/client/simulation_pb';
import * as explorationpb from '../proto/exploration/exploration_pb';
import { GeometryTags } from '../recoil/geometry/geometryTagsObject';
import { StaticVolume } from '../recoil/volumes';
import { NamesRecord } from '../state/external/project/simulation/param/boundaryNames';

import { ParamScope, chainParamScopes, createParamScope } from './ParamScope';
import { adVectorToArray } from './Vector';
import { getAdValue, newAdFloat, newScalarAdVector } from './adUtils';
import assert from './assert';
import {
  AnyBoundaryCondition,
  findFluidBoundaryCondition,
  findHeatBoundaryCondition,
  findParentPhysicsByBoundaryConditionId,
} from './boundaryConditionUtils';
import { EMPTY_VALUE } from './constants';
import { findHeatSourceById, findParentPhysicsByHeatSourceId } from './heatSourceUtils';
import { findMaterialEntityById, getMaterialName } from './materialUtils';
import { formatNumber, formatNumberList, fromBigInt } from './number';
import { findParentPhysicsByPhysicalBehaviorId, getAllPhysicalBehaviors } from './physicalBehaviorUtils';
import { findFluidPhysicsMaterial, findPhysicsById, getPhysicsName, isPhysicsFluid, isPhysicsHeat } from './physicsUtils';
import { getBoundaryCondName } from './simulationTree/utils';

/**
 * Subtitle copy shown while previewing an uploaded custom sample table.
 */
export const TABLE_PREVIEW_SUBTITLE =
  'Explore your specific data. ' +
  'User-provided samples are essential for capturing the unique characteristics of your ' +
  'experiment and ensuring accurate results.';

/** Node id used for the exploration variable at position `index`. */
export const explorationVariableNodeId = (index: number): string => {
  const prefix = 'Exploration Variable';
  return `${prefix} ${index}`;
};

/** Wrap a real AD value in an exploration Value (oneof case 'real'). */
export function newRealValue(value: basepb.AdFloatType): explorationpb.Value {
  return new explorationpb.Value({
    typ: { value, case: 'real' },
  });
}

/** Wrap a vector AD value in an exploration Value (oneof case 'vector3'). */
export function newVector3Value(value: basepb.AdVector3): explorationpb.Value {
  return new explorationpb.Value({
    typ: { value, case: 'vector3' },
  });
}

// Interface representing the state of a parameter: its descriptor together with its current
// value in a param scope (varSpecToParamState fills `value` from `scope.value(param)`).
interface ParamState {
  // Descriptor of the parameter.
  param: ProtoDescriptor.Param,
  // Current value. Its runtime shape depends on `param.type`: code in this file reads it as a
  // number for REAL params and as a ProtoDescriptor.Vector3 for VECTOR3 params.
  value: any,
}

/**
 * Create a new exploration value holding the value of a param state.
 *
 * Only REAL and VECTOR3 params are supported.
 *
 * @throws Error for any other param type.
 */
export function newExplorationValue(paramState: ParamState) {
  const { param, value } = paramState;
  if (param.type === ProtoDescriptor.ParamType.REAL) {
    return newRealValue(newAdFloat(value as number));
  }
  if (param.type === ProtoDescriptor.ParamType.VECTOR3) {
    const vec = value as ProtoDescriptor.Vector3;
    return newVector3Value(newScalarAdVector(vec.x, vec.y, vec.z));
  }
  throw Error(`Incompatible param type ${param.type}`);
}

/**
 * Create a VarSpec that identifies the field `field` (labelled `text`) of the entity `id`,
 * whose kind is `type`.
 */
export function newVarSpec(
  id: string,
  type: explorationpb.VarType,
  field: string,
  text: string,
): explorationpb.VarSpec {
  const spec = new explorationpb.VarSpec({ field, id, text, type });
  return spec;
}

// Whether a param may be used as an exploration variable: REAL params always qualify, VECTOR3
// params only when this is not a sensitivity analysis, and nothing else does.
function isCandidate(param: ProtoDescriptor.Param, isSensitivityAnalysis: boolean) {
  switch (param.type) {
    case ProtoDescriptor.ParamType.REAL:
      return true;
    case ProtoDescriptor.ParamType.VECTOR3:
      // Vector3 params are not offered for sensitivity analysis.
      return !isSensitivityAnalysis;
    default:
      return false;
  }
}

// The proto that owns an exploration var together with the param group describing its fields,
// as returned by getProtoForType. For some var types `proto` is undefined when the owning entity
// could not be found, so callers check `desc?.proto` before using it.
type Descriptor = {
  proto?: Message,
  group: ProtoDescriptor.ParamGroup,
}

/**
 * Build the param scope in which the params of an exploration var's proto are evaluated.
 *
 * - BOUNDARY / PHYSICAL_BEHAVIOR: the owning physics, that physics' fluid material (if any) and
 *   `proto` are passed to chainParamScopes together with `baseParamScope`.
 * - HEAT_SOURCE: the owning physics and `proto` are chained with `baseParamScope`.
 * - PHYSICS: the fluid material of the physics (`proto` itself) and `proto` are chained.
 * - REFERENCE_FRAME / PARTICLE_GROUP / MATERIAL: a scope over `proto` via createParamScope.
 * - GLOBAL: `baseParamScope` is returned unchanged.
 *
 * @param baseline the baseline simulation param, used to look up the owning physics
 * @param proto the proto owning the var (as returned by getProtoForType)
 * @param baseParamScope the scope the returned scope is built from
 * @param varType the kind of entity owning the var
 * @param varId the id of that entity
 * @param geometryTags forwarded to findFluidPhysicsMaterial
 * @param staticVolumes forwarded to findFluidPhysicsMaterial
 * @returns paramScope
 * @throws Error for INVALID or unrecognized var types
 */
function getParamScopeForType(
  baseline: simulationpb.SimulationParam,
  proto: Message,
  baseParamScope: ParamScope,
  varType: explorationpb.VarType,
  varId: string,
  geometryTags: GeometryTags,
  staticVolumes: StaticVolume[],
) {
  if (varType === explorationpb.VarType.GLOBAL) {
    return baseParamScope;
  }
  if (
    varType === explorationpb.VarType.REFERENCE_FRAME ||
    varType === explorationpb.VarType.PARTICLE_GROUP ||
    varType === explorationpb.VarType.MATERIAL
  ) {
    return createParamScope(proto, [], baseParamScope);
  }
  if (varType === explorationpb.VarType.BOUNDARY) {
    const owner = findParentPhysicsByBoundaryConditionId(baseline, varId);
    const fluidMaterial = owner &&
      findFluidPhysicsMaterial(baseline, owner, geometryTags, staticVolumes);
    return chainParamScopes([owner, fluidMaterial, proto], [], baseParamScope);
  }
  if (varType === explorationpb.VarType.PHYSICAL_BEHAVIOR) {
    const owner = findParentPhysicsByPhysicalBehaviorId(baseline, varId);
    const fluidMaterial = owner &&
      findFluidPhysicsMaterial(baseline, owner, geometryTags, staticVolumes);
    return chainParamScopes([owner, fluidMaterial, proto], [], baseParamScope);
  }
  if (varType === explorationpb.VarType.HEAT_SOURCE) {
    const owner = findParentPhysicsByHeatSourceId(baseline, varId);
    return chainParamScopes([owner, proto], [], baseParamScope);
  }
  if (varType === explorationpb.VarType.PHYSICS) {
    const fluidMaterial = findFluidPhysicsMaterial(
      baseline,
      proto as simulationpb.Physics,
      geometryTags,
      staticVolumes,
    );
    return chainParamScopes([fluidMaterial, proto], [], baseParamScope);
  }
  // INVALID and any unknown var type.
  throw Error('Invalid var type.');
}

/**
 * Find the proto that owns an exploration var, together with the param group of its fields.
 *
 * BOUNDARY, MATERIAL and PHYSICS return undefined when the entity is missing or is of an
 * unsupported kind. PHYSICAL_BEHAVIOR, REFERENCE_FRAME, PARTICLE_GROUP and HEAT_SOURCE always
 * return a descriptor, whose `proto` is undefined when the entity is missing.
 *
 * @throws Error for INVALID or unrecognized var types
 */
export function getProtoForType(
  baseline: simulationpb.SimulationParam,
  varType: explorationpb.VarType,
  id: string,
): Descriptor | undefined {
  if (varType === explorationpb.VarType.GLOBAL) {
    return { proto: baseline, group: paramGroupDesc.simulation_param };
  }
  if (varType === explorationpb.VarType.BOUNDARY) {
    // The id may belong to either a fluid or a heat boundary condition.
    const fluidBc = findFluidBoundaryCondition(baseline, id);
    if (fluidBc) {
      return { proto: fluidBc, group: paramGroupDesc.boundary_conditions_fluid };
    }
    const heatBc = findHeatBoundaryCondition(baseline, id);
    return heatBc ? { proto: heatBc, group: paramGroupDesc.boundary_conditions_heat } : undefined;
  }
  if (varType === explorationpb.VarType.PHYSICS) {
    const physics = findPhysicsById(baseline, id);
    if (!physics) {
      return undefined;
    }
    if (isPhysicsFluid(physics)) {
      return { proto: physics, group: paramGroupDesc.fluid };
    }
    return isPhysicsHeat(physics) ? { proto: physics, group: paramGroupDesc.heat } : undefined;
  }
  if (varType === explorationpb.VarType.MATERIAL) {
    const material = findMaterialEntityById(baseline, id);
    switch (material?.material.case) {
      case 'materialFluid':
        return { proto: material, group: paramGroupDesc.material_fluid };
      case 'materialSolid':
        return { proto: material, group: paramGroupDesc.material_solid };
      default:
        return undefined;
    }
  }
  // The remaining kinds keep their group even if the lookup misses (proto is then undefined).
  if (varType === explorationpb.VarType.PHYSICAL_BEHAVIOR) {
    const behavior = getAllPhysicalBehaviors(baseline).find(
      (candidate) => candidate.physicalBehaviorId === id,
    );
    return { proto: behavior, group: paramGroupDesc.physical_behavior };
  }
  if (varType === explorationpb.VarType.REFERENCE_FRAME) {
    const frame = baseline.motionData.find((candidate) => candidate.frameId === id);
    return { proto: frame, group: paramGroupDesc.motion_data };
  }
  if (varType === explorationpb.VarType.PARTICLE_GROUP) {
    const particleGroup = baseline.particleGroup.find(
      (candidate) => candidate.particleGroupId === id,
    );
    return { proto: particleGroup, group: paramGroupDesc.particle_group };
  }
  if (varType === explorationpb.VarType.HEAT_SOURCE) {
    const heatSource = findHeatSourceById(baseline, id);
    return { proto: heatSource, group: paramGroupDesc.heat_source };
  }
  // INVALID and any unknown var type.
  throw Error('Invalid var type.');
}

/**
 * Resolve the display name of the entity an exploration var is attached to.
 *
 * Boundary conditions, physical behaviors and heat sources are prefixed with the name of their
 * owning physics (`<physics>/<entity>`); global vars have an empty name.
 *
 * Keep this in sync with GetProtoName() in go/core/exploration/utils.go
 *
 * @throws via assert when the owning proto cannot be found, and Error for unknown var types
 */
export function findNameByType(
  param: simulationpb.SimulationParam,
  id: string,
  bcNames: NamesRecord,
  varType: explorationpb.VarType,
) {
  const desc = getProtoForType(param, varType, id);
  assert(!!desc && !!desc.proto, `Proto containing exploration var ${varType} ${id} not found`);
  const { proto } = desc;

  if (varType === explorationpb.VarType.GLOBAL) {
    return '';
  }
  if (varType === explorationpb.VarType.REFERENCE_FRAME) {
    return (proto as simulationpb.MotionData).frameName;
  }
  if (varType === explorationpb.VarType.PARTICLE_GROUP) {
    return (proto as simulationpb.ParticleGroup).particleGroupName;
  }
  if (varType === explorationpb.VarType.PHYSICS) {
    return getPhysicsName(proto as simulationpb.Physics, param);
  }
  if (varType === explorationpb.VarType.MATERIAL) {
    return getMaterialName(proto as simulationpb.MaterialEntity, param);
  }
  if (varType === explorationpb.VarType.BOUNDARY) {
    const owner = findParentPhysicsByBoundaryConditionId(param, id);
    const bcName = getBoundaryCondName(bcNames, (proto as AnyBoundaryCondition));
    return `${getPhysicsName(owner!, param)}/${bcName}`;
  }
  if (varType === explorationpb.VarType.PHYSICAL_BEHAVIOR) {
    const owner = findParentPhysicsByPhysicalBehaviorId(param, id);
    const behaviorName = (proto as simulationpb.PhysicalBehavior).physicalBehaviorName;
    return `${getPhysicsName(owner!, param)}/${behaviorName}`;
  }
  if (varType === explorationpb.VarType.HEAT_SOURCE) {
    const owner = findParentPhysicsByHeatSourceId(param, id);
    const heatSourceName = (proto as simulationpb.HeatSource).heatSourceName;
    return `${getPhysicsName(owner!, param)}/${heatSourceName}`;
  }
  throw Error('Undefined var type.');
}

/**
 * Build the display label of an exploration var: `<entity name>/<field text>`, or just the
 * field text when the var has no entity id or the entity name is empty.
 */
export function getLabel(
  param: simulationpb.SimulationParam,
  bcNames: NamesRecord,
  varSpec: explorationpb.VarSpec,
) {
  const { id, type, text } = varSpec;
  // Vars without an entity id are never prefixed.
  if (!id.length) {
    return text;
  }
  const entityName = findNameByType(param, id, bcNames, type);
  return entityName.length ? `${entityName}/${text}` : text;
}

/**
 * Get the enabled parameters of `paramGroup` in `paramScope` that can be used as exploration
 * vars on the entity identified by `varType` and `varId`.
 *
 * A param qualifies if isCandidate accepts it and no var in `otherVars` already targets the same
 * field of the same entity.
 */
function getCandidatesInScope(
  varType: explorationpb.VarType,
  varId: string,
  paramGroup: ProtoDescriptor.ParamGroup,
  isSensitivityAnalysis: boolean,
  paramScope: ParamScope,
  otherVars: explorationpb.Var[],
) {
  // True when another exploration var already uses this param on this entity.
  const isTaken = (param: ProtoDescriptor.Param) => otherVars.some(
    (otherVar) => (
      otherVar.spec?.field === param.name &&
      otherVar.spec.id === varId &&
      otherVar.spec.type === varType
    ),
  );
  return paramScope.enabledParams(paramGroup, true).filter(
    (param) => isCandidate(param, isSensitivityAnalysis) && !isTaken(param),
  );
}

/**
 * Get all parameters that are available as exploration variables on the entity identified by
 * `varType` and `varId`.
 *
 * @param varType kind of entity the variable is attached to
 * @param varId id of that entity
 * @param baseline baseline simulation param
 * @param paramScope base param scope
 * @param isSensitivityAnalysis when true, VECTOR3 params are not offered
 * @param otherVars existing exploration vars; params they already use on this entity are excluded
 * @param geometryTags forwarded to the param scope lookup
 * @param staticVolumes forwarded to the param scope lookup
 * @returns the candidate params
 */
export function getCandidates(
  varType: explorationpb.VarType,
  varId: string,
  baseline: simulationpb.SimulationParam,
  paramScope: ParamScope,
  isSensitivityAnalysis: boolean,
  otherVars: explorationpb.Var[],
  geometryTags: GeometryTags,
  staticVolumes: StaticVolume[],
) {
  const desc = getProtoForType(baseline, varType, varId);
  assert(!!desc?.proto, `Proto containing exploration var ${varType} ${varId} not found`);
  const { group } = desc;
  const varScope = getParamScopeForType(
    baseline,
    desc.proto,
    paramScope,
    varType,
    varId,
    geometryTags,
    staticVolumes,
  );
  return getCandidatesInScope(varType, varId, group, isSensitivityAnalysis, varScope, otherVars);
}

/**
 * Create a param state from a variable spec.
 *
 * Resolves the entity the spec points at, builds its param scope, and looks up `field` among the
 * params enabled in that scope (VECTOR3 params included).
 *
 * @returns the param and its current value, or null when the entity cannot be found or `field`
 *   is not an enabled candidate param of it.
 * @throws Error for invalid var types (from getProtoForType)
 */
export function varSpecToParamState(
  varSpec: explorationpb.VarSpec,
  baseline: simulationpb.SimulationParam,
  paramScope: ParamScope,
  geometryTags: GeometryTags,
  staticVolumes: StaticVolume[],
): ParamState | null {
  const { id, type, field } = varSpec;
  const desc = getProtoForType(baseline, type, id);
  if (!desc?.proto) {
    return null;
  }
  const scope = getParamScopeForType(
    baseline,
    desc.proto,
    paramScope,
    type,
    id,
    geometryTags,
    staticVolumes,
  );
  // No other vars are excluded, and isSensitivityAnalysis=false keeps VECTOR3 params resolvable.
  const candidates = getCandidatesInScope(type, id, desc.group, false, scope, []);
  // The lookup legitimately misses when the param is disabled, so no non-null assertion here.
  const param = candidates.find((candidate) => candidate.name === field);
  if (!param) {
    return null;
  }
  return { param, value: scope.value(param) };
}

/**
 * Create a new exploration var for `param` on the entity identified by `varType` and `id`.
 *
 * - 'range': min and max both start at the param's baseline value, with 2 samples.
 * - 'enumerated': a single entry holding the baseline value.
 * - 'column': references the table `tableName` ('' when not given).
 *
 * @throws Error when a 'range' or 'enumerated' var is requested but the param cannot be resolved
 *   on the entity, or when `valueType` is not one of the cases above.
 */
export function newExplorationVar(
  baseline: simulationpb.SimulationParam,
  param: ProtoDescriptor.Param,
  paramScope: ParamScope,
  valueType: explorationpb.Var['valueTyp']['case'],
  varType: explorationpb.VarType,
  id: string,
  geometryTags: GeometryTags,
  staticVolumes: StaticVolume[],
  /** the name of the table for a TableColumn variable */
  tableName?: string,
): explorationpb.Var {
  const spec = newVarSpec(id, varType, param.name, param.text);
  const variable = new explorationpb.Var({ spec });
  const paramState = varSpecToParamState(spec, baseline, paramScope, geometryTags, staticVolumes);
  // Builds a fresh Value holding the baseline value. Only range/enumerated vars need one, so a
  // missing param state is reported here with context instead of failing on a null dereference.
  const baselineValue = (): explorationpb.Value => {
    if (!paramState) {
      throw Error(`Param ${param.name} not found for exploration var ${varType} ${id}`);
    }
    return newExplorationValue(paramState);
  };
  switch (valueType) {
    case 'range': {
      variable.valueTyp = {
        case: 'range',
        value: new explorationpb.Range({
          min: baselineValue(),
          max: baselineValue(),
          nSamples: 2,
        }),
      };
      break;
    }
    case 'enumerated': {
      // Default to a single value to avoid having an initial exploration defined with 2 simulations
      // with the same parameters.
      variable.valueTyp = {
        case: 'enumerated',
        value: new explorationpb.Enumerated({
          value: [baselineValue()],
        }),
      };
      break;
    }
    case 'column': {
      variable.valueTyp = {
        case: 'column',
        value: new explorationpb.TableColumn({
          table: tableName ?? '',
        }),
      };
      break;
    }
    default:
      throw Error(`type: ${valueType}`);
  }
  return variable;
}

// Format an AD vector as a number list of its x, y and z values (read through getAdValue).
function formatVector3(vector: basepb.AdVector3): string {
  const components = [vector.x, vector.y, vector.z].map((component) => getAdValue(component));
  return formatNumberList(components);
}

// A display label paired with a formatted value, as produced by extractValues.
interface NamedValue {
  // Label of the variable: `<entity name>/<field text>`, or just the field text.
  name: string;
  // Formatted value (see extractFormattedValue).
  value: string;
}

/**
 * Extract a scalar from an exploration value.
 *
 * @returns the raw integer for 'int' values, the number read via getAdValue for 'real' values,
 *   and undefined for anything else (including a missing value).
 */
export function extractScalarValue(value?: explorationpb.Value): number | bigint | undefined {
  const typ = value?.typ;
  if (typ?.case === 'int') {
    return typ.value;
  }
  if (typ?.case === 'real') {
    return getAdValue(typ.value);
  }
  return undefined;
}

/**
 * Extract and format an exploration value for display.
 *
 * @throws Error for value kinds other than int, real and vector3 (including an unset value).
 */
export function extractFormattedValue(value: explorationpb.Value): string {
  switch (value.typ.case) {
    case 'int':
      return formatNumber(fromBigInt(value.typ.value));
    case 'real':
      return formatNumber(getAdValue(value.typ.value));
    case 'vector3':
      return formatVector3(value.typ.value);
    default:
      // Interpolating the message itself would print "[object Object]"; serialize it so the
      // error identifies the offending value.
      throw Error(`formatValue: illegal value: ${value.toJsonString()}`);
  }
}

/** Whether two var specs name the same variable: same entity type, entity id and field. */
export function varNameEquals(a: explorationpb.VarSpec, b: explorationpb.VarSpec): boolean {
  if (a.type !== b.type) {
    return false;
  }
  if (a.id !== b.id) {
    return false;
  }
  return a.field === b.field;
}

/**
 * Extract a list of named values from an exploration.
 *
 * Synthetic vars are skipped. The remaining vars are paired by position with `values.value`,
 * and at most as many entries as there are such vars are returned.
 * NOTE(review): this assumes `values.value` is ordered like the non-synthetic vars — confirm.
 *
 * Must be in sync with GetExplorationVarName from go/core/exploration/utils.go
 */
export function extractValues(
  values: explorationpb.Values,
  exploration: explorationpb.Exploration,
  param: simulationpb.SimulationParam,
  bcNames: NamesRecord,
): NamedValue[] {
  const userVars = exploration.var.filter(({ synthetic }) => !synthetic);
  const paired = values.value.slice(0, userVars.length);
  return paired.map((value, index) => {
    const spec = userVars[index].spec!;
    const prefix = findNameByType(param, spec.id, bcNames, spec.type);
    return {
      name: prefix.length ? `${prefix}/${spec.text}` : spec.text,
      value: extractFormattedValue(value),
    };
  });
}

/**
 * Get the baseline value of an exploration var, formatted for display.
 *
 * @returns EMPTY_VALUE when the var's param is not active for this simulation (no param state can
 *   be built); otherwise the formatted vector (VECTOR3 params) or number (everything else).
 */
export function getBaselineValueForVarSpec(
  varSpec: explorationpb.VarSpec,
  baseline: simulationpb.SimulationParam,
  paramScope: ParamScope,
  geometryTags: GeometryTags,
  staticVolumes: StaticVolume[],
) {
  const state = varSpecToParamState(varSpec, baseline, paramScope, geometryTags, staticVolumes);
  if (state === null) {
    return EMPTY_VALUE;
  }
  if (state.param.type !== ProtoDescriptor.ParamType.VECTOR3) {
    return formatNumber(state.value as number);
  }
  const vector = state.value as ProtoDescriptor.Vector3;
  return formatVector3(newScalarAdVector(vector.x, vector.y, vector.z));
}

/**
 * Return the display value of a variable for a job.
 *
 * If the variable is part of the exploration (and the exploration is not a sensitivity analysis)
 * and the job has a value at that position, that value is used. Otherwise the baseline value is
 * used when a baseline is given, and '' when it is not.
 */
export function getValueForVarSpec(
  varSpecA: explorationpb.VarSpec,
  jobValues: explorationpb.Value[],
  paramScope: ParamScope,
  geometryTags: GeometryTags,
  staticVolumes: StaticVolume[],
  baseline?: simulationpb.SimulationParam,
  exploration?: explorationpb.Exploration,
): string {
  // Sensitivity analyses never take their values from the job here.
  const position = (exploration && exploration.policy.case !== 'sensitivityAnalysis') ?
    exploration.var.findIndex(({ spec }) => !!spec && varNameEquals(varSpecA, spec)) :
    -1;
  const hasJobValue = position >= 0 && position < jobValues.length;
  if (hasJobValue) {
    return extractFormattedValue(jobValues[position]);
  }
  if (!baseline) {
    return '';
  }
  return getBaselineValueForVarSpec(varSpecA, baseline, paramScope, geometryTags, staticVolumes);
}

/**
 * Whether `exploration` is defined and has at least one variable.
 *
 * Always returns a boolean (never `undefined`), so the result can be stored or compared directly.
 */
export function validExploration(exploration?: explorationpb.Exploration): boolean {
  return !!exploration && exploration.var.length > 0;
}

/**
 * Remove any experiment variables that have been disabled by the current parameters.
 *
 * A variable is kept only if it has a spec that still resolves to a param state (see
 * varSpecToParamState). The input experiment is not modified; a filtered clone is returned.
 */
export function filterExperimentVariables(
  experiment: explorationpb.Exploration,
  simParam: simulationpb.SimulationParam,
  paramScope: ParamScope,
  geometryTags: GeometryTags,
  staticVolumes: StaticVolume[],
): explorationpb.Exploration {
  const newExperiment = experiment.clone();
  // Filter the clone's own vars: filtering `experiment.var` would put the caller's Var messages
  // into the returned clone, so edits to one exploration would leak into the other.
  newExperiment.var = newExperiment.var.filter(({ spec }) => (
    // A var without a spec cannot be resolved to a param, so it is dropped as disabled.
    !!spec && !!varSpecToParamState(spec, simParam, paramScope, geometryTags, staticVolumes)
  ));
  return newExperiment;
}

/**
 * Reset the variables in the exploration
 *
 * If the exploration has the policy case "custom", then all variables are set to Column selectors.
 * Otherwise, the variables are all set to Range, seeded with their baseline value.
 *
 * The exploration is modified in place and also returned.
 *
 * @param baseline the baseline simulation param
 * @param paramScope the base param scope
 * @param exploration the exploration whose variables are replaced (mutated)
 * @param geometryTags forwarded to the param lookups
 * @param staticVolumes forwarded to the param lookups
 *
 * @throws an error if a variable has no spec or its paramState cannot be determined
 *
 * @returns the same exploration, with its variables replaced
 */
export function resetColumnVariables(
  baseline: simulationpb.SimulationParam,
  paramScope: ParamScope,
  exploration: explorationpb.Exploration,
  geometryTags: GeometryTags,
  staticVolumes: StaticVolume[],
) {
  const valueType = exploration.policy.case === 'custom' ? 'column' : 'range';
  const newVariables = exploration.var.map(({ spec }) => {
    const paramState = spec ?
      varSpecToParamState(spec, baseline, paramScope, geometryTags, staticVolumes) :
      null;
    if (!spec || !paramState) {
      // Identify the var by its field and id; interpolating the Var message itself would only
      // print "[object Object]".
      throw Error(`Exploration variable ${spec?.field} ${spec?.id} not found`);
    }
    return newExplorationVar(
      baseline,
      paramState.param,
      paramScope,
      valueType,
      spec.type,
      spec.id,
      geometryTags,
      staticVolumes,
    );
  });
  exploration.var = newVariables;
  return exploration;
}

/**
 * Unwrap the AD value held by an exploration Value, checking that it matches `type`.
 *
 * @param value the exploration value to unwrap
 * @param type the param type the value must have (required): REAL for 'real' values, VECTOR3 for
 *   'vector3' values
 * @returns the wrapped AdFloatType or AdVector3
 * @throws Error when the value kind and `type` do not match, or the kind is unsupported
 */
export function primalToAd(
  value: explorationpb.Value,
  type: ProtoDescriptor.ParamType,
): basepb.AdFloatType | basepb.AdVector3 {
  switch (value.typ.case) {
    case 'real':
      if (type === ProtoDescriptor.ParamType.REAL) {
        return value.typ.value;
      }
      break;
    case 'vector3':
      if (type === ProtoDescriptor.ParamType.VECTOR3) {
        return value.typ.value;
      }
      break;
    default:
      break;
  }
  throw Error(`Unsupported value: ${value.toJsonString()}`);
}

/**
 * Wrap an AD value in an exploration Value, checking that it matches `type`.
 *
 * @param value the AD value to wrap
 * @param type the param type the value must have (required): REAL for AdFloatType values,
 *   VECTOR3 for AdVector3 values
 * @returns a 'real' or 'vector3' exploration Value holding `value`
 * @throws Error when the value kind and `type` do not match
 */
export function adToPrimal(
  value: basepb.AdFloatType | basepb.AdVector3,
  type: ProtoDescriptor.ParamType,
): explorationpb.Value {
  if (value instanceof basepb.AdFloatType && type === ProtoDescriptor.ParamType.REAL) {
    return newRealValue(value);
  }
  if (value instanceof basepb.AdVector3 && type === ProtoDescriptor.ParamType.VECTOR3) {
    return newVector3Value(value);
  }
  throw Error(`Unsupported value: ${value.toJsonString()}`);
}

/**
 * Returns a display friendly string of variable values when using an interval and full sweep
 * policy: the evenly spaced values sampled between the range's min and max.
 *
 * @param proto the range; `nSamples` is used when set, otherwise `nInterval + 1`
 * @param type REAL or VECTOR3; must match the kind of `min`/`max` (see primalToAd)
 * @returns one formatted string per sample, starting at min and, for two or more samples,
 *   ending at max
 * @throws Error if min/max do not match `type` (from primalToAd)
 */
export function getDisplayableIntervalValues(
  proto: explorationpb.Range,
  type: ProtoDescriptor.ParamType,
): string[] {
  // The range is only read, never mutated, so it does not need to be cloned.
  const min = primalToAd(proto.min!, type);
  const max = primalToAd(proto.max!, type);
  // If the Range does not have nSamples, use nInterval: n intervals means n + 1 values
  // (intervals are the gaps between values).
  const samples = proto.nSamples || proto.nInterval + 1;
  // Number of gaps between consecutive samples. A single sample has no gap; dividing by
  // `samples - 1 = 0` would render the lone value as NaN, so divide by 1 instead (min is shown).
  const gaps = Math.max(samples - 1, 1);
  const intervalValues = new Array<string>(samples);

  // All input types are either ParamType.REAL or VECTOR3
  if (type === ProtoDescriptor.ParamType.REAL) {
    // Interpolate between min and max
    const minNumber = getAdValue(min as basepb.AdFloatType);
    const maxNumber = getAdValue(max as basepb.AdFloatType);
    const step = (maxNumber - minNumber) / gaps;
    for (let i = 0; i < samples; i += 1) {
      intervalValues[i] = formatNumber(minNumber + i * step);
    }
  } else if (type === ProtoDescriptor.ParamType.VECTOR3) {
    // Linearly interpolate between the two vectors, component by component
    const [x1, y1, z1] = adVectorToArray(min as basepb.AdVector3)!;
    const [x2, y2, z2] = adVectorToArray(max as basepb.AdVector3)!;

    const xStep = (x2 - x1) / gaps;
    const yStep = (y2 - y1) / gaps;
    const zStep = (z2 - z1) / gaps;

    for (let i = 0; i < samples; i += 1) {
      const parts = [
        formatNumber(x1 + i * xStep),
        formatNumber(y1 + i * yStep),
        formatNumber(z1 + i * zStep),
      ];
      intervalValues[i] = `[${parts.join(', ')}]`;
    }
  }
  return intervalValues;
}
