"use strict";

var _interopRequireDefault = require("@babel/runtime/helpers/interopRequireDefault");
Object.defineProperty(exports, "__esModule", {
  value: true
});
exports.executeNerRule = executeNerRule;
var _lodash = require("lodash");
var _pLimit = _interopRequireDefault(require("p-limit"));
var _inferenceTracing = require("@kbn/inference-tracing");
var _get_entity_mask = require("./get_entity_mask");
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

// NER inference documents are limited to 512 tokens, and structured data can end up
// as one token per character. Capping each chunk at 512 characters (despite the name,
// this is used as a character count) is meant to keep every document within the
// limit, so the model does not truncate it.
const MAX_TOKENS_PER_DOC = 512;

/**
 * Splits `text` into consecutive chunks of at most `maxChars` UTF-16 code units.
 * Concatenating the chunks in order reproduces `text` exactly.
 *
 * A boundary is never placed between the two halves of a surrogate pair (e.g. an
 * emoji), since that would send unpaired surrogates (invalid Unicode) for inference.
 * The only exception is a limit of 1, where a pair cannot fit in one chunk.
 *
 * @param {string} text - Value to split; an empty string yields no chunks.
 * @param {number} [maxChars=MAX_TOKENS_PER_DOC] - Maximum chunk length (floored); must be >= 1.
 * @returns {string[]} The chunks, in order.
 * @throws {RangeError} If `maxChars` is below 1 or NaN (the loop could never advance).
 */
function chunkText(text, maxChars = MAX_TOKENS_PER_DOC) {
  if (!(maxChars >= 1)) {
    throw new RangeError(`maxChars must be at least 1, got ${maxChars}`);
  }
  const limit = Math.floor(maxChars);
  const chunks = [];
  let start = 0;
  while (start < text.length) {
    let end = Math.min(start + limit, text.length);
    // If the last code unit of this chunk is a high surrogate, its low half follows at
    // `end`: move the boundary back so the pair stays together (keeping the chunk non-empty).
    if (end < text.length && end - start > 1) {
      const lastUnit = text.charCodeAt(end - 1);
      if (lastUnit >= 0xd800 && lastUnit <= 0xdbff) {
        end -= 1;
      }
    }
    chunks.push(text.slice(start, end));
    start = end;
  }
  return chunks;
}
// Number of documents sent in a single `inferTrainedModel` request.
const DEFAULT_BATCH_SIZE = 1000;
// Concurrency passed to p-limit: how many inference requests may be in flight at once.
const DEFAULT_MAX_CONCURRENT_REQUESTS = 7;

/**
 * Executes a NER anonymization rule, by:
 *
 * - For each record, iterating over its key-value pairs.
 * - Splitting each value into chunks of at most MAX_TOKENS_PER_DOC characters, to stay
 *   within the token limit of NER tasks.
 * - Pushing every chunk onto one flat array and remembering, per record and key, the
 *   indices of its chunks, so the records can be reconstructed later.
 * - Creating a `{ text_field: string }` document for each chunk, and running NER
 *   inference over these documents in batches.
 * - After retrieving the results, for each record and key:
 *   - Looking up the inference result of each chunk by its index.
 *   - Replacing every detected (and allowed) entity with a mask.
 *   - Appending the original value & mask to the returned `anonymizations`.
 *   - Joining the masked chunks back into the record's value.
 *
 * @param {object} params
 * @param {{ records: object[], anonymizations: object[] }} params.state
 *   Input state; record values are chunked as strings. Neither `state.records` nor
 *   `state.anonymizations` is mutated.
 * @param {{ modelId: string, allowedEntityClasses?: string[], timeoutSeconds?: number }} params.rule
 *   The NER rule. When `allowedEntityClasses` is set, only entities of those classes are
 *   masked. `timeoutSeconds` defaults to 30.
 * @param {object} params.esClient Elasticsearch client; `ml.inferTrainedModel` is called.
 * @returns {Promise<{ records: object[], anonymizations: object[] }>} The masked records,
 *   and `state.anonymizations` followed by one entry per masked entity.
 * @throws {Error} `Inference failed for NER model '<modelId>': ...` when an inference
 *   request fails or returns a different number of results than documents sent.
 */
async function executeNerRule({
  state,
  rule,
  esClient
}) {
  const anonymizations = state.anonymizations.concat();
  const allowedNerEntities = rule.allowedEntityClasses ? new Set(rule.allowedEntityClasses) : undefined;
  const {
    texts,
    positions
  } = collectNerInputs(state.records);
  const results = await inferNerEntities({
    texts,
    rule,
    esClient
  });
  const records = state.records.map((record, recordIdx) => {
    const positionsForRecord = positions[recordIdx];
    return (0, _lodash.mapValues)(record, (_value, key) => positionsForRecord.get(key).map(position => maskEntities({
      text: texts[position],
      entities: results[position].entities ?? [],
      allowedNerEntities,
      rule,
      anonymizations
    })).join(''));
  });
  return {
    records,
    anonymizations
  };
}

/**
 * Splits every record value into inference-sized chunks (via `chunkText`) and flattens
 * them into one list.
 *
 * @returns {{ texts: string[], positions: Array<Map<string, number[]>> }} `texts` holds
 *   every chunk; `positions[i].get(key)` holds the indices into `texts` of the chunks of
 *   `records[i][key]`, in order.
 */
function collectNerInputs(records) {
  const texts = [];
  // A Map (rather than a plain object) keeps record keys such as `__proto__` from
  // interacting with the object prototype.
  const positions = records.map(record => {
    const positionsForRecord = new Map();
    for (const [key, value] of Object.entries(record)) {
      positionsForRecord.set(key, chunkText(value).map(text => {
        texts.push(text);
        return texts.length - 1;
      }));
    }
    return positionsForRecord;
  });
  return {
    texts,
    positions
  };
}

/**
 * Runs NER inference over `texts` in batches of DEFAULT_BATCH_SIZE documents, with at
 * most DEFAULT_MAX_CONCURRENT_REQUESTS requests in flight.
 *
 * @returns {Promise<object[]>} One inference result per input text, in input order.
 * @throws {Error} Wrapped as `Inference failed for NER model '<modelId>': ...`.
 */
async function inferNerEntities({
  texts,
  rule,
  esClient
}) {
  const limiter = (0, _pLimit.default)(DEFAULT_MAX_CONCURRENT_REQUESTS);
  const timeout = `${rule.timeoutSeconds ?? 30}s`;
  const batches = (0, _lodash.chunk)(texts, DEFAULT_BATCH_SIZE);
  const batchResults = await Promise.all(batches.map(batch => limiter(() => (0, _inferenceTracing.withActiveInferenceSpan)('InferTrainedModel', {
    attributes: {}
  }, async span => {
    const docs = batch.map(text => ({
      text_field: text
    }));
    // NOTE(review): this records the raw, not-yet-anonymized text on the trace span;
    // confirm that is acceptable for the configured tracing backend.
    span?.setAttribute('input.value', JSON.stringify(docs));
    const response = await esClient.ml.inferTrainedModel({
      model_id: rule.modelId,
      docs,
      timeout
    });
    const inferenceResults = response.inference_results;
    span?.setAttribute('output.value', JSON.stringify(inferenceResults));
    // Results are matched to inputs purely by index, so a short or long response would
    // shift every later chunk onto the wrong entities; fail loudly instead.
    if (!Array.isArray(inferenceResults) || inferenceResults.length !== docs.length) {
      const received = Array.isArray(inferenceResults) ? inferenceResults.length : 'none';
      throw new Error(`expected ${docs.length} inference results, got ${received}`);
    }
    return inferenceResults;
  })).catch(error => {
    const errorMessage = error instanceof Error ? error.message : String(error);
    throw new Error(`Inference failed for NER model '${rule.modelId}': ${errorMessage}`, {
      cause: error
    });
  })));
  return batchResults.flat();
}

/**
 * Masks the entities detected in one chunk of text and records each mask in
 * `anonymizations`.
 *
 * Entity offsets are relative to `text`. Entities are applied in `start_pos` order, so
 * the result does not depend on the order the model reports them in. An entity that
 * overlaps one already masked is skipped, because its offsets would land inside the
 * previous mask (presumably this never happens with token-classification output).
 *
 * @returns {string} `text` with every applied entity replaced by its mask.
 */
function maskEntities({
  text,
  entities,
  allowedNerEntities,
  rule,
  anonymizations
}) {
  const applicable = entities.filter(entity => !allowedNerEntities || allowedNerEntities.has(entity.class_name)).sort((a, b) => a.start_pos - b.start_pos);
  const parts = [];
  let cursor = 0;
  for (const entity of applicable) {
    if (entity.start_pos < cursor) {
      continue;
    }
    const value = text.slice(entity.start_pos, entity.end_pos);
    const mask = (0, _get_entity_mask.getEntityMask)({
      class_name: entity.class_name,
      value
    });
    parts.push(text.slice(cursor, entity.start_pos), mask);
    cursor = entity.end_pos;
    anonymizations.push({
      entity: {
        class_name: entity.class_name,
        value,
        mask
      },
      rule
    });
  }
  parts.push(text.slice(cursor));
  return parts.join('');
}