Commit a68b15fe authored by Yuanzhong Xu, committed by TensorFlower Gardener

[XLA] Use all-gather in SPMD to replicate a tiled tensor.

PiperOrigin-RevId: 312806463
Change-Id: If0fde80b91f1302256694554fe0cd645ad210df0
parent 63f70b56
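
The change replaces all-reduce-based replication of a tiled tensor with a direct all-gather of its shards. As a rough intuition (not XLA code; the partition count and shard values below are made up for illustration), an all-gather simply concatenates the per-partition shards, whereas the all-reduce fallback has each partition place its shard into a zero-filled buffer of the full shape and then sums the buffers, producing the same replicated result at the extra cost of the padding and the summation:

#include <cstdint>
#include <iostream>
#include <vector>

// Toy illustration (not XLA code): replicating a 1-D tensor that is tiled
// evenly across 4 partitions.
int main() {
  const int64_t num_partitions = 4;
  const int64_t shard_size = 2;
  // Per-partition shards of the full tensor [0, 1, ..., 7].
  const std::vector<std::vector<int64_t>> shards = {
      {0, 1}, {2, 3}, {4, 5}, {6, 7}};

  // All-gather path: concatenate the shards along the sharded dimension.
  std::vector<int64_t> all_gather_result;
  for (const auto& shard : shards) {
    all_gather_result.insert(all_gather_result.end(), shard.begin(),
                             shard.end());
  }

  // All-reduce fallback: each partition writes its shard into a zero-filled
  // buffer of the full shape at its own offset, then the buffers are summed
  // element-wise across partitions.
  std::vector<int64_t> all_reduce_result(num_partitions * shard_size, 0);
  for (int64_t p = 0; p < num_partitions; ++p) {
    for (int64_t i = 0; i < shard_size; ++i) {
      all_reduce_result[p * shard_size + i] += shards[p][i];
    }
  }

  // Both paths yield the same replicated tensor; the all-gather avoids the
  // zero padding and the summation.
  std::cout << (all_gather_result == all_reduce_result ? "equal" : "differ")
            << std::endl;
  return 0;
}
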
@@ -99,8 +99,20 @@ struct SPMDCollectiveOpsCreator {
const std::vector<ReplicaGroup>& replica_groups, int64 channel_id,
absl::optional<int64> split_dimension)>
create_cross_partition_all_to_all;
// Function used to create a cross-partition all-gather HLO. This is optional:
// if it is nullptr, the partitioner will use all-reduce instead.
std::function<HloInstruction*(
SpmdBuilder*, HloInstruction* operand, const Shape& ag_shape,
const std::vector<std::vector<int64>>& partition_subgroups,
int64 channel_id, int64 all_gather_dimension)>
create_cross_partition_all_gather;
};
// Creates a default SPMDCollectiveOpsCreator.
SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64 num_partitions,
int64 num_replicas);
// Logger to report memory usage during SPMD partitioning.
class SpmdLogger {
public:
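
A minimal sketch of how a caller might customize the new hook, using only the declarations from this hunk. It assumes the default creator returned by GetDefaultCollectiveOpsCreator populates create_cross_partition_all_gather (the test change below relies on that by resetting the field to nullptr); the wrapper function, its logging, and the include paths are illustrative, not part of this commit:

#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"  // assumed path
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace spmd {

// Hypothetical helper: wraps the default all-gather factory with logging.
SPMDCollectiveOpsCreator MakeLoggingCollectiveOpsCreator(int64 num_partitions,
                                                         int64 num_replicas) {
  SPMDCollectiveOpsCreator creator =
      GetDefaultCollectiveOpsCreator(num_partitions, num_replicas);
  auto default_all_gather = creator.create_cross_partition_all_gather;
  creator.create_cross_partition_all_gather =
      [default_all_gather](
          SpmdBuilder* b, HloInstruction* operand, const Shape& ag_shape,
          const std::vector<std::vector<int64>>& partition_subgroups,
          int64 channel_id, int64 all_gather_dimension) -> HloInstruction* {
        VLOG(1) << "Creating cross-partition all-gather with shape "
                << ag_shape.ToString();
        return default_all_gather(b, operand, ag_shape, partition_subgroups,
                                  channel_id, all_gather_dimension);
      };
  // Resetting the field to nullptr instead would make the partitioner fall
  // back to all-reduce, as the comment on the field notes.
  return creator;
}

}  // namespace spmd
}  // namespace xla
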
@@ -153,6 +165,15 @@ class SpmdPartitioner : public HloModulePass {
int64* next_channel_id,
SpmdLogger* logger);
// Creates an all-gather based on HloSharding. Can be overridden to customize.
// The default uses a single all-gather even if there are multiple sharded
// dimensions, and adds potential reshapes and transposes to achieve that.
// If it returns nullptr, the partitioner will fall back to all-reduce.
virtual HloInstruction* AllGatherShards(SpmdBuilder* b,
HloInstruction* operand,
const HloSharding& sharding,
int64 channel_id);
protected:
virtual std::unique_ptr<SpmdPartitioningVisitor> CreateVisitor(
HloComputation* computation, int64 num_partitions, int64 num_replicas,
@@ -160,7 +181,6 @@ class SpmdPartitioner : public HloModulePass {
int64* next_channel_id, SpmdLogger* logger,
SpmdPartitionerOptions options);
private:
// Verifies that the shardings of instructions in the module are valid, and
// also fills in missing sharding information.
Status PreprocessSharding(HloModule* module);
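
Since AllGatherShards is declared virtual above, a backend-specific partitioner can override it. A minimal sketch of such an override is below, delegating to the base implementation; the subclass name, the include path, and the constructor inheritance are assumptions for illustration, not part of this commit:

#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"  // assumed path

namespace xla {
namespace spmd {

// Hypothetical subclass that intercepts the all-gather lowering.
class MyBackendSpmdPartitioner : public SpmdPartitioner {
 public:
  using SpmdPartitioner::SpmdPartitioner;  // reuse the base constructors

  HloInstruction* AllGatherShards(SpmdBuilder* b, HloInstruction* operand,
                                  const HloSharding& sharding,
                                  int64 channel_id) override {
    // A real backend could emit a custom collective here; returning nullptr
    // makes the partitioner fall back to the all-reduce path.
    return SpmdPartitioner::AllGatherShards(b, operand, sharding, channel_id);
  }
};

}  // namespace spmd
}  // namespace xla
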
@@ -205,6 +225,7 @@ class PartitionedHlo {
SPMDCollectiveOpsCreator collective_ops_creator;
int64* next_channel_id;
ReshardCache* reshard_cache;
SpmdPartitioner* partitioner;
};
PartitionedHlo(HloInstruction* hlo, Shape base_shape, PartitioningState state)
: hlo_(hlo), base_shape_(base_shape), state_(std::move(state)) {
@@ -378,6 +399,7 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
state.collective_ops_creator = collective_ops_creator_;
state.next_channel_id = next_channel_id_;
state.reshard_cache = &reshard_cache_;
state.partitioner = partitioner_;
return state;
}
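
The new partitioner field in PartitioningState presumably gives resharding code access to the overridable AllGatherShards hook. A hedged sketch of that usage pattern follows; the helper function, its parameters, and the channel-id handling are assumptions, and only the field and method names come from this diff:

#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"  // assumed path

namespace xla {
namespace spmd {

// Hypothetical helper: replicate a tiled HLO using the state's partitioner.
HloInstruction* ReplicateViaAllGather(
    SpmdBuilder* b, const PartitionedHlo::PartitioningState& state,
    HloInstruction* hlo, const HloSharding& sharding) {
  // Allocate a fresh channel id for the collective.
  HloInstruction* result = state.partitioner->AllGatherShards(
      b, hlo, sharding, (*state.next_channel_id)++);
  // A nullptr result means the caller should fall back to the
  // all-reduce-based replication path.
  return result;
}

}  // namespace spmd
}  // namespace xla
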
......
@@ -41,13 +41,19 @@ class SpmdPartitioningTest : public HloTestBase {
SpmdPartitionerOptions options;
options.conv_halo_exchange_always_on_lhs = conv_halo_exchange_always_on_lhs;
options.allow_module_signature_change = true;
auto collective_ops_creator =
GetDefaultCollectiveOpsCreator(num_devices, /*num_replicas=*/1);
// Do not use all-gather for pattern-matching purposes, as the partitioner
// might create reshapes/transposes around it.
collective_ops_creator.create_cross_partition_all_gather = nullptr;
TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(
hlo_module, GetModuleConfigForTest()));
HloPassPipeline pass("spmd-partitioning");
pass.AddPass<HloVerifier>(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/false);
pass.AddPass<SpmdPartitioner>(num_devices, /*num_replicas=*/1, options);
pass.AddPass<SpmdPartitioner>(num_devices, /*num_replicas=*/1, options,
collective_ops_creator);
pass.AddPass<HloVerifier>(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/false);
TF_RETURN_IF_ERROR(pass.Run(module.get()).status());
......