ecsimsw

TF_source ) matmul_op.cc 본문

TF_source ) matmul_op.cc

JinHwan Kim 2019. 4. 30. 20:15

/// matmul_op.cc 

#endif  // GOOGLE_CUDA

template 
class MatMulOp : public OpKernel {
 public:
  explicit MatMulOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), algorithms_set_already_(false) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));

    LaunchMatMul<Device, T, USE_CUBLAS>::GetBlasGemmAlgorithm(
        ctx, &algorithms_, &algorithms_set_already_);
    use_autotune_ = MatmulAutotuneEnable();
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& a = ctx->input(0);
    const Tensor& b = ctx->input(1);

    // Check that the dimensions of the two matrices are valid.
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()),
                errors::InvalidArgument("In[0] is not a matrix"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
                errors::InvalidArgument("In[1] is not a matrix"));
    Eigen::array<Eigen::IndexPair, 1> dim_pair;
    dim_pair[0].first = transpose_a_ ? 0 : 1;
    dim_pair[0].second = transpose_b_ ? 1 : 0;

    OP_REQUIRES(
        ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
        errors::InvalidArgument(
            "Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
            ", In[1]: ", b.shape().DebugString()));
    int a_dim_remaining = 1 - dim_pair[0].first;
    int b_dim_remaining = 1 - dim_pair[0].second;
    TensorShape out_shape(
        {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));

   if (out->NumElements() == 0) {
      // If a has shape [0, x] or b has shape [x, 0], the output shape
      // is a 0-element matrix, so there is nothing to do.
      return;
    }

    if (a.NumElements() == 0 || b.NumElements() == 0) {
      // If a has shape [x, 0] and b has shape [0, y], the
      // output shape is [x, y] where x and y are non-zero, so we fill
      // the output with zeros.
      functor::SetZeroFunctor<Device, T> f;
      f(ctx->eigen_device(), out->flat());
      return;
    }

    LaunchMatMul<Device, T, USE_CUBLAS>::launch(
        ctx, a, b, dim_pair, &algorithms_, use_autotune_, out);
  }

 private:
  std::vector algorithms_;
  bool algorithms_set_already_;
  bool use_autotune_;
  bool transpose_a_;
  bool transpose_b_;
};

 

///OpkernelConstruction ctx

private:
  const DeviceType device_type_;
  DeviceBase* const device_;
  Allocator* allocator_;
  const NodeDef* def_;
  const OpDef* op_def_;
  FunctionLibraryRuntime* flib_;
  DataTypeSlice input_types_;
  MemoryTypeSlice input_memory_types_;
  DataTypeSlice output_types_;
  MemoryTypeSlice output_memory_types_;
  const int graph_def_version_;
  Status* status_;

'Machine Learning > tf_source' 카테고리의 다른 글

TF_source ) 2019-04-25  (0) 2019.04.25
TF_source ) kernel_builder  (0) 2019.04.23
What's the difference between user registers and kernel registers?  (0) 2019.04.20
TF_source ) KernelRegistry  (0) 2019.04.20
TF_source) Opkernel  (0) 2019.04.12
Comments