ecsimsw
TF_source) OpKernelConstruction 본문
///"tensorflow/core/framework/op_kernel.h"
// Context object passed to an OpKernel while it is being constructed.
//
// It exposes the NodeDef being instantiated, the device/allocator the kernel
// is built for, the expected input/output data types and memory types, the
// GraphDef version, an optional function library, and a Status slot used to
// record construction errors. All pointer arguments are stored as raw
// pointers and dereferenced by accessors (e.g. def(), env(), status()), so the
// pointees must stay valid for as long as this object is used.
class OpKernelConstruction {
 public:
  OpKernelConstruction(DeviceType device_type, DeviceBase* device,
                       Allocator* allocator, const NodeDef* node_def,
                       const OpDef* op_def, FunctionLibraryRuntime* flib,
                       const DataTypeSlice& input_types,
                       const MemoryTypeSlice& input_memory_types,
                       const DataTypeSlice& output_types,
                       const MemoryTypeSlice& output_memory_types,
                       int graph_def_version, Status* status);

  // Returns the Env of the device this kernel is constructed for.
  Env* env() const { return device_->env(); }

  // Allocation of tensors during kernel construction:
  //
  // It is legal to temporarily allocate scratch tensor storage during
  // Op kernel construction. Scratch tensors should be allocated using
  // allocate_temp below. Some kernels need to keep tensors in between
  // invocations. If such a Tensor is allocated during kernel
  // construction this must be done using allocate_persistent, and the
  // Op may only store the returned PersistentTensor object. When the
  // Tensor is needed in a subsequent invocation, it can be retrieved
  // from the PersistentTensor using the AccessTensor method. This
  // ensures that the system is made aware of any use of the tensor's
  // allocated memory, which is needed for correctness on asynchronous
  // devices such as GPUs.

  // Allocates a temporary Tensor of the specified type and shape. The
  // Tensor must not be used after kernel construction is
  // complete. See comment above.
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp);

  // Allocates a Tensor of the specified type and shape which the Op
  // plans to maintain as persistent state. out_persistent holds the
  // PersistentTensor which is the object the caller should store. For
  // convenience, if out_tensor is non-null then it will be filled in
  // with a Tensor* pointing to the newly-allocated tensor which the
  // caller can use instead of calling
  // out_persistent->AccessTensor. The caller does not own out_tensor
  // and should not keep a copy of it. See comment above.
  Status allocate_persistent(DataType type, const TensorShape& shape,
                             PersistentTensor* out_persistent,
                             Tensor** out_tensor);

  // User-supplied configuration of this operation.
  const NodeDef& def() const { return *def_; }

  // For inspecting the inputs to this operation.
  int num_inputs() const { return input_types_.size(); }
  DataType input_type(int i) const { return input_types_[i]; }
  const DataTypeSlice& input_types() const { return input_types_; }
  const MemoryTypeSlice& input_memory_types() const {
    return input_memory_types_;
  }

  // For inspecting the outputs expected from this operation.
  int num_outputs() const { return output_types_.size(); }
  DataType output_type(int i) const { return output_types_[i]; }
  const DataTypeSlice& output_types() const { return output_types_; }
  const MemoryTypeSlice& output_memory_types() const {
    return output_memory_types_;
  }

  // If expected_inputs == input_types() and expected_outputs ==
  // output_types(), returns OK, else returns INVALID_ARGUMENT with an error
  // message. Recommended for Ops with dynamic signatures.
  Status MatchSignature(const DataTypeSlice expected_inputs,
                        const DataTypeSlice expected_outputs);

  // For recording configuration errors during construction.
  void SetStatus(const Status& status);
  // Returns the Status object supplied to the constructor.
  const Status& status() const { return *status_; }

  // Look up the attr with name attr_name and set *value to its value. If no
  // attr with attr_name is found in def(), or the attr does not have
  // a matching type, a non-ok status will be returned.
  template <class T>
  Status GetAttr(StringPiece attr_name, T* value) const;

  // Return true if the attr_name is defined in def().
  bool HasAttr(StringPiece attr_name) const;

  // Return the device type.
  const DeviceType& device_type() const { return device_type_; }

  // If not nullptr, the kernel can instantiate functions defined in
  // the library. E.g.,
  // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...).
  FunctionLibraryRuntime* function_library() const { return flib_; }

  // The GraphDef version whose behavior we should follow.
  int graph_def_version() const { return graph_def_version_; }

  // Helper routines for the OP_REQUIRES macros
  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

  // Unrecommended functions: these are functions that have some
  // current uses but are not recommended for use, and may go away at
  // some future major version release.

  // May be used, e.g., to get GPU handles, etc.
  //
  // Currently only used to call MakeTensorFromProto() for
  // implementing ConstantOp for every device. See comments
  // on Device::MakeTensorFromProto for longer-term replacement
  // ideas.
  DeviceBase* device() const { return device_; }

 private:
  const DeviceType device_type_;
  DeviceBase* const device_;
  Allocator* allocator_;
  const NodeDef* def_;
  const OpDef* op_def_;
  FunctionLibraryRuntime* flib_;
  DataTypeSlice input_types_;
  MemoryTypeSlice input_memory_types_;
  DataTypeSlice output_types_;
  MemoryTypeSlice output_memory_types_;
  const int graph_def_version_;
  // Caller-supplied status slot, dereferenced by status(). Declared last so
  // declaration order matches the constructor's member-initializer list,
  // which initializes status_ after graph_def_version_.
  Status* status_;

  // Allow OpKernel (but not its subclasses) to access op_def_.
  // TODO(irving): Remove protos from this header entirely.
  friend class OpKernel;

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction);
};
///"tensorflow/core/framework/op_kernel.cc"
// Builds the construction context by capturing every argument in the
// matching member. `device_type` is moved into device_type_; the pointer
// arguments are stored as given (no copies of the pointees are made); the
// DataTypeSlice / MemoryTypeSlice arguments are copied into members; and
// `status` is kept so that status() can later dereference it.
OpKernelConstruction::OpKernelConstruction(
    DeviceType device_type,
    DeviceBase* device,
    Allocator* allocator,
    const NodeDef* node_def,
    const OpDef* op_def,
    FunctionLibraryRuntime* flib,
    const DataTypeSlice& input_types,
    const MemoryTypeSlice& input_memory_types,
    const DataTypeSlice& output_types,
    const MemoryTypeSlice& output_memory_types,
    int graph_def_version,
    Status* status)
    : device_type_(std::move(device_type)),  // by-value param, safe to move
      device_(device),
      allocator_(allocator),
      def_(node_def),
      op_def_(op_def),
      flib_(flib),
      input_types_(input_types),
      input_memory_types_(input_memory_types),
      output_types_(output_types),
      output_memory_types_(output_memory_types),
      graph_def_version_(graph_def_version),
      status_(status) {}
// Reports whether the NodeDef this kernel is being constructed from has an
// attr named `attr_name`; the lookup itself is delegated to HasNodeAttr.
bool OpKernelConstruction::HasAttr(StringPiece attr_name) const {
  const NodeDef& node_def = *def_;  // same object def() returns
  return HasNodeAttr(node_def, attr_name);
}
311,23 20%
'Machine Learning > tf_source' 카테고리의 다른 글
TF_source ) KernelRegistry (0) | 2019.04.20 |
---|---|
TF_source) Opkernel (0) | 2019.04.12 |
TF_source ) Factory (0) | 2019.04.08 |
TF_source ) OpKernelRegistrar (0) | 2019.04.07 |
TF_source ) REGISTER_KERNELS (0) | 2019.04.07 |