Fix the CVEs of TensorFlow
parent 4c9abf5ce5
commit 736ee0c8ef
CVE-2020-26267-1.patch (new file, 252 lines)
@@ -0,0 +1,252 @@
From 1a11d01c1fdd6683e9aa210dccde81de127dbf3e Mon Sep 17 00:00:00 2001
From: Kaixi Hou <kaixih@nvidia.com>
Date: Mon, 14 Sep 2020 15:52:22 -0700
Subject: [PATCH 1/1] support reduce ops for 5d tensors in layout optimizer

---
.../generic_layout_optimizer_transposer.cc | 27 +++++++++-
tensorflow/core/kernels/data_format_ops.cc | 10 ++--
tensorflow/core/kernels/data_format_ops.h | 53 ++++++++++++++-----
.../python/grappler/layout_optimizer_test.py | 39 ++++++++++++++
tensorflow/python/ops/nn_test.py | 27 ++++++++++
5 files changed, 136 insertions(+), 20 deletions(-)

diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
index ab7d8fcd..fbbeffc7 100644
--- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
+++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
@@ -1283,11 +1283,31 @@ bool ReduceTransposer::IsReduceAxisSupported(
Status ReduceTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsReduceOp(*node->node()));
- if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, 4) ||
+ const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
+ const auto& shape = output_shape_attr->list().shape(0);
+ const int rank = shape.dim_size();
+ std::string src_format = context->src_format;
+ std::string dst_format = context->dst_format;
+ // Update the format from 4D to 5D layout if necessary.
+ if (rank == 5) {
+ std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
+ std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
+ context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
+ dst_format_3d);
+ }
+ if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, rank) ||
!IsReduceAxisSupported(*context, *node) ||
!IsAfterDstToSrcTransform(*context, *node)) {
+ // Change back to the original layout due to early exit.
+ if (rank == 5) {
+ context->AssignDeviceAndDataFormats(context->target_device, src_format,
+ dst_format);
+ }
return Status::OK();
}
+ VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
+ << "' with op '" << node->GetOp() << "' from data format '"
+ << context->src_format << "' to '" << context->dst_format << "'";
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatDimMap));
@@ -1295,6 +1315,11 @@ Status ReduceTransposer::TransposeNode(TransposeContext* context,
TF_RETURN_IF_ERROR(
UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
}
+ // Change back the format from 5D to 4D layout.
+ if (rank == 5) {
+ context->AssignDeviceAndDataFormats(context->target_device, src_format,
+ dst_format);
+ }
return context->graph_view->GetMutationBuilder()->Apply();
}

diff --git a/tensorflow/core/kernels/data_format_ops.cc b/tensorflow/core/kernels/data_format_ops.cc
index 181aa1b8..e9c71f17 100644
--- a/tensorflow/core/kernels/data_format_ops.cc
+++ b/tensorflow/core/kernels/data_format_ops.cc
@@ -37,14 +37,14 @@ class DataFormatDimMapOp : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
string dst_format;
OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
- OP_REQUIRES(context, src_format.size() == 4,
+ OP_REQUIRES(context, src_format.size() == 4 || src_format.size() == 5,
errors::InvalidArgument(strings::StrCat(
- "Source format must of length 4, received src_format = ",
- src_format)));
+ "Source format must of length 4 or 5, received "
+ "src_format = ", src_format)));
OP_REQUIRES(
- context, dst_format.size() == 4,
+ context, dst_format.size() == 4 || dst_format.size() == 5,
errors::InvalidArgument(strings::StrCat(
- "Destination format must of length 4, received dst_format = ",
+ "Destination format must of length 4 or 5, received dst_format = ",
dst_format)));
dst_idx_ = Tensor(DT_INT32, {static_cast<int64>(src_format.size())});
for (int i = 0; i < src_format.size(); ++i) {
diff --git a/tensorflow/core/kernels/data_format_ops.h b/tensorflow/core/kernels/data_format_ops.h
index bc416fa7..89b54901 100644
--- a/tensorflow/core/kernels/data_format_ops.h
+++ b/tensorflow/core/kernels/data_format_ops.h
@@ -28,24 +28,49 @@ template <typename Device, typename T>
struct DataFormatDimMap {
void operator()(const Device& d, typename TTypes<T>::ConstFlat x,
typename TTypes<T>::Flat y, const TTypes<int>::Vec dst) {
- auto zero = x.constant(0);
- auto one = x.constant(1);
- auto two = x.constant(2);
+ if (dst.size() == 4) {
+ auto zero = x.constant(0);
+ auto one = x.constant(1);
+ auto two = x.constant(2);

- auto f_zero = x.constant(dst(0));
- auto f_one = x.constant(dst(1));
- auto f_two = x.constant(dst(2));
- auto f_three = x.constant(dst(3));
+ auto f_zero = x.constant(dst(0));
+ auto f_one = x.constant(dst(1));
+ auto f_two = x.constant(dst(2));
+ auto f_three = x.constant(dst(3));

- auto four = x.constant(4);
- auto x_mod = (x + four) % 4;
+ auto four = x.constant(4);
+ auto x_mod = (x + four) % 4;

- auto is_zero = (x_mod == zero);
- auto is_one = (x_mod == one);
- auto is_two = (x_mod == two);
+ auto is_zero = (x_mod == zero);
+ auto is_one = (x_mod == one);
+ auto is_two = (x_mod == two);

- y.device(d) = is_zero.select(
- f_zero, is_one.select(f_one, is_two.select(f_two, f_three)));
+ y.device(d) = is_zero.select(
+ f_zero, is_one.select(f_one, is_two.select(f_two, f_three)));
+ } else {
+ auto zero = x.constant(0);
+ auto one = x.constant(1);
+ auto two = x.constant(2);
+ auto three = x.constant(3);
+
+ auto f_zero = x.constant(dst(0));
+ auto f_one = x.constant(dst(1));
+ auto f_two = x.constant(dst(2));
+ auto f_three = x.constant(dst(3));
+ auto f_four = x.constant(dst(4));
+
+ auto five = x.constant(5);
+ auto x_mod = (x + five) % 5;
+
+ auto is_zero = (x_mod == zero);
+ auto is_one = (x_mod == one);
+ auto is_two = (x_mod == two);
+ auto is_three = (x_mod == three);
+
+ y.device(d) = is_zero.select(
+ f_zero, is_one.select(f_one, is_two.select(f_two,
+ is_three.select(f_three, f_four))));
+ }
}
};

diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 10f86980..f90da7ed 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -215,6 +215,9 @@ class LayoutOptimizerTest(test.TestCase):
def _assert_map_nhwc_to_nchw(self, name, nodes):
self.assertIn(name + '-DimMapNHWCToNCHW-LayoutOptimizer', nodes)

+ def _assert_map_ndhwc_to_ncdhw(self, name, nodes):
+ self.assertIn(name + '-DataFormatDimMapNDHWCToNCDHW-LayoutOptimizer', nodes)
+
def _assert_vec_nchw_to_nhwc(self, name, nodes):
self.assertIn(name + '-VecPermuteNCHWToNHWC-LayoutOptimizer', nodes)

@@ -286,6 +289,42 @@ class LayoutOptimizerTest(test.TestCase):

self.assertAllClose(output_val_ref, output_val, atol=1e-3)

+ @test_util.deprecated_graph_mode_only
+ def testReduceOpsFor5DTensors(self):
+ if test.is_gpu_available(cuda_only=True):
+ random_seed.set_random_seed(0)
+ x = random_ops.truncated_normal([1, 4, 2, 3, 3], seed=0)
+ w = random_ops.truncated_normal([2, 2, 2, 3, 3], seed=0)
+ gamma = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0)
+ beta = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0)
+ conv3d = gen_nn_ops.conv3d(x, w, [1, 1, 1, 1, 1], 'SAME')
+ y = math_ops.reduce_mean(conv3d, [0, 1, 2, 3], keepdims=True)
+ output = array_ops.identity(y)
+
+ with session.Session(config=_get_config(False)) as sess:
+ output_val_ref = sess.run(output)
+
+ with session.Session(config=_get_config()) as sess:
+ metadata = config_pb2.RunMetadata()
+ output_val = sess.run(output, run_metadata=metadata)
+
+ nodes = []
+ num_transposes = 0
+ for node in metadata.cost_graph.node:
+ if _is_transpose(node.name):
+ num_transposes += 1
+ nodes.append(node.name)
+ print(node.name)
+
+ # The reduce op Mean needs to dim map the input reduce index to NCDHW.
+ # Then, the output needs to be tranposed back to NDHWC.
+ expected_num_transposes = 2
+ self.assertEqual(expected_num_transposes, num_transposes)
+ self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
+ self._assert_map_ndhwc_to_ncdhw('Mean-1', nodes)
+ self._assert_trans_ncdhw_to_ndhwc('Mean-0-0', nodes)
+ self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
@test_util.deprecated_graph_mode_only
def testSplitWithNonConstAxis(self):
if test.is_gpu_available(cuda_only=True):
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index bfe11b63..55d11a35 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -1207,6 +1207,33 @@ class DataFormatDimMapTest(test_lib.TestCase):
y_val = self.evaluate(y)
self.assertAllEqual(y_val, y_val_expected)

+ def testNDHWCtoNCDHW(self):
+ x_val = [1, -4, -3, -2]
+ y_val_expected = [2, 2, 3, 4]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_dim_map(x, src_format="NDHWC", dst_format="NCDHW")
+ with test_util.use_gpu():
+ y_val = self.evaluate(y)
+ self.assertAllEqual(y_val, y_val_expected)
+
+ def testNDHWCtoDHWNC(self):
+ x_val = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
+ y_val_expected = [3, 0, 1, 2, 4, 3, 0, 1, 2, 4]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_dim_map(x, src_format="NDHWC", dst_format="DHWNC")
+ with test_util.use_gpu():
+ y_val = self.evaluate(y)
+ self.assertAllEqual(y_val, y_val_expected)
+
+ def testDNHWCtoWHDCN(self):
+ x_val = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
+ y_val_expected = [4, 2, 1, 0, 3, 4, 2, 1, 0, 3]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_dim_map(x, src_format="NDHWC", dst_format="WHDCN")
+ with test_util.use_gpu():
+ y_val = self.evaluate(y)
+ self.assertAllEqual(y_val, y_val_expected)
+
def testArbitraryASCII(self):
x_val = [-4, -3, -2, -1, 0, 1, 2, 3]
y_val_expected = [3, 2, 1, 0, 3, 2, 1, 0]
--
2.27.0

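The patch above teaches the layout optimizer and DataFormatDimMap about 5-D (NDHWC/NCDHW) layouts. For reference, a minimal sketch of the axis mapping the new tests exercise, assuming a TensorFlow 2.x build that includes this patch (tf.raw_ops.DataFormatDimMap is the public raw-op endpoint):

import tensorflow as tf

# Axis indices in NDHWC (N=0, D=1, H=2, W=3, C=4); negative values wrap mod 5.
x = tf.constant([1, -4, -3, -2])
# Map each axis to its position in NCDHW (N=0, C=1, D=2, H=3, W=4).
y = tf.raw_ops.DataFormatDimMap(x=x, src_format="NDHWC", dst_format="NCDHW")
print(y.numpy())  # [2 2 3 4], the values asserted in testNDHWCtoNCDHW above
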
CVE-2020-26267-2.patch (new file, 266 lines)
@@ -0,0 +1,266 @@
From ebc70b7a592420d3d2f359e4b1694c236b82c7ae Mon Sep 17 00:00:00 2001
From: Mihai Maruseac <mihaimaruseac@google.com>
Date: Mon, 7 Dec 2020 11:15:21 -0800
Subject: [PATCH] Validate that `DataFormat*` attributes form a permutation.

The `src_format` and `dst_format` attributes for the `DataFormatDimMap` and `DataFormatVecPermute` raw ops are supposed to determine a permutation. However, this was not validated and could result in unitialized memory accesses as well as writes outside of bounds and potential crashes.

While here, we also test that the format attributes have the needed length, add tests for all validation failure cases, remove unnecessary calls to `strings::StrCat`, and fix a few grammar errors.

This will be cherry-picked on the supported release branches.

PiperOrigin-RevId: 346135579
Change-Id: I1c76392382c89ad8f072d5bc93d70669851eb404
---
tensorflow/core/kernels/data_format_ops.cc | 72 ++++++++++++++--
tensorflow/python/ops/nn_test.py | 96 ++++++++++++++++++++++
2 files changed, 161 insertions(+), 7 deletions(-)

diff --git a/tensorflow/core/kernels/data_format_ops.cc b/tensorflow/core/kernels/data_format_ops.cc
index e9c71f17..abe2fbc3 100644
--- a/tensorflow/core/kernels/data_format_ops.cc
+++ b/tensorflow/core/kernels/data_format_ops.cc
@@ -18,16 +18,52 @@ limitations under the License.
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/data_format_ops.h"
+
+#include <map>
+
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

+// Ensure that `src` and `dst` define a valid permutation.
+// Ops defined in this file assume that user specifies a permutation via two
+// string attributes. This check validates that these attributes properly define
+// it to prevent security vulnerabilities.
+static bool IsValidPermutation(const std::string& src, const std::string& dst) {
+ if (src.size() != dst.size()) {
+ return false;
+ }
+
+ std::map<char, bool> characters;
+
+ // Every character in `src` must be present only once
+ for (const auto c : src) {
+ if (characters[c]) {
+ return false;
+ }
+ characters[c] = true;
+ }
+
+ // Every character in `dst` must show up in `src` exactly once
+ for (const auto c : dst) {
+ if (!characters[c]) {
+ return false;
+ }
+ characters[c] = false;
+ }
+
+ // At this point, characters[] has been switched to true and false exactly
+ // once for all character in `src` (and `dst`) so we have a valid permutation
+ return true;
+}
+
template <typename Device, typename T>
class DataFormatDimMapOp : public OpKernel {
public:
@@ -38,14 +74,18 @@ class DataFormatDimMapOp : public OpKernel {
string dst_format;
OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
OP_REQUIRES(context, src_format.size() == 4 || src_format.size() == 5,
- errors::InvalidArgument(strings::StrCat(
- "Source format must of length 4 or 5, received "
- "src_format = ", src_format)));
+ errors::InvalidArgument(
+ "Source format must be of length 4 or 5, received "
+ "src_format = ", src_format));
+ OP_REQUIRES(context, dst_format.size() == 4 || dst_format.size() == 5,
+ errors::InvalidArgument("Destination format must be of length "
+ "4 or 5, received dst_format = ",
+ dst_format));
OP_REQUIRES(
- context, dst_format.size() == 4 || dst_format.size() == 5,
- errors::InvalidArgument(strings::StrCat(
- "Destination format must of length 4 or 5, received dst_format = ",
- dst_format)));
+ context, IsValidPermutation(src_format, dst_format),
+ errors::InvalidArgument(
+ "Destination and source format must determine a permutation, got ",
+ src_format, " and ", dst_format));
dst_idx_ = Tensor(DT_INT32, {static_cast<int64>(src_format.size())});
for (int i = 0; i < src_format.size(); ++i) {
for (int j = 0; j < dst_format.size(); ++j) {
@@ -77,8 +117,22 @@ class DataFormatVecPermuteOp : public OpKernel {
: OpKernel(context) {
string src_format;
OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
+ OP_REQUIRES(context, src_format.size() == 4 || src_format.size() == 5,
+ errors::InvalidArgument(
+ "Source format must be of length 4 or 5, received "
+ "src_format = ",
+ src_format));
string dst_format;
OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
+ OP_REQUIRES(context, dst_format.size() == 4 || dst_format.size() == 5,
+ errors::InvalidArgument("Destination format must be of length "
+ "4 or 5, received dst_format = ",
+ dst_format));
+ OP_REQUIRES(
+ context, IsValidPermutation(src_format, dst_format),
+ errors::InvalidArgument(
+ "Destination and source format must determine a permutation, got ",
+ src_format, " and ", dst_format));
src_format_ = src_format;
dst_format_ = dst_format;
}
@@ -124,6 +178,10 @@ class DataFormatVecPermuteOp : public OpKernel {
};
keep_only_spatial_dimensions(&src_format_str);
keep_only_spatial_dimensions(&dst_format_str);
+ OP_REQUIRES(context,
+ src_format_str.size() == 2 && dst_format_str.size() == 2,
+ errors::InvalidArgument(
+ "Format specifier must contain H and W for 2D case"));
}
ComputeDstIndex(src_format_str, dst_format_str, input.dims(), &dst_idx);

diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 55d11a35..d2094a7d 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -27,6 +27,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
@@ -1234,6 +1235,7 @@ class DataFormatDimMapTest(test_lib.TestCase):
y_val = self.evaluate(y)
self.assertAllEqual(y_val, y_val_expected)

+ @test_util.disable_xla("XLA catches the error and rethrows as different one")
def testArbitraryASCII(self):
x_val = [-4, -3, -2, -1, 0, 1, 2, 3]
y_val_expected = [3, 2, 1, 0, 3, 2, 1, 0]
@@ -1243,6 +1245,46 @@ class DataFormatDimMapTest(test_lib.TestCase):
y_val = self.evaluate(y)
self.assertAllEqual(y_val, y_val_expected)

+ @test_util.disable_xla("XLA catches the error and rethrows as different one")
+ def testInvalidLength(self):
+ x = [-4, -3, -2, -1, 0, 1, 2, 3]
+ with self.assertRaisesRegex(errors.InvalidArgumentError,
+ "Source format must be of length 4 or 5"):
+ op = nn_ops.data_format_dim_map(
+ x, src_format="12345678", dst_format="87654321")
+ with test_util.use_gpu():
+ self.evaluate(op)
+
+ @test_util.disable_xla("XLA catches the error and rethrows as different one")
+ def testDuplicateSrc(self):
+ x = [-4, -3, -2, -1, 0, 1, 2, 3]
+ with self.assertRaisesRegex(
+ errors.InvalidArgumentError,
+ "Destination and source format must determine a permutation"):
+ op = nn_ops.data_format_dim_map(x, src_format="1233", dst_format="4321")
+ with test_util.use_gpu():
+ self.evaluate(op)
+
+ @test_util.disable_xla("XLA catches the error and rethrows as different one")
+ def testDuplicateDst(self):
+ x = [-4, -3, -2, -1, 0, 1, 2, 3]
+ with self.assertRaisesRegex(
+ errors.InvalidArgumentError,
+ "Destination and source format must determine a permutation"):
+ op = nn_ops.data_format_dim_map(x, src_format="1234", dst_format="3321")
+ with test_util.use_gpu():
+ self.evaluate(op)
+
+ @test_util.disable_xla("XLA catches the error and rethrows as different one")
+ def testExtraSpecifiers(self):
+ x = [-4, -3, -2, -1, 0, 1, 2, 3]
+ with self.assertRaisesRegex(
+ errors.InvalidArgumentError,
+ "Destination and source format must determine a permutation"):
+ op = nn_ops.data_format_dim_map(x, src_format="1234", dst_format="5321")
+ with test_util.use_gpu():
+ self.evaluate(op)
+

class DataFormatVectorPermuteTest(test_lib.TestCase):

@@ -1344,6 +1386,60 @@ class DataFormatVectorPermuteTest(test_lib.TestCase):
y_val = self.evaluate(y)
self.assertAllEqual(y_val, [[7, 4], [4, 5], [5, 1], [9, 3]])

+ @test_util.disable_xla("XLA catches the error and rethrows as different one")
+ def testInvalidLength(self):
+ x = [0, 1, 2, 3]
+ with self.assertRaisesRegex(errors.InvalidArgumentError,
+ "Source format must be of length 4 or 5"):
+ op = nn_ops.data_format_vec_permute(
+ x, src_format="12345678", dst_format="87654321")
+ with test_util.use_gpu():
+ self.evaluate(op)
+
+ @test_util.disable_xla("XLA catches the error and rethrows as different one")
+ def testDuplicateSrc(self):
+ x = [0, 1, 2, 3]
+ with self.assertRaisesRegex(
+ errors.InvalidArgumentError,
+ "Destination and source format must determine a permutation"):
+ op = nn_ops.data_format_vec_permute(
+ x, src_format="1233", dst_format="4321")
+ with test_util.use_gpu():
+ self.evaluate(op)
+
+ @test_util.disable_xla("XLA catches the error and rethrows as different one")
+ def testDuplicateDst(self):
+ x = [0, 1, 2, 3]
+ with self.assertRaisesRegex(
+ errors.InvalidArgumentError,
+ "Destination and source format must determine a permutation"):
+ op = nn_ops.data_format_vec_permute(
+ x, src_format="1234", dst_format="3321")
+ with test_util.use_gpu():
+ self.evaluate(op)
+
+ @test_util.disable_xla("XLA catches the error and rethrows as different one")
+ def testExtraSpecifiers(self):
+ x = [0, 1, 2, 3]
+ with self.assertRaisesRegex(
+ errors.InvalidArgumentError,
+ "Destination and source format must determine a permutation"):
+ op = nn_ops.data_format_vec_permute(
+ x, src_format="1234", dst_format="5321")
+ with test_util.use_gpu():
+ self.evaluate(op)
+
+ @test_util.disable_xla("XLA catches the error and rethrows as different one")
+ def test2DNoWH(self):
+ x = [[0, 1], [2, 3]]
+ with self.assertRaisesRegex(
+ errors.InvalidArgumentError,
+ "Format specifier must contain H and W for 2D case"):
+ op = nn_ops.data_format_vec_permute(
+ x, src_format="1234", dst_format="4321")
+ with test_util.use_gpu():
+ self.evaluate(op)
+

@test_util.run_all_in_graph_and_eager_modes
class AvgPoolTest(test_lib.TestCase):
--
2.27.0

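As the commit message above explains, src_format and dst_format must determine a permutation. A minimal sketch of the new validation path, assuming a TensorFlow 2.x build that includes this patch (values taken from testDuplicateDst above):

import tensorflow as tf

try:
    # "3321" repeats '3' and drops '4', so it is not a permutation of "1234".
    tf.raw_ops.DataFormatVecPermute(
        x=tf.constant([0, 1, 2, 3]), src_format="1234", dst_format="3321")
except tf.errors.InvalidArgumentError as e:
    print(e.message)  # Destination and source format must determine a permutation...
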
CVE-2021-29515.patch (new file, 41 lines)
@@ -0,0 +1,41 @@
From a7116dd3913c4a4afd2a3a938573aa7c785fdfc6 Mon Sep 17 00:00:00 2001
From: Mihai Maruseac <mihaimaruseac@google.com>
Date: Sat, 17 Apr 2021 20:55:53 -0700
Subject: [PATCH] Validate `MatrixDiagV{2,3}` arguments to prevent breakage.

PiperOrigin-RevId: 369056033
Change-Id: Ic2018c297d3dd6f252dc1dd3667f1ed5cb1eaa42
---
.../core/kernels/matrix_diag_op.cc | 19 ++++++++++++++++---
1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/tensorflow/core/kernels/matrix_diag_op.cc b/tensorflow/core/kernels/matrix_diag_op.cc
index 69cc8170793ae..d4eb589836a85 100644
--- a/tensorflow/core/kernels/matrix_diag_op.cc
+++ b/tensorflow/core/kernels/matrix_diag_op.cc
@@ -192,9 +192,22 @@ class MatrixDiagOp : public OpKernel {
upper_diag_index = diag_index.flat<int32>()(1);
}
}
- num_rows = context->input(2).flat<int32>()(0);
- num_cols = context->input(3).flat<int32>()(0);
- padding_value = context->input(4).flat<T>()(0);
+
+ auto& num_rows_tensor = context->input(2);
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_rows_tensor.shape()),
+ errors::InvalidArgument("num_rows must be a scalar"));
+ num_rows = num_rows_tensor.flat<int32>()(0);
+
+ auto& num_cols_tensor = context->input(3);
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_cols_tensor.shape()),
+ errors::InvalidArgument("num_cols must be a scalar"));
+ num_cols = num_cols_tensor.flat<int32>()(0);
+
+ auto& padding_value_tensor = context->input(4);
+ OP_REQUIRES(context,
+ TensorShapeUtils::IsScalar(padding_value_tensor.shape()),
+ errors::InvalidArgument("padding_value must be a scalar"));
+ padding_value = padding_value_tensor.flat<T>()(0);
}

// Size validations.
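The checks above replace unguarded flat<int32>()(0) reads with scalar-shape validation. A minimal sketch of the failure mode they close, assuming a TensorFlow 2.x build that includes this patch (argument names follow tf.raw_ops.MatrixDiagV2):

import tensorflow as tf

try:
    tf.raw_ops.MatrixDiagV2(
        diagonal=tf.constant([1.0, 2.0]),
        k=0,
        num_rows=tf.constant([2, 2]),  # a vector instead of a scalar: now rejected
        num_cols=-1,
        padding_value=0.0)
except tf.errors.InvalidArgumentError as e:
    print(e.message)  # num_rows must be a scalar
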
CVE-2021-29516-1.patch (new file, 199 lines)
@@ -0,0 +1,199 @@
From ce47a396ff795bdb6cf48eb53dbcba46cb51fa7d Mon Sep 17 00:00:00 2001
From: Katherine Tian <kattian@google.com>
Date: Tue, 30 Jun 2020 04:12:11 +0000
Subject: [PATCH 1/1] TensorKey class and TensorMap tests

---
tensorflow/core/BUILD | 1 +
tensorflow/core/framework/BUILD | 70 ++++++++++++++++++++++++++
tensorflow/core/framework/tensor_key.h | 64 +++++++++++++++++++++++
tensorflow/core/kernels/BUILD | 1 +
4 files changed, 136 insertions(+)
create mode 100644 tensorflow/core/framework/tensor_key.h

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index d0be6ee9..6e745b4e 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -495,6 +495,7 @@ tf_cuda_library(
"//tensorflow/core/framework:shared_ptr_variant.h",
"//tensorflow/core/framework:stats_aggregator.h",
"//tensorflow/core/framework:tensor.h",
+ "//tensorflow/core/framework:tensor_key.h",
"//tensorflow/core/framework:tensor_shape.h",
"//tensorflow/core/framework:tensor_slice.h",
"//tensorflow/core/framework:tensor_types.h",
diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD
index 9b6ddb2a..093f0545 100644
--- a/tensorflow/core/framework/BUILD
+++ b/tensorflow/core/framework/BUILD
@@ -209,6 +209,7 @@ filegroup(
"shared_ptr_variant.h",
"stats_aggregator.h",
"tensor.h",
+ "tensor_key.h",
"tensor_reference.h",
"tensor_shape.h",
"tensor_slice.h",
@@ -760,6 +761,75 @@ tf_cuda_library(
alwayslink = 1,
)

+tf_cuda_library(
+ name = "tensor_key",
+ srcs = [
+ "log_memory.cc",
+ "tensor.cc",
+ "typed_allocator.cc",
+ "types.cc",
+ "variant.cc",
+ "variant_op_registry.cc",
+ "variant_tensor_data.cc",
+ ],
+ hdrs = [
+ "log_memory.h",
+ "register_types.h",
+ "tensor.h",
+ "tensor_key.h",
+ "typed_allocator.h",
+ "types.h",
+ "variant.h",
+ "variant_encode_decode.h",
+ "variant_op_registry.h",
+ "variant_tensor_data.h",
+ ],
+ visibility = [
+ "//tensorflow/core:__pkg__",
+ "//tensorflow/core/util:__pkg__",
+ ],
+ deps = [
+ ":allocation_description_proto_cc",
+ ":allocator",
+ ":bfloat16",
+ ":log_memory_proto_cc",
+ ":numeric_types",
+ ":resource_handle",
+ ":resource_handle_proto_cc",
+ ":tensor_description_proto_cc",
+ ":tensor_proto_cc",
+ ":tensor_shape",
+ ":tensor_types",
+ ":type_index",
+ ":type_traits",
+ ":types_proto_cc",
+ "//tensorflow/core/lib/core:coding",
+ "//tensorflow/core/lib/core:errors",
+ "//tensorflow/core/lib/core:refcount",
+ "//tensorflow/core/lib/core:status",
+ "//tensorflow/core/lib/core:stringpiece",
+ "//tensorflow/core/lib/gtl:array_slice",
+ "//tensorflow/core/lib/gtl:flatmap",
+ "//tensorflow/core/lib/gtl:inlined_vector",
+ "//tensorflow/core/lib/hash",
+ "//tensorflow/core/lib/strings:str_util",
+ "//tensorflow/core/lib/strings:strcat",
+ "//tensorflow/core/platform:abi",
+ "//tensorflow/core/platform:logging",
+ "//tensorflow/core/platform:macros",
+ "//tensorflow/core/platform:platform_port",
+ "//tensorflow/core/platform:protobuf",
+ "//tensorflow/core/platform:strcat",
+ "//tensorflow/core/platform:tensor_coding",
+ "//tensorflow/core/platform:types",
+ "//tensorflow/core/public:version",
+ "//third_party/eigen3",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ ],
+ alwayslink = 1,
+)
+
cc_library(
name = "shape_inference",
srcs = ["shape_inference.cc"],
diff --git a/tensorflow/core/framework/tensor_key.h b/tensorflow/core/framework/tensor_key.h
new file mode 100644
index 00000000..8eff58b2
--- /dev/null
+++ b/tensorflow/core/framework/tensor_key.h
@@ -0,0 +1,64 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+
+class TensorKey : public Tensor {
+ public:
+ using Tensor::Tensor;
+
+ TensorKey(const Tensor& t) : Tensor(t) {}
+
+ // Equality operator. Needed for absl hashing.
+ friend bool operator==(const TensorKey& t1, const TensorKey& t2) {
+ if (t1.dtype() != t2.dtype() || t1.shape() != t2.shape()) {
+ return false;
+ }
+ if (DataTypeCanUseMemcpy(t1.dtype())) {
+ return t1.tensor_data() == t2.tensor_data();
+ }
+ if (t1.dtype() == DT_STRING) {
+ const auto s1 = t1.unaligned_flat<tstring>();
+ const auto s2 = t2.unaligned_flat<tstring>();
+ for (int64 i = 0, n = t1.NumElements(); i < n; ++i) {
+ if (TF_PREDICT_FALSE(s1(i) != s2(i))) {
+ return false;
+ }
+ }
+ return true;
+ }
+ return false;
+ }
+
+ friend bool operator!=(const TensorKey& t1, const TensorKey& t2) {
+ return !(t1==t2);
+ }
+
+ // AbslHashValue() function, needed for absl hashing.
+ template <typename H>
+ friend H AbslHashValue(H h, const TensorKey& k) {
+ uint8* d = (uint8*)(k.data());
+ size_t s = k.AllocatedBytes();
+ std::vector<uint8> vec;
+ for (int i=0; i < s; i++) {
+ vec.push_back(d[i]);
+ }
+ return H::combine(std::move(h), s);
+ }
+};
+
+} //namespace tensorflow
\ No newline at end of file
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index f5a480b3..4ef86efb 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3219,6 +3219,7 @@ tf_cc_tests(
],
deps = [
":eigen_helpers",
+ "//tensorflow/core/framework:tensor_testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/strings",
--
2.27.0

CVE-2021-29516-2.patch (new file, 70 lines)
@@ -0,0 +1,70 @@
From b6a0cba2b381e83a1d0a19b675ca6f7459d2d2bc Mon Sep 17 00:00:00 2001
From: Edward Loper <edloper@google.com>
Date: Tue, 25 Aug 2020 08:12:53 -0700
Subject: [PATCH 1/1] Fix segmentation fault in tf.map_fn when fn_output_spec
 is a RaggedTensorSpec and the input tensor has shape [0, ...].

PiperOrigin-RevId: 328332518
Change-Id: I6aff03152bbc96507fb6c5f89b05722f3cc30164
---
.../kernels/ragged_tensor_from_variant_op.cc | 16 +++++++++++++++-
.../python/ops/ragged/ragged_map_fn_op_test.py | 15 +++++++++++++++
2 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
index ad0712e6fd0..aa736ad7f60 100644
--- a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
+++ b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
@@ -175,8 +175,22 @@ Status NestedStackRaggedTensors(
}
}

+ // If the variant tensor input is empty, then we have no way to determine
+ // the correct shape for the dense_values. (It must have rank>=1, and its
+ // outer dimension must be 0, but we don't know its shape beyond that.)
+ // For now, we just use a shape of `[0]` in this case.
+ // TODO(edloper): Update this op with an attribute containing information
+ // about dense_values shape. If it's `None`, then we'll probably still have
+ // to use shape=[0] here, but if we have more info, then we can use it.
+ // E.g., in map_fn, we may have shape info from the RaggedTensorSpec.
+ TensorShape component_values_shape;
+ if (ragged_components.empty()) {
+ component_values_shape = TensorShape({0});
+ } else {
+ component_values_shape = ragged_components[0].values.shape();
+ }
+
// Populate values.
- TensorShape component_values_shape = ragged_components[0].values.shape();
int values_size = component_values_shape.dim_size(0);
for (int i = 1; i < ragged_components.size(); i++) {
if (ragged_components[i].values.dims() != component_values_shape.dims()) {
diff --git a/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py b/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py
index 8a40e396a68..bead4923a0a 100644
--- a/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py
+++ b/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py
@@ -150,6 +150,21 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase,
result_dtype=ragged_tensor.RaggedTensorType(
dtype=dtypes.int64, ragged_rank=4),
),
+ # [d1] -> [d1, (d2), (d3)]
+ dict(
+ fn=ragged_math_ops.range,
+ elems=np.array([1, 2, 3], np.int64),
+ expected_output=[[[0]], [[0, 1]], [[0, 1, 2]]],
+ result_dtype=ragged_tensor.RaggedTensorType(
+ dtype=dtypes.int64, ragged_rank=2)),
+ # [0] -> [0, (d2), (d3)] (github issue #36232)
+ dict(
+ fn=ragged_math_ops.range,
+ elems=np.zeros([0], np.int64),
+ expected_output=[],
+ expected_ragged_rank=2,
+ result_dtype=ragged_tensor.RaggedTensorType(
+ dtype=dtypes.int64, ragged_rank=2)),
])

def testRaggedMap(
--
2.27.0

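The new `[0] -> [0, (d2), (d3)]` test case above corresponds to github issue #36232. A minimal sketch of that scenario through the public API, assuming TensorFlow 2.3+ for fn_output_signature and a build that includes this patch:

import tensorflow as tf

# With empty elems, RaggedTensorFromVariant used to dereference
# ragged_components[0] and segfault; it now returns an empty RaggedTensor.
out = tf.map_fn(
    fn=tf.range,
    elems=tf.zeros([0], dtype=tf.int64),
    fn_output_signature=tf.RaggedTensorSpec(shape=[None], dtype=tf.int64))
print(out)  # <tf.RaggedTensor []>
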
CVE-2021-29516-3.patch (new file, 904 lines)
@@ -0,0 +1,904 @@
|
||||
From be6b1fdb0699d4000b70ad32cc23d1503e5c7511 Mon Sep 17 00:00:00 2001
|
||||
From: Edward Loper <edloper@google.com>
|
||||
Date: Wed, 14 Oct 2020 09:41:17 -0700
|
||||
Subject: [PATCH 1/1] Added gradients for RaggedTensorToVariant and
|
||||
RaggedTensorFromVariant. (This allows gradients to pass through map_fn when
|
||||
it is applied to ragged tensors.)
|
||||
|
||||
PiperOrigin-RevId: 337108621
|
||||
Change-Id: I73d5f3296181877f0cc4c7a6273b693bcf8310ab
|
||||
---
|
||||
tensorflow/core/kernels/BUILD | 15 ++
|
||||
.../kernels/ragged_tensor_from_variant_op.cc | 164 +++++++---------
|
||||
.../kernels/ragged_tensor_to_variant_op.cc | 180 +++++++++++-------
|
||||
.../core/kernels/ragged_tensor_variant.cc | 86 +++++++++
|
||||
.../core/kernels/ragged_tensor_variant.h | 110 +++++++++++
|
||||
tensorflow/core/ops/ragged_conversion_ops.cc | 20 +-
|
||||
tensorflow/python/ops/ragged/BUILD | 1 +
|
||||
9 files changed, 478 insertions(+), 172 deletions(-)
|
||||
create mode 100644 tensorflow/core/framework/tensor_key.h
|
||||
create mode 100644 tensorflow/core/kernels/ragged_tensor_variant.cc
|
||||
create mode 100644 tensorflow/core/kernels/ragged_tensor_variant.h
|
||||
|
||||
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
|
||||
index f5a480b3..12adb2b2 100644
|
||||
--- a/tensorflow/core/kernels/BUILD
|
||||
+++ b/tensorflow/core/kernels/BUILD
|
||||
@@ -1529,10 +1529,22 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
+cc_library(
|
||||
+ name = "ragged_tensor_variant",
|
||||
+ srcs = ["ragged_tensor_variant.cc"],
|
||||
+ hdrs = ["ragged_tensor_variant.h"],
|
||||
+ deps = [
|
||||
+ ":cwise_op",
|
||||
+ "//tensorflow/core:framework",
|
||||
+ ],
|
||||
+)
|
||||
+
|
||||
tf_kernel_library(
|
||||
name = "ragged_tensor_to_variant_op",
|
||||
srcs = ["ragged_tensor_to_variant_op.cc"],
|
||||
deps = [
|
||||
+ ":concat_lib",
|
||||
+ ":ragged_tensor_variant",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
@@ -1542,6 +1554,7 @@ tf_kernel_library(
|
||||
name = "ragged_tensor_from_variant_op",
|
||||
srcs = ["ragged_tensor_from_variant_op.cc"],
|
||||
deps = [
|
||||
+ ":ragged_tensor_variant",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
@@ -1554,6 +1567,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":ops_testutil",
|
||||
":ragged_tensor_to_variant_op",
|
||||
+ ":ragged_tensor_variant",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
@@ -1570,6 +1584,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":ops_testutil",
|
||||
":ragged_tensor_from_variant_op",
|
||||
+ ":ragged_tensor_variant",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
diff --git a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
|
||||
index d7b6a89a..fa8853af 100644
|
||||
--- a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
|
||||
+++ b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
|
||||
@@ -20,110 +20,76 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/variant.h"
|
||||
#include "tensorflow/core/framework/variant_encode_decode.h"
|
||||
+#include "tensorflow/core/kernels/ragged_tensor_variant.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
-struct RaggedTensor {
|
||||
- Tensor values;
|
||||
- std::vector<Tensor> nested_splits;
|
||||
-};
|
||||
-
|
||||
-Status RaggedComponentsFromVariant(const Tensor& encoded_variant,
|
||||
- int ragged_rank, DataType value_dtype,
|
||||
- DataType split_dtype,
|
||||
- std::vector<RaggedTensor>* decoded_ragged) {
|
||||
+Status RaggedComponentsFromVariant(
|
||||
+ const Tensor& encoded_variant, int ragged_rank, DataType value_dtype,
|
||||
+ DataType split_dtype, std::vector<RaggedTensorVariant>* decoded_ragged) {
|
||||
const auto& flat_variants = encoded_variant.flat<Variant>();
|
||||
- decoded_ragged->resize(flat_variants.size());
|
||||
- // Step 1: Extract the 1-D DT_VARIANT Tensor from each Variant element in the
|
||||
- // input.
|
||||
+ decoded_ragged->reserve(flat_variants.size());
|
||||
+
|
||||
for (int i = 0; i < flat_variants.size(); i++) {
|
||||
const auto& flat_variant = flat_variants(i);
|
||||
- const Tensor* encoded_list = flat_variant.get<Tensor>();
|
||||
- if (encoded_list == nullptr) {
|
||||
+ const RaggedTensorVariant* decoded =
|
||||
+ flat_variant.get<RaggedTensorVariant>();
|
||||
+ if (decoded == nullptr) {
|
||||
return errors::InvalidArgument(
|
||||
"Input Variant element at index ", i,
|
||||
- " doesn't hold a Tensor: ", flat_variant.DebugString());
|
||||
+ " doesn't hold a RaggedTensorVariant: ", flat_variant.DebugString());
|
||||
}
|
||||
- if (encoded_list->dims() != 1) {
|
||||
+ decoded_ragged->push_back(*decoded);
|
||||
+ decoded = &decoded_ragged->back();
|
||||
+ // Check ragged rank & types
|
||||
+ if (decoded->ragged_rank() != ragged_rank) {
|
||||
return errors::InvalidArgument(
|
||||
- "Encoded input Variant must have rank 1, but found rank: ",
|
||||
- encoded_list->dims(),
|
||||
- ". encoded input Variant: ", encoded_list->DebugString());
|
||||
+ "Encoded input RaggedTensorVariant has ragged_rank=",
|
||||
+ decoded->ragged_rank(), ". Expected ragged_rank=", ragged_rank, ".");
|
||||
}
|
||||
- if (encoded_list->NumElements() != (ragged_rank + 1) &&
|
||||
- encoded_list->NumElements() != 1) {
|
||||
+ if (decoded->values().dtype() != value_dtype) {
|
||||
return errors::InvalidArgument(
|
||||
- "Encoded input Variant must hold either input_ragged_rank + 1 "
|
||||
- "Tensors or an empty Tensor (zero splits Tensors, 1 values Tensor), "
|
||||
- "input_ragged_rank: ",
|
||||
- ragged_rank,
|
||||
- ", encoded input Variant: ", encoded_list->DebugString());
|
||||
+ "Expected values Tensor dtype: ", DataTypeString(value_dtype),
|
||||
+ ", found: ", DataTypeString(decoded->values().dtype()));
|
||||
}
|
||||
- const auto& input_vec = encoded_list->vec<Variant>();
|
||||
-
|
||||
- // Step 2: Get the splits and value Tensors from the 1-D DT_VARIANT Tensor
|
||||
- // to create the component RaggedTensors.
|
||||
- (*decoded_ragged)[i].nested_splits.reserve(ragged_rank);
|
||||
- for (int j = 0; j < ragged_rank; j++) {
|
||||
- const Tensor* split_tensor = input_vec(j).get<Tensor>();
|
||||
- if (split_tensor == nullptr) {
|
||||
- return errors::InvalidArgument(
|
||||
- "Encoded scalar element at index ", i,
|
||||
- " doesn't have a splits Tensor at split_index ", j, ": ",
|
||||
- input_vec(j).DebugString());
|
||||
- }
|
||||
- Tensor splits_tensor = *split_tensor;
|
||||
- if (splits_tensor.dtype() != split_dtype) {
|
||||
+ if (decoded->values().dims() < 1) {
|
||||
+ return errors::InvalidArgument(
|
||||
+ "Ragged values must have rank >= 1; encoded scalar element at index ",
|
||||
+ i, " has values Tensor: ", decoded->values().DebugString());
|
||||
+ }
|
||||
+ for (const auto& splits : decoded->nested_splits()) {
|
||||
+ if (splits.dtype() != split_dtype) {
|
||||
return errors::InvalidArgument(
|
||||
- "Expected splits Tensor dtype: ", split_dtype,
|
||||
- ", found: ", splits_tensor.dtype());
|
||||
+ "Expected row_splits Tensor dtype: ", DataTypeString(split_dtype),
|
||||
+ ", found: ", DataTypeString(splits.dtype()));
|
||||
}
|
||||
- if (splits_tensor.dims() != 1) {
|
||||
+ if (splits.dims() != 1) {
|
||||
return errors::InvalidArgument(
|
||||
"Ragged splits must have rank 1; encoded scalar element at index ",
|
||||
- i, " has splits Tensor at split_index ", j, ": ",
|
||||
- splits_tensor.DebugString());
|
||||
+ i, " has splits Tensor ", splits.DebugString());
|
||||
}
|
||||
- (*decoded_ragged)[i].nested_splits.push_back(splits_tensor);
|
||||
- }
|
||||
- const Tensor* values_tensor = input_vec(ragged_rank).get<Tensor>();
|
||||
- if (values_tensor == nullptr) {
|
||||
- return errors::InvalidArgument("Encoded scalar element at index ", i,
|
||||
- " doesn't have a values Tensor: ",
|
||||
- input_vec(ragged_rank).DebugString());
|
||||
- }
|
||||
- if (values_tensor->dtype() != value_dtype) {
|
||||
- return errors::InvalidArgument(
|
||||
- "Expected values Tensor dtype: ", DataTypeString(value_dtype),
|
||||
- ", found: ", DataTypeString(values_tensor->dtype()));
|
||||
- }
|
||||
- if (values_tensor->dims() < 1) {
|
||||
- return errors::InvalidArgument(
|
||||
- "Ragged values must have rank >= 1; encoded scalar element at index ",
|
||||
- i, " has values Tensor: ", values_tensor->DebugString());
|
||||
}
|
||||
- (*decoded_ragged)[i].values = *values_tensor;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename VALUE_TYPE, typename SPLIT_TYPE>
|
||||
Status NestedStackRaggedTensors(
|
||||
- const std::vector<RaggedTensor>& ragged_components,
|
||||
+ const std::vector<RaggedTensorVariant>& ragged_components,
|
||||
const std::vector<int>& nested_dim_sizes, const int input_ragged_rank,
|
||||
- const int output_ragged_rank, RaggedTensor* output_ragged) {
|
||||
- output_ragged->nested_splits.reserve(output_ragged_rank);
|
||||
+ const int output_ragged_rank, RaggedTensorVariant* output_ragged) {
|
||||
+ output_ragged->mutable_nested_splits()->reserve(output_ragged_rank);
|
||||
const int dims = nested_dim_sizes.size();
|
||||
|
||||
// Populate first `dims - 1` splits.
|
||||
for (int i = 0; i < dims - 1; i++) {
|
||||
int dims_splits_size = nested_dim_sizes[i] + 1;
|
||||
- output_ragged->nested_splits.push_back(Tensor(
|
||||
- DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({dims_splits_size})));
|
||||
- auto splits_vec = output_ragged->nested_splits[i].vec<SPLIT_TYPE>();
|
||||
+ output_ragged->append_splits(Tensor(DataTypeToEnum<SPLIT_TYPE>::value,
|
||||
+ TensorShape({dims_splits_size})));
|
||||
+ auto splits_vec = output_ragged->mutable_splits(i)->vec<SPLIT_TYPE>();
|
||||
int split_diff = nested_dim_sizes[i + 1];
|
||||
for (int j = 0; j < dims_splits_size; j++) {
|
||||
splits_vec(j) = j * split_diff;
|
||||
@@ -132,15 +98,15 @@ Status NestedStackRaggedTensors(
|
||||
|
||||
// Populate `dims`-th split.
|
||||
int splits_size = ragged_components.size() + 1;
|
||||
- output_ragged->nested_splits.push_back(
|
||||
+ output_ragged->append_splits(
|
||||
Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({splits_size})));
|
||||
auto dims_splits_vec =
|
||||
- output_ragged->nested_splits[dims - 1].vec<SPLIT_TYPE>();
|
||||
+ output_ragged->mutable_splits(dims - 1)->vec<SPLIT_TYPE>();
|
||||
dims_splits_vec(0) = 0;
|
||||
for (int i = 0; i < ragged_components.size(); i++) {
|
||||
- int split_val = ragged_components[i].values.shape().dim_size(0);
|
||||
- if (input_ragged_rank != 0 && !ragged_components[i].nested_splits.empty()) {
|
||||
- split_val = ragged_components[i].nested_splits[0].NumElements() - 1;
|
||||
+ int split_val = ragged_components[i].values().shape().dim_size(0);
|
||||
+ if (input_ragged_rank != 0 && ragged_components[i].ragged_rank() > 0) {
|
||||
+ split_val = ragged_components[i].splits(0).NumElements() - 1;
|
||||
}
|
||||
dims_splits_vec(i + 1) = dims_splits_vec(i) + split_val;
|
||||
}
|
||||
@@ -150,24 +116,24 @@ Status NestedStackRaggedTensors(
|
||||
int split_index = dims + i;
|
||||
int split_size = 1;
|
||||
for (int j = 0; j < ragged_components.size(); j++) {
|
||||
- if (!ragged_components[j].nested_splits.empty()) {
|
||||
- split_size += ragged_components[j].nested_splits[i].NumElements() - 1;
|
||||
+ if (!ragged_components[j].nested_splits().empty()) {
|
||||
+ split_size += ragged_components[j].splits(i).NumElements() - 1;
|
||||
}
|
||||
}
|
||||
- output_ragged->nested_splits.push_back(
|
||||
+ output_ragged->append_splits(
|
||||
Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size})));
|
||||
auto splits_vec =
|
||||
- output_ragged->nested_splits[split_index].vec<SPLIT_TYPE>();
|
||||
+ output_ragged->mutable_splits(split_index)->vec<SPLIT_TYPE>();
|
||||
splits_vec(0) = 0;
|
||||
SPLIT_TYPE last_split_value = 0;
|
||||
int index = 1;
|
||||
for (int j = 0; j < ragged_components.size(); j++) {
|
||||
- if (ragged_components[j].nested_splits.empty()) {
|
||||
+ if (ragged_components[j].nested_splits().empty()) {
|
||||
// Corner case: empty row. e.g [ [[x], [x]], [] ]
|
||||
continue;
|
||||
}
|
||||
auto component_splits_vec =
|
||||
- ragged_components[j].nested_splits[i].vec<SPLIT_TYPE>();
|
||||
+ ragged_components[j].splits(i).vec<SPLIT_TYPE>();
|
||||
for (int k = 1; k < component_splits_vec.size(); k++, index++) {
|
||||
splits_vec(index) = component_splits_vec(k) + last_split_value;
|
||||
}
|
||||
@@ -187,35 +153,35 @@ Status NestedStackRaggedTensors(
|
||||
if (ragged_components.empty()) {
|
||||
component_values_shape = TensorShape({0});
|
||||
} else {
|
||||
- component_values_shape = ragged_components[0].values.shape();
|
||||
+ component_values_shape = ragged_components[0].values().shape();
|
||||
}
|
||||
|
||||
// Populate values.
|
||||
int values_size = component_values_shape.dim_size(0);
|
||||
for (int i = 1; i < ragged_components.size(); i++) {
|
||||
- if (ragged_components[i].values.dims() != component_values_shape.dims()) {
|
||||
+ if (ragged_components[i].values().dims() != component_values_shape.dims()) {
|
||||
return errors::InvalidArgument(
|
||||
"Rank of values must match for all "
|
||||
"components; values shape at index 0: ",
|
||||
component_values_shape.DebugString(), ", values shape at index ", i,
|
||||
- ": ", ragged_components[i].values.shape().DebugString());
|
||||
+ ": ", ragged_components[i].values().shape().DebugString());
|
||||
}
|
||||
- values_size += ragged_components[i].values.shape().dim_size(0);
|
||||
+ values_size += ragged_components[i].values().shape().dim_size(0);
|
||||
}
|
||||
component_values_shape.set_dim(0, values_size);
|
||||
- output_ragged->values =
|
||||
- Tensor(DataTypeToEnum<VALUE_TYPE>::value, component_values_shape);
|
||||
+ output_ragged->set_values(
|
||||
+ Tensor(DataTypeToEnum<VALUE_TYPE>::value, component_values_shape));
|
||||
auto output_values_flat =
|
||||
- output_ragged->values.flat_outer_dims<VALUE_TYPE, 2>();
|
||||
+ output_ragged->mutable_values()->flat_outer_dims<VALUE_TYPE, 2>();
|
||||
int values_index = 0;
|
||||
for (int i = 0; i < ragged_components.size(); i++) {
|
||||
auto component_values_flat =
|
||||
- ragged_components[i].values.flat_outer_dims<VALUE_TYPE, 2>();
|
||||
- int num_inner_elements = ragged_components[i].values.NumElements();
|
||||
- if (ragged_components[i].values.dim_size(0) > 0) {
|
||||
- num_inner_elements /= ragged_components[i].values.dim_size(0);
|
||||
+ ragged_components[i].values().flat_outer_dims<VALUE_TYPE, 2>();
|
||||
+ int num_inner_elements = ragged_components[i].values().NumElements();
|
||||
+ if (ragged_components[i].values().dim_size(0) > 0) {
|
||||
+ num_inner_elements /= ragged_components[i].values().dim_size(0);
|
||||
}
|
||||
- for (int j = 0; j < ragged_components[i].values.dim_size(0);
|
||||
+ for (int j = 0; j < ragged_components[i].values().dim_size(0);
|
||||
j++, values_index++) {
|
||||
for (int k = 0; k < num_inner_elements; k++) {
|
||||
output_values_flat(values_index, k) = component_values_flat(j, k);
|
||||
@@ -265,7 +231,7 @@ class RaggedTensorFromVariantOp : public OpKernel {
|
||||
// Decode all variants.
|
||||
const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
|
||||
const auto split_dtype = DataTypeToEnum<SPLIT_TYPE>::v();
|
||||
- std::vector<RaggedTensor> decoded_components;
|
||||
+ std::vector<RaggedTensorVariant> decoded_components;
|
||||
OP_REQUIRES_OK(context, RaggedComponentsFromVariant(
|
||||
encoded_variant, input_ragged_rank_,
|
||||
value_dtype, split_dtype, &decoded_components));
|
||||
@@ -281,7 +247,7 @@ class RaggedTensorFromVariantOp : public OpKernel {
|
||||
for (int i = 0; i < encoded_variant.dims(); i++) {
|
||||
encoded_dim_sizes[i] = encoded_variant.dim_size(i);
|
||||
}
|
||||
- RaggedTensor output_ragged;
|
||||
+ RaggedTensorVariant output_ragged;
|
||||
OP_REQUIRES_OK(
|
||||
context, NestedStackRaggedTensors<VALUE_TYPE, SPLIT_TYPE>(
|
||||
decoded_components, encoded_dim_sizes, input_ragged_rank_,
|
||||
@@ -296,15 +262,15 @@ class RaggedTensorFromVariantOp : public OpKernel {
|
||||
int output_ragged_rank_;
|
||||
|
||||
void ReturnRaggedTensor(OpKernelContext* context,
|
||||
- RaggedTensor ragged_tensor) {
|
||||
- int ragged_rank = ragged_tensor.nested_splits.size();
|
||||
+ const RaggedTensorVariant& ragged_tensor) {
|
||||
+ int ragged_rank = ragged_tensor.ragged_rank();
|
||||
OpOutputList splits_out;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->output_list("output_nested_splits", &splits_out));
|
||||
for (int i = 0; i < ragged_rank; i++) {
|
||||
- splits_out.set(i, ragged_tensor.nested_splits[i]);
|
||||
+ splits_out.set(i, ragged_tensor.splits(i));
|
||||
}
|
||||
- context->set_output(ragged_rank, ragged_tensor.values);
|
||||
+ context->set_output(ragged_rank, ragged_tensor.values());
|
||||
}
|
||||
};
|
||||
|
diff --git a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc
index 3190534b..a60e5c62 100644
--- a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc
+++ b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc
@@ -18,50 +18,38 @@ limitations under the License.
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/variant.h"
 #include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
+#include "tensorflow/core/kernels/concat_lib.h"
+#include "tensorflow/core/kernels/ragged_tensor_variant.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/tensor_ops_util.h"
 
 namespace tensorflow {
 namespace {
 
-struct RaggedTensor {
-  Tensor values;
-  std::vector<Tensor> nested_splits;
-};
-
-Status RaggedToVariant(const RaggedTensor& ragged, Tensor* encoded_list) {
-  // Encode as a rank-1 Variant Tensor.
-  int ragged_rank = ragged.nested_splits.size();
-  *encoded_list = Tensor(DT_VARIANT, TensorShape({ragged_rank + 1}));
-  auto encoded_vec = encoded_list->vec<Variant>();
-  for (int i = 0; i < ragged_rank; i++) {
-    encoded_vec(i) = ragged.nested_splits[i];
-  }
-  encoded_vec(ragged_rank) = ragged.values;
-  return Status::OK();
-}
-
 template <typename VALUE_TYPE, typename SPLIT_TYPE>
-Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
-                              std::vector<RaggedTensor>* ragged_components) {
+Status UnbatchRaggedZerothDim(
+    const RaggedTensorVariant& batched_ragged,
+    std::vector<RaggedTensorVariant>* ragged_components) {
   // Set up the component Ragged Tensors.
-  int ragged_rank = batched_ragged.nested_splits.size();
-  auto batched_splits_top_vec =
-      batched_ragged.nested_splits[0].vec<SPLIT_TYPE>();
+  int ragged_rank = batched_ragged.ragged_rank();
+  auto batched_splits_top_vec = batched_ragged.splits(0).vec<SPLIT_TYPE>();
   int num_components = batched_splits_top_vec.size() - 1;
   int num_splits = ragged_rank - 1;
   ragged_components->resize(num_components);
-  for (RaggedTensor ragged_component : *ragged_components) {
-    ragged_component.nested_splits.reserve(num_splits);
+  for (RaggedTensorVariant& ragged_component : *ragged_components) {
+    ragged_component.mutable_nested_splits()->reserve(num_splits);
   }
-  const auto& batched_flat = batched_ragged.values.flat<VALUE_TYPE>();
-  int num_inner_elems = batched_ragged.values.NumElements();
-  if (batched_ragged.values.dim_size(0) > 1) {
-    num_inner_elems /= batched_ragged.values.dim_size(0);
+  const auto& batched_flat = batched_ragged.values().flat<VALUE_TYPE>();
+  int num_inner_elems = batched_ragged.values().NumElements();
+  if (batched_ragged.values().dim_size(0) > 1) {
+    num_inner_elems /= batched_ragged.values().dim_size(0);
   }
-  TensorShape values_shape = batched_ragged.values.shape();
+  TensorShape values_shape = batched_ragged.values().shape();
 
   // Corner case: ragged_rank == 1, e.g. [[1, 2, 3], [4, 5]]
   if (num_splits == 0) {
@@ -70,10 +58,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
       int limit = batched_splits_top_vec(i + 1);
       int num_values = limit - start;
       values_shape.set_dim(0, num_values);
-      (*ragged_components)[i].values =
-          Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape);
+      (*ragged_components)[i].set_values(
+          Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
       auto ragged_component_values_flat =
-          (*ragged_components)[i].values.flat<VALUE_TYPE>();
+          (*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>();
       for (int j = 0; j < num_values * num_inner_elems; j++) {
         ragged_component_values_flat(j) =
             batched_flat(j + start * num_inner_elems);
@@ -86,8 +74,7 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
   std::vector<typename TTypes<SPLIT_TYPE>::ConstVec> batched_splits_vec;
   batched_splits_vec.reserve(ragged_rank);
   for (int i = 0; i < ragged_rank; i++) {
-    batched_splits_vec.push_back(
-        batched_ragged.nested_splits[i].vec<SPLIT_TYPE>());
+    batched_splits_vec.push_back(batched_ragged.splits(i).vec<SPLIT_TYPE>());
   }
   std::vector<int> index(num_splits, 1);
   std::vector<int> ragged_component_values_size(num_components, 0);
@@ -104,10 +91,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
         int last_index = ragged_component_splits_vec[j - 1].size() - 1;
         split_size = ragged_component_splits_vec[j - 1](last_index) + 1;
       }
-      (*ragged_components)[i].nested_splits.push_back(
+      (*ragged_components)[i].append_splits(
           Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size})));
       ragged_component_splits_vec.push_back(
-          (*ragged_components)[i].nested_splits[j].vec<SPLIT_TYPE>());
+          (*ragged_components)[i].mutable_splits(j)->vec<SPLIT_TYPE>());
       SPLIT_TYPE last_split_value = batched_splits_vec[j + 1](index[j] - 1);
       ragged_component_splits_vec[j](0) = 0;
       for (int k = 1; k < split_size; k++, index[j]++) {
@@ -125,10 +112,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
   for (int i = 0; i < num_components; i++) {
     int num_values = ragged_component_values_size[i];
     values_shape.set_dim(0, num_values);
-    (*ragged_components)[i].values =
-        Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape);
+    (*ragged_components)[i].set_values(
+        Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
     auto ragged_component_values_flat =
-        (*ragged_components)[i].values.flat<VALUE_TYPE>();
+        (*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>();
     for (int j = 0; j < num_values * num_inner_elems; j++, value_index++) {
       ragged_component_values_flat(j) = batched_flat(value_index);
     }
@@ -152,24 +139,21 @@ class RaggedTensorToVariantOp : public OpKernel {
     OP_REQUIRES_OK(context, context->input_list("rt_nested_splits",
                                                 &ragged_nested_splits_in));
     const int ragged_nested_splits_len = ragged_nested_splits_in.size();
-    RaggedTensor batched_ragged_input;
+    RaggedTensorVariant batched_ragged_input;
     // Read ragged_values input.
-    batched_ragged_input.values = context->input(ragged_nested_splits_len);
-    batched_ragged_input.nested_splits.reserve(ragged_nested_splits_len);
+    batched_ragged_input.set_values(context->input(ragged_nested_splits_len));
+    batched_ragged_input.mutable_nested_splits()->reserve(
+        ragged_nested_splits_len);
     for (int i = 0; i < ragged_nested_splits_len; i++) {
-      batched_ragged_input.nested_splits.push_back(ragged_nested_splits_in[i]);
+      batched_ragged_input.append_splits(ragged_nested_splits_in[i]);
     }
 
     if (!batched_input_) {
-      // Encode the input as is.
-      Tensor encoded_list;
-      OP_REQUIRES_OK(context,
-                     RaggedToVariant(batched_ragged_input, &encoded_list));
       // Encode as a Scalar Variant Tensor.
       Tensor* encoded_scalar;
       OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}),
                                                        &encoded_scalar));
-      encoded_scalar->scalar<Variant>()() = std::move(encoded_list);
+      encoded_scalar->scalar<Variant>()() = std::move(batched_ragged_input);
       return;
     }
 
@@ -180,24 +164,19 @@ class RaggedTensorToVariantOp : public OpKernel {
                     "received rt_nested_splits of length 0."));
 
     // Unbatch the Ragged Tensor and encode the components.
-    std::vector<RaggedTensor> ragged_components;
+    std::vector<RaggedTensorVariant> unbatched_ragged_input;
     OP_REQUIRES_OK(context, UnbatchRaggedZerothDim<VALUE_TYPE, SPLIT_TYPE>(
-                                batched_ragged_input, &ragged_components));
-    std::vector<Tensor> encoded_components(ragged_components.size());
-    for (int i = 0; i < ragged_components.size(); i++) {
-      OP_REQUIRES_OK(context, RaggedToVariant(ragged_components[i],
-                                              &encoded_components[i]));
-    }
+                                batched_ragged_input, &unbatched_ragged_input));
 
     // Bundle the encoded scalar Variant Tensors into a rank-1 Variant Tensor.
-    Tensor* encoded_ragged;
-    int output_size = ragged_components.size();
+    Tensor* encoded_vector;
+    int output_size = unbatched_ragged_input.size();
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, TensorShape({output_size}),
-                                            &encoded_ragged));
-    auto encoded_ragged_vec = encoded_ragged->vec<Variant>();
+                                            &encoded_vector));
+    auto encoded_vector_t = encoded_vector->vec<Variant>();
     for (int i = 0; i < output_size; i++) {
-      encoded_ragged_vec(i) = encoded_components[i];
+      encoded_vector_t(i) = unbatched_ragged_input[i];
     }
   }
 
@@ -205,12 +184,81 @@ class RaggedTensorToVariantOp : public OpKernel {
   bool batched_input_;
 };
 
-#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type)      \
-  REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant")               \
-                              .Device(DEVICE_CPU)                     \
-                              .TypeConstraint<value_type>("Tvalues")  \
-                              .TypeConstraint<split_type>("Tsplits"), \
-                          RaggedTensorToVariantOp<value_type, split_type>);
+template <typename VALUE_TYPE, typename SPLIT_TYPE>
+class RaggedTensorToVariantGradientOp : public OpKernel {
+ public:
+  using OpKernel::OpKernel;
+
+  void Compute(OpKernelContext* context) override {
+    // Read inputs.
+    Tensor encoded_variant = context->input(0);
+    Tensor row_splits = context->input(1);
+    auto flat_row_splits = row_splits.flat<SPLIT_TYPE>();
+    TensorShape dense_values_shape;
+    OP_REQUIRES_OK(context,
+                   TensorShapeUtils::MakeShape(context->input(2).vec<int32>(),
+                                               &dense_values_shape));
+
+    const auto& flat_variants = encoded_variant.flat<Variant>();
+
+    // Get a Tensor containing the flat_values for each variant.
+    std::vector<Tensor> values;
+    for (int i = 0; i < flat_variants.size(); ++i) {
+      if (const auto* encoded = flat_variants(i).get<RaggedTensorVariant>()) {
+        values.push_back(encoded->values());
+      } else {
+        // Missing value: this happens if only some of the variant values
+        // generated by ragged_tensor_to_variant impacted the value that we're
+        // calculating the gradient for. In this case, we will see a
+        // default-constructed variant; so treat it as a zero tensor with the
+        // appropriate shape.
+        const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
+        int piece_size = flat_row_splits(i + 1) - flat_row_splits(i);
+        TensorShape zeros_shape = dense_values_shape;
+        zeros_shape.set_dim(0, piece_size);
+        Tensor zero(value_dtype, zeros_shape);
+        zero.flat<VALUE_TYPE>() =
+            zero.flat<VALUE_TYPE>().constant(VALUE_TYPE());
+        values.push_back(zero);
+      }
+    }
+
+    if (values.size() == 1) {
+      // Just one flat_value tensor: return as-is.
+      context->set_output(0, values[0]);
+    } else {
+      // Multiple flat_values tensors: concatenate them together.
+      using Piece = typename TTypes<VALUE_TYPE, 2>::Matrix;
+      using ConstPiece = typename TTypes<VALUE_TYPE, 2>::ConstMatrix;
+      std::vector<std::unique_ptr<ConstPiece>> pieces;
+      pieces.reserve(values.size());
+      for (const Tensor& t : values) {
+        pieces.emplace_back(
+            new ConstPiece(t.shaped<VALUE_TYPE, 2>({1, t.NumElements()})));
+      }
+      Tensor* out = nullptr;
+      OP_REQUIRES_OK(context,
+                     context->allocate_output(0, dense_values_shape, &out));
+      Piece out_flat =
+          out->shaped<VALUE_TYPE, 2>({1, dense_values_shape.num_elements()});
+      ConcatCPU<VALUE_TYPE>(context->device(), pieces, &out_flat);
+    }
+  }
+};
+
+#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type)            \
+  REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant")                     \
+                              .Device(DEVICE_CPU)                           \
+                              .TypeConstraint<value_type>("Tvalues")        \
+                              .TypeConstraint<split_type>("Tsplits"),       \
+                          RaggedTensorToVariantOp<value_type, split_type>); \
+  REGISTER_KERNEL_BUILDER(                                                  \
+      Name("RaggedTensorToVariantGradient")                                 \
+          .Device(DEVICE_CPU)                                               \
+          .TypeConstraint<value_type>("Tvalues")                            \
+          .TypeConstraint<split_type>("Tsplits"),                           \
+      RaggedTensorToVariantGradientOp<value_type, split_type>);
+
 #define REGISTER_KERNELS(value_type)                  \
   REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int32) \
   REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int64)
diff --git a/tensorflow/core/kernels/ragged_tensor_variant.cc b/tensorflow/core/kernels/ragged_tensor_variant.cc
new file mode 100644
index 00000000..94663138
--- /dev/null
+++ b/tensorflow/core/kernels/ragged_tensor_variant.cc
@@ -0,0 +1,86 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#define EIGEN_USE_GPU
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+#include "tensorflow/core/kernels/ragged_tensor_variant.h"
+
+namespace tensorflow {
+
+string RaggedTensorVariant::TypeName() const { return "RaggedTensorVariant"; }
+
+string RaggedTensorVariant::DebugString() const {
+  return absl::StrCat(
+      "RaggedTensorVariant(dtype=", DataTypeString(values_.dtype()),
+      ", ragged_rank=", nested_splits_.size(), ", splits_dtype=",
+      DataTypeString(nested_splits_.empty() ? DT_INVALID
+                                            : nested_splits_.back().dtype()));
+}
+
+void RaggedTensorVariant::Encode(VariantTensorData* data) const {
+  data->set_type_name(TypeName());
+  for (const auto& splits : nested_splits_) {
+    *data->add_tensors() = splits;
+  }
+  *data->add_tensors() = values_;
+}
+
+bool RaggedTensorVariant::Decode(const VariantTensorData& data) {
+  if (data.tensors_size() < 1) {
+    return false;
+  }
+  nested_splits_.assign(data.tensors().begin(),
+                        std::prev(data.tensors().end()));
+  values_ = data.tensors().back();
+  return true;
+}
+
+namespace {
+
+Status RaggedTensorVariantDeviceCopy(
+    const RaggedTensorVariant& from, RaggedTensorVariant* to,
+    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
+  TF_RETURN_IF_ERROR(copy(from.values(), to->mutable_values()));
+  // TODO(b/170415165) Should we use `copy` to move splits from device<->host?
+  *to->mutable_nested_splits() = from.nested_splits();
+  return Status::OK();
+}
+
+}  // namespace
+
+REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(
+    ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, RaggedTensorVariant,
+    RaggedTensorVariantZerosLike<CPUDevice>);
+
+REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(
+    ADD_VARIANT_BINARY_OP, DEVICE_CPU, RaggedTensorVariant,
+    RaggedTensorVariantBinaryAdd<CPUDevice>);
+
+REGISTER_UNARY_VARIANT_DECODE_FUNCTION(RaggedTensorVariant,
+                                       "RaggedTensorVariant");
+
+#define REGISTER_RAGGED_TENSOR_VARIANT_COPY(DIRECTION)  \
+  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
+      RaggedTensorVariant, DIRECTION, RaggedTensorVariantDeviceCopy)
+
+REGISTER_RAGGED_TENSOR_VARIANT_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
+REGISTER_RAGGED_TENSOR_VARIANT_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
+REGISTER_RAGGED_TENSOR_VARIANT_COPY(
+    VariantDeviceCopyDirection::DEVICE_TO_DEVICE);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/ragged_tensor_variant.h b/tensorflow/core/kernels/ragged_tensor_variant.h
new file mode 100644
index 00000000..730758a3
--- /dev/null
+++ b/tensorflow/core/kernels/ragged_tensor_variant.h
@@ -0,0 +1,110 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_
+#define TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_
+
+#define EIGEN_USE_THREADS
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#define EIGEN_USE_GPU
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+#include <vector>
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_key.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
+#include "tensorflow/core/framework/variant_tensor_data.h"
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+#include "tensorflow/core/util/tensor_ops_util.h"
+
+namespace tensorflow {
+
+// Class used to store a RaggedTensor as a Variant scalar.
+class RaggedTensorVariant {
+ public:
+  RaggedTensorVariant() {}
+  RaggedTensorVariant(Tensor values, const std::vector<Tensor>& nested_splits)
+      : values_(std::move(values)), nested_splits_(nested_splits) {}
+
+  // Variant support methods.
+  string TypeName() const;
+  string DebugString() const;
+  void Encode(VariantTensorData* data) const;
+  bool Decode(const VariantTensorData& data);
+
+  // The flat_values of the RaggedTensor.
+  const Tensor& values() const { return values_; }
+  Tensor* mutable_values() { return &values_; }
+  void set_values(const Tensor& new_values) { values_ = new_values; }
+
+  // The nested row_splits of the RaggedTensor.
+  int ragged_rank() const { return nested_splits_.size(); }
+  const std::vector<Tensor>& nested_splits() const { return nested_splits_; }
+  std::vector<Tensor>* mutable_nested_splits() { return &nested_splits_; }
+  const Tensor& splits(int i) const { return nested_splits_[i]; }
+  Tensor* mutable_splits(int i) { return &nested_splits_[i]; }
+  void set_nested_splits(const std::vector<Tensor>& nested_splits) {
+    nested_splits_ = nested_splits;
+  }
+  void append_splits(const Tensor& splits) { nested_splits_.push_back(splits); }
+
+ private:
+  Tensor values_;
+  std::vector<Tensor> nested_splits_;
+};
+
+template <typename Device>
+Status RaggedTensorVariantZerosLike(OpKernelContext* c,
+                                    const RaggedTensorVariant& x,
+                                    RaggedTensorVariant* y) {
+  y->set_nested_splits(x.nested_splits());
+  TF_RETURN_IF_ERROR(
+      ZerosLikeTensor<Device>(c, x.values(), y->mutable_values()));
+  return Status::OK();
+}
+
+template <typename Device>
+Status RaggedTensorVariantBinaryAdd(OpKernelContext* c,
+                                    const RaggedTensorVariant& x,
+                                    const RaggedTensorVariant& y,
+                                    RaggedTensorVariant* out) {
+  if (x.values().dtype() != y.values().dtype()) {
+    return errors::InvalidArgument(
+        "Can't add RaggedTensorVariants of different dtypes. One is ",
+        DataTypeString(x.values().dtype()), " and the other is ",
+        DataTypeString(y.values().dtype()));
+  }
+  if (x.ragged_rank() != y.ragged_rank()) {
+    return errors::InvalidArgument(
+        "Can't add RaggedTensorVariants of different ragged rank. ", "One is ",
+        x.ragged_rank(), " and the other is ", y.ragged_rank());
+  }
+  for (int i = 0; i < x.ragged_rank(); ++i) {
+    if (TensorKey(x.splits(i)) != TensorKey(y.splits(i))) {
+      return errors::InvalidArgument(
+          "Can't add RaggedTensorVariants with different row_splits.");
+    }
+  }
+  out->set_nested_splits(x.nested_splits());
+  TF_RETURN_IF_ERROR(BinaryAddTensors<Device>(c, x.values(), y.values(),
+                                              out->mutable_values()));
+  return Status::OK();
+}
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_
diff --git a/tensorflow/core/ops/ragged_conversion_ops.cc b/tensorflow/core/ops/ragged_conversion_ops.cc
index 6bee189c..8512bcf3 100644
--- a/tensorflow/core/ops/ragged_conversion_ops.cc
+++ b/tensorflow/core/ops/ragged_conversion_ops.cc
@@ -92,7 +92,8 @@ tensorflow::Status ValidateRowPartitionTypesAndShapes(
 Status RaggedTensorToSparseShapeFn(InferenceContext* c);
 Status RaggedTensorToVariantShapeFn(InferenceContext* c);
 Status RaggedTensorFromVariantShapeFn(InferenceContext* c);
-tensorflow::Status RaggedTensorToTensorShapeFn(InferenceContext* c);
+Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c);
+Status RaggedTensorToTensorShapeFn(InferenceContext* c);
 
 //==============================================================================
 // Registered Ops
@@ -129,6 +130,15 @@ REGISTER_OP("RaggedTensorFromVariant")
     .Attr("Tsplits: {int32, int64} = DT_INT64")
     .SetShapeFn(RaggedTensorFromVariantShapeFn);
 
+REGISTER_OP("RaggedTensorToVariantGradient")
+    .Input("encoded_ragged_grad: variant")
+    .Input("row_splits: Tsplits")
+    .Input("dense_values_shape: int32")
+    .Output("dense_values_grad: Tvalues")
+    .Attr("Tvalues: type")
+    .Attr("Tsplits: {int32, int64} = DT_INT64")
+    .SetShapeFn(RaggedTensorToVariantGradientShapeFn);
+
 REGISTER_OP("RaggedTensorToTensor")
     .Attr("T: type")
     .Attr("Tindex: {int64, int32}")
@@ -201,6 +211,14 @@ Status RaggedTensorToVariantShapeFn(InferenceContext* c) {
   return Status::OK();
 }
 
+Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c) {
+  ShapeHandle shape;
+  TF_RETURN_IF_ERROR(
+      c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(2, &shape));
+  c->set_output(0, shape);
+  return Status::OK();
+}
+
 Status RaggedTensorFromVariantShapeFn(InferenceContext* c) {
   int64 input_ragged_rank;
   TF_RETURN_IF_ERROR(
diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD
index 95e5602a..34372160 100644
--- a/tensorflow/python/ops/ragged/BUILD
+++ b/tensorflow/python/ops/ragged/BUILD
@@ -507,6 +507,7 @@ py_test(
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:platform_test",
+        "//tensorflow/python:tensor_array_grad",
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:tensor_spec",
         "//tensorflow/python/data/ops:dataset_ops",
--
2.27.0

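Note on the refactor above: `RaggedTensorToVariant` packs a ragged tensor's row_splits and flat_values into variant scalars, and the new `RaggedTensorToVariantGradient` op makes that encoding differentiable. A minimal Python sketch of the encode/decode round trip, assuming the `tf.raw_ops` signatures of the patched ops (illustrative only; the ragged-rank arguments may need adjustment for other inputs):

import tensorflow as tf

# A ragged matrix with rows of different lengths.
rt = tf.ragged.constant([[1, 2, 3], [4, 5]], row_splits_dtype=tf.int64)

# batched_input=True encodes each row as its own RaggedTensorVariant scalar,
# bundled into a rank-1 variant tensor (the code path changed by this patch).
encoded = tf.raw_ops.RaggedTensorToVariant(
    rt_nested_splits=rt.nested_row_splits,
    rt_dense_values=rt.flat_values,
    batched_input=True)

# Re-assemble the batched ragged tensor from the encoded components.
decoded = tf.raw_ops.RaggedTensorFromVariant(
    encoded_ragged=encoded,
    input_ragged_rank=0,    # each encoded row is dense here
    output_ragged_rank=1,   # stacking the rows adds one ragged dimension
    Tvalues=tf.int32,
    Tsplits=tf.int64)
print(decoded.output_nested_splits[0].numpy())  # expected: [0 3 5]
print(decoded.output_dense_values.numpy())      # expected: [1 2 3 4 5]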
30
CVE-2021-29516-4.patch
Normal file
@@ -0,0 +1,30 @@
From b055b9c474cd376259dde8779908f9eeaf097d93 Mon Sep 17 00:00:00 2001
From: Amit Patankar <amitpatankar@google.com>
Date: Tue, 13 Apr 2021 14:49:50 -0700
Subject: [PATCH] Fix `tf.raw_ops.RaggedTensorToVariant` invalid resize.

PiperOrigin-RevId: 368299574
Change-Id: I751c186325aa0bab397928845e790e60c2d90918
---
 tensorflow/core/kernels/ragged_tensor_to_variant_op.cc | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc
index a60e5c62..fb1f25fc 100644
--- a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc
+++ b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc
@@ -165,6 +165,11 @@ class RaggedTensorToVariantOp : public OpKernel {
 
     // Unbatch the Ragged Tensor and encode the components.
     std::vector<RaggedTensorVariant> unbatched_ragged_input;
+    auto batched_splits_top_vec =
+        batched_ragged_input.splits(0).vec<SPLIT_TYPE>();
+    int num_components = batched_splits_top_vec.size() - 1;
+    OP_REQUIRES(context, num_components >= 0,
+                errors::Internal("Invalid split argument."));
     OP_REQUIRES_OK(context, UnbatchRaggedZerothDim<VALUE_TYPE, SPLIT_TYPE>(
                                 batched_ragged_input, &unbatched_ragged_input));
 
--
2.27.0

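The added `OP_REQUIRES` guards the unbatched path: an empty outermost row_splits tensor makes `batched_splits_top_vec.size() - 1` negative, and that value previously reached `ragged_components->resize()`. A hypothetical repro sketch along those lines (op and error strings taken from the patch; before the fix this input triggered the invalid resize rather than a clean error):

import tensorflow as tf

# An empty splits tensor makes batched_splits_top_vec.size() - 1 == -1.
try:
  tf.raw_ops.RaggedTensorToVariant(
      rt_nested_splits=[tf.constant([], dtype=tf.int64)],  # empty row_splits
      rt_dense_values=tf.constant([1, 2, 3]),
      batched_input=True)
except tf.errors.InternalError as e:
  print(e.message)  # "Invalid split argument."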
53
CVE-2021-29551.patch
Normal file
@@ -0,0 +1,53 @@
From 480641e3599775a8895254ffbc0fc45621334f68 Mon Sep 17 00:00:00 2001
From: Mihai Maruseac <mihaimaruseac@google.com>
Date: Sat, 24 Apr 2021 16:47:25 -0700
Subject: [PATCH] Validate (and ensure validation sticks) inputs for
 `MatrixTriangularSolve`.

PiperOrigin-RevId: 370282444
Change-Id: Iaed61a0b0727cc42c830658b72eb69f785f48dc5
---
 .../matrix_triangular_solve_op_impl.h | 20 +++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/tensorflow/core/kernels/matrix_triangular_solve_op_impl.h b/tensorflow/core/kernels/matrix_triangular_solve_op_impl.h
index 99249f792b6ed..ce5392e62b9fa 100644
--- a/tensorflow/core/kernels/matrix_triangular_solve_op_impl.h
+++ b/tensorflow/core/kernels/matrix_triangular_solve_op_impl.h
@@ -162,6 +162,9 @@ class BaseMatrixTriangularSolveOp : public OpKernel {
     const Tensor& in1 = ctx->input(1);
 
     ValidateInputTensors(ctx, in0, in1);
+    if (!ctx->status().ok()) {
+      return;
+    }
 
     MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
     OP_REQUIRES(
@@ -230,13 +233,22 @@ class MatrixTriangularSolveOp
  private:
   void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                             const Tensor& in1) override {
+    const auto in0_num_dims = in0.dims();
     OP_REQUIRES(
-        ctx, in0.dims() >= 2,
-        errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims()));
+        ctx, in0_num_dims >= 2,
+        errors::InvalidArgument("In[0] ndims must be >= 2: ", in0_num_dims));
 
+    const auto in1_num_dims = in1.dims();
     OP_REQUIRES(
-        ctx, in1.dims() >= 2,
-        errors::InvalidArgument("In[0] ndims must be >= 2: ", in1.dims()));
+        ctx, in1_num_dims >= 2,
+        errors::InvalidArgument("In[1] ndims must be >= 2: ", in1_num_dims));
+
+    const auto in0_last_dim = in0.dim_size(in0_num_dims - 1);
+    const auto in0_prev_dim = in0.dim_size(in0_num_dims - 2);
+    OP_REQUIRES(ctx, in0_last_dim == in0_prev_dim,
+                errors::InvalidArgument(
+                    "In[0] matrices in the last dimensions must be square (",
+                    in0_last_dim, " =/= ", in0_prev_dim, ")"));
   }
 };
 
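The new checks reject a non-square coefficient matrix before `MatMulBCast` runs, and the early `return` after `ValidateInputTensors` makes the validation stick instead of being silently ignored. An illustrative call that the square-matrix check now rejects, assuming the public `tf.raw_ops.MatrixTriangularSolve` signature:

import tensorflow as tf

matrix = tf.ones([2, 3])   # not square in the last two dimensions
rhs = tf.ones([3, 1])
try:
  tf.raw_ops.MatrixTriangularSolve(matrix=matrix, rhs=rhs,
                                   lower=True, adjoint=False)
except tf.errors.InvalidArgumentError as e:
  print(e.message)  # "In[0] matrices in the last dimensions must be square (3 =/= 2)"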
29
CVE-2021-37645.patch
Normal file
@@ -0,0 +1,29 @@
From 96f364a1ca3009f98980021c4b32be5fdcca33a1 Mon Sep 17 00:00:00 2001
From: Laura Pak <lpak@google.com>
Date: Mon, 2 Aug 2021 13:27:01 -0700
Subject: [PATCH] Validate axis input in tf.raw_ops.QuantizeAndDequantizeV4Grad

PiperOrigin-RevId: 388291385
Change-Id: I3bab68dc61d935afa96c0da021a7b722c6dc8dc8
---
 tensorflow/core/kernels/quantize_and_dequantize_op.cc | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op.cc b/tensorflow/core/kernels/quantize_and_dequantize_op.cc
index 540d900f9f869..d63a49a04be62 100644
--- a/tensorflow/core/kernels/quantize_and_dequantize_op.cc
+++ b/tensorflow/core/kernels/quantize_and_dequantize_op.cc
@@ -158,6 +158,13 @@ class QuantizeAndDequantizeV4GradientOp : public OpKernel {
     Tensor* input_backprop = nullptr;
     OP_REQUIRES_OK(ctx,
                    ctx->allocate_output(0, input.shape(), &input_backprop));
+    OP_REQUIRES(
+        ctx, axis_ >= -1,
+        errors::InvalidArgument("Axis must be at least -1. Found ", axis_));
+    OP_REQUIRES(ctx, (axis_ == -1 || axis_ < input.shape().dims()),
+                errors::InvalidArgument(
+                    "Axis should be -1 or 0 or a positive value less than ",
+                    input.shape().dims(), " but given axis value was ", axis_));
 
     OP_REQUIRES(
         ctx, input.IsSameSize(gradient),
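The two added `OP_REQUIRES` bound `axis_` to the range [-1, rank of input). A sketch of an out-of-range call that is now rejected cleanly (raw-op keyword names assumed from the op definition; this applies to builds where the V4 ops are present, as in this backport):

import tensorflow as tf

x = tf.constant([2.5], dtype=tf.float32)
try:
  tf.raw_ops.QuantizeAndDequantizeV4Grad(
      gradients=x, input=x,
      input_min=tf.constant(-10.0), input_max=tf.constant(10.0),
      axis=-100)  # axis < -1: rejected by the first new OP_REQUIRES
except tf.errors.InvalidArgumentError as e:
  print(e.message)  # "Axis must be at least -1. Found -100"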
88
CVE-2021-37679.patch
Normal file
@@ -0,0 +1,88 @@
From 4e2565483d0ffcadc719bd44893fb7f609bb5f12 Mon Sep 17 00:00:00 2001
From: Edward Loper <edloper@google.com>
Date: Thu, 29 Jul 2021 09:50:01 -0700
Subject: [PATCH] Fix bug that could cause map_fn to produce incorrect results
 (rather than an error) when mapping over a ragged tensor with an
 inappropriate fn_output_signature. (Note: there are cases where the default
 value for fn_output_signature is not appropriate, so the user needs to
 explicitly specify the correct output signature.)

PiperOrigin-RevId: 387606546
Change-Id: Ib4ea27b9634e6ab413f211cfe809a69a90f0e2cd
---
 .../kernels/ragged_tensor_from_variant_op.cc | 16 +++++++++++++
 .../ops/ragged/ragged_map_fn_op_test.py      | 23 +++++++++++++++++++
 2 files changed, 39 insertions(+)

diff --git a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
index d9993bb6d3907..c481d90638e4e 100644
--- a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
+++ b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
@@ -174,7 +174,23 @@ Status NestedStackRaggedTensors(
   auto output_values_flat =
       output_ragged->mutable_values()->flat_outer_dims<VALUE_TYPE, 2>();
   int values_index = 0;
+
+  TensorShape expected_value_shape = component_values_shape;
+  expected_value_shape.RemoveDim(0);
+
   for (int i = 0; i < ragged_components.size(); i++) {
+    // Check that the flat_values tensor shape is compatible.
+    TensorShape value_shape = ragged_components[i].values().shape();
+    value_shape.RemoveDim(0);
+    if (value_shape != expected_value_shape) {
+      return errors::InvalidArgument(
+          "All flat_values must have compatible shapes. Shape at index 0: ",
+          expected_value_shape, ". Shape at index ", i, ": ", value_shape,
+          ". If you are using tf.map_fn, then you may need to specify an "
+          "explicit fn_output_signature with appropriate ragged_rank, and/or "
+          "convert output tensors to RaggedTensors.");
+    }
+
     auto component_values_flat =
         ragged_components[i].values().flat_outer_dims<VALUE_TYPE, 2>();
     int num_inner_elements = ragged_components[i].values().NumElements();
diff --git a/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py b/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py
index bead4923a0a4c..ace724ac8711d 100644
--- a/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py
+++ b/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py
@@ -21,9 +21,11 @@
 import numpy as np
 
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import map_fn as map_fn_lib
 from tensorflow.python.ops import math_ops as mo
 from tensorflow.python.ops import string_ops
 from tensorflow.python.ops.ragged import ragged_factory_ops
@@ -309,6 +311,27 @@ def testMapOnSparseTensor(self):
     )
     self.assertAllEqual(id_t2, [[0, 5], [0, 4]])
 
+  def testRaggedMapWithIncorrectFnOutputSignature(self):
+    x = ragged_factory_ops.constant([[1, 2, 3, 4], [1]])
+    with self.assertRaisesRegex(errors.InvalidArgumentError,
+                                'All flat_values must have compatible shapes'):
+      y = map_fn_lib.map_fn(lambda r: map_fn_lib.map_fn(lambda y: r, r), x)
+      self.evaluate(y)
+
+  def testNestedRaggedMapWithFnOutputSignature(self):
+    ragged1d = ragged_tensor.RaggedTensorSpec([None], dtypes.int32)
+    ragged2d = ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)
+
+    x = ragged_factory_ops.constant([[1, 2, 3, 4], [1]])
+    # pylint: disable=g-long-lambda
+    y = map_fn_lib.map_fn(
+        lambda r: map_fn_lib.map_fn(
+            lambda y: r, r, fn_output_signature=ragged1d),
+        x,
+        fn_output_signature=ragged2d)
+    expected = [[[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]], [[1]]]
+    self.assertAllEqual(y, expected)
+
 
 if __name__ == '__main__':
   googletest.main()
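For users, the added kernel check turns the former silent corruption into an `InvalidArgumentError`, and the second test shows the documented remedy: pass an explicit `fn_output_signature`. The same pattern with the public `tf.map_fn` API, a sketch mirroring the test above:

import tensorflow as tf

x = tf.ragged.constant([[1, 2, 3, 4], [1]])

# Without explicit signatures the default output spec is wrong for this fn,
# and the patched kernel now reports the shape mismatch instead of returning
# bad results. With explicit specs the nested map is well-defined:
y = tf.map_fn(
    lambda r: tf.map_fn(
        lambda v: r, r,
        fn_output_signature=tf.RaggedTensorSpec([None], tf.int32)),
    x,
    fn_output_signature=tf.RaggedTensorSpec([None, None], tf.int32))
print(y.to_list())
# [[[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]], [[1]]]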
113
CVE-2021-37681-1.patch
Normal file
@@ -0,0 +1,113 @@
From 9d94482224acde044692d74107339a29f862cbac Mon Sep 17 00:00:00 2001
From: Advait Jain <advaitjain@google.com>
Date: Wed, 15 Jul 2020 16:20:40 -0700
Subject: [PATCH] Change some getters to not be inline. This enables some

---
 tensorflow/lite/kernels/kernel_util.cc | 25 +++++++++++++
 tensorflow/lite/kernels/kernel_util.h  | 49 +++++++++++---------------
 2 files changed, 46 insertions(+), 28 deletions(-)

diff --git a/tensorflow/lite/kernels/kernel_util.cc b/tensorflow/lite/kernels/kernel_util.cc
index 164aec3f..f7d7c25b 100644
--- a/tensorflow/lite/kernels/kernel_util.cc
+++ b/tensorflow/lite/kernels/kernel_util.cc
@@ -27,6 +27,31 @@ limitations under the License.
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
 
 namespace tflite {
+const TfLiteTensor* GetInput(const TfLiteContext* context,
+                             const TfLiteNode* node, int index) {
+  return &context->tensors[node->inputs->data[index]];
+}
+
+TfLiteTensor* GetVariableInput(TfLiteContext* context, const TfLiteNode* node,
+                               int index) {
+  TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]];
+  return (tensor->is_variable) ? tensor : nullptr;
+}
+
+TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
+                        int index) {
+  return &context->tensors[node->outputs->data[index]];
+}
+
+const TfLiteTensor* GetOptionalInputTensor(const TfLiteContext* context,
+                                           const TfLiteNode* node, int index) {
+  const bool use_tensor = index < node->inputs->size &&
+                          node->inputs->data[index] != kTfLiteOptionalTensor;
+  if (use_tensor) {
+    return &context->tensors[node->inputs->data[index]];
+  }
+  return nullptr;
+}
 
 // Per-axis
 TfLiteStatus PopulateConvolutionQuantizationParams(
diff --git a/tensorflow/lite/kernels/kernel_util.h b/tensorflow/lite/kernels/kernel_util.h
index 59b1974c..371b712f 100644
--- a/tensorflow/lite/kernels/kernel_util.h
+++ b/tensorflow/lite/kernels/kernel_util.h
@@ -24,38 +24,31 @@ limitations under the License.
 
 namespace tflite {
 
-inline int NumDimensions(const TfLiteTensor* t) { return t->dims->size; }
-inline int SizeOfDimension(const TfLiteTensor* t, int dim) {
-  return t->dims->data[dim];
-}
-inline const TfLiteTensor* GetInput(const TfLiteContext* context,
-                                    const TfLiteNode* node, int index) {
-  const int tensor_index = node->inputs->data[index];
-  if (tensor_index < 0) {
-    return nullptr;
-  }
-  return &context->tensors[tensor_index];
-}
+// A fair number of functions in this header have historically been inline.
+// It is ok to change functions to not be inline if the latency with
+// benchmark_model for MobileNet + MobileBERT is unaffected. If such a change is
+// made, move the newly non-inlined function declarations to the top of this
+// header file.
+const TfLiteTensor* GetInput(const TfLiteContext* context,
+                             const TfLiteNode* node, int index);
+
 // Note: You must check if result is not null:
 //   TfLiteTensor* my_tensor = GetVariableInput(context, node, kMyTensorIdx);
 //   TF_LITE_ENSURE(context, my_tensor != nullptr);
-inline TfLiteTensor* GetVariableInput(TfLiteContext* context,
-                                      const TfLiteNode* node, int index) {
-  const int tensor_index = node->inputs->data[index];
-  if (tensor_index < 0) {
-    return nullptr;
-  }
-  TfLiteTensor* tensor = &context->tensors[tensor_index];
-  return (tensor->is_variable) ? tensor : nullptr;
-}
-inline TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
-                               int index) {
-  const int tensor_index = node->outputs->data[index];
-  if (tensor_index < 0) {
-    return nullptr;
-  }
-  return &context->tensors[tensor_index];
+TfLiteTensor* GetVariableInput(TfLiteContext* context, const TfLiteNode* node,
+                               int index);
+
+TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
+                        int index);
+
+const TfLiteTensor* GetOptionalInputTensor(const TfLiteContext* context,
+                                           const TfLiteNode* node, int index);
+
+inline int NumDimensions(const TfLiteTensor* t) { return t->dims->size; }
+inline int SizeOfDimension(const TfLiteTensor* t, int dim) {
+  return t->dims->data[dim];
 }
+
 inline TfLiteTensor* GetTemporary(TfLiteContext* context,
                                   const TfLiteNode* node, int index) {
   const int tensor_index = node->temporaries->data[index];
--
2.23.0

37
CVE-2021-37681-2.patch
Normal file
@@ -0,0 +1,37 @@
From 5b048e87e4e55990dae6b547add4dae59f4e1c76 Mon Sep 17 00:00:00 2001
From: Mihai Maruseac <mihaimaruseac@google.com>
Date: Fri, 16 Jul 2021 09:14:31 -0700
Subject: [PATCH] Fix a null pointer exception in SVDF

---
 tensorflow/lite/kernels/kernel_util.cc | 1 +
 tensorflow/lite/kernels/svdf.cc        | 1 +
 2 files changed, 2 insertions(+)

diff --git a/tensorflow/lite/kernels/kernel_util.cc b/tensorflow/lite/kernels/kernel_util.cc
index f7d7c25b..f0074c09 100644
--- a/tensorflow/lite/kernels/kernel_util.cc
+++ b/tensorflow/lite/kernels/kernel_util.cc
@@ -35,6 +35,7 @@ const TfLiteTensor* GetInput(const TfLiteContext* context,
 TfLiteTensor* GetVariableInput(TfLiteContext* context, const TfLiteNode* node,
                                int index) {
   TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]];
+  if (tensor == nullptr) return nullptr;
   return (tensor->is_variable) ? tensor : nullptr;
 }
 
diff --git a/tensorflow/lite/kernels/svdf.cc b/tensorflow/lite/kernels/svdf.cc
index ec19fb92..863c18fd 100644
--- a/tensorflow/lite/kernels/svdf.cc
+++ b/tensorflow/lite/kernels/svdf.cc
@@ -281,6 +281,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0);
 
   TfLiteTensor* state = GetVariableInput(context, node, kStateTensor);
+  TF_LITE_ENSURE(context, state != nullptr);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
 
   switch (weights_feature->type) {
--
2.23.0

@@ -1,7 +1,7 @@
 %global _empty_manifest_terminate_build 0
 Name: tensorflow
 Version: 2.3.1
-Release: 8
+Release: 9
 Summary: An Open Source Machine Learning Framework for Everyone
 License: Apache License 2.0
 URL: https://www.tensorflow.org/
@@ -173,6 +173,18 @@ Patch0161: CVE-2021-29526-1.patch
 Patch0162: CVE-2021-29526-2.patch
 Patch0163: CVE-2021-29544-1.patch
 Patch0164: CVE-2021-29544-2.patch
+Patch0165: CVE-2020-26267-1.patch
+Patch0166: CVE-2020-26267-2.patch
+Patch0167: CVE-2021-29515.patch
+Patch0168: CVE-2021-29551.patch
+Patch0169: CVE-2021-37645.patch
+Patch0170: CVE-2021-37681-1.patch
+Patch0171: CVE-2021-37681-2.patch
+Patch0172: CVE-2021-29516-1.patch
+Patch0173: CVE-2021-29516-2.patch
+Patch0174: CVE-2021-29516-3.patch
+Patch0175: CVE-2021-29516-4.patch
+Patch0176: CVE-2021-37679.patch
 Requires: python3-future
 Requires: python3-numpy
 
@@ -219,6 +231,9 @@ bazel --output_user_root=`pwd`/../output_user_root build --host_copt=-Wno-string
 %{_bindir}/*
 
 %changelog
+* Mon Sep 13 2021 houyingchao <houyingchao@huawei.com> - 2.3.1-9
+- Fix CVE-2020-26267 CVE-2021-29515 CVE-2021-29551 CVE-2021-37645 CVE-2021-37681 CVE-2021-29516 CVE-2021-37679
+
 * Tue Aug 31 2021 houyingchao <houyingchao@huawei.com> - 2.3.1-8
 - Fix CVE-2020-15265 CVE-2020-15266 CVE-2021-29517 CVE-2021-29518 CVE-2021-29521 CVE-2021-29526 CVE-2021-29533 CVE-2021-29537 CVE-2021-29544 CVE-2021-29560 CVE-2021-29571 CVE-2021-29583 CVE-2021-29589 CVE-2021-29595 CVE-2021-29602 CVE-2021-29604 CVE-2021-29610 CVE-2021-29611 CVE-2021-29612 CVE-2021-29614 CVE-2021-29618 CVE-2021-37635 CVE-2021-37640 CVE-2021-37642 CVE-2021-37643 CVE-2021-37651 CVE-2021-37653 CVE-2021-37654 CVE-2021-37655 CVE-2021-37657 CVE-2021-37658 CVE-2021-37661 CVE-2021-37662 CVE-2021-37664 CVE-2021-37665 CVE-2021-37666 CVE-2021-37668 CVE-2021-37669 CVE-2021-37674 CVE-2021-37675 CVE-2021-37678 CVE-2021-37683 CVE-2021-37691 CVE-2021-29513
 