Commit a33f9b44 authored by Thomas O'Malley, committed by TensorFlower Gardener

Reduce Layer.__call__ overhead by ~20%

This is achieved by improving how input and output masks are handled.
In the common case, where no mask is passed in and no mask is produced,
minimal work is now done. When masking is actually used, the cost is about
the same as before.

PiperOrigin-RevId: 312871996
Change-Id: I2e122551bec27d075193e1881bf236d570d25ce4
parent f4601414
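The ~20% figure refers to Python-side bookkeeping in `Layer.__call__`, which dominates per-call cost for small layers. A rough way to observe that overhead yourself, not part of this change, assuming a TensorFlow 2.x install with eager execution:

import timeit

import tensorflow as tf

layer = tf.keras.layers.Dense(64)
x = tf.ones((32, 64))
layer(x)  # First call triggers `build`; later calls measure only `__call__`.

# For a layer this small, most of the per-call time is Keras bookkeeping
# in `Layer.__call__`, which is what this commit trims.
print(timeit.timeit(lambda: layer(x), number=1000))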
@@ -386,6 +386,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     # might want to turn it off, like Sequential model.
     self._auto_track_sub_layers = True
 
+    # Will compute masking if `compute_mask` is overridden or `supports_masking`
+    # is set.
+    self._compute_mask_overridden = (not getattr(self.compute_mask,
+                                                 '_is_default', False))
+
   @trackable.no_automatic_dependency_tracking
   @generic_utils.default
   def build(self, input_shape):
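The new `_compute_mask_overridden` check relies on Keras tagging default implementations: the `generic_utils.default` decorator sets `_is_default` on a base method so an override can be detected without inspecting the class hierarchy. A minimal standalone sketch of that pattern, with a hypothetical `default` helper mirroring the real decorator:

def default(method):
  # Mark a base-class implementation so overrides are detectable.
  method._is_default = True
  return method

class Base:

  @default
  def compute_mask(self, inputs, mask=None):
    return None

class Custom(Base):

  def compute_mask(self, inputs, mask=None):  # Override: not tagged.
    return mask

print(getattr(Base().compute_mask, '_is_default', False))    # True
print(getattr(Custom().compute_mask, '_is_default', False))  # False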
@@ -844,7 +849,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     # setting the `_keras_mask` attribute on the inputs to a Layer. Masks passed
     # explicitly take priority.
     mask_arg_passed_by_framework = False
-    input_masks = self._collect_input_masks(inputs, args, kwargs)
+    input_masks = self._collect_input_masks(inputs, input_list, args, kwargs)
     if (self._expects_mask_arg and input_masks is not None and
         not self._call_arg_was_passed('mask', args, kwargs)):
       mask_arg_passed_by_framework = True
@@ -973,7 +978,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
             outputs = self._set_connectivity_metadata((inputs,) + args, kwargs,
                                                       outputs)
           self._handle_activity_regularization(inputs, outputs)
-          self._set_mask_metadata(inputs, outputs, input_masks)
+          self._set_mask_metadata(inputs, outputs, input_masks, build_graph)
           if hasattr(self, '_set_inputs') and not self.inputs:
             # Subclassed network: explicitly set metadata normally set by
             # a call to self._set_inputs().
@@ -987,7 +992,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
               self._compute_dtype):
             outputs = self.call(cast_inputs, *args, **kwargs)
           self._handle_activity_regularization(inputs, outputs)
-          self._set_mask_metadata(inputs, outputs, input_masks)
+          self._set_mask_metadata(inputs, outputs, input_masks, build_graph)
           if hasattr(self, '_set_save_spec'):
             self._set_save_spec(cast_inputs)
@@ -2259,47 +2264,45 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
           mean_activity_loss = activity_loss / batch_size
           self.add_loss(mean_activity_loss)
 
-  def _set_mask_metadata(self, inputs, outputs, previous_mask):
+  def _set_mask_metadata(self, inputs, outputs, previous_mask, build_graph):
+    # Many `Layer`s don't need to call `compute_mask`.
+    # This method is optimized to do as little work as needed for the common
+    # case.
+    if not self.supports_masking and not self._compute_mask_overridden:
+      return
+
     flat_outputs = nest.flatten(outputs)
 
     mask_already_computed = (
         getattr(self, '_compute_output_and_mask_jointly', False) or
         all(getattr(x, '_keras_mask', None) is not None for x in flat_outputs))
-
-    # Only compute the mask if the Layer explicitly supports masking or has
-    # overridden `compute_mask`.
-    should_compute_mask = (
-        hasattr(self, 'compute_mask') and
-        (self.supports_masking or
-         not getattr(self.compute_mask, '_is_default', False)))
-
     if mask_already_computed:
-      flat_masks = [getattr(x, '_keras_mask', None) for x in flat_outputs]
-    elif not should_compute_mask:
-      flat_masks = [None for _ in flat_outputs]
-    else:
-      output_masks = self.compute_mask(inputs, previous_mask)
-      # `compute_mask` can return a single `None` even when a Layer
-      # has multiple outputs.
-      if output_masks is None:
-        flat_masks = [None for _ in flat_outputs]
-      else:
-        flat_masks = nest.flatten(output_masks)
+      if build_graph:
+        self._set_mask_keras_history_checked(flat_outputs)
+      return
 
-    for output, mask in zip(flat_outputs, flat_masks):
+    output_masks = self.compute_mask(inputs, previous_mask)
+    if output_masks is None:
+      return
+
+    flat_masks = nest.flatten(output_masks)
+    for tensor, mask in zip(flat_outputs, flat_masks):
       try:
-        output._keras_mask = mask
+        tensor._keras_mask = mask
       except AttributeError:
         # C Type such as np.ndarray.
         pass
 
-    if tf_utils.are_all_symbolic_tensors(flat_outputs):
-      for output in flat_outputs:
-        if getattr(output, '_keras_mask', None) is not None:
-          # Do not track masks for `TensorFlowOpLayer` construction.
-          output._keras_mask._keras_history_checked = True
+    if build_graph:
+      self._set_mask_keras_history_checked(flat_outputs)
+
+  def _set_mask_keras_history_checked(self, flat_outputs):
+    for output in flat_outputs:
+      if getattr(output, '_keras_mask', None) is not None:
+        # Do not track masks for `TensorFlowOpLayer` construction.
+        output._keras_mask._keras_history_checked = True
 
-  def _collect_input_masks(self, inputs, args, kwargs):
+  def _collect_input_masks(self, inputs, input_list, args, kwargs):
     """Checks if `mask` argument was passed, else gathers mask from inputs."""
     if self._call_arg_was_passed('mask', args, kwargs):
       return self._get_call_arg_value('mask', args, kwargs)
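The effect of the rewritten `_set_mask_metadata` is visible from user code: a layer that neither receives nor produces a mask no longer attaches anything to its outputs. A small sketch, assuming TensorFlow 2.x in eager mode (`_keras_mask` is a private attribute, read here only for illustration):

import tensorflow as tf

out = tf.keras.layers.Dense(4)(tf.ones((2, 3)))
# No mask came in and none was produced, so no `_keras_mask` attribute is
# attached to the output at all (previously a `None` mask was set on it).
print(getattr(out, '_keras_mask', None))  # None

masked = tf.keras.layers.Masking(mask_value=0.)(tf.ones((2, 3, 4)))
# `Masking` computes its output and mask jointly, so the mask is already
# attached and only the `_keras_history_checked` bookkeeping remains.
print(getattr(masked, '_keras_mask', None))  # tf.Tensor of shape (2, 3)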
@@ -2307,11 +2310,12 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     if not self._should_compute_mask:
       return None
 
-    input_masks = nest.map_structure(lambda t: getattr(t, '_keras_mask', None),
-                                     inputs)
-    if generic_utils.is_all_none(input_masks):
+    input_masks = [getattr(t, '_keras_mask', None) for t in input_list]
+    if all(mask is None for mask in input_masks):
       return None
-    return input_masks
+    # Only do expensive `nest` operation when masking is actually being used.
+    return nest.pack_sequence_as(inputs, input_masks)
 
   def _call_arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False):
     # Performance optimization: do no work in most common case.
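`_collect_input_masks` now gathers masks with a plain list comprehension over the already-flattened inputs and rebuilds the input structure only when a mask is actually present. The pattern, sketched with the public `tf.nest` API and a hypothetical nested input:

import tensorflow as tf

inputs = {'ids': tf.ones((2, 3)), 'extra': [tf.ones((2, 4))]}
flat = tf.nest.flatten(inputs)
masks = [getattr(t, '_keras_mask', None) for t in flat]

if all(m is None for m in masks):
  print(None)  # Fast path: skip the relatively expensive packing step.
else:
  # Only now mirror the masks back into the original input structure.
  print(tf.nest.pack_sequence_as(inputs, masks))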
@@ -358,7 +358,8 @@ class Functional(training_lib.Model):
     # by itself because it will duplicate any updates and losses in graph
     # mode by `call`ing the Layers again.
     output_tensors = self._run_internal_graph(inputs, mask=mask)
-    return nest.map_structure(lambda t: t._keras_mask, output_tensors)
+    return nest.map_structure(lambda t: getattr(t, '_keras_mask', None),
+                              output_tensors)
 
   def call(self, inputs, training=None, mask=None):
     """Calls the model on new inputs.
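Since outputs may now lack the `_keras_mask` attribute entirely, `Functional.compute_mask` switches from direct attribute access to `getattr(..., None)`: a mask-free graph yields `None` instead of raising `AttributeError`. For example (a sketch, TF 2.x):

import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
model = tf.keras.Model(inputs, tf.keras.layers.Dense(2)(inputs))
# No layer in this graph produces a mask.
print(model.compute_mask(tf.ones((1, 3)), mask=None))  # None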
@@ -397,7 +397,7 @@ class Sequential(functional.Functional):
         raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG)
       # `outputs` will be the inputs to the next layer.
       inputs = outputs
-      mask = outputs._keras_mask
+      mask = getattr(outputs, '_keras_mask', None)
     return outputs
 
   def compute_output_shape(self, input_shape):
@@ -411,7 +411,7 @@ class Sequential(functional.Functional):
     # by itself because it will duplicate any updates and losses in graph
     # mode by `call`ing the Layers again.
     outputs = self.call(inputs, mask=mask)
-    return outputs._keras_mask
+    return getattr(outputs, '_keras_mask', None)
 
   @deprecated('2021-01-01', 'Please use `model.predict()` instead.')
   def predict_proba(self, x, batch_size=32, verbose=0):
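`Sequential` gets the same defensive `getattr` in its layer-to-layer mask forwarding and in `compute_mask`. A quick check (sketch, TF 2.x):

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(4),
    tf.keras.layers.Dense(2),
])
out = model(tf.ones((1, 3)))
# No layer in the stack produced a mask, so the output carries no
# `_keras_mask` attribute; `getattr` keeps the forwarding code safe.
print(getattr(out, '_keras_mask', None))  # None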