Commit a33f9b44 authored by Thomas O'Malley's avatar Thomas O'Malley Committed by TensorFlower Gardener

Reduce Layer.__call__ overhead by ~20%

This is achieved by improving the way masks are handled for inputs and outputs.
For the common case, where no mask is passed in with the inputs and no mask
is produced for the outputs, only minimal work is now performed.
For the masking case, the work done is about the same.

PiperOrigin-RevId: 312871996
Change-Id: I2e122551bec27d075193e1881bf236d570d25ce4
parent f4601414
...@@ -386,6 +386,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector): ...@@ -386,6 +386,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
# might want to turn it off, like Sequential model. # might want to turn it off, like Sequential model.
self._auto_track_sub_layers = True self._auto_track_sub_layers = True
# Will compute masking if `compute_mask` is overridden or `supports_masking`
# is set.
self._compute_mask_overridden = (not getattr(self.compute_mask,
'_is_default', False))
@trackable.no_automatic_dependency_tracking @trackable.no_automatic_dependency_tracking
@generic_utils.default @generic_utils.default
def build(self, input_shape): def build(self, input_shape):
...@@ -844,7 +849,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector): ...@@ -844,7 +849,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
# setting the `_keras_mask` attribute on the inputs to a Layer. Masks passed # setting the `_keras_mask` attribute on the inputs to a Layer. Masks passed
# explicitly take priority. # explicitly take priority.
mask_arg_passed_by_framework = False mask_arg_passed_by_framework = False
input_masks = self._collect_input_masks(inputs, args, kwargs) input_masks = self._collect_input_masks(inputs, input_list, args, kwargs)
if (self._expects_mask_arg and input_masks is not None and if (self._expects_mask_arg and input_masks is not None and
not self._call_arg_was_passed('mask', args, kwargs)): not self._call_arg_was_passed('mask', args, kwargs)):
mask_arg_passed_by_framework = True mask_arg_passed_by_framework = True
...@@ -973,7 +978,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector): ...@@ -973,7 +978,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
outputs = self._set_connectivity_metadata((inputs,) + args, kwargs, outputs = self._set_connectivity_metadata((inputs,) + args, kwargs,
outputs) outputs)
self._handle_activity_regularization(inputs, outputs) self._handle_activity_regularization(inputs, outputs)
self._set_mask_metadata(inputs, outputs, input_masks) self._set_mask_metadata(inputs, outputs, input_masks, build_graph)
if hasattr(self, '_set_inputs') and not self.inputs: if hasattr(self, '_set_inputs') and not self.inputs:
# Subclassed network: explicitly set metadata normally set by # Subclassed network: explicitly set metadata normally set by
# a call to self._set_inputs(). # a call to self._set_inputs().
...@@ -987,7 +992,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector): ...@@ -987,7 +992,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
self._compute_dtype): self._compute_dtype):
outputs = self.call(cast_inputs, *args, **kwargs) outputs = self.call(cast_inputs, *args, **kwargs)
self._handle_activity_regularization(inputs, outputs) self._handle_activity_regularization(inputs, outputs)
self._set_mask_metadata(inputs, outputs, input_masks) self._set_mask_metadata(inputs, outputs, input_masks, build_graph)
if hasattr(self, '_set_save_spec'): if hasattr(self, '_set_save_spec'):
self._set_save_spec(cast_inputs) self._set_save_spec(cast_inputs)
...@@ -2259,47 +2264,45 @@ class Layer(module.Module, version_utils.LayerVersionSelector): ...@@ -2259,47 +2264,45 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
mean_activity_loss = activity_loss / batch_size mean_activity_loss = activity_loss / batch_size
self.add_loss(mean_activity_loss) self.add_loss(mean_activity_loss)
def _set_mask_metadata(self, inputs, outputs, previous_mask): def _set_mask_metadata(self, inputs, outputs, previous_mask, build_graph):
# Many `Layer`s don't need to call `compute_mask`.
# This method is optimized to do as little work as needed for the common
# case.
if not self.supports_masking and not self._compute_mask_overridden:
return
flat_outputs = nest.flatten(outputs) flat_outputs = nest.flatten(outputs)
mask_already_computed = ( mask_already_computed = (
getattr(self, '_compute_output_and_mask_jointly', False) or getattr(self, '_compute_output_and_mask_jointly', False) or
all(getattr(x, '_keras_mask', None) is not None for x in flat_outputs)) all(getattr(x, '_keras_mask', None) is not None for x in flat_outputs))
# Only compute the mask if the Layer explicitly supports masking or has
# overridden `compute_mask`.
should_compute_mask = (
hasattr(self, 'compute_mask') and
(self.supports_masking or
not getattr(self.compute_mask, '_is_default', False)))
if mask_already_computed: if mask_already_computed:
flat_masks = [getattr(x, '_keras_mask', None) for x in flat_outputs] if build_graph:
elif not should_compute_mask: self._set_mask_keras_history_checked(flat_outputs)
flat_masks = [None for _ in flat_outputs] return
else:
output_masks = self.compute_mask(inputs, previous_mask) output_masks = self.compute_mask(inputs, previous_mask)
# `compute_mask` can return a single `None` even when a Layer if output_masks is None:
# has multiple outputs. return
if output_masks is None:
flat_masks = [None for _ in flat_outputs]
else:
flat_masks = nest.flatten(output_masks)
for output, mask in zip(flat_outputs, flat_masks): flat_masks = nest.flatten(output_masks)
for tensor, mask in zip(flat_outputs, flat_masks):
try: try:
output._keras_mask = mask tensor._keras_mask = mask
except AttributeError: except AttributeError:
# C Type such as np.ndarray. # C Type such as np.ndarray.
pass pass
if tf_utils.are_all_symbolic_tensors(flat_outputs): if build_graph:
for output in flat_outputs: self._set_mask_keras_history_checked(flat_outputs)
if getattr(output, '_keras_mask', None) is not None:
# Do not track masks for `TensorFlowOpLayer` construction. def _set_mask_keras_history_checked(self, flat_outputs):
output._keras_mask._keras_history_checked = True for output in flat_outputs:
if getattr(output, '_keras_mask', None) is not None:
# Do not track masks for `TensorFlowOpLayer` construction.
output._keras_mask._keras_history_checked = True
def _collect_input_masks(self, inputs, args, kwargs): def _collect_input_masks(self, inputs, input_list, args, kwargs):
"""Checks if `mask` argument was passed, else gathers mask from inputs.""" """Checks if `mask` argument was passed, else gathers mask from inputs."""
if self._call_arg_was_passed('mask', args, kwargs): if self._call_arg_was_passed('mask', args, kwargs):
return self._get_call_arg_value('mask', args, kwargs) return self._get_call_arg_value('mask', args, kwargs)
...@@ -2307,11 +2310,12 @@ class Layer(module.Module, version_utils.LayerVersionSelector): ...@@ -2307,11 +2310,12 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
if not self._should_compute_mask: if not self._should_compute_mask:
return None return None
input_masks = nest.map_structure(lambda t: getattr(t, '_keras_mask', None), input_masks = [getattr(t, '_keras_mask', None) for t in input_list]
inputs) if all(mask is None for mask in input_masks):
if generic_utils.is_all_none(input_masks):
return None return None
return input_masks
# Only do expensive `nest` operation when masking is actually being used.
return nest.pack_sequence_as(inputs, input_masks)
def _call_arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False): def _call_arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False):
# Performance optimization: do no work in most common case. # Performance optimization: do no work in most common case.
......
...@@ -358,7 +358,8 @@ class Functional(training_lib.Model): ...@@ -358,7 +358,8 @@ class Functional(training_lib.Model):
# by itself because it will duplicate any updates and losses in graph # by itself because it will duplicate any updates and losses in graph
# mode by `call`ing the Layers again. # mode by `call`ing the Layers again.
output_tensors = self._run_internal_graph(inputs, mask=mask) output_tensors = self._run_internal_graph(inputs, mask=mask)
return nest.map_structure(lambda t: t._keras_mask, output_tensors) return nest.map_structure(lambda t: getattr(t, '_keras_mask', None),
output_tensors)
def call(self, inputs, training=None, mask=None): def call(self, inputs, training=None, mask=None):
"""Calls the model on new inputs. """Calls the model on new inputs.
......
...@@ -397,7 +397,7 @@ class Sequential(functional.Functional): ...@@ -397,7 +397,7 @@ class Sequential(functional.Functional):
raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG) raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG)
# `outputs` will be the inputs to the next layer. # `outputs` will be the inputs to the next layer.
inputs = outputs inputs = outputs
mask = outputs._keras_mask mask = getattr(outputs, '_keras_mask', None)
return outputs return outputs
def compute_output_shape(self, input_shape): def compute_output_shape(self, input_shape):
...@@ -411,7 +411,7 @@ class Sequential(functional.Functional): ...@@ -411,7 +411,7 @@ class Sequential(functional.Functional):
# by itself because it will duplicate any updates and losses in graph # by itself because it will duplicate any updates and losses in graph
# mode by `call`ing the Layers again. # mode by `call`ing the Layers again.
outputs = self.call(inputs, mask=mask) outputs = self.call(inputs, mask=mask)
return outputs._keras_mask return getattr(outputs, '_keras_mask', None)
@deprecated('2021-01-01', 'Please use `model.predict()` instead.') @deprecated('2021-01-01', 'Please use `model.predict()` instead.')
def predict_proba(self, x, batch_size=32, verbose=0): def predict_proba(self, x, batch_size=32, verbose=0):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment