/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_GRAPH_EXECUTOR_H_
#define TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_GRAPH_EXECUTOR_H_

#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "learning/brain/experimental/tfrt/native_lowering/kernels/sync_context.h"
#include "learning/infra/mira/mlrt/bytecode/bytecode.h"
#include "learning/infra/mira/mlrt/bytecode/executable.h"
#include "learning/infra/mira/mlrt/interpreter/context.h"
#include "absl/base/call_once.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/tfrt/fallback/cost_recorder.h"
#include "tensorflow/core/tfrt/fallback/fallback_state.h"
#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h"
#include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h"
#include "tensorflow/core/tfrt/runtime/runtime.h"
#include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
#include "tensorflow/core/tfrt/tpu/tpu_resources.h"  // NOLINT(unused-includes): For tfrt::tpu::TpuModelResource
#include "tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h"
#include "tensorflow/tsl/platform/thread_annotations.h"
#include "tfrt/bef/bef_buffer.h"  // from @tf_runtime
#include "tfrt/bef_executor/bef_file.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/host_context/function.h"  // from @tf_runtime
#include "tfrt/host_context/request_deadline_tracker.h"  // from @tf_runtime
#include "tfrt/support/ref_count.h"  // from @tf_runtime

namespace tensorflow {
namespace tfrt_stub {

// Per-request state that is assembled before a graph is executed.
struct RequestInfo {
  // TFRT request context for this request.
  tfrt::RCReference<tfrt::RequestContext> tfrt_request_context;

  // Owns the work queue when this request had to create a fresh one; stays
  // nullptr when an existing queue is reused instead.
  std::unique_ptr<WorkQueueInterface> request_queue_owner;

  // Inter-op thread pool for this request; must never be nullptr. When
  // `request_queue_owner` is set, this is the raw pointer it owns.
  WorkQueueInterface* request_queue{nullptr};

  // Task runner handed to tensorflow::OpKernel.
  std::function<void(std::function<void()>)> runner;
};

// Builds a `RequestInfo` from the given request-related data: the run
// options, model metadata, runtime, work queue, resource context and fallback
// state. `cost_recorder` is optional and defaults to nullptr.
StatusOr<std::unique_ptr<RequestInfo>> CreateRequestInfo(
    const GraphExecutionRunOptions& run_options,
    const SessionMetadata& model_metadata,
    const Runtime& runtime,
    tensorflow::tfrt_stub::WorkQueueInterface* work_queue,
    tfrt::ResourceContext* resource_context,
    const FallbackState& fallback_state,
    CostRecorder* cost_recorder = nullptr);

// Runs `func` on `inputs`, with results placed in `outputs`. The remaining
// arguments supply the execution options, the resource context, runtime and
// fallback state, and the request deadline tracker. `cost_recorder` is
// optional and defaults to nullptr.
tensorflow::Status GraphExecutionRunOnFunction(
    const GraphExecutionOptions& options,
    const GraphExecutionRunOptions& run_options,
    absl::string_view signature_name,
    const tfrt::Function& func,
    absl::Span<const tensorflow::Tensor> inputs,
    std::vector<tensorflow::Tensor>* outputs,
    tfrt::ResourceContext* resource_context,
    const Runtime& runtime,
    const FallbackState& fallback_state,
    tfrt::RequestDeadlineTracker* req_deadline_tracker,
    CostRecorder* cost_recorder = nullptr);

// Creates a ResourceContext and populates it with the per-model resources
// from `runtime`. When `tpu_target` is kTpurt, a special `AddTpuResources`
// function is also called to populate the TPU-related resources for tpurt.
//
// TODO(b/178227859): Remove the need for the special handling for TPU here.
std::unique_ptr<tfrt::ResourceContext> CreateResourceContext(
    const Runtime& runtime,
    tfrt::tpu::TpuModelResource* tpu_model_resource,
    tensorflow::TfrtDeviceInfraTarget tpu_target);

// Loads (if not already loaded) and runs a subgraph of a graph for each
// request. Loaded subgraphs are cached by name in `loaded_client_graphs_`.
class GraphExecutor {
 public:
  using Options = GraphExecutionOptions;
  using RunOptions = GraphExecutionRunOptions;

  // Stores BEF-related data: the BEF buffer together with a `BEFFile` handle.
  // NOTE(review): presumably `bef_file` was loaded from (and refers into)
  // `bef`, which is why the two are kept together — confirm.
  struct BefContext {
    BefContext(tfrt::BefBuffer bef, tfrt::RCReference<tfrt::BEFFile> bef_file)
        : bef(std::move(bef)), bef_file(std::move(bef_file)) {}

    tfrt::BefBuffer bef;
    tfrt::RCReference<tfrt::BEFFile> bef_file;
  };

  // The loading result of a `ClientGraph`: the compiled program (either a BEF
  // context or an MLRT bytecode executable) plus the resource context, MLIR
  // context and TFRT MLIR module it owns.
  class LoadedClientGraph {
   public:
    // Takes ownership of every argument. Only one of `bef_context` or
    // `bytecode_executable` is expected to be filled.
    LoadedClientGraph(
        std::string name,
        std::unique_ptr<tfrt::ResourceContext> resource_context,
        std::unique_ptr<mlir::MLIRContext> mlir_context,
        mlir::OwningOpRef<mlir::ModuleOp> tfrt_mlir,
        std::shared_ptr<BefContext> bef_context,
        mlrt::bc::Buffer bytecode_buffer,
        std::unique_ptr<mlrt::LoadedExecutable> bytecode_executable)
        : name_(std::move(name)),
          resource_context_(std::move(resource_context)),
          mlir_context_(std::move(mlir_context)),
          tfrt_mlir_(std::move(tfrt_mlir)),
          bef_context_(std::move(bef_context)),
          bytecode_buffer_(std::move(bytecode_buffer)),
          bytecode_executable_(std::move(bytecode_executable)) {}

    // Returns a `CostRecorder` if none has been created before for this
    // `LoadedClientGraph`. NOTE(review): presumably returns nullptr on later
    // calls, with `create_cost_recorder_once_` enforcing the "once" — confirm
    // in the .cc.
    std::unique_ptr<CostRecorder> MaybeCreateCostRecorder() const;

    // Updates the op cost values in this `LoadedClientGraph` with records from
    // `cost_recorder`. NOTE(review): presumably this is how `bef_context_` gets
    // replaced when online cost analysis is enabled — confirm.
    Status UpdateCost(const CostRecorder& cost_recorder,
                      const Runtime& runtime);

    // Getters.
    // Returns a copy of the current BEF context, taken under
    // `bef_context_mu_`. Because it is a `shared_ptr` copy, it stays valid even
    // if `bef_context_` is later replaced. May be null when this graph was
    // loaded as bytecode.
    std::shared_ptr<BefContext> bef_context() const {
      tensorflow::mutex_lock lock(bef_context_mu_);
      return bef_context_;
    }
    // The view refers to `name_` and is valid only while this object lives.
    absl::string_view name() const { return name_; }
    // Dereferences `resource_context_`, which must be non-null.
    tfrt::ResourceContext& resource_context() const {
      return *resource_context_;
    }
    // Returns the MLRT executable (not owned by the caller), or nullptr when
    // this graph was loaded as BEF.
    mlrt::LoadedExecutable* bytecode_executable() const {
      return bytecode_executable_.get();
    }

   private:
    std::string name_;
    std::unique_ptr<tfrt::ResourceContext> resource_context_;
    // Declared before `tfrt_mlir_`, so it is destroyed after the module.
    std::unique_ptr<mlir::MLIRContext> mlir_context_;
    // Thread-safety results from `create_cost_recorder_once_`.
    mlir::OwningOpRef<mlir::ModuleOp> tfrt_mlir_;
    // Guards `bef_context_`. `mutable` so the const getter can lock it.
    mutable tensorflow::mutex bef_context_mu_;
    // Only one of `bef_context_` or `bytecode_executable_` should be filled for
    // a single `LoadedClientGraph`.
    // Can be updated if online cost analysis is enabled.
    std::shared_ptr<BefContext> bef_context_ TF_GUARDED_BY(bef_context_mu_);
    // Declared before `bytecode_executable_`, so it is destroyed after the
    // executable. NOTE(review): presumably the executable refers into this
    // buffer — confirm.
    mlrt::bc::Buffer bytecode_buffer_;
    std::unique_ptr<mlrt::LoadedExecutable> bytecode_executable_ = nullptr;
    // Ensures `MaybeCreateCostRecorder()` creates a recorder at most once.
    mutable absl::once_flag create_cost_recorder_once_;
  };

  // A subgraph constructed by specifying input/output tensors.
  struct ClientGraph {
    // A unique name by joining all the input/output/target names.
    std::string name;
    // The feed nodes for the corresponding inputs, but they might not be in the
    // original order and if there are more than one original inputs mapped to
    // the same feed node, only one is picked here.
    tensorflow::GraphImportConfig::InputArrays input_nodes;
    // The fetch nodes for the outputs, which should be in the original order.
    std::vector<std::string> output_nodes;
    // The target nodes that should be run but not returned as outputs.
    std::vector<std::string> target_nodes;
  };

  // Creates a `GraphExecutor` given the args. `fallback_state` is held by
  // reference and must outlive the executor; `tpu_model_resource` is not owned.
  static StatusOr<std::unique_ptr<GraphExecutor>> Create(
      Options options, const FallbackState& fallback_state,
      tfrt::tpu::TpuModelResource* tpu_model_resource,
      tensorflow::GraphDef graph_def,
      std::unique_ptr<mlrt::KernelRegistry> kernel_registry);

  // Ctor. Public for `Create()`. Do not use directly.
  //
  // `options.runtime` must be non-null: it is dereferenced to build
  // `req_deadline_tracker_`. That initializer reads the already-moved-in
  // `options_`, which is only valid because `options_` is declared before
  // `req_deadline_tracker_` in this class.
  GraphExecutor(Options options, const FallbackState& fallback_state,
                tfrt::tpu::TpuModelResource* tpu_model_resource,
                std::unique_ptr<tensorflow::tfrt_stub::TfrtGraphExecutionState>
                    graph_execution_state,
                std::unique_ptr<mlrt::KernelRegistry> kernel_registry)
      : options_(std::move(options)),
        fallback_state_(fallback_state),
        tpu_model_resource_(tpu_model_resource),
        graph_execution_state_(std::move(graph_execution_state)),
        req_deadline_tracker_(
            options_.runtime->core_runtime()->GetHostContext()),
        kernel_registry_(std::move(kernel_registry)) {}

  // Runs on the graph according to given input/output. `inputs` pairs each
  // feed name with its tensor; the tensors for `output_tensor_names` are
  // written to `outputs`, and `target_tensor_names` are run without being
  // returned.
  tensorflow::Status Run(
      const RunOptions& run_options,
      absl::Span<const std::pair<std::string, tensorflow::Tensor>> inputs,
      absl::Span<const std::string> output_tensor_names,
      absl::Span<const std::string> target_tensor_names,
      std::vector<tensorflow::Tensor>* outputs);

  // Runs the graph identified by `graph_name` using the input `inputs` and
  // stores the output of the execution in `outputs`. It is the client's
  // responsibility to ensure `graph_name` corresponds to logically different
  // graphs, since this name is used to lookup compiled graphs in the cache. The
  // graph is run synchronously with the TFRT interpreter.
  tensorflow::Status RunWithSyncInterpreter(
      const std::string& graph_name, absl::Span<mlrt::Value> input_values,
      absl::Span<const std::string> input_names,
      absl::Span<const tensorflow::DataType> input_dtypes,
      absl::Span<const std::string> output_tensor_names,
      absl::Span<const std::string> target_tensor_names,
      absl::Span<mlrt::Value> outputs);

  // Extends the current graph by `graph`.
  tensorflow::Status Extend(const GraphDef& graph);

  // Returns the graph execution state. Note that this returns a mutable
  // reference from a const method, and dereferences `graph_execution_state_`
  // without a null check.
  tensorflow::tfrt_stub::TfrtGraphExecutionState& graph_execution_state()
      const {
    return *graph_execution_state_;
  }

  // Returns the underlying runtime. The null check is only a DCHECK, so
  // `options_.runtime` must be non-null in optimized builds as well.
  const tensorflow::tfrt_stub::Runtime& runtime() const {
    DCHECK(options_.runtime);
    return *options_.runtime;
  }

 private:
  // A set of methods to load a client graph.
  StatusOr<std::unique_ptr<GraphExecutor::LoadedClientGraph>> LoadClientGraph(
      const GraphExecutor::ClientGraph& client_graph,
      tensorflow::tfrt_stub::WorkQueueInterface* work_queue);
  StatusOr<std::unique_ptr<GraphExecutor::LoadedClientGraph>>
  ImportAndCompileClientGraph(const GraphExecutor::ClientGraph& client_graph);
  tensorflow::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
  ImportClientGraphToMlirModule(const GraphExecutor::ClientGraph& client_graph,
                                mlir::MLIRContext* context) const;
  StatusOr<tfrt::BefBuffer> CompileMlirModuleToBef(mlir::ModuleOp module) const;
  tensorflow::Status InitBef(
      tfrt::BEFFile* bef_file, tfrt::ResourceContext* resource_context,
      tensorflow::tfrt_stub::WorkQueueInterface* work_queue);

  tensorflow::Status InitBytecode(LoadedClientGraph* loaded_graph);

  // Returns a `LoadedClientGraph` given input/output tensor info. If there is
  // no existing one yet, creates one first.
  //
  // Must be called without holding `loaded_client_graphs_mu_`. The returned
  // reference points at an entry owned by `loaded_client_graphs_`.
  // NOTE(review): when `graph_name` is set it is presumably used as the cache
  // key instead of the joined name (cf. `RunWithSyncInterpreter`) — confirm.
  StatusOr<std::reference_wrapper<GraphExecutor::LoadedClientGraph>>
  GetOrCreateLoadedClientGraph(
      const RunOptions& run_options,
      absl::Span<const std::string> input_tensor_names,
      absl::Span<const tensorflow::DataType> input_tensor_dtypes,
      absl::Span<const std::string> output_tensor_names,
      absl::Span<const std::string> target_tensor_names,
      tensorflow::tfrt_stub::WorkQueueInterface* work_queue,
      std::optional<const std::string> graph_name = std::nullopt)
      TF_LOCKS_EXCLUDED(loaded_client_graphs_mu_);

  // Declaration order matters: `options_` must stay declared before
  // `req_deadline_tracker_`, whose initializer reads `options_.runtime`.
  Options options_;
  // Not owned; the referenced `FallbackState` must outlive this executor.
  std::reference_wrapper<const FallbackState> fallback_state_;
  tfrt::tpu::TpuModelResource* tpu_model_resource_;  // NOT owned.

  std::unique_ptr<tensorflow::tfrt_stub::TfrtGraphExecutionState>
      graph_execution_state_;

  // Built from the host context of `options_.runtime`'s core runtime.
  tfrt::RequestDeadlineTracker req_deadline_tracker_;

  // Guards `loaded_client_graphs_`.
  tensorflow::mutex loaded_client_graphs_mu_;
  // Caches `LoadedClientGraph` by the joined name.
  // For pointer stability of values in `absl::flat_hash_map<>`, additional
  // `std::unique_ptr<>` is necessary. (See https://abseil.io/tips/136.)
  absl::flat_hash_map<std::string /*joined_name*/,
                      std::unique_ptr<LoadedClientGraph>>
      loaded_client_graphs_ TF_GUARDED_BY(loaded_client_graphs_mu_);

  std::unique_ptr<mlrt::KernelRegistry> kernel_registry_;
};

}  // namespace tfrt_stub
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_GRAPH_EXECUTOR_H_
