From b0a7d9670efe8fe62d594ac5a0cc2858bc5f0705 Mon Sep 17 00:00:00 2001 From: Simon Praetorius Date: Sat, 21 Nov 2020 14:48:12 +0100 Subject: [PATCH] rename for_each_node into Traversal::forEachNode --- amdis/BiLinearForm.inc.hpp | 4 +- amdis/DataTransfer.inc.hpp | 12 +- amdis/LinearForm.inc.hpp | 2 +- amdis/PeriodicBC.inc.hpp | 2 +- amdis/ProblemStat.inc.hpp | 6 +- amdis/functions/Interpolate.hpp | 2 +- .../DiscreteLocalFunction.inc.hpp | 6 +- amdis/typetree/Traversal.hpp | 228 +++++++++--------- docs/reference/GlobalBasis.md | 8 +- examples/traversal.cc | 17 +- test/DataTransferTest.hpp | 4 +- test/TreeContainerTest.cpp | 6 +- 12 files changed, 156 insertions(+), 141 deletions(-) diff --git a/amdis/BiLinearForm.inc.hpp b/amdis/BiLinearForm.inc.hpp index 9f34e328..dca943f8 100644 --- a/amdis/BiLinearForm.inc.hpp +++ b/amdis/BiLinearForm.inc.hpp @@ -49,8 +49,8 @@ assemble(RowLocalView const& rowLocalView, ColLocalView const& colLocalView) auto const& element = rowLocalView.element(); auto geometry = element.geometry(); - for_each_node(rowLocalView.treeCache(), [&](auto const& rowNode, auto rowTp) { - for_each_node(colLocalView.treeCache(), [&](auto const& colNode, auto colTp) { + Traversal::forEachNode(rowLocalView.treeCache(), [&](auto const& rowNode, auto rowTp) { + Traversal::forEachNode(colLocalView.treeCache(), [&](auto const& colNode, auto colTp) { auto& matOp = operators_[rowTp][colTp]; if (matOp) { matOp.bind(element, geometry); diff --git a/amdis/DataTransfer.inc.hpp b/amdis/DataTransfer.inc.hpp index db0dab3d..94919374 100644 --- a/amdis/DataTransfer.inc.hpp +++ b/amdis/DataTransfer.inc.hpp @@ -58,7 +58,7 @@ preAdapt(C const& coeff, bool mightCoarsen) LocalView lv = basis_->localView(); auto const& idSet = gv.grid().localIdSet(); - for_each_leaf_node(lv.tree(), [&](auto const& node, auto const& tp) { + Traversal::forEachLeafNode(lv.tree(), [&](auto const& node, auto const& tp) { nodeDataTransfer_[tp].preAdaptInit(lv, coeff, node); }); @@ -70,7 +70,7 @@ preAdapt(C const& coeff, bool mightCoarsen) lv.bind(e); auto& treeContainer = it.first->second; - for_each_leaf_node(lv.tree(), [&](auto const& node, auto const& tp) { + Traversal::forEachLeafNode(lv.tree(), [&](auto const& node, auto const& tp) { nodeDataTransfer_[tp].cacheLocal(treeContainer[tp]); }); } @@ -130,7 +130,7 @@ preAdapt(C const& coeff, bool mightCoarsen) }; restrictLocalCompleted = true; - for_each_leaf_node(lv.tree(), [&](auto const& node, auto const& tp) { + Traversal::forEachLeafNode(lv.tree(), [&](auto const& node, auto const& tp) { restrictLocalCompleted &= nodeDataTransfer_[tp].restrictLocal(father, treeContainer[tp], xInChildCached, childContainer[tp], init); @@ -158,7 +158,7 @@ void DataTransfer::adapt(C& coeff) GridView gv = basis_->gridView(); LocalView lv = basis_->localView(); auto const& idSet = gv.grid().localIdSet(); - for_each_leaf_node(lv.tree(), [&](auto const& node, auto const& tp) { + Traversal::forEachLeafNode(lv.tree(), [&](auto const& node, auto const& tp) { nodeDataTransfer_[tp].adaptInit(lv, coeff, node); }); @@ -176,7 +176,7 @@ void DataTransfer::adapt(C& coeff) if (it != persistentContainer_.end()) { lv.bind(e); auto const& treeContainer = it->second; - for_each_leaf_node(lv.tree(), [&](auto const& node, auto const& tp) { + Traversal::forEachLeafNode(lv.tree(), [&](auto const& node, auto const& tp) { nodeDataTransfer_[tp].copyLocal(treeContainer[tp]); }); finished[index] = true; @@ -211,7 +211,7 @@ void DataTransfer::adapt(C& coeff) return fatherGeo.local(childGeo.global(x)); }; - for_each_leaf_node(lv.tree(), [&](auto const& node, auto const& tp) { + Traversal::forEachLeafNode(lv.tree(), [&](auto const& node, auto const& tp) { nodeDataTransfer_[tp].prolongLocal(father, treeContainer[tp], xInFather, init); }); diff --git a/amdis/LinearForm.inc.hpp b/amdis/LinearForm.inc.hpp index 47641d95..639903eb 100644 --- a/amdis/LinearForm.inc.hpp +++ b/amdis/LinearForm.inc.hpp @@ -43,7 +43,7 @@ assemble(LocalView const& localView) auto const& element = localView.element(); auto geometry = element.geometry(); - for_each_node(localView.treeCache(), [&](auto const& node, auto tp) { + Traversal::forEachNode(localView.treeCache(), [&](auto const& node, auto tp) { auto& rhsOp = operators_[tp]; if (rhsOp) { rhsOp.bind(element, geometry); diff --git a/amdis/PeriodicBC.inc.hpp b/amdis/PeriodicBC.inc.hpp index bd748950..84be1441 100644 --- a/amdis/PeriodicBC.inc.hpp +++ b/amdis/PeriodicBC.inc.hpp @@ -197,7 +197,7 @@ auto PeriodicBC:: coords(Node const& tree, std::vector const& localIndices) const { std::vector dofCoords(localIndices.size()); - for_each_leaf_node(tree, [&](auto const& node, auto&&) + Traversal::forEachLeafNode(tree, [&](auto const& node, auto&&) { std::size_t size = node.finiteElement().size(); auto geometry = node.element().geometry(); diff --git a/amdis/ProblemStat.inc.hpp b/amdis/ProblemStat.inc.hpp index 659ec71d..f474b624 100644 --- a/amdis/ProblemStat.inc.hpp +++ b/amdis/ProblemStat.inc.hpp @@ -213,7 +213,7 @@ void ProblemStat::createMatricesAndVectors() rhs_ = std::make_shared(globalBasis_); auto localView = globalBasis_->localView(); - for_each_node(localView.tree(), [&,this](auto&&, auto treePath) -> void + Traversal::forEachNode(localView.tree(), [&,this](auto&&, auto treePath) -> void { std::string i = to_string(treePath); estimates_[i].resize(globalBasis_->gridView().indexSet().size(0)); @@ -241,7 +241,7 @@ void ProblemStat::createMarker() { marker_.clear(); auto localView = globalBasis_->localView(); - for_each_node(localView.tree(), [&,this](auto&&, auto treePath) -> void + Traversal::forEachNode(localView.tree(), [&,this](auto&&, auto treePath) -> void { std::string componentName = name_ + "->marker[" + to_string(treePath) + "]"; @@ -264,7 +264,7 @@ void ProblemStat::createFileWriter() filewriter_.clear(); auto localView = globalBasis_->localView(); - for_each_node(localView.tree(), [&](auto const& /*node*/, auto treePath) -> void + Traversal::forEachNode(localView.tree(), [&](auto const& /*node*/, auto treePath) -> void { std::string componentName = name_ + "->output[" + to_string(treePath) + "]"; auto format = Parameters::get>(componentName + "->format"); diff --git a/amdis/functions/Interpolate.hpp b/amdis/functions/Interpolate.hpp index 9a147767..221abd21 100644 --- a/amdis/functions/Interpolate.hpp +++ b/amdis/functions/Interpolate.hpp @@ -54,7 +54,7 @@ namespace AMDiS lf.bind(e); auto&& subTree = Dune::TypeTree::child(localView.tree(),treePath); - for_each_leaf_node(subTree, [&](auto const& node, auto const& tp) + Traversal::forEachLeafNode(subTree, [&](auto const& node, auto const& tp) { using Traits = typename TYPEOF(node)::FiniteElement::Traits::LocalBasisType::Traits; using RangeField = typename Traits::RangeFieldType; diff --git a/amdis/gridfunctions/DiscreteLocalFunction.inc.hpp b/amdis/gridfunctions/DiscreteLocalFunction.inc.hpp index aad97bc7..3155551c 100644 --- a/amdis/gridfunctions/DiscreteLocalFunction.inc.hpp +++ b/amdis/gridfunctions/DiscreteLocalFunction.inc.hpp @@ -179,7 +179,7 @@ auto DiscreteFunction::LocalFunction ::evaluateFunction(LocalCoordinate const& local) const { Range y(0); - for_each_leaf_node(this->subTreeCache(), [&](auto const& node, auto const& tp) + Traversal::forEachLeafNode(this->subTreeCache(), [&](auto const& node, auto const& tp) { auto const& shapeFunctionValues = node.localBasisValuesAt(local); std::size_t size = node.finiteElement().size(); @@ -216,7 +216,7 @@ auto DiscreteFunction::LocalFunction ::evaluateJacobian(LocalCoordinate const& local) const { Range dy(0); - for_each_leaf_node(this->subTreeCache(), [&](auto const& node, auto const& tp) + Traversal::forEachLeafNode(this->subTreeCache(), [&](auto const& node, auto const& tp) { LocalToGlobalBasisAdapter localBasis(node, this->geometry()); auto const& gradients = localBasis.gradientsAt(local); @@ -284,7 +284,7 @@ auto DiscreteFunction::LocalFunction std::size_t comp = this->type_.comp; Range dy(0); - for_each_leaf_node(this->subTreeCache(), [&](auto const& node, auto const& tp) + Traversal::forEachLeafNode(this->subTreeCache(), [&](auto const& node, auto const& tp) { LocalToGlobalBasisAdapter localBasis(node, this->geometry()); auto const& partial = localBasis.partialsAt(local, comp); diff --git a/amdis/typetree/Traversal.hpp b/amdis/typetree/Traversal.hpp index e41a3a6a..52089544 100644 --- a/amdis/typetree/Traversal.hpp +++ b/amdis/typetree/Traversal.hpp @@ -8,126 +8,130 @@ #include #include -// NOTE: backport of dune/typetree/traversal.hpp from Dune 2.7 - -namespace AMDiS -{ - namespace Impl - { - /// \brief Helper function that returns the degree of a Tree. - /** - * The return type is either `size_t` if it is a dynamic tree or the flag `dynamic` - * is set to `true`, or as the type returned by the `degree()` member function of the - * tree. - * - * This function allows to change the tree traversal from static to dynamic in case - * of power nodes and uses static traversal for composite and dynamic traversal for - * all dynamic nodes. - **/ - template - auto traversalDegree(Tree const& tree) - { - if constexpr (dynamic && Tree::isPower) - return std::size_t(tree.degree()); - else if constexpr (Tree::isPower || Tree::isComposite) - return std::integral_constant{}; - else - return tree.degree(); - } - - /** - * Traverse tree and visit each node. The signature is the same - * as for the public for_each_node function in Dune::Typtree, - * despite the additionally passed treePath argument. The path - * passed here is associated to the tree and the relative - * paths of the children (wrt. to tree) are appended to this. - * Hence the behavior of the public function is resembled - * by passing an empty treePath. - **/ - template - void for_each_node(Tree&& tree, TreePath treePath, PreFunc&& preFunc, LeafFunc&& leafFunc, PostFunc&& postFunc) - { - using TreeType = std::decay_t; - if constexpr (TreeType::isLeaf) { - // If we have a leaf tree just visit it using the leaf function. - leafFunc(tree, treePath); - } else { - // Otherwise visit the tree with the pre function, - // visit all children using a static or dynamic loop, and - // finally visit the tree with the post function. - preFunc(tree, treePath); - Dune::Hybrid::forEach(Dune::range(traversalDegree(tree)), [&](auto i) { - auto childTreePath = push_back(treePath, i); - Impl::for_each_node(tree.child(i), childTreePath, preFunc, leafFunc, postFunc); - }); - postFunc(tree, treePath); - } - } - - } // end namespace Impl +// NOTE: backport of dune/typetree/traversal.hh from Dune 2.7 +namespace AMDiS { +namespace Traversal { +namespace Impl_ { - /// \brief Traverse tree and visit each node + /// \brief Helper function that returns the degree of a Tree. /** - * All passed callback functions are called with the - * node and corresponding treepath as arguments. + * The return type is either `size_t` if it is a dynamic tree or the flag `dynamic` + * is set to `true`, or as the type returned by the `degree()` member function of the + * tree. * - * \param tree The tree to traverse - * \param preFunc This function is called for each inner node before visiting its children - * \param leafFunc This function is called for each leaf node - * \param postFunc This function is called for each inner node after visiting its children - */ - template - void for_each_node(Tree&& tree, PreFunc&& preFunc, LeafFunc&& leafFunc, PostFunc&& postFunc) + * This function allows to change the tree traversal from static to dynamic in case + * of power nodes and uses static traversal for composite and dynamic traversal for + * all dynamic nodes. + **/ + template + auto traversalDegree(Tree const& tree) { - auto root = Dune::TypeTree::hybridTreePath(); - Impl::for_each_node(tree, root, preFunc, leafFunc, postFunc); + if constexpr (dynamic && Tree::isPower) + return std::size_t(tree.degree()); + else if constexpr (Tree::isPower || Tree::isComposite) + return std::integral_constant{}; + else + return tree.degree(); } - /// \brief Traverse tree and visit each node - /** - * All passed callback functions are called with the - * node and corresponding treepath as arguments. - * - * \param tree The tree to traverse - * \param innerFunc This function is called for each inner node before visiting its children - * \param leafFunc This function is called for each leaf node - */ - template - void for_each_node(Tree&& tree, InnerFunc&& innerFunc, LeafFunc&& leafFunc) - { - auto root = Dune::TypeTree::hybridTreePath(); - Impl::for_each_node(tree, root, innerFunc, leafFunc, NoOp{}); - } +} // end namespace Impl_ - /// \brief Traverse tree and visit each node - /** - * The passed callback function is called with the - * node and corresponding treepath as arguments. - * - * \param tree The tree to traverse - * \param nodeFunc This function is called for each node - */ - template - void for_each_node(Tree&& tree, NodeFunc&& nodeFunc) - { - auto root = Dune::TypeTree::hybridTreePath(); - Impl::for_each_node(tree, root, nodeFunc, nodeFunc, NoOp{}); - } - /// \brief Traverse tree and visit each leaf node - /** - * The passed callback function is called with the - * node and corresponding treepath as arguments. - * - * \param tree The tree to traverse - * \param leafFunc This function is called for each leaf node - */ - template - void for_each_leaf_node(Tree&& tree, LeafFunc&& leafFunc) - { - auto root = Dune::TypeTree::hybridTreePath(); - Impl::for_each_node(tree, root, NoOp{}, leafFunc, NoOp{}); +/** + * Traverse tree and visit each node. On each node call preFunc and postFunc before and after + * visiting its childs. Call leafFunc on all leaf nodes. + **/ +template +void forEachNode(Tree&& tree, TreePath treePath, Pre&& preFunc, Leaf&& leafFunc, Post&& postFunc) +{ + using TreeType = std::decay_t; + if constexpr (TreeType::isLeaf) { + // For leaf nodes just visit using the leaf function. + leafFunc(tree, treePath); + } else { + preFunc(tree, treePath); + const auto degree = Impl_::traversalDegree(tree); + if constexpr (std::is_integral_v) { + // Specialization for dynamic traversal + for (std::size_t i = 0; i < std::size_t(degree); ++i) { + auto childTreePath = Dune::TypeTree::push_back(treePath, i); + forEachNode(tree.child(i), childTreePath, preFunc, leafFunc, postFunc); + } + } else { + // Specialization for static traversal + const auto indices = std::make_index_sequence{}; + Ranges::forIndices(indices, [&](auto i) { + auto childTreePath = Dune::TypeTree::push_back(treePath, i); + forEachNode(tree.child(i), childTreePath, preFunc, leafFunc, postFunc); + }); + } + postFunc(tree, treePath); } +} + + +/// \brief Traverse tree and visit each node +/** + * All passed callback functions are called with the + * node and corresponding treepath as arguments. + * + * \param tree The tree to traverse + * \param preFunc This function is called for each inner node before visiting its children + * \param leafFunc This function is called for each leaf node + * \param postFunc This function is called for each inner node after visiting its children + */ +template +void forEachNode(Tree&& tree, Pre&& preFunc, Leaf&& leafFunc, Post&& postFunc) +{ + auto root = Dune::TypeTree::hybridTreePath(); + forEachNode(tree, root, preFunc, leafFunc, postFunc); +} + +/// \brief Traverse tree and visit each node +/** + * All passed callback functions are called with the + * node and corresponding treepath as arguments. + * + * \param tree The tree to traverse + * \param innerFunc This function is called for each inner node before visiting its children + * \param leafFunc This function is called for each leaf node + */ +template +void forEachNode(Tree&& tree, Inner&& innerFunc, Leaf&& leafFunc) +{ + auto root = Dune::TypeTree::hybridTreePath(); + forEachNode(tree, root, innerFunc, leafFunc, NoOp{}); +} + +/// \brief Traverse tree and visit each node +/** + * The passed callback function is called with the + * node and corresponding treepath as arguments. + * + * \param tree The tree to traverse + * \param nodeFunc This function is called for each node + */ +template +void forEachNode(Tree&& tree, NodeFunc&& nodeFunc) +{ + auto root = Dune::TypeTree::hybridTreePath(); + forEachNode(tree, root, nodeFunc, nodeFunc, NoOp{}); +} + +/// \brief Traverse tree and visit each leaf node +/** + * The passed callback function is called with the + * node and corresponding treepath as arguments. + * + * \param tree The tree to traverse + * \param leafFunc This function is called for each leaf node + */ +template +void forEachLeafNode(Tree&& tree, Leaf&& leafFunc) +{ + auto root = Dune::TypeTree::hybridTreePath(); + forEachNode(tree, root, NoOp{}, leafFunc, NoOp{}); +} -} // end namespace AMDiS +}} // end namespace AMDiS::Traversal diff --git a/docs/reference/GlobalBasis.md b/docs/reference/GlobalBasis.md index 5136f74f..5d815b95 100644 --- a/docs/reference/GlobalBasis.md +++ b/docs/reference/GlobalBasis.md @@ -110,12 +110,12 @@ for (const auto& e : elements(basis.gridView())) // loop over all grid elements A bound LocalView has the method `LocalView::index(size_type)` mapping a local index to a global index. In other words it maps a local basis function defined on an element to its corresponding global basis function. We can use that to build a global stiffness matrix from local contributions on a single element and then insert those into a single matrix in global indices. Another method is `LocalView::tree()` that returns the root node of the local basis tree. The main method all nodes share is `Node::localIndex(size_type)` which maps a leaf node index to the local index within the local basis tree. -#### The for_each_node and for_each_leaf_node helper functions -Quite often we want to perform operations on certain nodes of the tree other than the root node. This can be useful if we want to work with the actual implementations wich are usually leaf nodes. For this we can use the helper functions `for_each_node` and `for_each_leaf_node` defined in `amdis/typetree/Traversal.hpp`. Those functions traverse the tree and call the given function on every (leaf) node with the node and a type of tree index we shall explain later as arguments. we show the usage with the following example using the Taylor-Hood-Basis defined above. Here we assume to have a `LocalView` `localView` that is bound to an element. +#### The Traversal::forEachNode and Traversal::forEachLeafNode helper functions +Quite often we want to perform operations on certain nodes of the tree other than the root node. This can be useful if we want to work with the actual implementations wich are usually leaf nodes. For this we can use the helper functions `Traversal::forEachNode` and `Traversal::forEachLeafNode` defined in `amdis/typetree/Traversal.hpp`. Those functions traverse the tree and call the given function on every (leaf) node with the node and a type of tree index we shall explain later as arguments. we show the usage with the following example using the Taylor-Hood-Basis defined above. Here we assume to have a `LocalView` `localView` that is bound to an element. ```c++ auto root = localView.tree(); -for_each_leaf_node(root, [&](auto const& node, auto const& tp) { +Traversal::forEachLeafNode(root, [&](auto const& node, auto const& tp) { // do something on node }); ``` @@ -146,7 +146,7 @@ We shall show the usage of the local finite element class handed out by the func ```c++ auto root = localView.tree(); -for_each_leaf_node(root, [&](auto const& node, auto const& tp) { +Traversal::forEachLeafNode(root, [&](auto const& node, auto const& tp) { // Extract some types from the node using Node = Underlying_t; using LocalFunction = typename Node::FiniteElement::Traits::LocalInterpolationType::FunctionType; diff --git a/examples/traversal.cc b/examples/traversal.cc index e8753951..d2f41bcf 100644 --- a/examples/traversal.cc +++ b/examples/traversal.cc @@ -19,14 +19,25 @@ int main(int argc, char** argv) // create basis using namespace Dune::Functions::BasisFactory; auto basis1 = makeBasis(gridView, - composite(power<2>(lagrange<2>(), flatInterleaved()), lagrange<1>(), flatLexicographic())); + power<10>(composite(power<10>(lagrange<2>()), power<10>(lagrange<1>()))) ); + auto basis2 = makeBasis(gridView, + power<2>(power<2>(power<2>(power<2>(power<2>(power<2>(power<2>(lagrange<1>()))))))) ); + + auto basis3 = makeBasis(gridView, + composite(power<2>(lagrange<2>(), flatInterleaved()), lagrange<1>(), flatLexicographic())); + + auto basis4 = makeBasis(gridView, power<2>(power<2>(lagrange<2>(), flatInterleaved()), flatLexicographic())); auto localView = basis1.localView(); - for_each_leaf_node(localView.tree(), [](auto const& node, auto const& tp) + Traversal::forEachLeafNode(localView.tree(), [&](auto const& rowNode, auto const& r) { - std::cout << tp << std::endl; + Traversal::forEachLeafNode(localView.tree(), [&](auto const& colNode, auto const& c) + { + std::cout << r << " , " << c << std::endl; + // std::cout << colNode.tp() << std::endl; + }); }); } diff --git a/test/DataTransferTest.hpp b/test/DataTransferTest.hpp index f209984c..38b78ecc 100644 --- a/test/DataTransferTest.hpp +++ b/test/DataTransferTest.hpp @@ -65,7 +65,7 @@ auto makeProblem(typename BasisCreator::GlobalBasis::GridView::Grid& grid, Fcts // interpolate given function to initial grid int k = 0; - for_each_leaf_node(localView.tree(), [&](auto const& node, auto tp) + Traversal::forEachLeafNode(localView.tree(), [&](auto const& node, auto tp) { auto gf = makeGridFunction(funcs[k], globalBasis.gridView()); AMDiS::interpolate(globalBasis, prob->solution(tp).coefficients(), gf, tp); @@ -86,7 +86,7 @@ double calcError(Problem const& prob, Fcts const& funcs) int k = 0; // interpolate given function onto reference vector - for_each_leaf_node(localView.tree(), [&](auto const& node, auto tp) + Traversal::forEachLeafNode(localView.tree(), [&](auto const& node, auto tp) { auto gf = makeGridFunction(funcs[k], globalBasis.gridView()); AMDiS::interpolate(globalBasis, ref, gf, tp); diff --git a/test/TreeContainerTest.cpp b/test/TreeContainerTest.cpp index f3b5d839..ddd05784 100644 --- a/test/TreeContainerTest.cpp +++ b/test/TreeContainerTest.cpp @@ -36,7 +36,7 @@ int main (int argc, char** argv) auto c3 = makeTreeContainer(tree, [&](auto const&) { return makeTreeContainer(tree); }); // fill 1d treeContainer with data - for_each_leaf_node(tree, [&](auto const& node, auto tp) { + Traversal::forEachLeafNode(tree, [&](auto const& node, auto tp) { c1[tp] = double(node.treeIndex()); }); @@ -45,8 +45,8 @@ int main (int argc, char** argv) AMDIS_TEST(c4 == c1); // fill 2d treeContainer with data - for_each_leaf_node(tree, [&](auto const& row_node, auto row_tp) { - for_each_leaf_node(tree, [&](auto const& col_node, auto col_tp) { + Traversal::forEachLeafNode(tree, [&](auto const& row_node, auto row_tp) { + Traversal::forEachLeafNode(tree, [&](auto const& col_node, auto col_tp) { c3[row_tp][col_tp] = double(row_node.treeIndex() + col_node.treeIndex()); }); }); -- GitLab