Commit 2544e4e2 authored by Raman Sarokin's avatar Raman Sarokin Committed by TensorFlower Gardener

Added new attributes for 3D operations.

Reshape3DAttributes.
Slice3DAttributes.
Transpose3DAttributes.
Added methods for shape calculation for this attributes.

PiperOrigin-RevId: 312872498
Change-Id: Ia2539ad880bae0869f8d1e379d4aedad9a10095a
parent a33f9b44
......@@ -499,6 +499,14 @@ BHWC CalculateOutputShape(const BHWC& input, const SliceAttributes& attr) {
StridedSize(attr.ends.c - attr.starts.c, attr.strides.c));
}
BHWDC CalculateOutputShape(const BHWDC& input, const Slice3DAttributes& attr) {
return BHWDC(StridedSize(attr.ends.b - attr.starts.b, attr.strides.b),
StridedSize(attr.ends.h - attr.starts.h, attr.strides.h),
StridedSize(attr.ends.w - attr.starts.w, attr.strides.w),
StridedSize(attr.ends.d - attr.starts.d, attr.strides.d),
StridedSize(attr.ends.c - attr.starts.c, attr.strides.c));
}
BHWC CalculateOutputShape(const BHWC& input, const PadAttributes& attr) {
return BHWC(attr.appended.b + attr.prepended.b + input.b,
attr.appended.h + attr.prepended.h + input.h,
......@@ -734,5 +742,12 @@ BHWC CalculateOutputShape(const BHWC& input, const TransposeAttributes& attr) {
input.get(attr.perm.w), input.get(attr.perm.c));
}
BHWDC CalculateOutputShape(const BHWDC& input,
const Transpose3DAttributes& attr) {
return BHWDC(input.get(attr.perm.b), input.get(attr.perm.h),
input.get(attr.perm.w), input.get(attr.perm.d),
input.get(attr.perm.c));
}
} // namespace gpu
} // namespace tflite
......@@ -399,6 +399,9 @@ struct Resize3DAttributes {
// If true, the centers of the 8 corner pixels of the input and output tensors
// are aligned, preserving the values at the corner pixels. Defaults to false.
bool align_corners = false;
// half_pixel_centers assumes pixels are of half the actual dimensions, and
// yields more accurate resizes. Only applicable to BILINEAR sampling.
bool half_pixel_centers = false;
};
float CalculateResizeScale(int32_t input_size, int32_t output_size,
......@@ -460,6 +463,20 @@ struct SliceAttributes {
// input.
BHWC CalculateOutputShape(const BHWC& input, const SliceAttributes& attr);
// Simple slicing without advanced support for shrinking, reverse slicing etc.
struct Slice3DAttributes {
// Specifies start and end dimensions for slicing.
BHWDC starts;
BHWDC ends;
// Stride should be >= 1.
BHWDC strides;
};
// @return shape of a tensor after Slice3D operation is applied to the given
// input.
BHWDC CalculateOutputShape(const BHWDC& input, const Slice3DAttributes& attr);
struct AddAttributes {
TensorOrScalar param;
};
......@@ -485,6 +502,10 @@ struct ReshapeAttributes {
BHWC new_shape;
};
struct Reshape3DAttributes {
BHWDC new_shape;
};
struct TransposeAttributes {
// A permutation of the dimensions of input tensor
BHWC perm;
......@@ -494,6 +515,16 @@ struct TransposeAttributes {
// the given input.
BHWC CalculateOutputShape(const BHWC& input, const TransposeAttributes& attr);
struct Transpose3DAttributes {
// A permutation of the dimensions of input tensor
BHWDC perm;
};
// @return shape of a tensor after Transpose3D operation is applied to
// the given input.
BHWDC CalculateOutputShape(const BHWDC& input,
const Transpose3DAttributes& attr);
struct SpaceToDepthAttributes {
int block_size;
};
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment