Commit f186030f authored by Praetorius, Simon's avatar Praetorius, Simon
Browse files

make AMDiS GridFunction differentiable w.r.t. dune-functions definitions. Make...

make AMDiS GridFunction differentiable w.r.t. dune-functions definitions. Make DiscreteLocalFunction copyable
parent 7febbf22
...@@ -84,8 +84,8 @@ namespace AMDiS ...@@ -84,8 +84,8 @@ namespace AMDiS
: BiLinearForm(Dune::wrap_or_move(FWD(rowBasis))) : BiLinearForm(Dune::wrap_or_move(FWD(rowBasis)))
{} {}
std::shared_ptr<RowBasis const> const& rowBasis() const { return rowBasis_; } RowBasis const& rowBasis() const { return *rowBasis_; }
std::shared_ptr<ColBasis const> const& colBasis() const { return colBasis_; } ColBasis const& colBasis() const { return *colBasis_; }
/// \brief Associate a local operator with this BiLinearForm /// \brief Associate a local operator with this BiLinearForm
/** /**
......
...@@ -21,13 +21,13 @@ addOperator(ContextTag contextTag, Expr const& expr, ...@@ -21,13 +21,13 @@ addOperator(ContextTag contextTag, Expr const& expr,
"col must be a valid treepath, or an integer/index-constant"); "col must be a valid treepath, or an integer/index-constant");
auto i = makeTreePath(row); auto i = makeTreePath(row);
auto node_i = child(this->rowBasis()->localView().treeCache(), i); auto node_i = child(this->rowBasis().localView().treeCache(), i);
auto j = makeTreePath(col); auto j = makeTreePath(col);
auto node_j = child(this->colBasis()->localView().treeCache(), j); auto node_j = child(this->colBasis().localView().treeCache(), j);
using LocalContext = typename ContextTag::type; using LocalContext = typename ContextTag::type;
using Tr = DefaultAssemblerTraits<LocalContext, ElementMatrix>; using Tr = DefaultAssemblerTraits<LocalContext, ElementMatrix>;
auto op = makeLocalOperator<LocalContext>(expr, this->rowBasis()->gridView()); auto op = makeLocalOperator<LocalContext>(expr, this->rowBasis().gridView());
auto localAssembler = makeUniquePtr(makeAssembler<Tr>(std::move(op), node_i, node_j)); auto localAssembler = makeUniquePtr(makeAssembler<Tr>(std::move(op), node_i, node_j));
operators_[i][j].push(contextTag, std::move(localAssembler)); operators_[i][j].push(contextTag, std::move(localAssembler));
...@@ -45,7 +45,7 @@ assemble(RowLocalView const& rowLocalView, ColLocalView const& colLocalView) ...@@ -45,7 +45,7 @@ assemble(RowLocalView const& rowLocalView, ColLocalView const& colLocalView)
elementMatrix_.resize(rowLocalView.size(), colLocalView.size()); elementMatrix_.resize(rowLocalView.size(), colLocalView.size());
elementMatrix_ = 0; elementMatrix_ = 0;
auto const& gv = this->rowBasis()->gridView(); auto const& gv = this->rowBasis().gridView();
auto const& element = rowLocalView.element(); auto const& element = rowLocalView.element();
auto geometry = element.geometry(); auto geometry = element.geometry();
...@@ -68,11 +68,11 @@ template <class RB, class CB, class T, class Traits> ...@@ -68,11 +68,11 @@ template <class RB, class CB, class T, class Traits>
void BiLinearForm<RB,CB,T,Traits>:: void BiLinearForm<RB,CB,T,Traits>::
assemble() assemble()
{ {
auto rowLocalView = this->rowBasis()->localView(); auto rowLocalView = this->rowBasis().localView();
auto colLocalView = this->colBasis()->localView(); auto colLocalView = this->colBasis().localView();
this->init(); this->init();
for (auto const& element : elements(this->rowBasis()->gridView(), typename Traits::PartitionSet{})) { for (auto const& element : elements(this->rowBasis().gridView(), typename Traits::PartitionSet{})) {
rowLocalView.bind(element); rowLocalView.bind(element);
if (this->rowBasis() == this->colBasis()) if (this->rowBasis() == this->colBasis())
this->assemble(rowLocalView, rowLocalView); this->assemble(rowLocalView, rowLocalView);
......
...@@ -82,9 +82,9 @@ namespace AMDiS ...@@ -82,9 +82,9 @@ namespace AMDiS
{} {}
/// Return the global basis /// Return the global basis
std::shared_ptr<GlobalBasis const> const& basis() const GlobalBasis const& basis() const
{ {
return basis_; return *basis_;
} }
Coefficients const& coefficients() const Coefficients const& coefficients() const
......
...@@ -16,12 +16,12 @@ backup(std::string const& filename) ...@@ -16,12 +16,12 @@ backup(std::string const& filename)
{ {
std::ofstream out(filename, std::ios::binary); std::ofstream out(filename, std::ios::binary);
std::int64_t numElements = this->basis()->gridView().size(0); std::int64_t numElements = this->basis().gridView().size(0);
out.write((char*)&numElements, sizeof(std::int64_t)); out.write((char*)&numElements, sizeof(std::int64_t));
auto localView = this->basis()->localView(); auto localView = this->basis().localView();
std::vector<value_type> data; std::vector<value_type> data;
for (auto const& element : elements(this->basis()->gridView())) for (auto const& element : elements(this->basis().gridView()))
{ {
localView.bind(element); localView.bind(element);
this->gather(localView, data); this->gather(localView, data);
...@@ -43,13 +43,13 @@ restore(std::string const& filename) ...@@ -43,13 +43,13 @@ restore(std::string const& filename)
std::int64_t numElements = 0; std::int64_t numElements = 0;
in.read((char*)&numElements, sizeof(std::int64_t)); in.read((char*)&numElements, sizeof(std::int64_t));
assert(numElements == this->basis()->gridView().size(0)); assert(numElements == this->basis().gridView().size(0));
// assume the order of element traversal is not changed // assume the order of element traversal is not changed
auto localView = this->basis()->localView(); auto localView = this->basis().localView();
std::vector<value_type> data; std::vector<value_type> data;
this->init(sizeInfo(*this->basis()), true); this->init(sizeInfo(this->basis()), true);
for (auto const& element : elements(this->basis()->gridView())) for (auto const& element : elements(this->basis().gridView()))
{ {
std::uint64_t len = 0; std::uint64_t len = 0;
in.read((char*)&len, sizeof(std::uint64_t)); in.read((char*)&len, sizeof(std::uint64_t));
......
...@@ -55,7 +55,7 @@ namespace AMDiS ...@@ -55,7 +55,7 @@ namespace AMDiS
: LinearForm(Dune::wrap_or_move(FWD(basis))) : LinearForm(Dune::wrap_or_move(FWD(basis)))
{} {}
std::shared_ptr<GlobalBasis const> const& basis() const { return basis_; } GlobalBasis const& basis() const { return *basis_; }
/// \brief Associate a local operator with this LinearForm /// \brief Associate a local operator with this LinearForm
/** /**
......
...@@ -20,11 +20,11 @@ addOperator(ContextTag contextTag, Expr const& expr, TreePath path) ...@@ -20,11 +20,11 @@ addOperator(ContextTag contextTag, Expr const& expr, TreePath path)
"path must be a valid treepath, or an integer/index-constant"); "path must be a valid treepath, or an integer/index-constant");
auto i = makeTreePath(path); auto i = makeTreePath(path);
auto node = child(this->basis()->localView().treeCache(), i); auto node = child(this->basis().localView().treeCache(), i);
using LocalContext = typename ContextTag::type; using LocalContext = typename ContextTag::type;
using Tr = DefaultAssemblerTraits<LocalContext, ElementVector>; using Tr = DefaultAssemblerTraits<LocalContext, ElementVector>;
auto op = makeLocalOperator<LocalContext>(expr, this->basis()->gridView()); auto op = makeLocalOperator<LocalContext>(expr, this->basis().gridView());
auto localAssembler = makeUniquePtr(makeAssembler<Tr>(std::move(op), node)); auto localAssembler = makeUniquePtr(makeAssembler<Tr>(std::move(op), node));
operators_[i].push(contextTag, std::move(localAssembler)); operators_[i].push(contextTag, std::move(localAssembler));
...@@ -39,7 +39,7 @@ assemble(LocalView const& localView) ...@@ -39,7 +39,7 @@ assemble(LocalView const& localView)
elementVector_.resize(localView.size()); elementVector_.resize(localView.size());
elementVector_ = 0; elementVector_ = 0;
auto const& gv = this->basis()->gridView(); auto const& gv = this->basis().gridView();
auto const& element = localView.element(); auto const& element = localView.element();
auto geometry = element.geometry(); auto geometry = element.geometry();
...@@ -60,10 +60,10 @@ template <class GB, class T, class Traits> ...@@ -60,10 +60,10 @@ template <class GB, class T, class Traits>
void LinearForm<GB,T,Traits>:: void LinearForm<GB,T,Traits>::
assemble() assemble()
{ {
auto localView = this->basis()->localView(); auto localView = this->basis().localView();
this->init(sizeInfo(*this->basis()), true); this->init(sizeInfo(*this->basis()), true);
for (auto const& element : elements(this->basis()->gridView(), typename Traits::PartitionSet{})) { for (auto const& element : elements(this->basis().gridView(), typename Traits::PartitionSet{})) {
localView.bind(element); localView.bind(element);
this->assemble(localView); this->assemble(localView);
localView.unbind(); localView.unbind();
......
...@@ -15,6 +15,7 @@ namespace AMDiS ...@@ -15,6 +15,7 @@ namespace AMDiS
{ {
namespace tag namespace tag
{ {
struct value {};
struct jacobian {}; struct jacobian {};
struct gradient {}; struct gradient {};
struct divergence {}; struct divergence {};
...@@ -27,6 +28,12 @@ namespace AMDiS ...@@ -27,6 +28,12 @@ namespace AMDiS
template <class Sig, class Type> template <class Sig, class Type>
struct DerivativeTraits; struct DerivativeTraits;
template <class R, class D>
struct DerivativeTraits<R(D), tag::value>
{
using Range = R;
};
template <class R, class D> template <class R, class D>
struct DerivativeTraits<R(D), tag::jacobian> struct DerivativeTraits<R(D), tag::jacobian>
: public Dune::Functions::DefaultDerivativeTraits<R(D)> : public Dune::Functions::DefaultDerivativeTraits<R(D)>
......
...@@ -124,11 +124,11 @@ namespace AMDiS ...@@ -124,11 +124,11 @@ namespace AMDiS
public: public:
GlobalBasisIdSet(GlobalBasis const& globalBasis) GlobalBasisIdSet(GlobalBasis const& globalBasis)
: tree_(makeNode(globalBasis.preBasis(), TreePath{})) : tree_(globalBasis.localView().tree())
, nodeIdSet_(globalBasis.gridView()) , nodeIdSet_(globalBasis.gridView())
, twist_(globalBasis.gridView().grid().globalIdSet()) , twist_(globalBasis.gridView().grid().globalIdSet())
{ {
Dune::Functions::initializeTree(tree_); initializeTree(tree_);
} }
/// \brief Bind the IdSet to a grid element. /// \brief Bind the IdSet to a grid element.
...@@ -139,7 +139,7 @@ namespace AMDiS ...@@ -139,7 +139,7 @@ namespace AMDiS
**/ **/
void bind(Element const& element) void bind(Element const& element)
{ {
Dune::Functions::bindTree(tree_, element); bindTree(tree_, element);
nodeIdSet_.bind(tree_); nodeIdSet_.bind(tree_);
twist_.bind(element); twist_.bind(element);
data_.resize(size()); data_.resize(size());
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <dune/common/concept.hh> #include <dune/common/concept.hh>
#include <dune/functions/functionspacebases/concepts.hh> #include <dune/functions/functionspacebases/concepts.hh>
#include <amdis/common/TypeTraits.hpp>
#include <amdis/functions/NodeCache.hpp> #include <amdis/functions/NodeCache.hpp>
#include <amdis/functions/Nodes.hpp> #include <amdis/functions/Nodes.hpp>
...@@ -46,8 +47,8 @@ namespace AMDiS ...@@ -46,8 +47,8 @@ namespace AMDiS
using MultiIndex = typename NodeIndexSet::MultiIndex; using MultiIndex = typename NodeIndexSet::MultiIndex;
private: private:
template <class NIS> template <class NIS, class Iter>
using hasIndices = decltype(std::declval<NIS>().indices(std::declval<std::vector<typename NIS::MultiIndex>>().begin())); using hasIndices = decltype(std::declval<NIS>().indices(std::declval<Iter>()));
public: public:
/// \brief Construct local view for a given global finite element basis /// \brief Construct local view for a given global finite element basis
...@@ -61,6 +62,30 @@ namespace AMDiS ...@@ -61,6 +62,30 @@ namespace AMDiS
initializeTree(tree_); initializeTree(tree_);
} }
/// Copy constructor.
// NOTE: needs to be implemented manually, because of the reference dependency
// between members tree_ and treeCache_
LocalView (LocalView const& other)
: globalBasis_(other.globalBasis_)
, tree_(other.tree_)
, treeCache_(makeNodeCache(tree_))
, nodeIndexSet_(other.nodeIndexSet_)
, element_(other.element_)
, indices_(other.indices_)
{}
// Move constructor
// NOTE: needs to be implemented manually, because of the reference dependency
// between members tree_ and treeCache_
LocalView (LocalView&& other)
: globalBasis_(std::move(other.globalBasis_))
, tree_(std::move(other.tree_))
, treeCache_(makeNodeCache(tree_))
, nodeIndexSet_(std::move(other.nodeIndexSet_))
, element_(std::move(other.element_))
, indices_(std::move(other.indices_))
{}
/// \brief Bind the view to a grid element /// \brief Bind the view to a grid element
/** /**
* Having to bind the view to an element before being able to actually access any of * Having to bind the view to an element before being able to actually access any of
...@@ -74,7 +99,7 @@ namespace AMDiS ...@@ -74,7 +99,7 @@ namespace AMDiS
nodeIndexSet_.bind(tree_); nodeIndexSet_.bind(tree_);
indices_.resize(size()); indices_.resize(size());
if constexpr (Dune::Std::is_detected<hasIndices,NodeIndexSet>{}) if constexpr (Dune::Std::is_detected_v<hasIndices, NodeIndexSet, TYPEOF(indices_.begin())>)
nodeIndexSet_.indices(indices_.begin()); nodeIndexSet_.indices(indices_.begin());
else else
for (size_type i = 0; i < size(); ++i) for (size_type i = 0; i < size(); ++i)
...@@ -151,10 +176,11 @@ namespace AMDiS ...@@ -151,10 +176,11 @@ namespace AMDiS
protected: protected:
GlobalBasis const* globalBasis_; GlobalBasis const* globalBasis_;
std::optional<Element> element_;
Tree tree_; Tree tree_;
TreeCache treeCache_; TreeCache treeCache_;
NodeIndexSet nodeIndexSet_; NodeIndexSet nodeIndexSet_;
std::optional<Element> element_;
std::vector<MultiIndex> indices_; std::vector<MultiIndex> indices_;
}; };
......
...@@ -17,6 +17,16 @@ namespace AMDiS ...@@ -17,6 +17,16 @@ namespace AMDiS
return lf.makeDerivative(type); return lf.makeDerivative(type);
} }
/// Implementation of local-function derivative interface of dune-functions.
/// Implements the jacobian derivative type.
template <class LocalFunction>
auto derivative(LocalFunction const& lf)
-> decltype(lf.makeDerivative(tag::gradient{}))
{
return lf.makeDerivative(tag::gradient{});
}
namespace Concepts namespace Concepts
{ {
/** \addtogroup Concepts /** \addtogroup Concepts
......
...@@ -48,9 +48,9 @@ namespace AMDiS ...@@ -48,9 +48,9 @@ namespace AMDiS
public: public:
/// Constructor. Stores a pointer to the mutable `dofvector`. /// Constructor. Stores a pointer to the mutable `dofvector`.
template <class... Path> template <class... Path>
DiscreteFunction(Coefficients& dofVector, GlobalBasis const& basis, Path... path) DiscreteFunction(Coefficients& coefficients, GlobalBasis const& basis, Path... path)
: Super(dofVector, basis, path...) : Super(coefficients, basis, path...)
, mutableCoeff_(&dofVector) , mutableCoeff_(&coefficients)
{} {}
template <class... Path> template <class... Path>
...@@ -145,6 +145,7 @@ namespace AMDiS ...@@ -145,6 +145,7 @@ namespace AMDiS
private: private:
using Coefficients = std::remove_const_t<Coeff>; using Coefficients = std::remove_const_t<Coeff>;
using GlobalBasis = GB; using GlobalBasis = GB;
using LocalView = typename GlobalBasis::LocalView;
using GridView = typename GlobalBasis::GridView; using GridView = typename GlobalBasis::GridView;
using ValueType = typename Coefficients::value_type; using ValueType = typename Coefficients::value_type;
...@@ -154,8 +155,6 @@ namespace AMDiS ...@@ -154,8 +155,6 @@ namespace AMDiS
using TreeCache = NodeCache_t<Tree>; using TreeCache = NodeCache_t<Tree>;
using SubTreeCache = typename Dune::TypeTree::ChildForTreePath<TreeCache, TreePath>; using SubTreeCache = typename Dune::TypeTree::ChildForTreePath<TreeCache, TreePath>;
using NodeToRangeEntry = Dune::Functions::DefaultNodeToRangeMap<SubTree>;
public: public:
/// Set of entities the DiscreteFunction is defined on /// Set of entities the DiscreteFunction is defined on
using EntitySet = Dune::Functions::GridViewEntitySet<GridView, 0>; using EntitySet = Dune::Functions::GridViewEntitySet<GridView, 0>;
...@@ -171,15 +170,8 @@ namespace AMDiS ...@@ -171,15 +170,8 @@ namespace AMDiS
enum { hasDerivative = false }; enum { hasDerivative = false };
public: public:
/// A LocalFunction representing the derivative of the DOFVector on a bound element /// A LocalFunction representing the localfunction and derivative of the DOFVector on a bound element
template <class Type> template <class Type>
class DerivativeLocalFunctionBase;
class GradientLocalFunction;
class PartialLocalFunction;
class DivergenceLocalFunction;
/// A LocalFunction representing the value the DOFVector on a bound element
class LocalFunction; class LocalFunction;
public: public:
...@@ -190,7 +182,6 @@ namespace AMDiS ...@@ -190,7 +182,6 @@ namespace AMDiS
, basis_(&basis) , basis_(&basis)
, treePath_(makeTreePath(path...)) , treePath_(makeTreePath(path...))
, entitySet_(basis_->gridView()) , entitySet_(basis_->gridView())
, nodeToRangeEntry_(Dune::Functions::makeDefaultNodeToRangeMap(*basis_, treePath_))
{} {}
template <class... Path> template <class... Path>
...@@ -206,14 +197,13 @@ namespace AMDiS ...@@ -206,14 +197,13 @@ namespace AMDiS
: DiscreteFunction(dofVector.coefficients(), dofVector.basis(), path...) : DiscreteFunction(dofVector.coefficients(), dofVector.basis(), path...)
{} {}
/// \brief Evaluate DiscreteFunction in global coordinates. NOTE: expensive /// \brief Evaluate DiscreteFunction in global coordinates. NOTE: expensive
Range operator()(Domain const& x) const; Range operator()(Domain const& x) const;
/// \brief Create a local function for this view on the DOFVector. \relates LocalFunction /// \brief Create a local function for this view on the DOFVector. \relates LocalFunction
LocalFunction makeLocalFunction() const LocalFunction<tag::value> makeLocalFunction() const
{ {
return LocalFunction{*this}; return {std::make_shared<LocalView>(basis().localView()), treePath(), coefficients(), tag::value{}};
} }
/// \brief Return a \ref Dune::Functions::GridViewEntitySet /// \brief Return a \ref Dune::Functions::GridViewEntitySet
...@@ -252,9 +242,9 @@ namespace AMDiS ...@@ -252,9 +242,9 @@ namespace AMDiS
GlobalBasis const* basis_; GlobalBasis const* basis_;
TreePath treePath_; TreePath treePath_;
EntitySet entitySet_; EntitySet entitySet_;
NodeToRangeEntry nodeToRangeEntry_;
}; };
// deduction guides // deduction guides
template <class Coeff, class Basis, class... Path, template <class Coeff, class Basis, class... Path,
......
...@@ -36,215 +36,136 @@ void gather(Coeff const& coeff, LocalView const& localView, LocalCoeff& localCoe ...@@ -36,215 +36,136 @@ void gather(Coeff const& coeff, LocalView const& localView, LocalCoeff& localCoe
template <class Coeff, class GB, class TP> template <class Coeff, class GB, class TP>
template <class Type>
class DiscreteFunction<Coeff const,GB,TP>::LocalFunction class DiscreteFunction<Coeff const,GB,TP>::LocalFunction
{ {
using DomainType = typename DiscreteFunction::Domain;
using RangeType = typename DiscreteFunction::Range;
template <class T>
using DerivativeRange = typename DerivativeTraits<RangeType(DomainType), T>::Range;
public: public:
using Domain = typename EntitySet::LocalCoordinate; using Domain = typename EntitySet::LocalCoordinate;
using Range = typename DiscreteFunction::Range; using Range = DerivativeRange<Type>;
enum { hasDerivative = true }; enum { hasDerivative = std::is_same<Type, tag::value>::value };
private: private:
using LocalView = typename GlobalBasis::LocalView; using LocalView = typename GlobalBasis::LocalView;
using Element = typename EntitySet::Element; using Element = typename EntitySet::Element;
using Geometry = typename Element::Geometry; using Geometry = typename Element::Geometry;
using NodeToRangeMap = Dune::Functions::DefaultNodeToRangeMap<SubTree>;
public: public:
/// Constructor. Stores a copy of the DiscreteFunction. /// Constructor. Stores a copy of the DiscreteFunction.
LocalFunction(DiscreteFunction const& globalFunction) template <class LV>
: globalFunction_(globalFunction) LocalFunction(std::shared_ptr<LV> localView, TP const& treePath, Coeff const& coefficients, Type type)
, localView_(globalFunction_.basis().localView()) : localView_(std::move(localView))
, treeCache_(makeNodeCache(localView_.tree())) , treePath_(treePath)
, subTreeCache_(&Dune::TypeTree::child(treeCache_, globalFunction_.treePath())) , coefficients_(coefficients)
{} , type_(type)
, nodeToRangeMap_(subTree())
/// Copy constructor.
LocalFunction(LocalFunction const& other)
: globalFunction_(other.globalFunction_)
, localView_(globalFunction_.basis().localView())
, treeCache_(makeNodeCache(localView_.tree()))
, subTreeCache_(&Dune::TypeTree::child(treeCache_, globalFunction_.treePath()))
{} {}
/// \brief Bind the LocalView to the element /// \brief Bind the LocalView to the element
void bind(Element const& element) void bind(Element const& element)
{ {
localView_.bind(element); localView_->bind(element);
Impl::gather(globalFunction_.coefficients(), localView_, localCoefficients_, Dune::PriorityTag<4>{}); geometry_.emplace(element.geometry());
Impl::gather(coefficients_, *localView_, localCoefficients_, Dune::PriorityTag<4>{});
bound_ = true; bound_ = true;
} }
/// \brief Unbind the LocalView from the element /// \brief Unbind the LocalView from the element
void unbind() void unbind()
{ {
localView_.unbind();