...
 
Commits (4)
......@@ -560,7 +560,8 @@ Status CapturedFunction::Instantiate(
if (!metadata_->use_inter_op_parallelism()) {
inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
}
bool is_multi_device = metadata_->use_multi_device_function();
bool is_multi_device = false;
TF_RETURN_IF_ERROR(IsMultiDevice(ctx, &is_multi_device));
inst_opts.is_multi_device_function = is_multi_device;
// We infer the target device from the function library runtime.
......@@ -863,5 +864,77 @@ CapturedFunction::CapturedFunction(
: metadata_(std::move(metadata)),
captured_inputs_(std::move(captured_inputs)) {}
Status CapturedFunction::IsMultiDevice(IteratorContext* ctx,
bool* is_multi_device) {
if (!metadata_->use_multi_device_function()) {
*is_multi_device = false;
return Status::OK();
}
const FunctionDef* fdef;
TF_RETURN_IF_ERROR(
LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef));
Device* current_device = ctx->flr()->device();
DeviceType current_device_type(current_device->device_type());
DeviceNameUtils::ParsedName current_device_name;
if (!DeviceNameUtils::ParseFullName(current_device->name(),
&current_device_name)) {
return errors::InvalidArgument("Failed to parse device name: ",
current_device->name());
}
// Check if any of the captured inputs are placed on a device not compatible
// with the current device. For non-captured inputs, we assume they are placed
// on the current device.
for (const auto& input : captured_inputs_) {
DataType dtype = input.dtype();
if (dtype == DT_RESOURCE) {
const ResourceHandle& handle = input.flat<ResourceHandle>()(0);
DeviceNameUtils::ParsedName resource_device_name;
if (!DeviceNameUtils::ParseFullName(handle.device(),
&resource_device_name)) {
return errors::InvalidArgument("Failed to parse device name: ",
handle.device());
}
if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
resource_device_name)) {
*is_multi_device = true;
return Status::OK();
}
}
}
// Check if all ops could be placed on the current device.
for (const auto& name : metadata_->lib_def()->ListFunctionNames()) {
const FunctionDef* fdef;
TF_RETURN_IF_ERROR(LookupFunction(*metadata_->lib_def(), name, &fdef));
for (const auto& node : fdef->node_def()) {
// Check if the op has a kernel available for the current device.
if (!KernelDefAvailable(current_device_type, node)) {
*is_multi_device = true;
return Status::OK();
}
// If the op has a requested device, check if the requested device is
// compatible with the current device.
if (!node.device().empty()) {
DeviceNameUtils::ParsedName node_device_name;
if (!DeviceNameUtils::ParseFullName(node.device(), &node_device_name)) {
return errors::InvalidArgument("Failed to parse device name: ",
node.device());
}
if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
node_device_name)) {
*is_multi_device = true;
return Status::OK();
}
}
}
}
*is_multi_device = false;
return Status::OK();
}
} // namespace data
} // namespace tensorflow
......@@ -256,6 +256,10 @@ class CapturedFunction {
CapturedFunction(std::shared_ptr<const FunctionMetadata> metadata,
std::vector<Tensor> captured_inputs);
// Determines whether the captured function requires the use of the
// multi-device function backend.
Status IsMultiDevice(IteratorContext* ctx, bool* is_multi_device);
const std::shared_ptr<const FunctionMetadata> metadata_;
const std::vector<Tensor> captured_inputs_;
......
......@@ -111,27 +111,41 @@ uniform 0.0379589: 907 Windsor tie 0.00735866: 466 bulletproof vest 0.00605307:
To run a model with the Hexagon Delegate, assuming we have followed the
[Hexagon Delegate Guide](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/hexagon_delegate.md)
and installed Hexagon libraries in `/data/local/tmp`. Run it `adb shell
"/data/local/tmp/label_image \ -m
and installed Hexagon libraries in `/data/local/tmp`. Run it wth (`-j 1`) `adb
shell \ "/data/local/tmp/label_image \ -m
/data/local/tmp/mobilenet_v1_1.0_224_quant.tflite \ -i
/data/local/tmp/grace_hopper.bmp \ -l /data/local/tmp/labels.txt -j 1"` then you
should see something like the followings: ``` Loaded model
/data/local/tmp/mobilenet_v1_1.0_224_quant.tflite resolved reporter INFO:
Initialized TensorFlow Lite runtime. INFO: Created TensorFlow Lite delegate for
Hexagon. INFO: Hexagon delegate: 31 nodes delegated out of 31 nodes.
Initialized TensorFlow Lite runtime. loaded libcdsprpc.so INFO: Created
TensorFlow Lite delegate for Hexagon. INFO: Hexagon delegate: 31 nodes delegated
out of 31 nodes with 1 partitions.
remote_handle_control available and used Applied Hexagon delegate.invoked
average time: 8.307 ms 0.729412: 653 military uniform 0.0980392: 907 Windsor tie
0.0313726: 466 bulletproof vest 0.0313726: 458 bow tie 0.0117647: 700 panpipe
```
Applied Hexagon delegate.invoked average time: 4.231 ms 0.639216: 458 bow tie
0.329412: 653 military uniform 0.00784314: 835 suit 0.00784314: 611 jersey
0.00392157: 514 cornet ```
Run the model with the XNNPACK delegate (`-x 1`), `adb shell
Run the model with the XNNPACK delegate (`-x 1`), `adb shell \
"/data/local/tmp/label_image \ -m /data/local/tmp/mobilenet_v1_1.0_224.tflite \
-i /data/local/tmp/grace_hopper.bmp \ -l /data/local/tmp/labels.txt -x 1"` then
you should see something like the followings: `Loaded model
/data/local/tmp/mobilenet_v1_1.0_224.tflite resolved reporter INFO: Initialized
TensorFlow Lite runtime. Applied XNNPACK delegate.invoked average time: 11.0237
ms 0.90707: 653 military uniform 0.0372418: 907 Windsor tie 0.0073376: 466
bulletproof vest 0.00592856: 458 bow tie 0.00414093: 514 cornet`
TensorFlow Lite runtime. Applied XNNPACK delegate.invoked average time: 17.33 ms
0.90707: 653 military uniform 0.0372418: 907 Windsor tie 0.0073376: 466
bulletproof vest 0.00592857: 458 bow tie 0.00414093: 514 cornet`
With `-h` or any other unsupported flags, `label_image` will list supported
options `sargo:/data/local/tmp $ ./label_image -h ./label_image: invalid
option -- h label_image --accelerated, -a: [0|1], use Android NNAPI or not
--old_accelerated, -d: [0|1], use old Android NNAPI delegate or not
--allow_fp16, -f: [0|1], allow running fp32 models with fp16 or not --count, -c:
loop interpreter->Invoke() for certain times --gl_backend, -g: [0|1]: use GL GPU
Delegate on Android --hexagon_delegate, -j: [0|1]: use Hexagon Delegate on
Android --input_mean, -b: input mean --input_std, -s: input standard deviation
--image, -i: image_name.bmp --labels, -l: labels for the model --tflite_model,
-m: model_name.tflite --profiling, -p: [0|1], profiling or not --num_results,
-r: number of results to show --threads, -t: number of threads --verbose, -v:
[0|1] print more information --warmup_runs, -w: number of warmup runs
--xnnpack_delegate, -x [0:1]: xnnpack delegate`
See the `label_image.cc` source code for other command line options.
......@@ -362,8 +362,8 @@ void display_usage() {
<< "--old_accelerated, -d: [0|1], use old Android NNAPI delegate or not\n"
<< "--allow_fp16, -f: [0|1], allow running fp32 models with fp16 or not\n"
<< "--count, -c: loop interpreter->Invoke() for certain times\n"
<< "--gl_backend, -g: use GL GPU Delegate on Android\n"
<< "--hexagon_delegate: use Hexagon Delegate on Android\n"
<< "--gl_backend, -g: [0|1]: use GL GPU Delegate on Android\n"
<< "--hexagon_delegate, -j: [0|1]: use Hexagon Delegate on Android\n"
<< "--input_mean, -b: input mean\n"
<< "--input_std, -s: input standard deviation\n"
<< "--image, -i: image_name.bmp\n"
......@@ -374,7 +374,7 @@ void display_usage() {
<< "--threads, -t: number of threads\n"
<< "--verbose, -v: [0|1] print more information\n"
<< "--warmup_runs, -w: number of warmup runs\n"
<< "--xnnpack_delegate, -x: xnnpack delegate\n"
<< "--xnnpack_delegate, -x [0:1]: xnnpack delegate\n"
<< "\n";
}
......