Skip to content

Commit

Permalink
cross-chain warmup
Browse files Browse the repository at this point in the history
  • Loading branch information
yiz committed Feb 24, 2020
1 parent b69492b commit a7c61ac
Show file tree
Hide file tree
Showing 23 changed files with 1,746 additions and 426 deletions.
1 change: 1 addition & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[submodule "lib/stan_math"]
path = lib/stan_math
url = https://github.com/stan-dev/math.git
branch = mpi_warmup_v2
5 changes: 5 additions & 0 deletions make/mpi_warmup.mk
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
ifdef MPI_ADAPTED_WARMUP
CXXFLAGS += -DSTAN_LANG_MPI -DMPI_ADAPTED_WARMUP
CC=mpicxx
CXX=mpicxx
endif
1 change: 1 addition & 0 deletions makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ help:

-include $(HOME)/.config/stan/make.local # user-defined variables
-include make/local # user-defined variables
-include make/mpi_warmup.mk

MATH ?= lib/stan_math/
ifeq ($(OS),Windows_NT)
Expand Down
138 changes: 138 additions & 0 deletions src/stan/callbacks/mpi_stream_writer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
#ifndef STAN_CALLBACKS_MPI_STREAM_WRITER_HPP
#define STAN_CALLBACKS_MPI_STREAM_WRITER_HPP

#ifdef MPI_ADAPTED_WARMUP

#include <stan/callbacks/writer.hpp>
#include <stan/math/mpi/envionment.hpp>
#include <ostream>
#include <vector>
#include <string>

namespace stan {
namespace callbacks {
/**
* <code>mpi_stream_writer</code> is an implementation
* of <code>writer</code> that writes to a stream.
*/
class mpi_stream_writer : public writer {
public:
/**
* Constructs a stream writer with an output stream
* and an optional prefix for comments.
*
* @param[in, out] output stream to write
* @param[in] comment_prefix string to stream before
* each comment line. Default is "".
*/
mpi_stream_writer(int num_chains, std::ostream& output,
const std::string& comment_prefix = "")
: num_chains_(num_chains), output_(output),
comment_prefix_(comment_prefix)
{}

/**
* Virtual destructor
*/
virtual ~mpi_stream_writer() {}

/**
* Set new value for @c num_chains_.
*
* @param[in] n new value of @c num_chains_
*/
void set_num_chains(int n) {
num_chains_ = n;
}

/**
* Writes a set of names on a single line in csv format followed
* by a newline.
*
* Note: the names are not escaped.
*
* @param[in] names Names in a std::vector
*/
void operator()(const std::vector<std::string>& names) {
write_vector(names);
}

/**
* Writes a set of values in csv format followed by a newline.
*
* Note: the precision of the output is determined by the settings
* of the stream on construction.
*
* @param[in] state Values in a std::vector
*/
void operator()(const std::vector<double>& state) {
write_vector(state);
}

/**
* Writes the comment_prefix to the stream followed by a newline.
*/
void operator()() {
if (stan::math::mpi::Session::is_in_inter_chain_comm(num_chains_)) {
output_ << comment_prefix_ << std::endl;
}
}

/**
* Writes the comment_prefix then the message followed by a newline.
*
* @param[in] message A string
*/
void operator()(const std::string& message) {
if (stan::math::mpi::Session::is_in_inter_chain_comm(num_chains_)) {
output_ << comment_prefix_ << message << std::endl;
}
}

private:

/**
* nb. of chains that have its own output stream
*/
int num_chains_;

/**
* Output stream
*/
std::ostream& output_;

/**
* Comment prefix to use when printing comments: strings and blank lines
*/
std::string comment_prefix_;

/**
* Writes a set of values in csv format followed by a newline.
*
* Note: the precision of the output is determined by the settings
* of the stream on construction.
*
* @param[in] v Values in a std::vector
*/
template <class T>
void write_vector(const std::vector<T>& v) {
if (stan::math::mpi::Session::is_in_inter_chain_comm(num_chains_)) {
if (v.empty()) return;

typename std::vector<T>::const_iterator last = v.end();
--last;

for (typename std::vector<T>::const_iterator it = v.begin();
it != last; ++it)
output_ << *it << ",";
output_ << v.back() << std::endl;
}
}
};

}
}

#endif

#endif
173 changes: 88 additions & 85 deletions src/stan/callbacks/stream_writer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,101 +7,104 @@
#include <string>

namespace stan {
namespace callbacks {
namespace callbacks {

/**
* <code>stream_writer</code> is an implementation
* of <code>writer</code> that writes to a stream.
*/
class stream_writer : public writer {
public:
/**
* Constructs a stream writer with an output stream
* and an optional prefix for comments.
*
* @param[in, out] output stream to write
* @param[in] comment_prefix string to stream before
* each comment line. Default is "".
*/
explicit stream_writer(std::ostream& output,
const std::string& comment_prefix = "")
: output_(output), comment_prefix_(comment_prefix) {}
/**
* <code>stream_writer</code> is an implementation
* of <code>writer</code> that writes to a stream.
*/
class stream_writer : public writer {
public:
/**
* Constructs a stream writer with an output stream
* and an optional prefix for comments.
*
* @param[in, out] output stream to write
* @param[in] comment_prefix string to stream before
* each comment line. Default is "".
*/
stream_writer(std::ostream& output,
const std::string& comment_prefix = ""):
output_(output), comment_prefix_(comment_prefix) {}

/**
* Virtual destructor
*/
virtual ~stream_writer() {}
/**
* Virtual destructor
*/
virtual ~stream_writer() {}

/**
* Writes a set of names on a single line in csv format followed
* by a newline.
*
* Note: the names are not escaped.
*
* @param[in] names Names in a std::vector
*/
void operator()(const std::vector<std::string>& names) {
write_vector(names);
}
/**
* Writes a set of names on a single line in csv format followed
* by a newline.
*
* Note: the names are not escaped.
*
* @param[in] names Names in a std::vector
*/
void operator()(const std::vector<std::string>& names) {
write_vector(names);
}

/**
* Writes a set of values in csv format followed by a newline.
*
* Note: the precision of the output is determined by the settings
* of the stream on construction.
*
* @param[in] state Values in a std::vector
*/
void operator()(const std::vector<double>& state) { write_vector(state); }
/**
* Writes a set of values in csv format followed by a newline.
*
* Note: the precision of the output is determined by the settings
* of the stream on construction.
*
* @param[in] state Values in a std::vector
*/
void operator()(const std::vector<double>& state) {
write_vector(state);
}

/**
* Writes the comment_prefix to the stream followed by a newline.
*/
void operator()() { output_ << comment_prefix_ << std::endl; }
/**
* Writes the comment_prefix to the stream followed by a newline.
*/
void operator()() {
output_ << comment_prefix_ << std::endl;
}

/**
* Writes the comment_prefix then the message followed by a newline.
*
* @param[in] message A string
*/
void operator()(const std::string& message) {
output_ << comment_prefix_ << message << std::endl;
}
/**
* Writes the comment_prefix then the message followed by a newline.
*
* @param[in] message A string
*/
void operator()(const std::string& message) {
output_ << comment_prefix_ << message << std::endl;
}

private:
/**
* Output stream
*/
std::ostream& output_;
private:
/**
* Output stream
*/
std::ostream& output_;

/**
* Comment prefix to use when printing comments: strings and blank lines
*/
std::string comment_prefix_;
/**
* Comment prefix to use when printing comments: strings and blank lines
*/
std::string comment_prefix_;

/**
* Writes a set of values in csv format followed by a newline.
*
* Note: the precision of the output is determined by the settings
* of the stream on construction.
*
* @param[in] v Values in a std::vector
*/
template <class T>
void write_vector(const std::vector<T>& v) {
if (v.empty())
return;
/**
* Writes a set of values in csv format followed by a newline.
*
* Note: the precision of the output is determined by the settings
* of the stream on construction.
*
* @param[in] v Values in a std::vector
*/
template <class T>
void write_vector(const std::vector<T>& v) {
if (v.empty()) return;

typename std::vector<T>::const_iterator last = v.end();
--last;
typename std::vector<T>::const_iterator last = v.end();
--last;

for (typename std::vector<T>::const_iterator it = v.begin(); it != last;
++it)
output_ << *it << ",";
output_ << v.back() << std::endl;
}
};
for (typename std::vector<T>::const_iterator it = v.begin();
it != last; ++it)
output_ << *it << ",";
output_ << v.back() << std::endl;
}
};

} // namespace callbacks
} // namespace stan
}
}
#endif
Loading

0 comments on commit a7c61ac

Please sign in to comment.