Skip to content

Commit

Permalink
Merge pull request OpenMDAO#3387 from swryan/2711_set_input_defaults
Browse files Browse the repository at this point in the history
Added additional documentation for the `set_input_defaults` function
  • Loading branch information
swryan authored Nov 5, 2024
2 parents a4c18d8 + 2728228 commit 878c3a0
Show file tree
Hide file tree
Showing 6 changed files with 589 additions and 9 deletions.
11 changes: 11 additions & 0 deletions openmdao/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -1974,6 +1974,17 @@ def _resolve_group_input_defaults(self, show_warnings=False):
abs_in2prom_info[tgt][tree_level] = \
_PromotesInfo(src_shape=src_shape, prom=prom,
promoted_from=self.pathname)
else:
# check for discrete targets
for tgt in prom2abs_in[prom]:
if tgt in self._discrete_inputs:
for key, val in meta.items():
# for discretes we can only set the value (not units/src_shape)
if key == 'val':
self._discrete_inputs[tgt] = val
elif key in ('units', 'src_shape'):
self._collect_error(f"{self.msginfo}: Cannot set '{key}={val}'"
f" for discrete variable '{tgt}'.")

meta.update(fullmeta)

Expand Down
103 changes: 103 additions & 0 deletions openmdao/core/tests/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -3161,6 +3161,109 @@ def test_conflicting_val(self):
"promoted input 'x' with conflicting values for 'val'. Call <group>.set_input_defaults('x', val=?), "
"where <group> is the model to remove the ambiguity.")

def test_set_input_defaults_discrete(self):
import math
import openmdao.api as om

density = {
'steel': 7.85, # g/cm^3
'aluminum': 2.7 # g/cm^3
}

class SquarePlate(om.ExplicitComponent):
"""
Calculate the weight of a square plate.
material is a discrete input (default: steel)
"""

def setup(self):
self.add_discrete_input('material', 'steel')

self.add_input('length', 1.0, units='cm')
self.add_input('width', 1.0, units='cm')
self.add_input('thickness', 1.0, units='cm')

self.add_output('weight', 1.0, units='g')

def compute(self, inputs, outputs, discrete_inputs, discrete_outputs):
length = inputs['length']
width = inputs['width']
thickness = inputs['thickness']
material = discrete_inputs['material']

outputs['weight'] = length * width * thickness * density[material]

class CirclePlate(om.ExplicitComponent):
"""
Calculate the weight of a circular plate.
material is a discrete input (default: aluminum)
"""

def setup(self):
self.add_discrete_input('material', 'aluminum')

self.add_input('radius', 1.0, units='cm')
self.add_input('thickness', 1.0, units='g')

self.add_output('weight', 1.0, units='g')

def compute(self, inputs, outputs, discrete_inputs, discrete_output):
radius = inputs['radius']
thickness = inputs['thickness']
material = discrete_inputs['material']

outputs['weight'] = math.pi * radius**2 * thickness * density[material]

#
# first check that we get errors when using invalid args to set defaults on a discrete
#
p = om.Problem()
model = p.model

model.add_subsystem('square', SquarePlate(), promotes_inputs=['material'])
model.add_subsystem('circle', CirclePlate(), promotes_inputs=['material'])

# setting input defaults for units/src_shape is not valid for a discrete and will generate errors
model.set_input_defaults('material', 'steel', units='kg', src_shape=(1,))
expect_errors = [
f"Collected errors for problem '{p._get_inst_id()}':",
" <model> <class Group>: Cannot set 'units=kg' for discrete variable 'circle.material'.",
" <model> <class Group>: Cannot set 'src_shape=(1,)' for discrete variable 'circle.material'.",
" <model> <class Group>: Cannot set 'units=kg' for discrete variable 'square.material'.",
" <model> <class Group>: Cannot set 'src_shape=(1,)' for discrete variable 'square.material'.",
]

with self.assertRaises(Exception) as cm:
p.setup()

err_msgs = cm.exception.args[0].split('\n')
for err_msg in expect_errors:
self.assertTrue(err_msg in err_msgs,
err_msg + ' not found in:\n' + cm.exception.args[0])

#
# now make sure that setting just the default value for a discrete works as expected
#
p = om.Problem()
model = p.model

model.add_subsystem('square', SquarePlate(), promotes_inputs=['material'])
model.add_subsystem('circle', CirclePlate(), promotes_inputs=['material'])

model.set_input_defaults('material', 'steel')

p.setup()
p.run_model()

inputs = model.list_inputs(return_format='dict', out_stream=None)
self.assertEqual(inputs['square.material']['val'], 'steel')
self.assertEqual(inputs['circle.material']['val'], 'steel')

assert_near_equal(p['square.weight'], 7.85)
assert_near_equal(p['circle.weight'], 24.66150233, 1e-6)


class MultComp(om.ExplicitComponent):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"# Running Your Models\n",
"\n",
"- [Setting and Getting Component Variables](set_get.ipynb)\n",
"- [Using 'set_input_defaults'](set_input_defaults.ipynb)\n",
"- [Setup Your Model](setup.ipynb)\n",
"- [Run Your Model](run_model.ipynb)\n",
"- [Run a Driver](run_driver.ipynb)"
Expand All @@ -30,7 +31,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
"version": "3.12.6"
},
"orphan": true
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -641,14 +641,9 @@
"to the same name as an output, then again the framework will connect all of those inputs to an\n",
"`_auto_ivc` output. If, however, there is any difference between the units or values of any of those inputs,\n",
"then you must tell the framework what units and/or values to use when creating the corresponding\n",
"`_auto_ivc` output. You do this by calling the `set_input_defaults` function using the promoted\n",
"`_auto_ivc` output. You do this by calling the [set_input_defaults](./set_input_defaults.ipynb) function using the promoted\n",
"input name on a Group that contains all of the promoted inputs.\n",
"\n",
"```{eval-rst}\n",
" .. automethod:: openmdao.core.group.Group.set_input_defaults\n",
" :noindex:\n",
"```\n",
"\n",
"Below is an example of what you'll see if you do *not* call `set_input_defaults` to disambiguate\n",
"your units and/or values:"
]
Expand Down Expand Up @@ -1088,7 +1083,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.5"
"version": "3.12.6"
},
"orphan": true
},
Expand Down
Loading

0 comments on commit 878c3a0

Please sign in to comment.