Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions petab/v2/extensions/sciml.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,21 +272,22 @@ def from_config(
sciml_config: SciMLConfig = config.extensions[C.EXT_ID_SCIML]

# Neural network classes are constructed via pytorch for now to get
# the proper inputs
neural_networks = [
NNModel.from_pytorch_module(
NNModelStandard.load_data(
_generate_path(
file_path=nn_config.location,
base_path=base_path,
# the proper inputs. Non-YAML formats are opaque — the file is assumed
# to contain a valid model and is not read here.
neural_networks = []
for nn_id, nn_config in (sciml_config.neural_networks or {}).items():
if nn_config.format.lower() == "yaml":
neural_networks.append(
NNModel.from_pytorch_module(
NNModelStandard.load_data(
_generate_path(
file_path=nn_config.location,
base_path=base_path,
)
).to_pytorch_module(),
nn_model_id=nn_id,
)
).to_pytorch_module(),
nn_model_id=nn_id,
)
for nn_id, nn_config in (
sciml_config.neural_networks or {}
).items()
]
)

hybridization_tables = [
HybridizationTable.from_tsv(f, base_path)
Expand Down
2 changes: 2 additions & 0 deletions petab/v2/extensions/sciml_lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ def run(self, problem: core.Problem) -> lint.ValidationIssue | None:
condition_targets = {
c.target_id for ct in problem.conditions for c in ct.changes
}
# Only YAML-format networks are loaded as NNModel objects
nn_input_ids = {
inp.input_id
for nn in problem.extensions.sciml.neural_networks
if hasattr(nn, "inputs")
for inp in nn.inputs
}
hyb_target_ids = {
Expand Down
13 changes: 13 additions & 0 deletions tests/v2/test_sciml.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,16 @@ def _get_test_problem():
def test_lint():
problem = _get_test_problem()
assert problem.validate() == []


def test_lint_equinox_network_format():
"""Linter accepts non-YAML formats without reading the network file."""
problem = _get_test_problem()
# Replace the YAML network config with equinox format
sciml_cfg = problem.config.extensions["sciml"]
sciml_cfg.neural_networks["net1"] = NeuralNetConfig(
location="net1.py",
pre_initialization=False,
format="equinox",
)
assert problem.validate() == []
Loading