diff --git a/petab/v2/extensions/sciml.py b/petab/v2/extensions/sciml.py index 58189279..e5c2ba33 100644 --- a/petab/v2/extensions/sciml.py +++ b/petab/v2/extensions/sciml.py @@ -272,21 +272,22 @@ def from_config( sciml_config: SciMLConfig = config.extensions[C.EXT_ID_SCIML] # Neural network classes are constructed via pytorch for now to get - # the proper inputs - neural_networks = [ - NNModel.from_pytorch_module( - NNModelStandard.load_data( - _generate_path( - file_path=nn_config.location, - base_path=base_path, + # the proper inputs. Non-YAML formats are opaque — the file is assumed + # to contain a valid model and is not read here. + neural_networks = [] + for nn_id, nn_config in (sciml_config.neural_networks or {}).items(): + if nn_config.format.lower() == "yaml": + neural_networks.append( + NNModel.from_pytorch_module( + NNModelStandard.load_data( + _generate_path( + file_path=nn_config.location, + base_path=base_path, + ) + ).to_pytorch_module(), + nn_model_id=nn_id, ) - ).to_pytorch_module(), - nn_model_id=nn_id, - ) - for nn_id, nn_config in ( - sciml_config.neural_networks or {} - ).items() - ] + ) hybridization_tables = [ HybridizationTable.from_tsv(f, base_path) diff --git a/petab/v2/extensions/sciml_lint.py b/petab/v2/extensions/sciml_lint.py index 354274a8..06d716d8 100644 --- a/petab/v2/extensions/sciml_lint.py +++ b/petab/v2/extensions/sciml_lint.py @@ -16,9 +16,11 @@ def run(self, problem: core.Problem) -> lint.ValidationIssue | None: condition_targets = { c.target_id for ct in problem.conditions for c in ct.changes } + # Only YAML-format networks are loaded as NNModel objects nn_input_ids = { inp.input_id for nn in problem.extensions.sciml.neural_networks + if hasattr(nn, "inputs") for inp in nn.inputs } hyb_target_ids = { diff --git a/tests/v2/test_sciml.py b/tests/v2/test_sciml.py index 390605ec..7db715f0 100644 --- a/tests/v2/test_sciml.py +++ b/tests/v2/test_sciml.py @@ -142,3 +142,16 @@ def _get_test_problem(): def test_lint(): problem = _get_test_problem() assert problem.validate() == [] + + +def test_lint_equinox_network_format(): + """Linter accepts non-YAML formats without reading the network file.""" + problem = _get_test_problem() + # Replace the YAML network config with equinox format + sciml_cfg = problem.config.extensions["sciml"] + sciml_cfg.neural_networks["net1"] = NeuralNetConfig( + location="net1.py", + pre_initialization=False, + format="equinox", + ) + assert problem.validate() == []