@@ -290,7 +290,7 @@ void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use,
   end_logical_time_ = end_time;
   // Find the earliest time we're allowed to start prefetching.
   for (current_logical_prefetch_time_ = start_time;
-       current_logical_prefetch_time_ <= end_logical_time_ &&
+       current_logical_prefetch_time_ < end_logical_time_ &&
        max_async_copy_to_overlap_ratio_ * async_copy_elapsed_ <
            GetLogicalIntervalElapsed(current_logical_prefetch_time_,
                                      end_logical_time_);
@@ -305,9 +305,9 @@ int64 CostAnalysisPrefetchIntervalPicker::Next() {
 }
 
 bool CostAnalysisPrefetchIntervalPicker::Done() const {
-  // The end time is inclusive, so we're done if the prefetch time is greater
-  // than that.
-  if (current_logical_prefetch_time_ > end_logical_time_) {
+  // The end time is exclusive, so we're done if the prefetch time is greater
+  // than or equal to the end time.
+  if (current_logical_prefetch_time_ >= end_logical_time_) {
     return true;
   }
   float logical_interval_elapsed = GetLogicalIntervalElapsed(
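
The two hunks above flip `end_logical_time_` from an inclusive to an exclusive bound: `Begin` now stops advancing strictly before the end time, and `Done` reports completion once the prefetch time reaches it. The contract becomes the usual half-open interval, the same shape as Python's `range`. A minimal sketch of that iterator contract (a hypothetical `IntervalPicker`, not XLA's `CostAnalysisPrefetchIntervalPicker` API, omitting its elapsed-time overlap heuristic):

```python
class IntervalPicker:
  """Toy picker with the half-open [start, end) contract from the diff."""

  def begin(self, start_time, end_time):
    self._time = start_time
    self._end = end_time  # exclusive, like range(start, end)

  def done(self):
    # The end time is exclusive: a prefetch issued at t == end_time would
    # leave the async copy zero logical time to complete before its use.
    return self._time >= self._end

  def next(self):
    t = self._time
    self._time += 1
    return t

picker = IntervalPicker()
picker.begin(3, 6)
while not picker.done():
  print(picker.next())  # prints 3, 4, 5 -- never 6, the exclusive end
```
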
@@ -1473,6 +1473,7 @@ void AlternateMemoryBestFitHeap::AddAsyncCopy(
                   : "alternate")
           << " memory between " << start_time << " and "
           << copy_done_schedule_before_time << " keeping until " << end_time;
+  CHECK_LT(start_time, copy_done_schedule_before_time);
   allocations->push_back(
       absl::make_unique<MemorySpaceAssignment::CopyAllocation>(
@@ -1760,6 +1761,7 @@ bool AlternateMemoryBestFitHeap::Prefetch(
   alternate_mem_interval.size = request.size;
   while (!options_.prefetch_interval_picker->Done()) {
     alternate_mem_interval.start = options_.prefetch_interval_picker->Next();
+    CHECK_LT(alternate_mem_interval.start, request.latest_prefetch_time);
     VLOG(4) << "Trying alternate memory allocation ("
             << alternate_mem_interval.start << ", " << request.end_time << ")";
     // If this additional asynchronous copy would violate the limit, try a
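
Both `CHECK_LT`s assert the invariant the half-open fix guarantees: a prefetch must start strictly before the point where its copy-done is scheduled, so a zero-length async copy can never be emitted (under the old inclusive bound, `Next()` could return a start equal to the latest prefetch time). A Python analogue of the same fail-fast guard (hypothetical function, not the XLA code):

```python
def add_async_copy(start_time, copy_done_schedule_before_time):
  # Fail fast on a zero- or negative-length copy interval, mirroring
  # CHECK_LT(start_time, copy_done_schedule_before_time) in the diff.
  assert start_time < copy_done_schedule_before_time, (
      "async copy must start strictly before its copy-done point")
  return (start_time, copy_done_schedule_before_time)

add_async_copy(3, 6)    # ok: the copy has three logical steps to finish
# add_async_copy(6, 6)  # AssertionError: zero-length copy interval
```
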
@@ -912,7 +912,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
           # Build layer if applicable (if the `build` method has been
           # overridden).
           self._maybe_build(inputs)
-          cast_inputs = self._maybe_cast_inputs(inputs)
+          cast_inputs = self._maybe_cast_inputs(inputs, input_list)
 
           if not self.dynamic:
             # Wrapping `call` function in autograph to allow for dynamic control
@@ -982,7 +982,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
         # Eager execution on data tensors.
         with backend.name_scope(self._name_scope()):
           self._maybe_build(inputs)
-          cast_inputs = self._maybe_cast_inputs(inputs)
+          cast_inputs = self._maybe_cast_inputs(inputs, input_list)
           with base_layer_utils.autocast_context_manager(
               self._compute_dtype):
             outputs = self.call(cast_inputs, *args, **kwargs)
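
Both call sites already hold `input_list`, the flat list that `Layer.__call__` computes from `inputs` on entry; passing it down lets `_maybe_cast_inputs` do a cheap linear scan instead of re-flattening a possibly deep structure. A sketch of the flatten-once pattern using the public `tf.nest` API (hypothetical helper names, not the Keras internals):

```python
import tensorflow as tf

def maybe_cast_inputs(inputs, input_list, compute_dtype=tf.float16):
  # Stand-in for Layer._maybe_cast_inputs: decide on the flat list,
  # rebuild the structure only if a cast is actually needed.
  if not any(x.dtype.is_floating and x.dtype != compute_dtype
             for x in input_list):
    return inputs
  return tf.nest.map_structure(
      lambda x: tf.cast(x, compute_dtype) if x.dtype.is_floating else x,
      inputs)

def call_layer(layer_fn, inputs):
  input_list = tf.nest.flatten(inputs)  # flattened once, up front
  return layer_fn(maybe_cast_inputs(inputs, input_list))

out = call_layer(lambda d: d["a"] * 2.0,
                 {"a": tf.ones([2]), "b": tf.range(3)})  # "a" cast, "b" kept
```
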
@@ -2117,7 +2117,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     """
     return self._dtype_policy.compute_dtype
 
-  def _maybe_cast_inputs(self, inputs):
+  def _maybe_cast_inputs(self, inputs, input_list):
     """Maybe casts the inputs to the compute dtype.
 
     If self._compute_dtype is floating-point, and self._autocast is True,
@@ -2125,32 +2125,38 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
 
     Args:
       inputs: Input tensor, or structure of input tensors.
+      input_list: Flat list of input tensors.
 
     Returns:
       `inputs`, but tensors may have been casted to self._compute_dtype
     """
     compute_dtype = self._compute_dtype
-    if (self._autocast and compute_dtype and
-        dtypes.as_dtype(compute_dtype).is_floating):
-      def f(x):
-        """Cast a single Tensor or TensorSpec to the compute dtype."""
-        cast_types = (ops.Tensor, sparse_tensor.SparseTensor,
-                      ragged_tensor.RaggedTensor)
-        if (isinstance(x, cast_types) and x.dtype.is_floating and
-            x.dtype.base_dtype.name != compute_dtype):
-          if self._dtype_defaulted_to_floatx:
-            self._warn_about_input_casting(x.dtype.base_dtype)
-          return math_ops.cast(x, compute_dtype)
-        elif isinstance(x, tensor_spec.TensorSpec) and x.dtype.is_floating:
-          # Inputs may be TensorSpecs when this function is called from
-          # model._set_inputs.
-          return tensor_spec.TensorSpec(x.shape, compute_dtype, x.name)
-        else:
-          return x
-      return nest.map_structure(f, inputs)
+    should_autocast = (
+        self._autocast and compute_dtype and
+        dtypes.as_dtype(compute_dtype).is_floating)
+    if (should_autocast and
+        any(self._should_cast_single_input(x) for x in input_list)):
+      # Only perform expensive `nest` operation when needed.
+      return nest.map_structure(self._cast_single_input, inputs)
     else:
       return inputs
 
+  def _should_cast_single_input(self, x):
+    cast_types = (ops.Tensor, sparse_tensor.SparseTensor,
+                  ragged_tensor.RaggedTensor)
+    return (isinstance(x, cast_types) and x.dtype.is_floating and
+            x.dtype.base_dtype.name != self._compute_dtype)
+
+  def _cast_single_input(self, x):
+    """Cast a single Tensor or TensorSpec to the compute dtype."""
+    if self._should_cast_single_input(x):
+      if self._dtype_defaulted_to_floatx:
+        self._warn_about_input_casting(x.dtype.base_dtype)
+      return math_ops.cast(x, self._compute_dtype)
+    else:
+      return x
+
   def _warn_about_input_casting(self, input_dtype):
     # self._already_warned_about_input_casting is only retrieved or set in this
     # function.
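
The payoff in the rewritten `_maybe_cast_inputs` is the early-out: `nest.map_structure` rebuilds the entire input structure and is comparatively expensive, so the new code first runs a cheap `any(...)` scan over the pre-flattened `input_list` and only maps when at least one tensor actually needs casting. A standalone approximation of that fast path, using the public `tf.nest` API instead of Keras' internal `nest`/`math_ops` modules (names here are illustrative):

```python
import tensorflow as tf

COMPUTE_DTYPE = "float16"  # stand-in for self._compute_dtype

def should_cast(x):
  # Mirrors _should_cast_single_input: a floating tensor whose dtype name
  # differs from the compute dtype.
  return (isinstance(x, tf.Tensor) and x.dtype.is_floating and
          x.dtype.base_dtype.name != COMPUTE_DTYPE)

def maybe_cast(inputs, input_list):
  if any(should_cast(x) for x in input_list):
    # Only pay for the structure-rebuilding map when it changes something.
    return tf.nest.map_structure(
        lambda x: tf.cast(x, COMPUTE_DTYPE) if should_cast(x) else x, inputs)
  return inputs

mixed = {"a": tf.ones([2], tf.float32), "b": tf.ones([2], tf.float16)}
cast = maybe_cast(mixed, tf.nest.flatten(mixed))  # "a" is cast, "b" kept
fp16 = {"a": tf.ones([2], tf.float16)}
same = maybe_cast(fp16, tf.nest.flatten(fp16))
assert same is fp16  # fast path: original structure returned untouched
```
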