Commit 2c05ca21 authored by Praetorius, Simon's avatar Praetorius, Simon
Browse files

add runner to LinearSolver and make it accessible, allow to derive from...

add runner to LinearSolver and make it accessible, allow to derive from PetscRunner for user-defined petsc solver implementation
parent 52603493
...@@ -43,20 +43,41 @@ namespace AMDiS ...@@ -43,20 +43,41 @@ namespace AMDiS
using ElementMatrix = FlatMatrix<CoefficientType>; using ElementMatrix = FlatMatrix<CoefficientType>;
public: public:
/// Constructor. Wraps the reference into a non-destroying shared_ptr or moves the basis into /// Constructor. Stores the row and column basis in a local `shared_ptr` to const
/// a new shared_ptr.
BiLinearForm(std::shared_ptr<RB> const& rowBasis, std::shared_ptr<CB> const& colBasis) BiLinearForm(std::shared_ptr<RB> const& rowBasis, std::shared_ptr<CB> const& colBasis)
: Super(*rowBasis, *colBasis) : Super(*rowBasis, *colBasis)
, rowBasis_(rowBasis) , rowBasis_(rowBasis)
, colBasis_(colBasis) , colBasis_(colBasis)
{ {
operators_.init(*rowBasis, *colBasis); operators_.init(*rowBasis_, *colBasis_);
auto const rowSize = rowBasis->localView().maxSize(); auto const rowSize = rowBasis_->localView().maxSize();
auto const colSize = colBasis->localView().maxSize(); auto const colSize = colBasis_->localView().maxSize();
elementMatrix_.resize(rowSize, colSize); elementMatrix_.resize(rowSize, colSize);
} }
/// Wraps the passed global bases into (non-destroying) shared_ptr
template <class RB_, class CB_,
REQUIRES(Concepts::Similar<RB_,RB>),
REQUIRES(Concepts::Similar<CB_,CB>)>
BiLinearForm(RB_&& rowBasis, CB_&& colBasis)
: BiLinearForm(Dune::wrap_or_move(FWD(rowBasis)), Dune::wrap_or_move(FWD(colBasis)))
{}
/// Constructor for rowBasis == colBasis
template <class RB_ = RB, class CB_ = CB,
REQUIRES(std::is_same<RB_,CB_>::value)>
explicit BiLinearForm(std::shared_ptr<RB> const& rowBasis)
: BiLinearForm(rowBasis, rowBasis)
{}
/// Wraps the passed row-basis into a (non-destroying) shared_ptr
template <class RB_,
REQUIRES(Concepts::Similar<RB_,RB>)>
explicit BiLinearForm(RB_&& rowBasis)
: BiLinearForm(Dune::wrap_or_move(FWD(rowBasis)))
{}
std::shared_ptr<RowBasis const> const& rowBasis() const { return rowBasis_; } std::shared_ptr<RowBasis const> const& rowBasis() const { return rowBasis_; }
std::shared_ptr<ColBasis const> const& colBasis() const { return colBasis_; } std::shared_ptr<ColBasis const> const& colBasis() const { return colBasis_; }
...@@ -108,7 +129,6 @@ namespace AMDiS ...@@ -108,7 +129,6 @@ namespace AMDiS
void assemble(SymmetryStructure symmetry = SymmetryStructure::unknown); void assemble(SymmetryStructure symmetry = SymmetryStructure::unknown);
protected: protected:
/// Dense matrix to store coefficients during \ref assemble() /// Dense matrix to store coefficients during \ref assemble()
ElementMatrix elementMatrix_; ElementMatrix elementMatrix_;
...@@ -124,15 +144,26 @@ namespace AMDiS ...@@ -124,15 +144,26 @@ namespace AMDiS
template <class RB, class CB> template <class RB, class CB>
BiLinearForm(RB&&, CB&&) BiLinearForm(RB&&, CB&&)
-> BiLinearForm<Underlying_t<RB>, Underlying_t<CB>>; -> BiLinearForm<Underlying_t<RB>, Underlying_t<CB>>;
template <class RB>
BiLinearForm(RB&&)
-> BiLinearForm<Underlying_t<RB>, Underlying_t<RB>>;
#endif #endif
template <class RB, class CB, class... Args> template <class T = double, class RB, class CB>
auto makeBiLinearForm(RB&& rowBasis, CB&& colBasis) auto makeBiLinearForm(RB&& rowBasis, CB&& colBasis)
{ {
using BLF = BiLinearForm<Underlying_t<RB>, Underlying_t<CB>>; using BLF = BiLinearForm<Underlying_t<RB>, Underlying_t<CB>, T>;
return BLF{FWD(rowBasis), FWD(colBasis)}; return BLF{FWD(rowBasis), FWD(colBasis)};
} }
template <class T = double, class RB>
auto makeBiLinearForm(RB&& rowBasis)
{
using BLF = BiLinearForm<Underlying_t<RB>, Underlying_t<RB>, T>;
return BLF{FWD(rowBasis)};
}
} // end namespace AMDiS } // end namespace AMDiS
#include <amdis/BiLinearForm.inc.hpp> #include <amdis/BiLinearForm.inc.hpp>
...@@ -49,6 +49,13 @@ namespace AMDiS ...@@ -49,6 +49,13 @@ namespace AMDiS
elementVector_.resize(localSize); elementVector_.resize(localSize);
} }
/// Wraps the passed global basis into a (non-destroying) shared_ptr
template <class GB_,
REQUIRES(Concepts::Similar<GB_,GB>)>
explicit LinearForm(GB_&& basis)
: LinearForm(Dune::wrap_or_move(FWD(basis)))
{}
std::shared_ptr<GlobalBasis const> const& basis() const { return basis_; } std::shared_ptr<GlobalBasis const> const& basis() const { return basis_; }
/// \brief Associate a local operator with this LinearForm /// \brief Associate a local operator with this LinearForm
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <dune/functions/functionspacebases/powerbasis.hh> #include <dune/functions/functionspacebases/powerbasis.hh>
#include <dune/grid/yaspgrid.hh> #include <dune/grid/yaspgrid.hh>
#include <amdis/AdaptiveGrid.hpp>
#include <amdis/common/Logical.hpp> #include <amdis/common/Logical.hpp>
#include <amdis/common/TypeTraits.hpp> #include <amdis/common/TypeTraits.hpp>
#include <amdis/functions/ParallelGlobalBasis.hpp> #include <amdis/functions/ParallelGlobalBasis.hpp>
...@@ -70,7 +71,7 @@ namespace AMDiS ...@@ -70,7 +71,7 @@ namespace AMDiS
template <class HostGrid, class PreBasisCreator, class T = double> template <class HostGrid, class PreBasisCreator, class T = double>
struct DefaultBasisCreator struct DefaultBasisCreator
{ {
using Grid = AdaptiveGrid<HostGrid>; using Grid = AdaptiveGrid_t<HostGrid>;
using GridView = typename Grid::LeafGridView; using GridView = typename Grid::LeafGridView;
static auto create(std::string const& name, GridView const& gridView) static auto create(std::string const& name, GridView const& gridView)
......
#pragma once #pragma once
#include <algorithm> #include <algorithm>
#include <functional>
#include <utility>
namespace AMDiS namespace AMDiS
{ {
template <class InputIter, class T, class Func> /// \brief Split a sequence `[first,last)` by the separators `sep` and pass the tokens as
void split(InputIter first, InputIter end, T const& t, Func f) /// begin-end iterator pair to the provided functor `f = void(InputIterator, InputIterator)`
template <class InputIter, class Tp, class BinaryFunc>
void split(InputIter first, InputIter last, Tp sep, BinaryFunc f)
{ {
if (first == end) if (first == last)
return; return;
while (true) { while (true) {
InputIter found = std::find(first, end, t); InputIter found = std::find(first, last, sep);
f(first, found); f(first, found);
if (found == end) if (found == last)
break; break;
first = ++found; first = ++found;
} }
} }
template <class InputIter, class SeparaterIter, class Func> /// \brief Split a sequence `[first,last)` by any of the separators `[s_first, s_last)` and pass the tokens as
void split(InputIter first, InputIter end, SeparaterIter s_first, SeparaterIter s_end, Func f) /// begin-end iterator pair to the provided functor `f = void(InputIterator, InputIterator)`
template <class InputIter, class SeparaterIter, class BinaryFunc>
void split(InputIter first, InputIter last, SeparaterIter s_first, SeparaterIter s_last, BinaryFunc f)
{ {
if (first == end) if (first == last)
return; return;
while (true) { while (true) {
InputIter found = std::find_first_of(first, end, s_first, s_end); InputIter found = std::find_first_of(first, last, s_first, s_last);
f(first, found); f(first, found);
if (found == end) if (found == last)
break; break;
first = ++found; first = ++found;
} }
} }
/// \brief Output the cumulative sum of one range to a second range
// NOTE: backport of std::exclusive_scan from c++17
template <class InputIter, class OutputIter, class Tp, class BinaryOperation>
OutputIter exclusive_scan(InputIter first, InputIter last,
OutputIter result, Tp init, BinaryOperation binary_op)
{
while (first != last) {
auto v = init;
init = binary_op(init, *first);
++first;
*result++ = std::move(v);
}
return result;
}
/// \brief Output the cumulative sum of one range to a second range
// NOTE: backport of std::exclusive_scan from c++17
template <class InputIter, class OutputIter, class Tp>
inline OutputIter exclusive_scan(InputIter first, InputIter last,
OutputIter result, Tp init)
{
return AMDiS::exclusive_scan(first, last, result, std::move(init), std::plus<>());
}
} }
\ No newline at end of file
...@@ -44,13 +44,17 @@ namespace AMDiS ...@@ -44,13 +44,17 @@ namespace AMDiS
public: public:
/// Constructor /// Constructor
explicit LinearSolver(std::string prefix) template <class... Args>
: runner_(prefix) explicit LinearSolver(std::string prefix, Args&&... args)
: runner_(prefix, FWD(args)...)
{} {}
Runner& runner() { return runner_; }
Runner const& runner() const { return runner_; }
private: private:
/// Implements \ref LinearSolverInterface::solveSystemImpl() /// Implements \ref LinearSolverInterface::solveSystemImpl()
void solveImpl(Mat const& A, Vec& x, Vec const& b, Comm& comm, SolverInfo& solverInfo) override void solveImpl(Mat const& A, Vec& x, Vec const& b, Comm const& comm, SolverInfo& solverInfo) override
{ {
Dune::Timer t; Dune::Timer t;
if (solverInfo.doCreateMatrixData()) { if (solverInfo.doCreateMatrixData()) {
......
...@@ -38,14 +38,14 @@ namespace AMDiS ...@@ -38,14 +38,14 @@ namespace AMDiS
* \p x A [block-]vector for the unknown components. * \p x A [block-]vector for the unknown components.
* \p b A [block-]vector for the right-hand side of the linear system. * \p b A [block-]vector for the right-hand side of the linear system.
**/ **/
void solve(Mat const& A, Vec& x, Vec const& b, Comm& comm, SolverInfo& solverInfo) void solve(Mat const& A, Vec& x, Vec const& b, Comm const& comm, SolverInfo& solverInfo)
{ {
solveImpl(A, x, b, comm, solverInfo); solveImpl(A, x, b, comm, solverInfo);
} }
private: private:
/// main methods that all solvers must implement /// main methods that all solvers must implement
virtual void solveImpl(Mat const& A, Vec& x, Vec const& b, Comm& comm, SolverInfo& solverInfo) = 0; virtual void solveImpl(Mat const& A, Vec& x, Vec const& b, Comm const& comm, SolverInfo& solverInfo) = 0;
}; };
} // end namespace AMDiS } // end namespace AMDiS
...@@ -20,7 +20,7 @@ namespace AMDiS ...@@ -20,7 +20,7 @@ namespace AMDiS
virtual ~RunnerInterface() = default; virtual ~RunnerInterface() = default;
/// Is called at the beginning of a solution procedure /// Is called at the beginning of a solution procedure
virtual void init(Mat const& A, Comm& comm) = 0; virtual void init(Mat const& A, Comm const& comm) = 0;
/// Is called at the end of a solution procedure /// Is called at the end of a solution procedure
virtual void exit() = 0; virtual void exit() = 0;
......
...@@ -35,7 +35,7 @@ namespace AMDiS ...@@ -35,7 +35,7 @@ namespace AMDiS
} }
/// Implements \ref RunnerInterface::init() /// Implements \ref RunnerInterface::init()
void init(Mat const& A, Comm& comm) override void init(Mat const& A, Comm const& comm) override
{ {
DUNE_UNUSED_PARAMETER(comm); DUNE_UNUSED_PARAMETER(comm);
...@@ -49,9 +49,11 @@ namespace AMDiS ...@@ -49,9 +49,11 @@ namespace AMDiS
"Error in solver.compute(matrix)"); "Error in solver.compute(matrix)");
} }
/// Implements \ref RunnerInterface::exit() /// Implements \ref RunnerInterface::exit()
void exit() override {} void exit() override
{
initialized_ = false;
}
/// Implements \ref RunnerInterface::solve() /// Implements \ref RunnerInterface::solve()
int solve(Mat const& A, Vec& x, Vec const& b, SolverInfo& solverInfo) override int solve(Mat const& A, Vec& x, Vec const& b, SolverInfo& solverInfo) override
......
...@@ -30,7 +30,7 @@ namespace AMDiS ...@@ -30,7 +30,7 @@ namespace AMDiS
/// Implements \ref RunnerInterface::init() /// Implements \ref RunnerInterface::init()
void init(Mat const& A, Comm& comm) override void init(Mat const& A, Comm const& comm) override
{ {
DUNE_UNUSED_PARAMETER(comm); DUNE_UNUSED_PARAMETER(comm);
...@@ -45,7 +45,10 @@ namespace AMDiS ...@@ -45,7 +45,10 @@ namespace AMDiS
} }
/// Implements \ref RunnerInterface::exit() /// Implements \ref RunnerInterface::exit()
void exit() override {} void exit() override
{
initialized_ = false;
}
/// Implements \ref RunnerInterface::solve() /// Implements \ref RunnerInterface::solve()
int solve(Mat const& A, Vec& x, Vec const& b, SolverInfo& solverInfo) override int solve(Mat const& A, Vec& x, Vec const& b, SolverInfo& solverInfo) override
......
...@@ -32,7 +32,7 @@ namespace AMDiS ...@@ -32,7 +32,7 @@ namespace AMDiS
} }
/// Implements \ref RunnerInterface::init() /// Implements \ref RunnerInterface::init()
void init(Mat const& mat, typename Traits::Comm& comm) override void init(Mat const& mat, typename Traits::Comm const& comm) override
{ {
solver_ = solverCreator_->create(mat, comm); solver_ = solverCreator_->create(mat, comm);
} }
......
...@@ -41,13 +41,14 @@ namespace AMDiS ...@@ -41,13 +41,14 @@ namespace AMDiS
Parameters::get(prefix_ + "->relative tolerance", rTol_); Parameters::get(prefix_ + "->relative tolerance", rTol_);
Parameters::get(prefix_ + "->max iteration", maxIter_); Parameters::get(prefix_ + "->max iteration", maxIter_);
Parameters::get(prefix_ + "->print cycle", printCycle_); Parameters::get(prefix_ + "->print cycle", printCycle_);
initPrecon(prefix_);
} }
/// Implementation of \ref RunnerInterface::init() /// Implementation of \ref RunnerInterface::init()
void init(Matrix const& A, Comm& comm) override void init(Matrix const& A, Comm const& comm) override
{ {
DUNE_UNUSED_PARAMETER(comm); DUNE_UNUSED_PARAMETER(comm);
initPrecon(prefix_);
P_->init(A); P_->init(A);
} }
......
...@@ -45,7 +45,7 @@ namespace AMDiS ...@@ -45,7 +45,7 @@ namespace AMDiS
} }
/// Implementation of \ref RunnerInterface::init() /// Implementation of \ref RunnerInterface::init()
void init(Matrix const& matrix, Comm&) override void init(Matrix const& matrix, Comm const&) override
{ {
try { try {
if (bool(solver_) && storeSymbolic_) if (bool(solver_) && storeSymbolic_)
......
...@@ -13,11 +13,13 @@ ...@@ -13,11 +13,13 @@
namespace AMDiS namespace AMDiS
{ {
template <class> class Constraints;
/// \brief The basic container that stores a base matrix /// \brief The basic container that stores a base matrix
template <class DofMap> template <class DofMap>
class PetscMatrix class PetscMatrix
{ {
friend struct Constraints<PetscMatrix<DofMap>>; template <class> friend class Constraints;
public: public:
/// The matrix type of the underlying base matrix /// The matrix type of the underlying base matrix
......
...@@ -68,7 +68,7 @@ namespace AMDiS ...@@ -68,7 +68,7 @@ namespace AMDiS
} }
/// Implements \ref RunnerInterface::init() /// Implements \ref RunnerInterface::init()
void init(typename Traits::Mat const& mat, typename Traits::Comm& comm) override void init(typename Traits::Mat const& mat, typename Traits::Comm const& comm) override
{ {
exit(); exit();
#if HAVE_MPI #if HAVE_MPI
...@@ -110,24 +110,27 @@ namespace AMDiS ...@@ -110,24 +110,27 @@ namespace AMDiS
protected: protected:
// initialize the KSP solver from the initfile // initialize the KSP solver from the initfile
void initKSP(KSP ksp, std::string prefix) const virtual void initKSP(KSP ksp, std::string prefix) const
{ {
// see https://www.mcs.anl.gov/petsc/petsc-current/docs/manualpages/KSP/KSPType.html // see https://www.mcs.anl.gov/petsc/petsc-current/docs/manualpages/KSP/KSPType.html
auto kspTypeStr = Parameters::get<std::string>(prefix); auto kspType = Parameters::get<std::string>(prefix + "->ksp");
if (kspTypeStr) { std::string kspTypeStr = kspType.value_or("default");
if (kspTypeStr.value() == "direct")
if (!kspType)
Parameters::get(prefix, kspTypeStr);
if (kspTypeStr == "direct")
initDirectSolver(ksp, prefix); initDirectSolver(ksp, prefix);
else if (kspTypeStr.value() != "default") { else if (kspTypeStr != "default") {
KSPSetType(ksp, kspTypeStr.value().c_str()); KSPSetType(ksp, kspTypeStr.c_str());
// initialize some KSP specific parameters // initialize some KSP specific parameters
initKSPParameters(ksp, kspTypeStr.value().c_str(), prefix); initKSPParameters(ksp, kspTypeStr.c_str(), prefix);
// set initial guess to nonzero only for non-preonly ksp type // set initial guess to nonzero only for non-preonly ksp type
if (kspTypeStr.value() != "preonly") if (kspTypeStr != "preonly")
KSPSetInitialGuessNonzero(ksp, PETSC_TRUE); KSPSetInitialGuessNonzero(ksp, PETSC_TRUE);
} }
}
// set a KSPMonitor if info > 0 // set a KSPMonitor if info > 0
int info = 0; int info = 0;
...@@ -167,7 +170,7 @@ namespace AMDiS ...@@ -167,7 +170,7 @@ namespace AMDiS
// initialize a direct solver from the initfile // initialize a direct solver from the initfile
void initDirectSolver(KSP ksp, std::string prefix) const virtual void initDirectSolver(KSP ksp, std::string prefix) const
{ {
KSPSetInitialGuessNonzero(ksp, PETSC_TRUE); KSPSetInitialGuessNonzero(ksp, PETSC_TRUE);
KSPSetType(ksp, KSPRICHARDSON); KSPSetType(ksp, KSPRICHARDSON);
...@@ -192,7 +195,7 @@ namespace AMDiS ...@@ -192,7 +195,7 @@ namespace AMDiS
// initialize the preconditioner pc from the initfile // initialize the preconditioner pc from the initfile
void initPC(PC pc, std::string prefix) const virtual void initPC(PC pc, std::string prefix) const
{ {
// see https://www.mcs.anl.gov/petsc/petsc-current/docs/manualpages/PC/PCType.html // see https://www.mcs.anl.gov/petsc/petsc-current/docs/manualpages/PC/PCType.html
auto pcType = Parameters::get<std::string>(prefix); auto pcType = Parameters::get<std::string>(prefix);
...@@ -234,7 +237,7 @@ namespace AMDiS ...@@ -234,7 +237,7 @@ namespace AMDiS
// provide initfile parameters for some PETSc KSP parameters // provide initfile parameters for some PETSc KSP parameters
void initKSPParameters(KSP ksp, char const* ksptype, std::string prefix) const virtual void initKSPParameters(KSP ksp, char const* ksptype, std::string prefix) const
{ {
// parameters for the Richardson solver // parameters for the Richardson solver
if (std::strcmp(ksptype, KSPRICHARDSON) == 0) if (std::strcmp(ksptype, KSPRICHARDSON) == 0)
...@@ -288,7 +291,7 @@ namespace AMDiS ...@@ -288,7 +291,7 @@ namespace AMDiS
} }
} }
private: protected:
std::string prefix_; std::string prefix_;
int info_ = 0; int info_ = 0;
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment