You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
46 lines
2.0 KiB
46 lines
2.0 KiB
diff --git a/caffe2/operators/recurrent_network_op.cc b/caffe2/operators/recurrent_network_op.cc
|
|
index dd4fded..5995e8a 100644
|
|
--- a/caffe2/operators/recurrent_network_op.cc
|
|
+++ b/caffe2/operators/recurrent_network_op.cc
|
|
@@ -1,4 +1,4 @@
|
|
-#include "recurrent_network_op.h"
|
|
+#include "caffe2/operators/recurrent_network_op.h"
|
|
#include "caffe2/core/workspace.h"
|
|
|
|
namespace caffe2 {
|
|
diff --git a/caffe2/operators/recurrent_network_op.h b/caffe2/operators/recurrent_network_op.h
|
|
index 55328e5..ea898bc 100644
|
|
--- a/caffe2/operators/recurrent_network_op.h
|
|
+++ b/caffe2/operators/recurrent_network_op.h
|
|
@@ -762,8 +762,8 @@ class AccumulateInputGradientOp : public Operator<Context> {
|
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
|
|
bool RunOnDevice() override {
|
|
- const auto t =
|
|
- OperatorBase::Input<Tensor<CPUContext>>(0).template data<int32_t>()[0];
|
|
+ const auto& t0 = OperatorBase::Input<Tensor<CPUContext>>(0);
|
|
+ const auto t = t0.template data<int32_t>()[0];
|
|
auto& og = Input(1);
|
|
auto* g = Output(0);
|
|
|
|
diff --git a/caffe2/queue/queue_ops.h b/caffe2/queue/queue_ops.h
|
|
index f2c0a33..642343f 100644
|
|
--- a/caffe2/queue/queue_ops.h
|
|
+++ b/caffe2/queue/queue_ops.h
|
|
@@ -17,13 +17,10 @@ class CreateBlobsQueueOp final : public Operator<Context> {
|
|
name(operator_def.output().Get(0)) {}
|
|
|
|
bool RunOnDevice() override {
|
|
- const auto capacity =
|
|
- OperatorBase::template GetSingleArgument<int>("capacity", 1);
|
|
- const auto numBlobs =
|
|
- OperatorBase::template GetSingleArgument<int>("num_blobs", 1);
|
|
+ const auto capacity = GetSingleArgument("capacity", 1);
|
|
+ const auto numBlobs = GetSingleArgument("num_blobs", 1);
|
|
const auto enforceUniqueName =
|
|
- OperatorBase::template GetSingleArgument<int>(
|
|
- "enforce_unique_name", false);
|
|
+ GetSingleArgument("enforce_unique_name", false);
|
|
const auto fieldNames =
|
|
OperatorBase::template GetRepeatedArgument<std::string>("field_names");
|
|
CAFFE_ENFORCE_EQ(this->OutputSize(), 1);
|
|
|