ecsimsw

TF_source) OpkernelConstruction 본문

TF_source) OpkernelConstruction

JinHwan Kim 2019. 4. 12. 12:20

///"tensorflow/core/framework/op_kernel.h"

class OpKernelConstruction {
 public:
  OpKernelConstruction(DeviceType device_type, DeviceBase* device,
                       Allocator* allocator, const NodeDef* node_def,const OpDef* op_def, FunctionLibraryRuntime* flib,
                       const DataTypeSlice& input_types,
                       const MemoryTypeSlice& input_memory_types,
                       const DataTypeSlice& output_types,
                       const MemoryTypeSlice& output_memory_types,
                       int graph_def_version, Status* status);

  Env* env() const { return device_->env(); }

  // Allocation of tensors during kernel construction:
  //
  // It is legal to temporarily allocate scratch tensor storage during
  // Op kernel construction. Scratch tensors should be allocated using
  // allocate_temp below. Some kernels need to keep tensors in between
  // invocations. If such a Tensor is allocated during kernel
  // construction this must be done using allocate_persistent, and the
  // Op may only store the returned PersistentTensor object. When the
  // Tensor is needed in a subsequent invocation, it can be retrieved
  // from the PersistentTensor using the AccessTensor method. This
  // ensures that the system is made aware of any use of the tensor's
  // allocated memory, which is needed for correctness on asynchronous
  // devices such as GPUs.

  // Allocates a temporary Tensor of the specified type and shape. The
  // Tensor must not be used after kernel construction is
  // complete. See comment above.
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp);

  // Allocates a Tensor of the specified type and shape which the Op
  // plans to maintain as persistent state. out_persistent holds the
  // PersistentTensor which is the object the caller should store. For
  // convenience, if out_tensor is non-null then it will be filled in
  // with a Tensor* pointing to the newly-allocated tensor which the
  // caller can use instead of calling
  // out_persistent->AccessTensor. The caller does not own out_tensor
  // and should not keep a copy of it. See comment above.
  Status allocate_persistent(DataType type, const TensorShape& shape,
                             PersistentTensor* out_persistent,
                             Tensor** out_tensor);

  // User-supplied configuration of this operation.
  const NodeDef& def() const { return *def_; }

  // For inspecting the inputs to this operation.
  int num_inputs() const { return input_types_.size(); }
  DataType input_type(int i) const { return input_types_[i]; }
  const DataTypeSlice& input_types() const { return input_types_; }
  const MemoryTypeSlice& input_memory_types() const {
    return input_memory_types_;
  }

  // For inspecting the outputs expected from this operation.
  int num_outputs() const { return output_types_.size(); }
  DataType output_type(int i) const { return output_types_[i]; }
  const DataTypeSlice& output_types() const { return output_types_; }
  const MemoryTypeSlice& output_memory_types() const {
    return output_memory_types_;
  }

  // If expected_inputs == inputs() and expected_outputs == output_types(),
  // returns OK, else returns INVALID_ARGUMENT with an error message.
  // Recommended for Ops with dynamic signatures.
  Status MatchSignature(const DataTypeSlice expected_inputs,
                        const DataTypeSlice expected_outputs);

  // For recording configuration errors during construction.
  void SetStatus(const Status& status);
  const Status& status() const { return *status_; }

  // Look up the attr with name attr_name and set *value to its value.  If no
  // attr with attr_name is found in def(), or the attr does not have
  // a matching type, a non-ok status will be returned.
  template 
  Status GetAttr(StringPiece attr_name, T* value) const;
  // Return true if the attr_name is defined in def().
  bool HasAttr(StringPiece attr_name) const;
  // Return the device type.
  const DeviceType& device_type() const { return device_type_; }

  // If not nullptr, the kernel can instantiate functions defined in
  // the library. E.g.,
  // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...).
  FunctionLibraryRuntime* function_library() const { return flib_; }

  // The GraphDef version whose behavior we should follow.
  int graph_def_version() const { return graph_def_version_; }

  // Helper routines for the OP_REQUIRES macros
  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

  // Unrecommended functions: these are functions that have some
  // current uses but are not recommended for use, and may go away at
  // some future major version release.

  // May be used, e.g., to get GPU handles, etc.
  //
  // Currently only used to call MakeTensorFromProto() for
  // implementing ConstantOp for every device.  See comments
  // on Device::MakeTensorFromProto for longer-term replacement
  // ideas.
  DeviceBase* device() const { return device_; }

 private:
  const DeviceType device_type_;
  DeviceBase* const device_;
  Allocator* allocator_;
  const NodeDef* def_;
  const OpDef* op_def_;
  FunctionLibraryRuntime* flib_;
  DataTypeSlice input_types_;
  MemoryTypeSlice input_memory_types_;
  DataTypeSlice output_types_;
  MemoryTypeSlice output_memory_types_;
  const int graph_def_version_;
 
  // Allow op_def_ across from OpKernel, but not from subclasses.
  // TODO(irving): Remove protos from this header entirely.
  friend class OpKernel;

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction);
};



///"tensorflow/core/framework/op_kernel.cc"

OpKernelConstruction::OpKernelConstruction(
    DeviceType device_type, DeviceBase* device, Allocator* allocator,
    const NodeDef* node_def, const OpDef* op_def, FunctionLibraryRuntime* flib,
    const DataTypeSlice& input_types, const MemoryTypeSlice& input_memory_types,
    const DataTypeSlice& output_types,
    const MemoryTypeSlice& output_memory_types, int graph_def_version,
    Status* status)
    : device_type_(std::move(device_type)),
      device_(device),
      allocator_(allocator), 
      def_(node_def),
      op_def_(op_def),
      flib_(flib),
      input_types_(input_types),
      input_memory_types_(input_memory_types),
      output_types_(output_types),
      output_memory_types_(output_memory_types),
      graph_def_version_(graph_def_version),
      status_(status) {}

bool OpKernelConstruction::HasAttr(StringPiece attr_name) const {
  return HasNodeAttr(def(), attr_name);
}

                                                   311,23        20%

'Machine Learning > tf_source' 카테고리의 다른 글

TF_source ) KernelRegistry  (0) 2019.04.20
TF_source) Opkernel  (0) 2019.04.12
TF_source ) Factory  (0) 2019.04.08
TF_source ) OpKernelRegistrar  (0) 2019.04.07
TF_source ) REGISTER_KERNELS  (0) 2019.04.07
Comments