import {ID} from '@wandb/weave/common/util/id';

import * as Types from './types';

/**
 * Normalized store: for every object type, a map from part id to the stored
 * part of that type. `lookupPart` resolves refs against this shape
 * (`state[ref.type][ref.id]`).
 */
export type StateType = {
  [T in Types.ObjType]: {
    [id: string]: Types.PartFromType<T>;
  };
};

/**
 * Context threaded through a normalization pass. It is mutated in place:
 * normalizers built by `normFn` append to `result`.
 */
interface NormContext {
  /** viewID stamped onto every ref created during this pass. */
  viewID: string;
  /** Accumulates each normalized part together with the ref minted for it. */
  result: Array<PartWithRef<Types.ObjType>>;
}

/**
 * User-supplied normalizer: converts a whole object of schema `S` into its
 * stored part form. It receives the shared context, presumably so it can
 * normalize nested objects into the same `ctx.result` (confirm with callers).
 */
export type NormFn<S extends Types.ObjSchema> = (
  whole: S['whole'],
  ctx: NormContext
) => S['part'];

/**
 * Complete normalizer (as produced by `normFn`): stores the part in
 * `ctx.result` and returns the ref that identifies it.
 */
export type FullNormFn<S extends Types.ObjSchema> = (
  whole: S['whole'],
  ctx: NormContext
) => Types.PartRefFromType<S['type']>;

/** A stored part paired with the ref that identifies it in the store. */
export interface PartWithRef<T extends Types.ObjType> {
  ref: Types.PartRefFromType<T>;
  part: Types.PartFromType<T>;
}

/** One complete normalizer for every object type. */
export type NormFunctionMap = {
  [T in Types.ObjType]: FullNormFn<Types.ObjSchemaFromType<T>>;
};

/**
 * Context threaded through a denormalization pass. It is mutated in place:
 * denormalizers built by `denormFn` append to `partsWithRef`.
 */
interface DenormContext {
  /** Normalized store that refs are resolved against. */
  state: StateType;
  /** Records every (part, ref) pair visited during the pass. */
  partsWithRef: Array<PartWithRef<Types.ObjType>>;
}

/** Options controlling the shape of denormalized output. */
export interface DenormalizationOptions {
  /**
   * If true, each denormalized object carries its `ref`; if false, any `ref`
   * property is removed. `denormFn` uses `{includeRefs: true}` when no
   * options are passed.
   */
  includeRefs: boolean;
}

/**
 * User-supplied denormalizer: converts a stored part of schema `S` back into
 * a whole object. The `denormFn` wrapper adds or strips `ref` afterwards.
 */
export type DenormFn<S extends Types.ObjSchema> = (
  part: S['part'],
  ctx: DenormContext,
  opts: DenormalizationOptions
) => Types.WholeFromType<S['type']>;

/**
 * Complete denormalizer (as produced by `denormFn`): resolves `ref` in
 * `ctx.state` and returns the whole object. `opts` may be omitted.
 *
 * NOTE(review): the declared return type includes `ref`, but `denormFn`
 * removes `ref` when `opts.includeRefs` is false.
 */
export type FullDenormFn<S extends Types.ObjSchema> = (
  ref: Types.PartRefFromType<S['type']>,
  ctx: DenormContext,
  opts?: DenormalizationOptions
) => Types.WholeFromTypeWithRef<S['type']>;

/** One complete denormalizer for every object type. */
export type DenormFunctionMap = {
  [T in Types.ObjType]: FullDenormFn<Types.ObjSchemaFromType<T>>;
};

/**
 * Wraps a user-supplied normalizer into a `FullNormFn` for objects of `type`.
 *
 * The returned function does four things in order:
 * 1. Runs `userNormFn` to produce the part.
 * 2. Mints a ref `{type, viewID: ctx.viewID, id: ID()}`.
 * 3. Appends `{ref, part}` to `ctx.result`.
 * 4. Returns the ref.
 *
 * The append happens only after `userNormFn` returns. So any entries that
 * `userNormFn` pushes onto the same ctx (e.g. for nested objects) appear
 * earlier in `ctx.result`.
 *
 * @param type - object type recorded on every ref this normalizer creates.
 * @param userNormFn - converts a whole object into its stored part.
 * @returns a normalizer that stores the part and returns its new ref.
 */
export function normFn<S extends Types.ObjSchema>(
  type: S['type'],
  userNormFn: NormFn<S>
): FullNormFn<S> {
  const fullNormFn: FullNormFn<S> = (whole, ctx) => {
    const normalized = userNormFn(whole, ctx);
    const newRef = {
      type,
      viewID: ctx.viewID,
      id: ID(),
    } as Types.PartRefFromType<S['type']>;
    ctx.result.push({ref: newRef, part: normalized});
    return newRef;
  };
  return fullNormFn;
}

/**
 * Resolves a part ref against the normalized store.
 *
 * @param state - normalized store, keyed by object type and then by part id.
 * @param partRef - ref whose `type` and `id` select the part.
 * @returns the stored part addressed by `partRef`.
 * @throws Error('invalid state') when the store has no bucket for
 *   `partRef.type`, or when that bucket has no non-null entry for `partRef.id`.
 */
export function lookupPart<R extends Types.AllPartRefs>(
  state: StateType,
  partRef: R
): Types.PartFromType<R['type']> {
  // `?.` turns a missing type bucket into undefined, so both failure modes
  // fall through to the single null check below.
  const found = state[partRef.type]?.[partRef.id];
  if (found == null) {
    throw new Error('invalid state');
  }
  return found as Types.PartFromType<R['type']>;
}

/**
 * Wraps a user-supplied denormalizer into a `FullDenormFn`.
 *
 * The returned function:
 * 1. Resolves `ref` via `lookupPart(ctx.state, ref)`.
 * 2. Records the `(part, ref)` pair on `ctx.partsWithRef`. This happens before
 *    `userDenormFn` runs, so the entry precedes any pushed by nested
 *    denormalization sharing the same ctx.
 * 3. Converts the part with `userDenormFn`.
 *
 * `opts` defaults to `{includeRefs: true}` when omitted and is forwarded to
 * `userDenormFn` unchanged. With `includeRefs` true, the result is a shallow
 * copy of the whole with `ref` set, overwriting any existing `ref`. With it
 * false, the result is a shallow copy with any `ref` property removed.
 *
 * NOTE(review): the declared return type (`WholeFromTypeWithRef`) claims
 * `ref` is present even when `includeRefs` is false. Callers opting out
 * should not rely on that field.
 *
 * @throws Error('invalid state') from `lookupPart` when `ref` is not present
 *   in `ctx.state`.
 */
export function denormFn<S extends Types.ObjSchema>(
  userDenormFn: DenormFn<S>
): FullDenormFn<S> {
  const fullDenormFn: FullDenormFn<S> = (
    ref,
    ctx,
    opts = {includeRefs: true}
  ) => {
    const storedPart = lookupPart(ctx.state, ref);
    ctx.partsWithRef.push({part: storedPart, ref});
    const denormalized = userDenormFn(storedPart, ctx, opts);

    if (!opts.includeRefs) {
      // Copy before deleting so the object produced by userDenormFn is left
      // untouched.
      const withoutRef = {...denormalized};
      delete (withoutRef as any).ref;
      return withoutRef;
    }

    return {...denormalized, ref} as any;
  };
  return fullDenormFn;
}
