Commit 18aaa18c authored by Raman Sarokin's avatar Raman Sarokin Committed by TensorFlower Gardener

Added proper Batch dimension support in CalculateOutputShape for Concat.

PiperOrigin-RevId: 312867341
Change-Id: I089c71c5e913d089488f80a923caa81f6f156f7b
parent e9654dfb
......@@ -534,9 +534,10 @@ absl::Status CalculateOutputShape(const std::vector<BHWC>& input,
switch (attr.axis) {
case Axis::CHANNELS:
for (int i = 1; i < input.size(); i++) {
if (input[i].h != new_shape.h || input[i].w != new_shape.w) {
if (input[i].h != new_shape.h || input[i].w != new_shape.w ||
input[i].b != new_shape.b) {
return absl::InvalidArgumentError(
"Height and Width must be the same when concatenating "
"Height, Width and Batch must be the same when concatenating "
"by channels axis");
}
new_shape.c += input[i].c;
......@@ -544,9 +545,10 @@ absl::Status CalculateOutputShape(const std::vector<BHWC>& input,
break;
case Axis::HEIGHT:
for (int i = 1; i < input.size(); i++) {
if (input[i].w != new_shape.w || input[i].c != new_shape.c) {
if (input[i].w != new_shape.w || input[i].c != new_shape.c ||
input[i].b != new_shape.b) {
return absl::InvalidArgumentError(
"Channels and Width must be the same when concatenating "
"Channels, Width and Batch must be the same when concatenating "
"by height axis");
}
new_shape.h += input[i].h;
......@@ -554,14 +556,26 @@ absl::Status CalculateOutputShape(const std::vector<BHWC>& input,
break;
case Axis::WIDTH:
for (int i = 1; i < input.size(); i++) {
if (input[i].h != new_shape.h || input[i].c != new_shape.c) {
if (input[i].h != new_shape.h || input[i].c != new_shape.c ||
input[i].b != new_shape.b) {
return absl::InvalidArgumentError(
"Height and Channels must be the same when concatenating "
"Height, Channels and Batch must be the same when concatenating "
"by width axis");
}
new_shape.w += input[i].w;
}
break;
case Axis::BATCH:
for (int i = 1; i < input.size(); i++) {
if (input[i].h != new_shape.h || input[i].c != new_shape.c ||
input[i].w != new_shape.w) {
return absl::InvalidArgumentError(
"Width, Height and Channels must be the same when concatenating "
"by batch axis");
}
new_shape.b += input[i].b;
}
break;
default:
return absl::InvalidArgumentError("Invalid axis");
break;
......@@ -578,9 +592,10 @@ absl::Status CalculateOutputShape(const std::vector<BHWDC>& input,
case Axis::CHANNELS:
for (int i = 1; i < input.size(); ++i) {
if (input[i].h != new_shape.h || input[i].w != new_shape.w ||
input[i].d != new_shape.d) {
input[i].d != new_shape.d || input[i].b != new_shape.b) {
return absl::InvalidArgumentError(
"Height, Width and Depth must be the same when concatenating "
"Height, Width, Batch and Depth must be the same when "
"concatenating "
"by channels axis");
}
new_shape.c += input[i].c;
......@@ -589,9 +604,10 @@ absl::Status CalculateOutputShape(const std::vector<BHWDC>& input,
case Axis::HEIGHT:
for (int i = 1; i < input.size(); ++i) {
if (input[i].w != new_shape.w || input[i].c != new_shape.c ||
input[i].d != new_shape.d) {
input[i].d != new_shape.d || input[i].b != new_shape.b) {
return absl::InvalidArgumentError(
"Width, Depth and Channels must be the same when concatenating "
"Width, Depth, Batch and Channels must be the same when "
"concatenating "
"by height axis");
}
new_shape.h += input[i].h;
......@@ -600,9 +616,10 @@ absl::Status CalculateOutputShape(const std::vector<BHWDC>& input,
case Axis::WIDTH:
for (int i = 1; i < input.size(); ++i) {
if (input[i].h != new_shape.h || input[i].c != new_shape.c ||
input[i].d != new_shape.d) {
input[i].d != new_shape.d || input[i].b != new_shape.b) {
return absl::InvalidArgumentError(
"Height, Depth and Channels must be the same when concatenating "
"Height, Depth, Batch and Channels must be the same when "
"concatenating "
"by width axis");
}
new_shape.w += input[i].w;
......@@ -611,14 +628,27 @@ absl::Status CalculateOutputShape(const std::vector<BHWDC>& input,
case Axis::DEPTH:
for (int i = 1; i < input.size(); ++i) {
if (input[i].w != new_shape.w || input[i].h != new_shape.h ||
input[i].c != new_shape.c) {
input[i].c != new_shape.c || input[i].b != new_shape.b) {
return absl::InvalidArgumentError(
"Width, Height and Channels must be the same when concatenating "
"Width, Height, Batch and Channels must be the same when "
"concatenating "
"by depth axis");
}
new_shape.d += input[i].d;
}
break;
case Axis::BATCH:
for (int i = 1; i < input.size(); ++i) {
if (input[i].w != new_shape.w || input[i].h != new_shape.h ||
input[i].c != new_shape.c || input[i].d != new_shape.d) {
return absl::InvalidArgumentError(
"Width, Height, Depth and Channels must be the same when "
"concatenating "
"by batch axis");
}
new_shape.b += input[i].b;
}
break;
default:
return absl::InvalidArgumentError("Invalid axis");
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment