Smith  0.1
Smith is an implicit thermal structural mechanics simulation code.
differentiable_physics.cpp
1 // Copyright (c) Lawrence Livermore National Security, LLC and
2 // other Smith Project Developers. See the top-level LICENSE file for
3 // details.
4 //
5 // SPDX-License-Identifier: (BSD-3-Clause)
6 
7 #include "gretl/data_store.hpp"
10 #include "smith/physics/mesh.hpp"
13 #include "gretl/upstream_state.hpp"
14 
15 namespace smith {
16 
20 gretl::State<int> make_milestone(const std::vector<FieldState>& states)
21 {
22  std::vector<gretl::StateBase> base_states;
23  for (const auto& s : states) {
24  base_states.push_back(s);
25  }
26 
27  auto milestone = states[0].create_state<int, int>(base_states);
28 
29  milestone.set_eval(
30  []([[maybe_unused]] const gretl::UpstreamStates& inputs, gretl::DownstreamState& output) { output.set<int>(0); });
31  milestone.set_vjp(
32  []([[maybe_unused]] gretl::UpstreamStates& inputs, [[maybe_unused]] const gretl::DownstreamState& output) {});
33 
34  return milestone.finalize();
35 }
36 
37 // mesh, equation, fields, parameters, state advancer, solver
38 DifferentiablePhysics::DifferentiablePhysics(std::shared_ptr<Mesh> mesh, std::shared_ptr<gretl::DataStore> graph,
39  const FieldState& shape_disp, const std::vector<FieldState>& states,
40  const std::vector<FieldState>& params,
41  std::shared_ptr<StateAdvancer> advancer, std::string mech_name,
42  const std::vector<std::string>& reaction_names)
43  : BasePhysics(mech_name, mesh, 0, 0.0, false), // the false is checkpoint_to_disk
44  checkpointer_(graph),
45  advancer_(advancer),
46  reaction_names_(reaction_names)
47 {
48  SLIC_ERROR_IF(states.size() == 0, "Must have a least 1 state for a mechanics.");
49  field_shape_displacement_ = std::make_unique<FieldState>(shape_disp);
50  for (size_t i = 0; i < states.size(); ++i) {
51  const auto& s = states[i];
52  field_states_.push_back(s);
53  initial_field_states_.push_back(s);
54  state_name_to_field_index_[s.get()->name()] = i;
55  state_names_.push_back(s.get()->name());
56  }
57 
58  for (size_t i = 0; i < params.size(); ++i) {
59  const auto& p = params[i];
60  field_params_.push_back(p);
61  param_name_to_field_index_[p.get()->name()] = i;
62  param_names_.push_back(p.get()->name());
63  }
64 
65  for (size_t i = 0; i < reaction_names_.size(); ++i) {
66  reaction_name_to_reaction_index_[reaction_names_[i]] = i;
67  }
68 
69  completeSetup();
70 }
71 
73 {
74  SLIC_ERROR_IF(field_states_.empty(), "Empty field state during completeSetup()");
75 }
76 
77 void DifferentiablePhysics::resetStates(int cycle, double time)
78 {
79  for (size_t i = 0; i < initial_field_states_.size(); ++i) {
80  field_states_[i] = initial_field_states_[i];
81  }
82  milestones_.clear();
83  checkpointer_->reset_graph();
84  time_ = time;
85  cycle_ = cycle;
86 }
87 
89 {
90  checkpointer_->finalize_graph();
91  checkpointer_->reset_for_backprop();
92  gretl_assert(checkpointer_->check_validity());
93 }
94 
95 std::vector<std::string> DifferentiablePhysics::stateNames() const { return state_names_; }
96 
97 std::vector<std::string> DifferentiablePhysics::parameterNames() const { return param_names_; }
98 
99 std::vector<std::string> DifferentiablePhysics::dualNames() const { return reaction_names_; }
100 
/**
 * @brief Look up a live field state by name.
 * @param field_name Name of a state registered in the constructor.
 * @return The current value of that state. Raises a SLIC error if the name is unknown.
 */
const FiniteElementState& DifferentiablePhysics::state(const std::string& field_name) const
{
  // field_name is used, so it must not be marked [[maybe_unused]]; that attribute would hide a real
  // "unused parameter" warning if the lookup were ever removed.
  SLIC_ERROR_IF(
      state_name_to_field_index_.find(field_name) == state_name_to_field_index_.end(),
      std::format("Could not find field named {0} in mesh with tag \"{1}\" to get", field_name, mesh_->tag()));
  const size_t state_index = state_name_to_field_index_.at(field_name);
  return *field_states_[state_index].get();
}
109 
/**
 * @brief Recompute the reactions and return the one named @p dual_name.
 *
 * The reactions are recomputed through the advancer, using the time info saved by the most recent
 * advanceTimestep() call (time_prev_, dt_prev_, cycle_prev_). The result is cached in reaction_states_.
 *
 * @param dual_name One of the reaction names passed to the constructor.
 * @return The requested reaction. Raises a SLIC error for an unknown name, or if the advancer
 *         returned fewer reactions than needed.
 */
const FiniteElementDual& DifferentiablePhysics::dual(const std::string& dual_name) const
{
  SLIC_ERROR_IF(reaction_name_to_reaction_index_.find(dual_name) == reaction_name_to_reaction_index_.end(),
                std::format("Could not find dual named {0} in mesh with tag \"{1}\" to get", dual_name, mesh_->tag()));
  const size_t reaction_index = reaction_name_to_reaction_index_.at(dual_name);

  TimeInfo time_info(time_prev_, dt_prev_, static_cast<size_t>(cycle_prev_));
  reaction_states_ = advancer_->computeReactions(time_info, *field_shape_displacement_, field_states_, field_params_);

  // reaction_index always satisfies < reaction_names_.size(), because the lookup table is built from that
  // list. The index that can actually overflow is the one into the reactions the advancer just returned.
  SLIC_ERROR_IF(
      reaction_index >= reaction_states_.size(),
      "Dual reactions not correctly allocated yet, cannot get dual until after initializationStep is called.");
  return *reaction_states_[reaction_index].get();
}
123 
/**
 * @brief Return a copy of a state at the requested cycle.
 *
 * Only the current cycle can be requested; any other cycle raises a SLIC error.
 */
FiniteElementState DifferentiablePhysics::loadCheckpointedState(const std::string& state_name, int cycle)
{
  const bool is_current_cycle = (cycle == cycle_);
  SLIC_ERROR_IF(!is_current_cycle,
                std::format("Due to checkpointing restrictions in smith::Mechanics, cannot ask for an arbitrary "
                            "checkpointed cycle, asking for cycle {}, but physics is at cycle {}",
                            cycle, cycle_));
  return state(state_name);
}
132 
133 const FiniteElementState& DifferentiablePhysics::shapeDisplacement() const { return *field_shape_displacement_->get(); }
134 
/**
 * @brief Parameter field by position.
 * @param parameter_index Index into the parameters passed to the constructor; out-of-range raises a SLIC error.
 */
const FiniteElementState& DifferentiablePhysics::parameter(std::size_t parameter_index) const
{
  const std::size_t num_params = field_params_.size();
  SLIC_ERROR_IF(parameter_index >= num_params,
                std::format("Parameter index {} requested, but only {} parameters exist in physics module {}.",
                            parameter_index, num_params, name_));
  return *field_params_[parameter_index].get();
}
142 
/**
 * @brief Parameter field by name.
 * @param parameter_name Name of a parameter passed to the constructor; unknown names raise a SLIC error.
 */
const FiniteElementState& DifferentiablePhysics::parameter(const std::string& parameter_name) const
{
  const bool known = param_name_to_field_index_.find(parameter_name) != param_name_to_field_index_.end();
  SLIC_ERROR_IF(
      !known,
      std::format("Could not find parameter named {0} in mesh with tag \"{1}\" to get", parameter_name, mesh_->tag()));
  return parameter(param_name_to_field_index_.at(parameter_name));
}
151 
/**
 * @brief Overwrite the value of the parameter at @p parameter_index.
 *
 * An out-of-range index raises a SLIC error.
 */
void DifferentiablePhysics::setParameter(const size_t parameter_index, const FiniteElementState& parameter_state)
{
  const bool out_of_range = parameter_index >= field_params_.size();
  SLIC_ERROR_IF(out_of_range, std::format("Parameter '{}' requested when only '{}' parameters exist in physics module '{}'",
                                          parameter_index, field_params_.size(), name_));
  auto& target = *field_params_[parameter_index].get();
  target = parameter_state;
}
159 
160 void DifferentiablePhysics::setShapeDisplacement(const FiniteElementState& s) { *field_shape_displacement_->get() = s; }
161 
/**
 * @brief Overwrite a state by name, in both the live and the initial copies.
 *
 * Writing the initial copy as well means the value survives resetStates(). Unknown names raise a SLIC error.
 *
 * @param field_name Name of a state registered in the constructor.
 * @param s          New value for that state.
 */
void DifferentiablePhysics::setState(const std::string& field_name, const FiniteElementState& s)
{
  // Both parameters are used, so neither should carry [[maybe_unused]].
  SLIC_ERROR_IF(state_name_to_field_index_.find(field_name) == state_name_to_field_index_.end(),
                std::format("Could not find field named {0} in mesh with tag {1} to set", field_name, mesh_->tag()));
  const size_t state_index = state_name_to_field_index_.at(field_name);
  *field_states_[state_index].get() = s;
  *initial_field_states_[state_index].get() = s;
}
171 
/**
 * @brief Accumulate adjoint loads into the duals of the named states.
 *
 * @param string_to_dual Map from state name to the load that is added (+=) to that state's dual.
 *                       Unknown names raise a SLIC error.
 */
void DifferentiablePhysics::setAdjointLoad(
    std::unordered_map<std::string, const smith::FiniteElementDual&> string_to_dual)
{
  // Bind by const reference; iterating by value copied every key string on each pass.
  for (const auto& [field_name, adjoint_load] : string_to_dual) {
    SLIC_ERROR_IF(state_name_to_field_index_.find(field_name) == state_name_to_field_index_.end(),
                  std::format("Could not find dual named {0} in mesh with tag {1}", field_name, mesh_->tag()));
    const size_t state_index = state_name_to_field_index_.at(field_name);
    *field_states_[state_index].get_dual() += adjoint_load;
  }
}
184 
/**
 * @brief Accumulate adjoint boundary conditions into the duals of the named reactions.
 *
 * In this file, reaction_states_ is only filled by dual(). A reaction with no computed entry
 * therefore raises a SLIC error instead of being indexed out of range.
 *
 * @param string_to_bc Map from reaction name to the value added (+=) to that reaction's dual.
 */
void DifferentiablePhysics::setDualAdjointBcs(
    std::unordered_map<std::string, const smith::FiniteElementState&> string_to_bc)
{
  // Bind by const reference; iterating by value copied every key string on each pass.
  for (const auto& [reaction_name, reaction_dual] : string_to_bc) {
    SLIC_ERROR_IF(reaction_name_to_reaction_index_.find(reaction_name) == reaction_name_to_reaction_index_.end(),
                  std::format("When calling setDualAdjointBcs, could not find reaction named {0} in mesh with tag {1}",
                              reaction_name, mesh_->tag()));
    const size_t reaction_index = reaction_name_to_reaction_index_.at(reaction_name);
    SLIC_ERROR_IF(reaction_index >= reaction_states_.size(),
                  std::format("When calling setDualAdjointBcs, reaction named {0} has no computed reaction state "
                              "(only {1} available)",
                              reaction_name, reaction_states_.size()));
    *reaction_states_[reaction_index].get_dual() += reaction_dual;
  }
}
198 
199 const FiniteElementState& DifferentiablePhysics::adjoint([[maybe_unused]] const std::string& adjoint_name) const
200 {
201  // MRT, not implemented
202  SLIC_ERROR("What is the use case for asking for the adjoint solution field directly?");
203  return *adjoints_[0];
204 }
205 
207 {
208  if (cycle_ == 0) {
209  field_states_ = initial_field_states_;
210  milestones_.push_back(make_milestone(field_states_).step());
211  }
212 
213  cycle_prev_ = cycle_;
214  time_prev_ = time_;
215  dt_prev_ = dt;
216 
217  TimeInfo time_info(time_, dt, static_cast<size_t>(cycle_));
218  field_states_ = advancer_->advanceState(time_info, *field_shape_displacement_, field_states_, field_params_);
219 
220  cycle_++;
221  time_ += dt;
222  milestones_.push_back(make_milestone(field_states_).step());
223 }
224 
/**
 * @brief Step the adjoint sweep back by one cycle.
 *
 * Steps performed:
 * - Decrement cycle_ and look up the milestone recorded for that cycle.
 * - Clear the duals of the shape displacement and the parameters.
 * - Reverse the gretl data store until its current step is that milestone.
 * - Rebuild field_states_ (step, value and dual) from the milestone's upstream states.
 *
 * Raises a SLIC error if there is no earlier cycle to reverse to, or if no milestone is recorded for it.
 */
void DifferentiablePhysics::reverseAdjointTimestep()
{
  // Without these guards, reversing from cycle 0 indexes milestones_[size_t(-1)]. A cycle count larger
  // than the number of recorded milestones (e.g. after resetStates with a nonzero cycle) also indexes
  // past the end.
  SLIC_ERROR_IF(cycle_ <= 0,
                std::format("Cannot reverse adjoint timestep in physics module {}: already at cycle {}.", name_, cycle_));
  SLIC_ERROR_IF(static_cast<size_t>(cycle_) > milestones_.size(),
                std::format("Cannot reverse adjoint timestep in physics module {}: at cycle {} but only {} milestones "
                            "were recorded.",
                            name_, cycle_, milestones_.size()));

  --cycle_;
  const gretl::Int milestone = milestones_[static_cast<size_t>(cycle_)];

  field_shape_displacement_->clear_dual();
  for (auto& p : field_params_) {
    p.clear_dual();
  }

  gretl::Int current_step = checkpointer_->currentStep_;
  while (milestone != current_step) {
    checkpointer_->reverse_state();
    current_step = checkpointer_->currentStep_;
  }

  gretl::UpstreamStates upstreams(*checkpointer_, checkpointer_->upstreamSteps_[milestone]);

  SLIC_ERROR_IF(field_states_.size() != upstreams.size(), "field states and upstream sizes do not match.");
  // recreate the upstream field states with upstream step, field, and dual values.
  for (size_t s = 0; s < upstreams.size(); ++s) {
    field_states_[s].reset_step(upstreams[s].step_);
    field_states_[s].set(upstreams[s].get<FEFieldPtr>());
    field_states_[s].set_dual(upstreams[s].get_dual<FEDualPtr, FEFieldPtr>());
  }
}
251 
/**
 * @brief Return a copy of the accumulated dual (sensitivity) of the parameter at @p parameter_index.
 *
 * The index is bounds-checked, as in parameter() and setParameter(); out-of-range raises a SLIC error.
 */
FiniteElementDual DifferentiablePhysics::computeTimestepSensitivity(size_t parameter_index)
{
  SLIC_ERROR_IF(parameter_index >= field_params_.size(),
                std::format("Sensitivity for parameter index {} requested, but only {} parameters exist in physics "
                            "module {}.",
                            parameter_index, field_params_.size(), name_));
  return *field_params_[parameter_index].get_dual();
}
256 
258 {
259  return *field_shape_displacement_->get_dual();
260 }
261 
262 const std::unordered_map<std::string, const smith::FiniteElementDual&>
264 {
265  std::unordered_map<std::string, const smith::FiniteElementDual&> map;
266  for (auto& name : stateNames()) {
267  auto state_index = state_name_to_field_index_.at(name);
268  map.insert({name, *initial_field_states_[state_index].get_dual()});
269  }
270  return map;
271 }
272 
274 {
275  std::vector<FieldState> fields;
276  fields.insert(fields.end(), field_states_.begin(), field_states_.end());
277  fields.insert(fields.end(), field_params_.begin(), field_params_.end());
278  return fields;
279 }
280 
281 FieldState DifferentiablePhysics::getShapeDispFieldState() const { return *field_shape_displacement_; }
282 
283 } // namespace smith
This is the abstract base class for a generic forward solver.
std::string name_
Name of the physics module.
std::shared_ptr< smith::Mesh > mesh_
The primary mesh.
int cycle_
Current cycle (forward pass time iteration count)
std::string name() const
Return the name of the physics.
virtual double time() const
Get the current forward-solution time.
double time_
Current time for the forward pass.
virtual int cycle() const
Get the current forward-solution cycle iteration number.
std::vector< const smith::FiniteElementState * > adjoints_
List of finite element adjoint states associated with this physics module.
void resetStates(int cycle=0, double time=0.0) override
overload
void reverseAdjointTimestep() override
This is an overloaded member function, provided for convenience. It differs from the above function o...
void completeSetup() override
This is an overloaded member function, provided for convenience. It differs from the above function o...
void setShapeDisplacement(const FiniteElementState &s) override
This is an overloaded member function, provided for convenience. It differs from the above function o...
std::vector< std::string > parameterNames() const override
This is an overloaded member function, provided for convenience. It differs from the above function o...
FieldState getShapeDispFieldState() const
Get the shape displacement FieldState.
const FiniteElementState & adjoint(const std::string &adjoint_name) const override
This is an overloaded member function, provided for convenience. It differs from the above function o...
void setDualAdjointBcs(std::unordered_map< std::string, const smith::FiniteElementState & > string_to_bc) override
This is an overloaded member function, provided for convenience. It differs from the above function o...
std::vector< std::string > dualNames() const override
This is an overloaded member function, provided for convenience. It differs from the above function o...
FiniteElementDual computeTimestepSensitivity(size_t parameter_index) override
This is an overloaded member function, provided for convenience. It differs from the above function o...
virtual void resetAdjointStates() override
overload
void setAdjointLoad(std::unordered_map< std::string, const smith::FiniteElementDual & > string_to_dual) override
This is an overloaded member function, provided for convenience. It differs from the above function o...
const FiniteElementDual & computeTimestepShapeSensitivity() override
This is an overloaded member function, provided for convenience. It differs from the above function o...
const FiniteElementState & parameter(std::size_t parameter_index) const override
This is an overloaded member function, provided for convenience. It differs from the above function o...
DifferentiablePhysics(std::shared_ptr< Mesh > mesh, std::shared_ptr< gretl::DataStore > graph, const FieldState &shape_disp, const std::vector< FieldState > &states, const std::vector< FieldState > &params, std::shared_ptr< StateAdvancer > advancer, std::string physics_name, const std::vector< std::string > &reaction_names={})
constructor
void setState(const std::string &state_name, const FiniteElementState &s) override
This is an overloaded member function, provided for convenience. It differs from the above function o...
const FiniteElementState & shapeDisplacement() const override
This is an overloaded member function, provided for convenience. It differs from the above function o...
FiniteElementState loadCheckpointedState(const std::string &state_name, int cycle) override
This is an overloaded member function, provided for convenience. It differs from the above function o...
std::vector< FieldState > getFieldStatesAndParamStates() const
Get all the FieldStates: states first, then parameters.
virtual void advanceTimestep(double dt) override
This is an overloaded member function, provided for convenience. It differs from the above function o...
const FiniteElementState & state(const std::string &state_name) const override
This is an overloaded member function, provided for convenience. It differs from the above function o...
std::vector< std::string > stateNames() const override
This is an overloaded member function, provided for convenience. It differs from the above function o...
const FiniteElementDual & dual(const std::string &dual_name) const override
This is an overloaded member function, provided for convenience. It differs from the above function o...
const std::unordered_map< std::string, const smith::FiniteElementDual & > computeInitialConditionSensitivity() const override
This is an overloaded member function, provided for convenience. It differs from the above function o...
void setParameter(const size_t parameter_index, const FiniteElementState &parameter_state) override
This is an overloaded member function, provided for convenience. It differs from the above function o...
Class for encapsulating the dual vector space of a finite element space (i.e. the space of linear for...
Class for encapsulating the critical MFEM components of a primal finite element field.
Implementation of BasePhysics which uses FieldStates and gretl to track the computational graph,...
Smith mesh class which assists in constructing the appropriate parallel mfem meshes and registering a...
Accelerator functionality.
Definition: smith.cpp:36
gretl::State< int > make_milestone(const std::vector< FieldState > &states)
gretl-function to create a dummy-state which records all states and params of interest to the mechani...
gretl::State< FEFieldPtr, FEDualPtr > FieldState
typedef
Definition: field_state.hpp:22
Reaction class which is a named combination of a weak form and a set of Dirichlet-constrained nodes.
Interface and implementations for advancing from one step to the next. Typically these are time integ...
struct storing time and timestep information
Definition: common.hpp:18
Dual number struct (value plus gradient)
Definition: dual.hpp:28
Specifies interface for evaluating weak form residuals and their gradients.