Commit 7ef7792a authored by Praetorius, Simon's avatar Praetorius, Simon
Browse files

Implement a NodeCache for cached evaulation of local-basis functions and gradients

parent 200b1e32
...@@ -21,9 +21,9 @@ addOperator(ContextTag contextTag, Expr const& expr, ...@@ -21,9 +21,9 @@ addOperator(ContextTag contextTag, Expr const& expr,
"col must be a valid treepath, or an integer/index-constant"); "col must be a valid treepath, or an integer/index-constant");
auto i = makeTreePath(row); auto i = makeTreePath(row);
auto node_i = child(this->rowBasis()->localView().tree(), i); auto node_i = child(this->rowBasis()->localView().treeCache(), i);
auto j = makeTreePath(col); auto j = makeTreePath(col);
auto node_j = child(this->colBasis()->localView().tree(), j); auto node_j = child(this->colBasis()->localView().treeCache(), j);
using LocalContext = typename ContextTag::type; using LocalContext = typename ContextTag::type;
using Tr = DefaultAssemblerTraits<LocalContext, ElementMatrix>; using Tr = DefaultAssemblerTraits<LocalContext, ElementMatrix>;
...@@ -49,8 +49,8 @@ assemble(RowLocalView const& rowLocalView, ColLocalView const& colLocalView) ...@@ -49,8 +49,8 @@ assemble(RowLocalView const& rowLocalView, ColLocalView const& colLocalView)
auto const& element = rowLocalView.element(); auto const& element = rowLocalView.element();
auto geometry = element.geometry(); auto geometry = element.geometry();
for_each_node(rowLocalView.tree(), [&](auto const& rowNode, auto rowTp) { for_each_node(rowLocalView.treeCache(), [&](auto const& rowNode, auto rowTp) {
for_each_node(colLocalView.tree(), [&](auto const& colNode, auto colTp) { for_each_node(colLocalView.treeCache(), [&](auto const& colNode, auto colTp) {
auto& matOp = operators_[rowTp][colTp]; auto& matOp = operators_[rowTp][colTp];
if (matOp) { if (matOp) {
matOp.bind(element, geometry); matOp.bind(element, geometry);
......
...@@ -20,7 +20,7 @@ addOperator(ContextTag contextTag, Expr const& expr, TreePath path) ...@@ -20,7 +20,7 @@ addOperator(ContextTag contextTag, Expr const& expr, TreePath path)
"path must be a valid treepath, or an integer/index-constant"); "path must be a valid treepath, or an integer/index-constant");
auto i = makeTreePath(path); auto i = makeTreePath(path);
auto node = child(this->basis()->localView().tree(), i); auto node = child(this->basis()->localView().treeCache(), i);
using LocalContext = typename ContextTag::type; using LocalContext = typename ContextTag::type;
using Tr = DefaultAssemblerTraits<LocalContext, ElementVector>; using Tr = DefaultAssemblerTraits<LocalContext, ElementVector>;
...@@ -43,7 +43,7 @@ assemble(LocalView const& localView) ...@@ -43,7 +43,7 @@ assemble(LocalView const& localView)
auto const& element = localView.element(); auto const& element = localView.element();
auto geometry = element.geometry(); auto geometry = element.geometry();
for_each_node(localView.tree(), [&](auto const& node, auto tp) { for_each_node(localView.treeCache(), [&](auto const& node, auto tp) {
auto& rhsOp = operators_[tp]; auto& rhsOp = operators_[tp];
if (rhsOp) { if (rhsOp) {
rhsOp.bind(element, geometry); rhsOp.bind(element, geometry);
......
...@@ -33,7 +33,7 @@ namespace AMDiS ...@@ -33,7 +33,7 @@ namespace AMDiS
BoundarySubset<Intersection> bs; BoundarySubset<Intersection> bs;
}; };
/// Lists of \ref DataElement on the Element, BoundaryIntersction, and InteriorIntersections /// Lists of \ref DataElement on the Element, BoundaryIntersection, and InteriorIntersections
template <class... Nodes> template <class... Nodes>
struct Data struct Data
{ {
...@@ -143,13 +143,13 @@ namespace AMDiS ...@@ -143,13 +143,13 @@ namespace AMDiS
using MatrixOperators using MatrixOperators
= TreeMatrix< = TreeMatrix<
OperatorLists<typename RowBasis::GridView,ElementMatrix>::template MatData, OperatorLists<typename RowBasis::GridView,ElementMatrix>::template MatData,
typename RowBasis::LocalView::Tree, typename RowBasis::LocalView::TreeCache,
typename ColBasis::LocalView::Tree>; typename ColBasis::LocalView::TreeCache>;
template <class Basis, class ElementVector> template <class Basis, class ElementVector>
using VectorOperators using VectorOperators
= TreeContainer< = TreeContainer<
OperatorLists<typename Basis::GridView,ElementVector>::template VecData, OperatorLists<typename Basis::GridView,ElementVector>::template VecData,
typename Basis::LocalView::Tree>; typename Basis::LocalView::TreeCache>;
} // end namespace AMDiS } // end namespace AMDiS
...@@ -4,6 +4,8 @@ install(FILES ...@@ -4,6 +4,8 @@ install(FILES
GlobalIdSet.hpp GlobalIdSet.hpp
HierarchicNodeToRangeMap.hpp HierarchicNodeToRangeMap.hpp
Interpolate.hpp Interpolate.hpp
LocalView.hpp
NodeCache.hpp
NodeIndices.hpp NodeIndices.hpp
Nodes.hpp Nodes.hpp
Order.hpp Order.hpp
......
#pragma once
#include <tuple>
#include <optional>
#include <vector>
#include <dune/common/concept.hh>
#include <dune/functions/functionspacebases/concepts.hh>
#include <amdis/functions/NodeCache.hpp>
#include <amdis/functions/Nodes.hpp>
// NOTE: This is a variant of dune-functions DefaultLocalView
namespace AMDiS
{
/// \brief The restriction of a finite element basis to a single element
template <class GB>
class LocalView
{
using PrefixPath = Dune::TypeTree::HybridTreePath<>;
// Node index set provided by PreBasis
using NodeIndexSet = NodeIndexSet_t<typename GB::PreBasis, PrefixPath>;
public:
/// The global FE basis that this is a view on
using GlobalBasis = GB;
/// The grid view the global FE basis lives on
using GridView = typename GlobalBasis::GridView;
/// Type of the grid element we are bound to
using Element = typename GridView::template Codim<0>::Entity;
/// The type used for sizes
using size_type = std::size_t;
/// Tree of local finite elements / local shape function sets
using Tree = Node_t<typename GlobalBasis::PreBasis, PrefixPath>;
/// Cached basis-tree
using TreeCache = NodeCache_t<Tree>;
/// Type used for global numbering of the basis vectors
using MultiIndex = typename NodeIndexSet::MultiIndex;
private:
template <class NIS>
using hasIndices = decltype(std::declval<NIS>().indices(std::declval<std::vector<typename NIS::MultiIndex>>().begin()));
public:
/// \brief Construct local view for a given global finite element basis
LocalView (GlobalBasis const& globalBasis)
: globalBasis_(&globalBasis)
, tree_(makeNode(globalBasis_->preBasis(), PrefixPath{}))
, treeCache_(makeNodeCache(tree_))
, nodeIndexSet_(makeNodeIndexSet(globalBasis_->preBasis(), PrefixPath{}))
{
static_assert(Dune::models<Dune::Functions::Concept::BasisTree<GridView>, Tree>());
initializeTree(tree_);
}
/// \brief Bind the view to a grid element
/**
* Having to bind the view to an element before being able to actually access any of
* its data members offers to centralize some expensive setup code in the `bind`
* method, which can save a lot of run-time.
*/
void bind (Element const& element)
{
element_ = element;
bindTree(tree_, *element_);
nodeIndexSet_.bind(tree_);
indices_.resize(size());
if constexpr (Dune::Std::is_detected<hasIndices,NodeIndexSet>{})
nodeIndexSet_.indices(indices_.begin());
else
for (size_type i = 0; i < size(); ++i)
indices_[i] = nodeIndexSet_.index(i);
}
/// \brief Return if the view is bound to a grid element
bool isBound () const
{
return bool(element_);
}
/// \brief Return the grid element that the view is bound to
Element const& element () const
{
return *element_;
}
/// \brief Unbind from the current element
/**
* Calling this method should only be a hint that the view can be unbound.
*/
void unbind ()
{
nodeIndexSet_.unbind();
element_.reset();
}
/// \brief Return the local ansatz tree associated to the bound entity
Tree const& tree () const
{
return tree_;
}
/// \brief Cached version of the local ansatz tree
TreeCache const& treeCache () const
{
return treeCache_;
}
/// \brief Total number of degrees of freedom on this element
size_type size () const
{
return tree_.size();
}
/// \brief Maximum local size for any element on the GridView
/**
* This is the maximal size needed for local matrices and local vectors.
*/
size_type maxSize () const
{
return globalBasis_->preBasis().maxNodeSize();
}
/// \brief Maps from subtree index set [0..size-1] to a globally unique multi index
/// in global basis
MultiIndex index (size_type i) const
{
return indices_[i];
}
/// \brief Return the global basis that we are a view on
GlobalBasis const& globalBasis () const
{
return *globalBasis_;
}
/// \brief Return this local-view
LocalView const& rootLocalView () const
{
return *this;
}
protected:
GlobalBasis const* globalBasis_;
std::optional<Element> element_;
Tree tree_;
TreeCache treeCache_;
NodeIndexSet nodeIndexSet_;
std::vector<MultiIndex> indices_;
};
} // end namespace AMDiS
#pragma once
#include <memory>
#include <tuple>
#include <vector>
#include <dune/common/ftraits.hh>
#include <dune/common/hash.hh>
#include <dune/common/indices.hh>
#include <dune/geometry/type.hh>
#include <dune/typetree/leafnode.hh>
#include <dune/typetree/powernode.hh>
#include <dune/typetree/compositenode.hh>
#include <amdis/common/ConcurrentCache.hpp>
namespace AMDiS
{
namespace Impl
{
template <class Node, class NodeTag = typename Node::NodeTag>
struct NodeCacheFactory;
} // end namespace Impl
/// Defines the type of a node cache associated to a given Node
template <class Node>
using NodeCache_t = typename Impl::NodeCacheFactory<Node>::type;
/// Construct a new local-basis cache from a basis-node.
template <class Node>
auto makeNodeCache (Node const& node)
{
return NodeCache_t<Node>::create(node);
}
/// Wrapper around a basis-node storing just a pointer and providing some
/// essential functionality like size, localIndex, and treeIndex
template <class NodeType>
class NodeWrapper
{
public:
using Node = NodeType;
using Element = typename Node::Element;
public:
// Store a reference to the node
NodeWrapper (Node const& node)
: node_(&node)
{}
/// Return the stored basis-node
Node const& node () const
{
assert(node_ != nullptr);
return *node_;
}
/// Return the bound grid element
Element const& element () const
{
return node_->element();
}
/// Return the index of the i-th local basis function in the index-set of the whole tree
auto localIndex (std::size_t i) const
{
return node_->localIndex(i);
}
/// Return the size of the index-set of the node
auto size () const
{
return node_->size();
}
/// Return a unique index within the tree
auto treeIndex () const
{
return node_->treeIndex();
}
protected:
Node const* node_ = nullptr;
};
/// \brief Cache of LocalBasis evaluations and jacobians at local points
/**
* Caching is done using the ConcurrentCache data structure with a key that
* depends on the element type and location of points.
* Two methods are provided for evaluation of local basis functions and local
* basis jacobians at all quadrature points. A vector of values is returned.
*
* \tparam Node Type of the leaf basis node
**/
template <class Node>
class LeafNodeCache
: public Dune::TypeTree::LeafNode
, public NodeWrapper<Node>
{
public:
using BasisNode = Node;
using FiniteElement = typename Node::FiniteElement;
using LocalBasis = typename FiniteElement::Traits::LocalBasisType;
using ShapeValues = std::vector<typename LocalBasis::Traits::RangeType>;
using ShapeGradients = std::vector<typename LocalBasis::Traits::JacobianType>;
private:
using DomainType = typename LocalBasis::Traits::DomainType;
// Pair of GeometryType and local coordinates
struct CoordKey
{
unsigned int id; // topologyId
DomainType local; // local coordinate
struct hasher
{
std::size_t operator()(CoordKey const& t) const
{
std::size_t seed = 0;
Dune::hash_combine(seed, t.id);
Dune::hash_range(seed, t.local.begin(), t.local.end());
return seed;
}
};
friend bool operator==(CoordKey const& lhs, CoordKey const& rhs)
{
return std::tie(lhs.id,lhs.local) == std::tie(rhs.id,rhs.local);
}
};
private:
// Constructor storing a reference to the passed basis-node
LeafNodeCache (Node const& basisNode)
: NodeWrapper<Node>(basisNode)
{}
public:
/// Construct a new local-basis cache
static LeafNodeCache create (Node const& basisNode)
{
return {basisNode};
}
/// Return the local finite-element of the stored basis-node.
FiniteElement const& finiteElement () const
{
return this->node_->finiteElement();
}
/// Evaluate local basis functions at local coordinates
ShapeValues const& localBasisValuesAt (DomainType const& local) const
{
CoordKey key{this->element().type().id(), local};
return shapeValues_.get(key, [&](CoordKey const&)
{
ShapeValues data;
this->localBasis().evaluateFunction(local, data);
return data;
});
}
/// Evaluate local basis jacobians at local coordinates
ShapeGradients const& localBasisJacobiansAt (DomainType const& local) const
{
CoordKey key{this->element().type().id(), local};
return shapeGradients_.get(key, [&](CoordKey const&)
{
ShapeGradients data;
this->localBasis().evaluateJacobian(local, data);
return data;
});
}
private:
/// Return the local-basis of the stored basis-node.
LocalBasis const& localBasis () const
{
return this->node_->finiteElement().localBasis();
}
private:
template <class Value>
using CoordCache = ConcurrentCache<CoordKey, Value, ConsecutivePolicy,
std::unordered_map<CoordKey, Value, typename CoordKey::hasher>>;
CoordCache<ShapeValues> shapeValues_;
CoordCache<ShapeGradients> shapeGradients_;
};
template <class Node>
class PowerNodeCache
: public Impl::NodeCacheFactory<Node>::Base
, public NodeWrapper<Node>
{
using Self = PowerNodeCache;
using Super = typename Impl::NodeCacheFactory<Node>::Base;
private:
PowerNodeCache (Node const& basisNode)
: NodeWrapper<Node>(basisNode)
{}
public:
/// Construct a new power node by setting each child individually
static auto create (Node const& basisNode)
{
Self cache{basisNode};
for (std::size_t i = 0; i < Super::degree(); ++i)
cache.setChild(i, Super::ChildType::create(basisNode.child(i)));
return cache;
}
};
template <class Node>
class CompositeNodeCache
: public Impl::NodeCacheFactory<Node>::Base
, public NodeWrapper<Node>
{
using Self = CompositeNodeCache;
using Super = typename Impl::NodeCacheFactory<Node>::Base;
private:
CompositeNodeCache (Node const& basisNode)
: NodeWrapper<Node>(basisNode)
{}
public:
/// Construct a new composite node by setting each child individually
static auto create (Node const& basisNode)
{
using TT = typename Super::ChildTypes;
Self cache{basisNode};
Dune::Hybrid::forEach(std::make_index_sequence<Super::degree()>{}, [&](auto ii) {
cache.setChild(std::tuple_element_t<ii,TT>::create(basisNode.child(ii)), ii);
});
return cache;
}
};
namespace Impl
{
template <class Node>
struct NodeCacheFactory<Node, Dune::TypeTree::LeafNodeTag>
{
using Base = Dune::TypeTree::LeafNode;
using type = LeafNodeCache<Node>;
};
template <class Node>
struct NodeCacheFactory<Node, Dune::TypeTree::PowerNodeTag>
{
using Child = typename NodeCacheFactory<typename Node::ChildType>::type;
using Base = Dune::TypeTree::PowerNode<Child, Node::degree()>;
using type = PowerNodeCache<Node>;
};
template <class Node>
struct NodeCacheFactory<Node, Dune::TypeTree::CompositeNodeTag>
{
template <class Indices> struct Childs;
template <std::size_t... i>
struct Childs<std::index_sequence<i...>> {
using type = Dune::TypeTree::CompositeNode<
typename NodeCacheFactory<typename Node::template Child<i>::type>::type...
>;
};
using Base = typename Childs<std::make_index_sequence<Node::degree()>>::type;
using type = CompositeNodeCache<Node>;
};
} // end namespace Impl
} // end namespace AMDiS
...@@ -38,7 +38,7 @@ namespace AMDiS ...@@ -38,7 +38,7 @@ namespace AMDiS
auto makeNodeIndexSet(PB const& preBasis, [[maybe_unused]] TP const& treePath) auto makeNodeIndexSet(PB const& preBasis, [[maybe_unused]] TP const& treePath)
{ {
#if DUNE_VERSION_LT(DUNE_FUNCTIONS,2,7) #if DUNE_VERSION_LT(DUNE_FUNCTIONS,2,7)
return preBasis.indexSet(treePath); return preBasis.template indexSet<TP>();
#else #else
return preBasis.makeIndexSet(); return preBasis.makeIndexSet();
#endif #endif
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
#include <dune/functions/common/type_traits.hh> #include <dune/functions/common/type_traits.hh>
#include <dune/functions/functionspacebases/concepts.hh> #include <dune/functions/functionspacebases/concepts.hh>
#include <dune/functions/functionspacebases/defaultglobalbasis.hh> #include <dune/functions/functionspacebases/defaultglobalbasis.hh>
#include <dune/functions/functionspacebases/defaultlocalview.hh>
#include <dune/functions/functionspacebases/flatmultiindex.hh> #include <dune/functions/functionspacebases/flatmultiindex.hh>
#include <dune/grid/common/adaptcallback.hh> #include <dune/grid/common/adaptcallback.hh>
...@@ -29,6 +28,7 @@ ...@@ -29,6 +28,7 @@
#include <amdis/common/Concepts.hpp> #include <amdis/common/Concepts.hpp>
#include <amdis/common/TypeTraits.hpp> #include <amdis/common/TypeTraits.hpp>
#include <amdis/functions/FlatPreBasis.hpp> #include <amdis/functions/FlatPreBasis.hpp>
#include <amdis/functions/LocalView.hpp>
#include <amdis/linearalgebra/Traits.hpp> #include <amdis/linearalgebra/Traits.hpp>
#include <amdis/typetree/MultiIndex.hpp> #include <amdis/typetree/MultiIndex.hpp>
...@@ -71,7 +71,7 @@ namespace AMDiS ...@@ -71,7 +71,7 @@ namespace AMDiS
using Grid = typename GridView::Grid; using Grid = typename GridView::Grid;
/// Type of the local view on the restriction of the basis to a single element /// Type of the local view on the restriction of the basis to a single element
using LocalView = Dune::Functions::DefaultLocalView<Self>; using LocalView = AMDiS::LocalView<Self>;
/// Type of the communicator /// Type of the communicator
using Comm = std::conditional_t<Traits::IsFlatIndex<typename Super::MultiIndex>::value, using Comm = std::conditional_t<Traits::IsFlatIndex<typename Super::MultiIndex>::value,
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <dune/typetree/childextraction.hh> #include <dune/typetree/childextraction.hh>
#include <amdis/common/Tags.hpp> #include <amdis/common/Tags.hpp>