7 #include "gretl/data_store.hpp"
13 #include "gretl/upstream_state.hpp"
22 std::vector<gretl::StateBase> base_states;
23 for (
const auto& s : states) {
24 base_states.push_back(s);
27 auto milestone = states[0].create_state<int,
int>(base_states);
30 []([[maybe_unused]]
const gretl::UpstreamStates& inputs, gretl::DownstreamState& output) { output.set<
int>(0); });
32 []([[maybe_unused]] gretl::UpstreamStates& inputs, [[maybe_unused]]
const gretl::DownstreamState& output) {});
34 return milestone.finalize();
39 const FieldState& shape_disp,
const std::vector<FieldState>& states,
40 const std::vector<FieldState>& params,
41 std::shared_ptr<StateAdvancer> advancer, std::string mech_name,
42 const std::vector<std::string>& reaction_names)
46 reaction_names_(reaction_names)
48 SLIC_ERROR_IF(states.size() == 0,
"Must have a least 1 state for a mechanics.");
49 field_shape_displacement_ = std::make_unique<FieldState>(shape_disp);
50 for (
size_t i = 0; i < states.size(); ++i) {
51 const auto& s = states[i];
52 field_states_.push_back(s);
53 initial_field_states_.push_back(s);
54 state_name_to_field_index_[s.get()->name()] = i;
55 state_names_.push_back(s.get()->name());
58 for (
size_t i = 0; i < params.size(); ++i) {
59 const auto& p = params[i];
60 field_params_.push_back(p);
61 param_name_to_field_index_[p.get()->name()] = i;
62 param_names_.push_back(p.get()->name());
65 for (
size_t i = 0; i < reaction_names_.size(); ++i) {
66 reaction_name_to_reaction_index_[reaction_names_[i]] = i;
74 SLIC_ERROR_IF(field_states_.empty(),
"Empty field state during completeSetup()");
79 for (
size_t i = 0; i < initial_field_states_.size(); ++i) {
80 field_states_[i] = initial_field_states_[i];
83 checkpointer_->reset_graph();
90 checkpointer_->finalize_graph();
91 checkpointer_->reset_for_backprop();
92 gretl_assert(checkpointer_->check_validity());
104 state_name_to_field_index_.find(field_name) == state_name_to_field_index_.end(),
105 std::format(
"Could not find field named {0} in mesh with tag \"{1}\" to get", field_name,
mesh_->tag()));
106 size_t state_index = state_name_to_field_index_.at(field_name);
107 return *field_states_[state_index].get();
112 SLIC_ERROR_IF(reaction_name_to_reaction_index_.find(dual_name) == reaction_name_to_reaction_index_.end(),
113 std::format(
"Could not find dual named {0} in mesh with tag \"{1}\" to get", dual_name,
mesh_->tag()));
114 size_t reaction_index = reaction_name_to_reaction_index_.at(dual_name);
116 reaction_index >= reaction_names_.size(),
117 "Dual reactions not correctly allocated yet, cannot get dual until after initializationStep is called.");
119 TimeInfo time_info(time_prev_, dt_prev_,
static_cast<size_t>(cycle_prev_));
120 reaction_states_ = advancer_->computeReactions(time_info, *field_shape_displacement_, field_states_, field_params_);
121 return *reaction_states_[reaction_index].get();
127 std::format(
"Due to checkpointing restrictions in smith::Mechanics, cannot ask for an arbitrary "
128 "checkpointed cycle, asking for cycle {}, but physics is at cycle {}",
130 return state(state_name);
137 SLIC_ERROR_IF(parameter_index >= field_params_.size(),
138 std::format(
"Parameter index {} requested, but only {} parameters exist in physics module {}.",
139 parameter_index, field_params_.size(),
name_));
140 return *field_params_[parameter_index].get();
146 param_name_to_field_index_.find(parameter_name) == param_name_to_field_index_.end(),
147 std::format(
"Could not find parameter named {0} in mesh with tag \"{1}\" to get", parameter_name,
mesh_->tag()));
148 size_t param_index = param_name_to_field_index_.at(parameter_name);
154 SLIC_ERROR_IF(parameter_index >= field_params_.size(),
155 std::format(
"Parameter '{}' requested when only '{}' parameters exist in physics module '{}'",
156 parameter_index, field_params_.size(),
name_));
157 *field_params_[parameter_index].get() = parameter_state;
165 SLIC_ERROR_IF(state_name_to_field_index_.find(field_name) == state_name_to_field_index_.end(),
166 std::format(
"Could not find field named {0} in mesh with tag {1} to set", field_name,
mesh_->tag()));
167 size_t state_index = state_name_to_field_index_.at(field_name);
168 *field_states_[state_index].get() = s;
169 *initial_field_states_[state_index].get() = s;
173 std::unordered_map<std::string, const smith::FiniteElementDual&> string_to_dual)
175 for (
auto string_dual_pair : string_to_dual) {
176 std::string field_name = string_dual_pair.first;
178 SLIC_ERROR_IF(state_name_to_field_index_.find(field_name) == state_name_to_field_index_.end(),
179 std::format(
"Could not find dual named {0} in mesh with tag {1}", field_name,
mesh_->tag()));
180 size_t state_index = state_name_to_field_index_.at(field_name);
181 *field_states_[state_index].get_dual() +=
dual;
186 std::unordered_map<std::string, const smith::FiniteElementState&> string_to_bc)
188 for (
auto string_bc_pair : string_to_bc) {
189 std::string reaction_name = string_bc_pair.first;
191 SLIC_ERROR_IF(reaction_name_to_reaction_index_.find(reaction_name) == reaction_name_to_reaction_index_.end(),
192 std::format(
"When calling setDualAdjointBcs, could not find reaction named {0} in mesh with tag {1}",
193 reaction_name,
mesh_->tag()));
194 size_t reaction_index = reaction_name_to_reaction_index_.at(reaction_name);
195 *reaction_states_[reaction_index].get_dual() += reaction_dual;
202 SLIC_ERROR(
"What is the use case for asking for the adjoint solution field directly?");
209 field_states_ = initial_field_states_;
218 field_states_ = advancer_->advanceState(time_info, *field_shape_displacement_, field_states_, field_params_);
228 const gretl::Int milestone = milestones_[
static_cast<size_t>(
cycle_)];
230 field_shape_displacement_->clear_dual();
231 for (
auto& p : field_params_) {
235 gretl::Int current_step = checkpointer_->currentStep_;
236 while (milestone != current_step) {
237 checkpointer_->reverse_state();
238 current_step = checkpointer_->currentStep_;
241 gretl::UpstreamStates upstreams(*checkpointer_, checkpointer_->upstreamSteps_[milestone]);
243 SLIC_ERROR_IF(field_states_.size() != upstreams.size(),
"field states and upstream sizes do not match.");
245 for (
size_t s = 0; s < upstreams.size(); ++s) {
246 field_states_[s].reset_step(upstreams[s].step_);
247 field_states_[s].set(upstreams[s].get<FEFieldPtr>());
248 field_states_[s].set_dual(upstreams[s].get_dual<FEDualPtr, FEFieldPtr>());
254 return *field_params_[parameter_index].get_dual();
259 return *field_shape_displacement_->get_dual();
262 const std::unordered_map<std::string, const smith::FiniteElementDual&>
265 std::unordered_map<std::string, const smith::FiniteElementDual&> map;
267 auto state_index = state_name_to_field_index_.at(
name);
268 map.insert({
name, *initial_field_states_[state_index].get_dual()});
275 std::vector<FieldState> fields;
276 fields.insert(fields.end(), field_states_.begin(), field_states_.end());
277 fields.insert(fields.end(), field_params_.begin(), field_params_.end());
This is the abstract base class for a generic forward solver.
std::string name_
Name of the physics module.
std::shared_ptr< smith::Mesh > mesh_
The primary mesh.
int cycle_
Current cycle (forward pass time iteration count)
std::string name() const
Return the name of the physics.
virtual double time() const
Get the current forward-solution time.
double time_
Current time for the forward pass.
virtual int cycle() const
Get the current forward-solution cycle iteration number.
std::vector< const smith::FiniteElementState * > adjoints_
List of finite element adjoint states associated with this physics module.
void resetStates(int cycle=0, double time=0.0) override
overload
void reverseAdjointTimestep() override
This is an overloaded member function, provided for convenience. It differs from the above function o...
void completeSetup() override
This is an overloaded member function, provided for convenience. It differs from the above function o...
void setShapeDisplacement(const FiniteElementState &s) override
This is an overloaded member function, provided for convenience. It differs from the above function o...
std::vector< std::string > parameterNames() const override
This is an overloaded member function, provided for convenience. It differs from the above function o...
FieldState getShapeDispFieldState() const
Get the shape displacement FieldState.
const FiniteElementState & adjoint(const std::string &adjoint_name) const override
This is an overloaded member function, provided for convenience. It differs from the above function o...
void setDualAdjointBcs(std::unordered_map< std::string, const smith::FiniteElementState & > string_to_bc) override
This is an overloaded member function, provided for convenience. It differs from the above function o...
std::vector< std::string > dualNames() const override
This is an overloaded member function, provided for convenience. It differs from the above function o...
FiniteElementDual computeTimestepSensitivity(size_t parameter_index) override
This is an overloaded member function, provided for convenience. It differs from the above function o...
virtual void resetAdjointStates() override
overload
void setAdjointLoad(std::unordered_map< std::string, const smith::FiniteElementDual & > string_to_dual) override
This is an overloaded member function, provided for convenience. It differs from the above function o...
const FiniteElementDual & computeTimestepShapeSensitivity() override
This is an overloaded member function, provided for convenience. It differs from the above function o...
const FiniteElementState & parameter(std::size_t parameter_index) const override
This is an overloaded member function, provided for convenience. It differs from the above function o...
DifferentiablePhysics(std::shared_ptr< Mesh > mesh, std::shared_ptr< gretl::DataStore > graph, const FieldState &shape_disp, const std::vector< FieldState > &states, const std::vector< FieldState > ¶ms, std::shared_ptr< StateAdvancer > advancer, std::string physics_name, const std::vector< std::string > &reaction_names={})
constructor
void setState(const std::string &state_name, const FiniteElementState &s) override
This is an overloaded member function, provided for convenience. It differs from the above function o...
const FiniteElementState & shapeDisplacement() const override
This is an overloaded member function, provided for convenience. It differs from the above function o...
FiniteElementState loadCheckpointedState(const std::string &state_name, int cycle) override
This is an overloaded member function, provided for convenience. It differs from the above function o...
std::vector< FieldState > getFieldStatesAndParamStates() const
Get all the FieldStates... states first, parameters next.
virtual void advanceTimestep(double dt) override
This is an overloaded member function, provided for convenience. It differs from the above function o...
const FiniteElementState & state(const std::string &state_name) const override
This is an overloaded member function, provided for convenience. It differs from the above function o...
std::vector< std::string > stateNames() const override
This is an overloaded member function, provided for convenience. It differs from the above function o...
const FiniteElementDual & dual(const std::string &dual_name) const override
This is an overloaded member function, provided for convenience. It differs from the above function o...
const std::unordered_map< std::string, const smith::FiniteElementDual & > computeInitialConditionSensitivity() const override
This is an overloaded member function, provided for convenience. It differs from the above function o...
void setParameter(const size_t parameter_index, const FiniteElementState ¶meter_state) override
This is an overloaded member function, provided for convenience. It differs from the above function o...
Class for encapsulating the dual vector space of a finite element space (i.e. the space of linear for...
Class for encapsulating the critical MFEM components of a primal finite element field.
Implementation of BasePhysics which uses FieldStates and gretl to track the computational graph,...
Smith mesh class which assists in constructing the appropriate parallel mfem meshes and registering a...
Accelerator functionality.
gretl::State< int > make_milestone(const std::vector< FieldState > &states)
gretl-function to create a dummy-state which records all states and params of interest to the mechani...
gretl::State< FEFieldPtr, FEDualPtr > FieldState
typedef
Reaction class which is a names combination of a weak form and a set of dirichlet constrained nodes.
Interface and implementations for advancing from one step to the next. Typically these are time integ...
struct storing time and timestep information
Dual number struct (value plus gradient)