Commit e49a2c5b authored by Jan Oliver Oelerich's avatar Jan Oliver Oelerich

Improved mpi wrapper class

parent bc8586e6
......@@ -558,10 +558,7 @@ void Simulation<prec_t>::multisliceMaster(const SimulationState &st) {
// We loop here until a valid MPI request comes in. When it does, process the results and
// send out a new work package.
do {
int message_found = 0;
mpi_env.iprobe(mpi::Environment::ANY_SOURCE, mpi::Environment::MPI_RESULT, message_found, s);
if(message_found) {
if(mpi_env.iprobe(mpi::Environment::ANY_SOURCE, mpi::Environment::MPI_RESULT, s)) {
PRINT_DIAGNOSTICS(output::fmt("receiving result from %d", s.source()));
mpi_env.recv(_serialization_buffer, s.source(), s.tag());
......
......@@ -24,12 +24,13 @@
#include <mpi.h>
#include <memory>
#include <vector>
#include <iostream>
/** @file */
namespace stemsalabim {
/*!
* Utility functions and classes for MPI communication.
*/
......@@ -77,8 +78,6 @@ namespace stemsalabim {
public:
friend class Environment;
friend class Request;
Status() {
_mpi_status = std::make_shared<MPI_Status>();
}
......@@ -87,10 +86,21 @@ namespace stemsalabim {
* Return the message size
* @return message size
*/
int count() const {
int mpi_message_size;
MPI_Get_count(_mpi_status.get(), _mpi_type, &mpi_message_size);
return mpi_message_size;
template<typename T>
size_t count(const std::vector<T> & val) const {
MPI_Datatype dt = detail::mpi_type<T>();
return count(dt);
}
size_t count(const std::string & val) const {
MPI_Datatype dt = detail::mpi_type<char>();
return count(dt);
}
template<typename T>
size_t count(const T & val) const {
MPI_Datatype dt = detail::mpi_type<T>();
return count(dt);
}
/*!
......@@ -113,8 +123,11 @@ namespace stemsalabim {
/// The actual MPI_Status object is saved in a pointer.
std::shared_ptr<MPI_Status> _mpi_status;
/// The MPI Datatype
MPI_Datatype _mpi_type;
size_t count(MPI_Datatype t) const {
int mpi_message_size;
MPI_Get_count(_mpi_status.get(), t, &mpi_message_size);
return (size_t)mpi_message_size;
}
};
/*!
......@@ -129,41 +142,61 @@ namespace stemsalabim {
_request = std::make_shared<MPI_Request>();
}
Request() {
_request = std::make_shared<MPI_Request>();
}
Request& operator=(Request other) {
_request = other._request;
_mpi_type = other._mpi_type;
_valid = other._valid;
_flag = other._flag;
return *this;
}
/*!
* Test whether the request is valid already. This function is destructive,
* as soon as the request becomes valid, this shouldn't be called again.
* @return true, if request is valid.
*/
bool test() {
MPI_Test(_request.get(), &_flag, _status._mpi_status.get());
if(!_valid)
return false;
MPI_Test(_request.get(), &_flag, MPI_STATUS_IGNORE);
if(_flag != 0)
_valid = false;
return _flag != 0;
}
/*!
* Return the status object associated with the request.
* @return
* Cancel the request.
*/
const Status &status() {
_status._mpi_type = _mpi_type;
return _status;
void cancel() {
if(!_valid)
return;
MPI_Cancel(_request.get());
}
/*!
* Cancel the request.
* Set the request valid or invalid.
*/
void cancel() {
MPI_Cancel(_request.get());
void valid(bool v) {
_valid = v;
}
private:
/// The actual MPI_Status object is saved in a pointer.
std::shared_ptr<MPI_Request> _request;
/// Each Reuest also has a status
Status _status;
/// This boolean indicates, whether the request is created
/// by an MPI action.
bool _valid{false};
/// Flag that becomes true when the Request becomes valid
int _flag;
int _flag{0};
/// The MPI data type
MPI_Datatype _mpi_type;
......@@ -182,6 +215,11 @@ namespace stemsalabim {
const static int MPI_CANCEL = 3;
const static int MPI_WORK = 1;
const static int MPI_RESULT = 2;
const static int MPI_TAG_TASKS_REQUEST = 11;
const static int MPI_TAG_TASKS = 12;
const static int MPI_TAG_RESULTS_REQUEST = 13;
const static int MPI_TAG_RESULTS = 14;
/*!
* Return an instance of the mpi::Environment class. Objects should only
......@@ -352,6 +390,21 @@ namespace stemsalabim {
send(data.data(), count, destination, tag);
}
template<typename T>
Request isend(T &val, int destination, int tag) {
return isend(&val, 1, destination, tag);
}
template<typename T>
Request isend(std::vector<T> &data, int destination, int tag) {
return isend(data.data(), data.size(), destination, tag);
}
Request isend(int destination, int tag) {
int dummy;
return isend(&dummy, 0, destination, tag);
}
/*!
* non-blocking receive of a single value.
* @tparam T Type of the value
......@@ -363,7 +416,8 @@ namespace stemsalabim {
Request irecv(T &val, int source, int tag) {
MPI_Datatype dt = detail::mpi_type<T>();
Request req(dt);
MPI_Irecv(&val, 1, detail::mpi_type<T>(), source, tag, MPI_COMM_WORLD, req._request.get());
req.valid(true);
MPI_Irecv(&val, 1, dt, source, tag, MPI_COMM_WORLD, req._request.get());
return req;
}
......@@ -373,8 +427,16 @@ namespace stemsalabim {
* @param tag for which tag to probe
* @return Status object containing information.
*/
void iprobe(int source, int tag, int &flag, Status &s) {
bool iprobe(int source, int tag, Status &s) {
int flag;
MPI_Iprobe(source, tag, MPI_COMM_WORLD, &flag, s._mpi_status.get());
return flag != 0;
}
bool iprobe(Status &s) {
int flag;
MPI_Iprobe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &flag, s._mpi_status.get());
return flag != 0;
}
/*!
......@@ -392,9 +454,10 @@ namespace stemsalabim {
MPI_Datatype dt = detail::mpi_type<T>();
Request req(dt);
req.valid(true);
MPI_Irecv(data.data(),
(int) max_count,
detail::mpi_type<T>(),
dt,
source,
tag,
MPI_COMM_WORLD,
......@@ -425,8 +488,8 @@ namespace stemsalabim {
template<typename T>
Status recv(T &val, int source, int tag) {
Status status;
status._mpi_type = detail::mpi_type<T>();
MPI_Recv(&val, 1, detail::mpi_type<T>(), source, tag, MPI_COMM_WORLD, status._mpi_status.get());
MPI_Datatype dt = detail::mpi_type<T>();
MPI_Recv(&val, 1, dt, source, tag, MPI_COMM_WORLD, status._mpi_status.get());
return status;
}
......@@ -450,7 +513,6 @@ namespace stemsalabim {
int count;
Status status;
status._mpi_type = detail::mpi_type<char>();
MPI_Probe(source, tag, MPI_COMM_WORLD, status._mpi_status.get());
MPI_Get_count(status._mpi_status.get(), detail::mpi_type<char>(), &count);
......@@ -474,17 +536,17 @@ namespace stemsalabim {
Status recv(std::vector<T> &data, int source, int tag) {
int count;
MPI_Datatype dt = detail::mpi_type<T>();
Status status;
status._mpi_type = detail::mpi_type<T>();
MPI_Probe(source, tag, MPI_COMM_WORLD, status._mpi_status.get());
MPI_Get_count(status._mpi_status.get(), detail::mpi_type<T>(), &count);
MPI_Get_count(status._mpi_status.get(), dt, &count);
if((int) data.size() < count)
data.resize(count);
MPI_Recv(data.data(),
count,
detail::mpi_type<T>(),
dt,
source,
tag,
MPI_COMM_WORLD,
......@@ -502,7 +564,8 @@ namespace stemsalabim {
*/
template<typename T>
void broadcast(T *vals, std::size_t count, int root) {
MPI_Bcast(vals, (int) count, detail::mpi_type<T>(), root, MPI_COMM_WORLD);
MPI_Datatype dt = detail::mpi_type<T>();
MPI_Bcast(vals, (int) count, dt, root, MPI_COMM_WORLD);
}
/*!
......@@ -579,7 +642,17 @@ namespace stemsalabim {
*/
template<typename T>
void send(T *vals, std::size_t count, int destination, int tag) {
MPI_Send(vals, (int) count, detail::mpi_type<T>(), destination, tag, MPI_COMM_WORLD);
MPI_Datatype dt = detail::mpi_type<T>();
MPI_Send(vals, (int) count, dt, destination, tag, MPI_COMM_WORLD);
}
template<typename T>
Request isend(T *vals, std::size_t count, int destination, int tag) {
MPI_Datatype dt = detail::mpi_type<T>();
Request req(dt);
req.valid(true);
MPI_Isend(vals, (int) count, dt, destination, tag, MPI_COMM_WORLD, req._request.get());
return req;
}
/*!
......@@ -593,8 +666,8 @@ namespace stemsalabim {
template<typename T>
Status recv(T &val, int count, int source, int tag) {
Status status;
status._mpi_type = detail::mpi_type<T>();
MPI_Recv(&val, count, detail::mpi_type<T>(), source, tag, MPI_COMM_WORLD, status._mpi_status.get());
MPI_Datatype dt = detail::mpi_type<T>();
MPI_Recv(&val, count, dt, source, tag, MPI_COMM_WORLD, status._mpi_status.get());
return status;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment