Serac  0.1
Serac is an implicit thermal structural mechanics simulation code.
state_manager.cpp
1 // Copyright (c) 2019-2024, Lawrence Livermore National Security, LLC and
2 // other Serac Project Developers. See the top-level LICENSE file for
3 // details.
4 //
5 // SPDX-License-Identifier: (BSD-3-Clause)
6 
8 
9 #include "axom/core.hpp"
10 
11 namespace serac {
12 
13 // Initialize StateManager's static members - these will be fully initialized in StateManager::initialize
14 std::unordered_map<std::string, axom::sidre::MFEMSidreDataCollection> StateManager::datacolls_;
15 std::unordered_map<std::string, std::unique_ptr<FiniteElementState>> StateManager::shape_displacements_;
16 bool StateManager::is_restart_ = false;
17 axom::sidre::DataStore* StateManager::ds_ = nullptr;
18 std::string StateManager::output_dir_ = "";
19 std::unordered_map<std::string, mfem::ParGridFunction*> StateManager::named_states_;
20 std::unordered_map<std::string, mfem::ParGridFunction*> StateManager::named_duals_;
21 
22 double StateManager::newDataCollection(const std::string& name, const std::optional<int> cycle_to_load)
23 {
24  SLIC_ERROR_ROOT_IF(!ds_, "Cannot construct a DataCollection without a DataStore");
25  std::string coll_name = name + "_datacoll";
26 
27  auto global_grp = ds_->getRoot()->createGroup(coll_name + "_global");
28  auto bp_index_grp = global_grp->createGroup("blueprint_index/" + coll_name);
29  auto domain_grp = ds_->getRoot()->createGroup(coll_name);
30 
31  // Needs to be configured to own the mesh data so all mesh data is saved to datastore/output file
32  constexpr bool owns_mesh_data = true;
33  auto [iter, _] = datacolls_.emplace(std::piecewise_construct, std::forward_as_tuple(name),
34  std::forward_as_tuple(coll_name, bp_index_grp, domain_grp, owns_mesh_data));
35  auto& datacoll = iter->second;
36  datacoll.SetComm(MPI_COMM_WORLD);
37 
38  datacoll.SetPrefixPath(output_dir_);
39 
40  if (cycle_to_load) {
41  // NOTE: Load invalidates previous Sidre pointers
42  datacoll.Load(*cycle_to_load);
43  datacoll.SetGroupPointers(ds_->getRoot()->getGroup(coll_name + "_global/blueprint_index/" + coll_name),
44  ds_->getRoot()->getGroup(coll_name));
45  SLIC_ERROR_ROOT_IF(datacoll.GetBPGroup()->getNumGroups() == 0,
46  "Loaded datastore is empty, was the datastore created on a "
47  "different number of nodes?");
48 
49  datacoll.UpdateStateFromDS();
50  datacoll.UpdateMeshAndFieldsFromDS();
51 
52  // Functional needs the nodal grid function and neighbor data in the mesh
53 
54  // Determine if the existing nodal grid function is discontinuous. This
55  // indicates that the mesh is periodic and the new nodal grid function must also
56  // be discontinuous.
57  bool is_discontinuous = false;
58  auto nodes = mesh(name).GetNodes();
59  if (nodes) {
60  is_discontinuous = nodes->FESpace()->FEColl()->GetContType() == mfem::FiniteElementCollection::DISCONTINUOUS;
61  SLIC_WARNING_ROOT_IF(
62  is_discontinuous,
63  "Periodic mesh detected! This will only work on translational periodic surfaces for vector H1 fields and "
64  "has not been thoroughly tested. Proceed at your own risk.");
65  }
66 
67  // This mfem call ensures the mesh contains an H1 grid function describing nodal
68  // cordinates. The parameters do the following:
69  // 1. Sets the order of the mesh to p = 1
70  // 2. Uses the existing continuity of the mesh finite element space (periodic meshes are discontinuous)
71  // 3. Uses the spatial dimension as the mesh dimension (i.e. it is not a lower dimension manifold)
72  // 4. Uses nodal instead of VDIM ordering (i.e. xxxyyyzzz instead of xyzxyzxyz)
73  mesh(name).SetCurvature(1, is_discontinuous, -1, mfem::Ordering::byNODES);
74 
75  // Sidre will destruct the nodal grid function instead of the mesh
76  mesh(name).SetNodesOwner(false);
77 
78  // Generate the face neighbor information in the mesh. This is needed by the face restriction
79  // operators used by Functional
80  mesh(name).ExchangeFaceNbrData();
81 
82  // Construct and store the shape displacement fields and sensitivities associated with this mesh
83  constructShapeFields(name);
84 
85  } else {
86  datacoll.SetCycle(0); // Iteration counter
87  datacoll.SetTime(0.0); // Simulation time
88  }
89 
90  return datacoll.GetTime();
91 }
92 
93 void StateManager::loadCheckpointedStates(int cycle_to_load,
94  std::vector<std::reference_wrapper<FiniteElementState>> states_to_load)
95 {
96  std::string mesh_name = collectionID(&states_to_load.begin()->get().mesh());
97 
98  std::string coll_name = mesh_name + "_datacoll";
99 
100  axom::sidre::MFEMSidreDataCollection previous_datacoll(coll_name);
101 
102  previous_datacoll.SetComm(states_to_load.begin()->get().mesh().GetComm());
103  previous_datacoll.SetPrefixPath(output_dir_);
104  previous_datacoll.Load(cycle_to_load);
105 
106  for (auto state : states_to_load) {
107  SLIC_ERROR_ROOT_IF(collectionID(&state.get().mesh()) != mesh_name,
108  "Loading FiniteElementStates from two different meshes at one time is not allowed.");
109  mfem::ParGridFunction* datacoll_owned_grid_function = previous_datacoll.GetParField(state.get().name());
110 
111  state.get().setFromGridFunction(*datacoll_owned_grid_function);
112  }
113 }
114 
115 void StateManager::initialize(axom::sidre::DataStore& ds, const std::string& output_directory)
116 {
117  // If the global object has already been initialized, clear it out
118  if (ds_) {
119  reset();
120  }
121  ds_ = &ds;
122  output_dir_ = output_directory;
123  if (output_directory.empty()) {
124  SLIC_ERROR_ROOT(
125  "DataCollection output directory cannot be empty - this will result in problems if executables are run in "
126  "parallel");
127  }
128 }
129 
131 {
132  return *shape_displacements_[mesh_tag];
133 }
134 
136 {
137  SLIC_ERROR_ROOT_IF(!ds_, "Serac's data store was not initialized - call StateManager::initialize first");
138  auto mesh_tag = collectionID(&state.mesh());
139  SLIC_ERROR_ROOT_IF(named_states_.find(state.name()) != named_states_.end(),
140  axom::fmt::format("StateManager already contains a state named '{}'", state.name()));
141  auto& datacoll = datacolls_.at(mesh_tag);
142  const std::string name = state.name();
143  mfem::ParGridFunction* grid_function;
144  if (is_restart_) {
145  grid_function = datacoll.GetParField(name);
146  state.setFromGridFunction(*grid_function);
147  } else {
148  SLIC_ERROR_ROOT_IF(datacoll.HasField(name),
149  axom::fmt::format("StateManager already given a field named '{0}'", name));
150 
151  // Create a new grid function with unallocated data. This will be managed by sidre.
152  grid_function = new mfem::ParGridFunction(&state.space(), static_cast<double*>(nullptr));
153  datacoll.RegisterField(name, grid_function);
154  state.setFromGridFunction(*grid_function);
155  }
156  named_states_[name] = grid_function;
157 }
158 
159 FiniteElementState StateManager::newState(const mfem::ParFiniteElementSpace& space, const std::string& state_name)
160 {
161  std::string mesh_tag = collectionID(space.GetParMesh());
162 
163  SLIC_ERROR_ROOT_IF(!ds_, "Serac's data store was not initialized - call StateManager::initialize first");
164  SLIC_ERROR_ROOT_IF(datacolls_.find(mesh_tag) == datacolls_.end(),
165  axom::fmt::format("Mesh tag '{}' not found in the data store", mesh_tag));
166  SLIC_ERROR_ROOT_IF(named_states_.find(state_name) != named_states_.end(),
167  axom::fmt::format("StateManager already contains a state named '{}'", state_name));
168  auto state = FiniteElementState(space, state_name);
169  storeState(state);
170  return state;
171 }
172 
174 {
175  SLIC_ERROR_ROOT_IF(!ds_, "Serac's data store was not initialized - call StateManager::initialize first");
176  auto mesh_tag = collectionID(&dual.mesh());
177  SLIC_ERROR_ROOT_IF(named_duals_.find(dual.name()) != named_duals_.end(),
178  axom::fmt::format("StateManager already contains a state named '{}'", dual.name()));
179  auto& datacoll = datacolls_.at(mesh_tag);
180  const std::string name = dual.name();
181  mfem::ParGridFunction* grid_function;
182  if (is_restart_) {
183  grid_function = datacoll.GetParField(name);
184  std::unique_ptr<mfem::HypreParVector> true_dofs(grid_function->GetTrueDofs());
185  dual = *true_dofs;
186  } else {
187  SLIC_ERROR_ROOT_IF(datacoll.HasField(name),
188  axom::fmt::format("StateManager already given a field named '{0}'", name));
189 
190  // Create a new grid function with unallocated data. This will be managed by sidre.
191  grid_function = new mfem::ParGridFunction(&dual.space(), static_cast<double*>(nullptr));
192  datacoll.RegisterField(name, grid_function);
193  std::unique_ptr<mfem::HypreParVector> true_dofs(grid_function->GetTrueDofs());
194  dual = *true_dofs;
195  }
196  named_duals_[name] = grid_function;
197 }
198 
199 FiniteElementDual StateManager::newDual(const mfem::ParFiniteElementSpace& space, const std::string& dual_name)
200 {
201  std::string mesh_tag = collectionID(space.GetParMesh());
202 
203  SLIC_ERROR_ROOT_IF(!ds_, "Serac's data store was not initialized - call StateManager::initialize first");
204  SLIC_ERROR_ROOT_IF(datacolls_.find(mesh_tag) == datacolls_.end(),
205  axom::fmt::format("Mesh tag '{}' not found in the data store", mesh_tag));
206  SLIC_ERROR_ROOT_IF(named_duals_.find(dual_name) != named_duals_.end(),
207  axom::fmt::format("StateManager already contains a dual named '{}'", dual_name));
208  auto dual = FiniteElementDual(space, dual_name);
209  storeDual(dual);
210  return dual;
211 }
212 
213 void StateManager::save(const double t, const int cycle, const std::string& mesh_tag)
214 {
215  SLIC_ERROR_ROOT_IF(!ds_, "Serac's data store was not initialized - call StateManager::initialize first");
216  SLIC_ERROR_ROOT_IF(datacolls_.find(mesh_tag) == datacolls_.end(),
217  axom::fmt::format("Mesh tag '{}' not found in the data store", mesh_tag));
218  auto& datacoll = datacolls_.at(mesh_tag);
219  std::string file_path = axom::utilities::filesystem::joinPath(datacoll.GetPrefixPath(), datacoll.GetCollectionName());
220  SLIC_INFO_ROOT(
221  axom::fmt::format("Saving data collection at time: '{}' and cycle: '{}' to path: '{}'", t, cycle, file_path));
222 
223  datacoll.SetTime(t);
224  datacoll.SetCycle(cycle);
225  datacoll.Save();
226 }
227 
228 mfem::ParMesh& StateManager::setMesh(std::unique_ptr<mfem::ParMesh> pmesh, const std::string& mesh_tag)
229 {
230  // Determine if the existing nodal grid function is discontinuous. This
231  // indicates that the mesh is periodic and the new nodal grid function must also
232  // be discontinuous.
233  bool is_discontinuous = false;
234  auto nodes = pmesh->GetNodes();
235  if (nodes) {
236  is_discontinuous = nodes->FESpace()->FEColl()->GetContType() == mfem::FiniteElementCollection::DISCONTINUOUS;
237  SLIC_WARNING_ROOT_IF(
238  is_discontinuous,
239  "Periodic mesh detected! This will only work on translational periodic surfaces for vector H1 fields and "
240  "has not been thoroughly tested. Proceed at your own risk.");
241  }
242 
243  // This mfem call ensures the mesh contains an H1 grid function describing nodal
244  // cordinates. The parameters do the following:
245  // 1. Sets the order of the mesh to p = 1
246  // 2. Uses the existing continuity of the mesh finite element space (periodic meshes are discontinuous)
247  // 3. Uses the spatial dimension as the mesh dimension (i.e. it is not a lower dimension manifold)
248  // 4. Uses nodal instead of VDIM ordering (i.e. xxxyyyzzz instead of xyzxyzxyz)
249  pmesh->SetCurvature(1, is_discontinuous, -1, mfem::Ordering::byNODES);
250 
251  // Sidre will destruct the nodal grid function instead of the mesh
252  pmesh->SetNodesOwner(false);
253 
254  newDataCollection(mesh_tag);
255  auto& datacoll = datacolls_.at(mesh_tag);
256  datacoll.SetMesh(pmesh.release());
257  datacoll.SetOwnData(true);
258 
259  // Functional needs the nodal grid function and neighbor data in the mesh
260  auto& new_pmesh = mesh(mesh_tag);
261 
262  // Generate the face neighbor information in the mesh. This is needed by the face restriction
263  // operators used by Functional
264  new_pmesh.ExchangeFaceNbrData();
265 
266  // We must construct the shape fields here as the mesh did not exist during the newDataCollection call
267  // for the non-restart case
268  constructShapeFields(mesh_tag);
269 
270  return new_pmesh;
271 }
272 
273 void StateManager::constructShapeFields(const std::string& mesh_tag)
274 {
275  // Construct the shape displacement field associated with this mesh
276  auto& new_mesh = mesh(mesh_tag);
277 
278  if (new_mesh.Dimension() == 2) {
279  shape_displacements_[mesh_tag] =
280  std::make_unique<FiniteElementState>(new_mesh, SHAPE_DIM_2, mesh_tag + "_shape_displacement");
281  } else if (new_mesh.Dimension() == 3) {
282  shape_displacements_[mesh_tag] =
283  std::make_unique<FiniteElementState>(new_mesh, SHAPE_DIM_3, mesh_tag + "_shape_displacement");
284  } else {
285  SLIC_ERROR_ROOT(axom::fmt::format("Mesh of dimension {} given, only dimensions 2 or 3 are available in Serac.",
286  new_mesh.Dimension()));
287  }
288 
289  storeState(*shape_displacements_[mesh_tag]);
290 
291  *shape_displacements_[mesh_tag] = 0.0;
292 }
293 
294 mfem::ParMesh& StateManager::mesh(const std::string& mesh_tag)
295 {
296  SLIC_ERROR_ROOT_IF(datacolls_.find(mesh_tag) == datacolls_.end(),
297  axom::fmt::format("Mesh tag \"{}\" not found in the data store", mesh_tag));
298  auto mesh = datacolls_.at(mesh_tag).GetMesh();
299  SLIC_ERROR_ROOT_IF(!mesh, "The datacollection does not contain a mesh object");
300  return static_cast<mfem::ParMesh&>(*mesh);
301 }
302 
303 std::string StateManager::collectionID(const mfem::ParMesh* pmesh)
304 {
305  for (auto& [name, datacoll] : datacolls_) {
306  if (datacoll.GetMesh() == pmesh) {
307  return name;
308  }
309  }
310  SLIC_ERROR_ROOT("The mesh has not been registered with StateManager");
311  return {};
312 }
313 
314 int StateManager::cycle(std::string mesh_tag)
315 {
316  SLIC_ERROR_ROOT_IF(datacolls_.find(mesh_tag) == datacolls_.end(),
317  axom::fmt::format("Mesh tag \"{}\" not found in the data store", mesh_tag));
318  return datacolls_.at(mesh_tag).GetCycle();
319 }
320 
321 double StateManager::time(std::string mesh_tag)
322 {
323  SLIC_ERROR_ROOT_IF(datacolls_.find(mesh_tag) == datacolls_.end(),
324  axom::fmt::format("Mesh tag \"{}\" not found in the data store", mesh_tag));
325  return datacolls_.at(mesh_tag).GetTime();
326 }
327 
328 } // namespace serac
Class for encapsulating the dual vector space of a finite element space (i.e. the space of linear for...
Class for encapsulating the critical MFEM components of a primal finite element field.
void setFromGridFunction(const mfem::ParGridFunction &grid_function)
Initialize the true vector in the FiniteElementState based on an input grid function.
mfem::ParFiniteElementSpace & space()
Returns a non-owning reference to the internal FESpace.
std::string name() const
Returns the name of the FEState (field)
mfem::ParMesh & mesh()
Returns a non-owning reference to the internal mesh object.
static void loadCheckpointedStates(int cycle_to_load, std::vector< std::reference_wrapper< FiniteElementState >> states_to_load)
loads the finite element states from a previously checkpointed cycle
static void initialize(axom::sidre::DataStore &ds, const std::string &output_directory)
Initializes the StateManager with a sidre DataStore (into which state will be written/read)
static void storeDual(FiniteElementDual &dual)
Store a pre-constructed finite element dual in the state manager.
static int cycle(std::string mesh_tag)
Get the current cycle (iteration number) from the underlying datacollection.
static double time(std::string mesh_tag)
Get the current simulation time from the underlying datacollection.
static FiniteElementState newState(FunctionSpace space, const std::string &state_name, const std::string &mesh_tag)
Factory method for creating a new FEState object.
static std::string collectionID(const mfem::ParMesh *pmesh)
Returns the datacollection ID for a given mesh.
static void reset()
Resets the underlying global datacollection object.
static FiniteElementDual newDual(FunctionSpace space, const std::string &dual_name, const std::string &mesh_tag)
Factory method for creating a new FEDual object.
static mfem::ParMesh & mesh(const std::string &mesh_tag)
Returns a non-owning reference to mesh held by StateManager.
static void storeState(FiniteElementState &state)
Store a pre-constructed finite element state in the state manager.
static void save(const double t, const int cycle, const std::string &mesh_tag)
Updates the Conduit Blueprint state in the datastore and saves to a file.
static FiniteElementState & shapeDisplacement(const std::string &mesh_tag)
Get the shape displacement finite element state.
static mfem::ParMesh & setMesh(std::unique_ptr< mfem::ParMesh > pmesh, const std::string &mesh_tag)
Gives ownership of mesh to StateManager.
Accelerator functionality.
Definition: serac.cpp:38
constexpr H1< SHAPE_ORDER, 2 > SHAPE_DIM_2
Function space for shape displacement on dimension 2 meshes.
constexpr H1< SHAPE_ORDER, 3 > SHAPE_DIM_3
Function space for shape displacement on dimension 3 meshes.
dual(double, T) -> dual< T >
class template argument deduction guide for type dual.
This file contains the declaration of the StateManager class.
Dual number struct (value plus gradient)
Definition: dual.hpp:29