"""ProtFlow runner for PottsMPNN.
This module integrates the command-line PottsMPNN YAML workflows into
ProtFlow. It supports the two upstream scripts that expose a ``--config``
interface:
- ``sample_seqs.py`` for sequence design from backbone structures.
- ``energy_prediction.py`` for mutation-energy and deep-mutational-scan
scoring.
The runner writes script-specific YAML files, dispatches one command per
generated config through a :class:`~protflow.jobstarters.JobStarter`, and
collects PottsMPNN FASTA or CSV outputs back into a
:class:`~protflow.poses.Poses` dataframe.
Configuration
-------------
The runner reads default executable paths from the ProtFlow config:
``POTTSMPNN_DIR``
Path to the local PottsMPNN checkout. Commands are executed from this
directory so relative checkpoint paths from upstream YAML examples work.
``POTTSMPNN_PYTHON``
Python interpreter from the PottsMPNN environment.
``POTTSMPNN_PRE_CMD``
Optional shell prefix used to activate modules or environments before each
command.
Parameter Objects
-----------------
Use :class:`SampleSequencePottsMPNNParams` with ``sample_seqs.py`` and
:class:`EnergyPredictionPottsMPNNParams` with ``energy_prediction.py``. These
typed dataclasses expose PottsMPNN model and inference fields directly, so IDEs
can autocomplete nested attributes such as ``params.model.check_path`` and
``params.inference.num_samples``.
Pose-specific Values
--------------------
Wrap a dataframe column name in :class:`PoseCol` to fill a parameter from
``Poses.df``. Parameters ending in ``*_custom`` are written to JSON files inside
the runner work directory and can still be batched. Other pose-specific
parameters require one config per input pose.
Examples
--------
Design two sequences per backbone:
>>> from protflow.poses import Poses
>>> from protflow.tools import PottsMPNN, SampleSequencePottsMPNNParams
>>> poses = Poses(poses=["backbone_a.pdb", "backbone_b.pdb"], work_dir="work")
>>> params = SampleSequencePottsMPNNParams()
>>> params.inference.num_samples = 2
>>> params.inference.temperature = 0.1
>>> params.inference.optimization_mode = "none"
>>> poses = PottsMPNN().run(poses=poses, prefix="potts_design", params=params)
Score mutations from a CSV file:
>>> from protflow.tools import EnergyPredictionPottsMPNNParams
>>> params = EnergyPredictionPottsMPNNParams(mutant_csv="mutations.csv")
>>> poses = PottsMPNN().run(
... poses=poses,
... prefix="potts_energy",
... script="energy_prediction",
... params=params,
... )
"""
from __future__ import annotations
import copy
import json
import logging
import os
import shlex
import shutil
from dataclasses import dataclass, field, fields, is_dataclass
from glob import glob
from typing import Any, ClassVar
import pandas as pd
import yaml
from protflow import load_config_path, require_config
from protflow.jobstarters import JobStarter, split_list
from protflow.poses import Poses
from protflow.runners import (
Runner,
RunnerOutput,
options_flags_to_string,
parse_generic_options,
prepend_cmd,
)
# Upstream PottsMPNN entry points that accept a ``--config`` YAML file; every
# other script alias is rejected by ``PottsMPNN._resolve_script``.
SUPPORTED_CONFIG_SCRIPTS = {"energy_prediction", "sample_seqs"}
[docs]
class PottsMPNN(Runner):
    """Run PottsMPNN command-line scripts from ProtFlow.

    Parameters
    ----------
    python_path : str, optional
        Python interpreter used to execute PottsMPNN. If omitted, the value is
        loaded from ``POTTSMPNN_PYTHON`` in the ProtFlow config.
    pottsmpnn_dir : str, optional
        Path to the PottsMPNN checkout. If omitted, the value is loaded from
        ``POTTSMPNN_DIR``.
    pre_cmd : str, optional
        Shell prefix prepended to every command, commonly used to activate a
        conda environment or cluster module. Defaults to ``POTTSMPNN_PRE_CMD``.
    jobstarter : JobStarter, optional
        Default jobstarter used when :meth:`run` is called without one.

    Attributes
    ----------
    name : str
        Runner name used for job names and cached score files.
    index_layers : int
        Default merge index depth. The active value is selected per script in
        :meth:`run` because ``sample_seqs.py`` appends sample indices while
        ``energy_prediction.py`` keeps one row per input pose.
    pottsmpnn_dir : str
        Resolved PottsMPNN checkout path.
    python_path : str
        Resolved PottsMPNN Python interpreter.
    pre_cmd : str
        Resolved shell prefix.

    Notes
    -----
    Only upstream scripts with a ``--config`` YAML interface are supported.
    The runner currently supports ``sample_seqs.py`` and
    ``energy_prediction.py``. Every generated command first changes into
    ``pottsmpnn_dir`` and only starts the interpreter if that succeeds.
    """

    def __init__(
        self,
        python_path: str | None = None,
        pottsmpnn_dir: str | None = None,
        pre_cmd: str | None = None,
        jobstarter: JobStarter | None = None,
    ) -> None:
        """Initialize the runner and resolve PottsMPNN configuration.

        Parameters
        ----------
        python_path : str, optional
            Python interpreter used to run PottsMPNN.
        pottsmpnn_dir : str, optional
            Local PottsMPNN checkout.
        pre_cmd : str, optional
            Optional shell prefix for environment activation.
        jobstarter : JobStarter, optional
            Default jobstarter for this runner instance.
        """
        # config required
        config = require_config()
        # explicit arguments take precedence over ProtFlow config entries
        self.pottsmpnn_dir = str(pottsmpnn_dir or load_config_path(config, "POTTSMPNN_DIR"))
        self.python_path = str(python_path or load_config_path(config, "POTTSMPNN_PYTHON"))
        self.pre_cmd = pre_cmd or load_config_path(config, "POTTSMPNN_PRE_CMD", is_pre_cmd=True)
        # setup runner state
        self.jobstarter = jobstarter
        self.name = "pottsmpnn"
        self.index_layers = 0

    def __str__(self) -> str:
        """Return the short runner name.

        Returns
        -------
        str
            The literal runner name ``"pottsmpnn"``.
        """
        return self.name

    def run(
        self,
        poses: Poses,
        prefix: str,
        jobstarter: JobStarter | None = None,
        script: str | None = "sample_seqs",
        params: SampleSequencePottsMPNNParams | EnergyPredictionPottsMPNNParams | None = None,
        options: str | None = None,
        pose_options: str | list[str] | None = None,
        include_scores: list[str] | None = None,
        overwrite: bool = False,
    ) -> Poses:
        """Run PottsMPNN and merge collected results into ``poses``.

        Parameters
        ----------
        poses : Poses
            Input structures to pass to PottsMPNN. The ``poses`` column must
            contain PDB paths and ``poses_description`` is used as the upstream
            PottsMPNN structure identifier.
        prefix : str
            Unique run prefix used to create the runner work directory and
            prefixed output score columns.
        jobstarter : JobStarter, optional
            Jobstarter for this call. If omitted, the runner falls back to the
            instance jobstarter and then ``poses.default_jobstarter``.
        script : str, optional
            Script alias or path. Supported aliases are ``"sample_seqs"`` and
            ``"energy_prediction"``.
        params : SampleSequencePottsMPNNParams or EnergyPredictionPottsMPNNParams, optional
            Typed parameter object used to generate YAML configs. If omitted,
            defaults are created for the selected script.
        options : str, optional
            Extra command-line options passed to the upstream script. ``--config``
            is ignored because config files are managed by the runner.
        pose_options : str or list of str, optional
            Unsupported for PottsMPNN. Use :class:`PoseCol` fields in ``params``
            for pose-specific settings.
        include_scores : list of str, optional
            Reserved for API consistency with other runners. PottsMPNN collectors
            currently load the standard output fields.
        overwrite : bool, optional
            If ``True``, remove previous runner-owned outputs and rerun jobs.

        Returns
        -------
        Poses
            The input ``Poses`` object with PottsMPNN score columns merged in.

        Raises
        ------
        ValueError
            If ``pose_options`` are supplied or the params object does not match
            the selected script.
        NotImplementedError
            If ``script`` is not one of the supported config-based scripts.
        RuntimeError
            If PottsMPNN runs but no score rows can be collected.
        """
        # sanity
        if pose_options is not None:
            raise ValueError("PottsMPNN uses YAML configs; use PoseCol params instead of pose_options.")
        # sanitize script_path and params:
        script_path, script_key = self._resolve_script(script)
        # sample_seqs appends sample indices, so results need one extra merge layer
        index_layers = 1 if script_key == "sample_seqs" else 0
        if params is None:
            params = (
                SampleSequencePottsMPNNParams()
                if script_key == "sample_seqs"
                else EnergyPredictionPottsMPNNParams()
            )
        if params.script != script_key:
            raise ValueError(f"Params for '{params.script}' cannot be used with script '{script_key}'.")
        if script_key == "sample_seqs":
            _check_sample_descriptions(poses)

        # setup run directory and jobstarter
        work_dir, jobstarter = self.generic_run_setup(
            poses=poses,
            prefix=prefix,
            jobstarters=[jobstarter, self.jobstarter, poses.default_jobstarter],
        )
        logging.info("Running %s in %s on %d poses", self, work_dir, len(poses))

        # scorefile reuse shortcut
        scorefile = os.path.join(work_dir, f"{self.name}_scores.{poses.storage_format}")
        if (scores := self.check_for_existing_scorefile(scorefile=scorefile, overwrite=overwrite)) is not None:
            outputs = RunnerOutput(
                poses=poses,
                results=scores,
                prefix=prefix,
                index_layers=index_layers
            )
            return outputs.return_poses()

        # cleanup previous outputs
        if overwrite:
            self._cleanup_previous_outputs(work_dir)

        # prepare config files
        batched, config_files = params_to_config(
            poses=poses,
            n_batches=jobstarter.max_cores,
            work_dir=work_dir,
            params=params,
        )

        # build commands
        cmds = self._build_commands(
            script=script_path,
            config_files=config_files,
            options=options
        )
        # prepend configured environment command
        if self.pre_cmd:
            cmds = prepend_cmd(cmds=cmds, pre_cmd=self.pre_cmd)

        # execute jobs
        jobstarter.start(
            cmds=cmds,
            jobname=self.name,
            wait=True,
            output_path=work_dir
        )

        # collect and validate scores
        scores = collect_scores(
            work_dir=work_dir,
            script=script_key,
            batched=batched,
            include_scores=include_scores
        )
        scores = _fill_missing_locations(scores=scores, poses=poses, index_layers=index_layers)
        if len(scores.index) == 0:
            raise RuntimeError(f"{self}: collect_scores returned no rows. Check runner output directory: {work_dir}")

        # save scores and merge back into poses
        self.save_runner_scorefile(scores=scores, scorefile=scorefile)
        outputs = RunnerOutput(
            poses=poses,
            results=scores,
            prefix=prefix,
            index_layers=index_layers
        )
        return outputs.return_poses()

    def _resolve_script(self, script: str | None) -> tuple[str, str]:
        """Resolve a script alias or path to an executable PottsMPNN script.

        Candidates are tried in order: the value as given, the value with a
        ``.py`` suffix (when missing), and both of those joined onto
        ``pottsmpnn_dir``.

        Parameters
        ----------
        script : str, optional
            Script alias, script filename, or absolute path.

        Returns
        -------
        tuple of str
            Absolute script path and normalized script key.

        Raises
        ------
        NotImplementedError
            If the script does not provide a supported ``--config`` interface.
        FileNotFoundError
            If the script cannot be found directly or inside ``pottsmpnn_dir``.
        """
        # normalize script alias
        script = script or "sample_seqs"
        script_key = _script_key(script)
        # restrict to config-based scripts
        if script_key not in SUPPORTED_CONFIG_SCRIPTS:
            raise NotImplementedError(
                "Only PottsMPNN scripts with a '--config' YAML interface are supported: "
                f"{sorted(SUPPORTED_CONFIG_SCRIPTS)}"
            )
        # search direct path and checkout-relative path
        candidates = [str(script)]
        if not str(script).endswith(".py"):
            candidates.append(f"{script}.py")
        # list() snapshots the candidates so extend() does not iterate its own output
        candidates.extend(
            os.path.join(self.pottsmpnn_dir, candidate)
            for candidate in list(candidates)
        )
        # return first valid script path.
        for candidate in candidates:
            if os.path.isfile(candidate):
                return os.path.abspath(candidate), script_key
        raise FileNotFoundError(f"Could not find PottsMPNN script '{script}' in {self.pottsmpnn_dir}.")

    def _prep_pottsmpnn_opts(self, raw_opts: str | None) -> str:
        """Normalize extra command-line options for PottsMPNN.

        Parameters
        ----------
        raw_opts : str, optional
            User-supplied options string.

        Returns
        -------
        str
            Parsed option string with runner-managed ``--config`` removed, or
            ``""`` when no options were given.
        """
        # parse generic CLI options
        if not raw_opts:
            return ""
        opts, flags = parse_generic_options(raw_opts, "", sep="--")
        # config files are managed by the runner
        if "config" in opts:
            logging.warning("Ignoring user-specified PottsMPNN --config option: %s", opts["config"])
            del opts["config"]
        return options_flags_to_string(opts, flags, sep="--")

    def _build_commands(self, script: str, config_files: list[str], options: str | None) -> list[str]:
        """Build one PottsMPNN command per generated config file.

        Parameters
        ----------
        script : str
            Absolute path to the upstream PottsMPNN script.
        config_files : list of str
            YAML config files generated for this run.
        options : str, optional
            Extra command-line options shared by all configs.

        Returns
        -------
        list of str
            Shell commands ready for a :class:`~protflow.jobstarters.JobStarter`.
        """
        # options are parsed once and shared across generated configs
        cli_args = self._prep_pottsmpnn_opts(options)
        cmds = [
            self.write_cmd(script=script, config_path=config_file, cli_args=cli_args)
            for config_file in config_files
        ]
        return cmds

    def write_cmd(self, script: str, config_path: str, cli_args: str = "") -> str:
        """Format the shell command for a single PottsMPNN config.

        The command changes into the PottsMPNN checkout and then runs the
        script. The two steps are joined with ``&&``, so a failing ``cd``
        aborts the job instead of running the interpreter from the wrong
        directory, where relative checkpoint paths would not resolve.

        Parameters
        ----------
        script : str
            Absolute path to the upstream PottsMPNN script.
        config_path : str
            YAML config passed as ``--config``.
        cli_args : str, optional
            Additional parsed command-line arguments.

        Returns
        -------
        str
            Command that runs from the PottsMPNN checkout.
        """
        # run from PottsMPNN checkout so relative model paths work; "&&" stops
        # the job if the checkout directory cannot be entered
        cmd = (
            f"cd {shlex.quote(self.pottsmpnn_dir)} && "
            f"{shlex.quote(self.python_path)} {shlex.quote(script)} --config {shlex.quote(config_path)}"
        )
        return f"{cmd} {cli_args}" if cli_args else cmd

    def _cleanup_previous_outputs(self, work_dir: str) -> None:
        """Remove previous runner-owned outputs inside the work directory.

        Entries matched by ``glob(work_dir/*)`` are removed. Real directories
        are deleted recursively. Files and symlinks are unlinked, and symlinks
        are never followed.

        Parameters
        ----------
        work_dir : str
            Runner work directory created for this prefix.
        """
        # remove only files/directories inside runner work_dir
        if not os.path.isdir(work_dir):
            return
        for path in glob(os.path.join(work_dir, "*")):
            # isdir() is also true for symlinks to directories; shutil.rmtree
            # refuses those, and following them would delete the link target
            if os.path.isdir(path) and not os.path.islink(path):
                shutil.rmtree(path)
            else:
                os.remove(path)
[docs]
class PoseCol(str):
    """Reference a ``Poses.df`` column as the source of a PottsMPNN parameter.

    An instance *is* a string (the column name), so it can be assigned to any
    parameter field. Config generation detects it with
    ``isinstance(value, PoseCol)`` and substitutes per-pose dataframe values.

    Parameters
    ----------
    col_name : str
        Name of the column in ``poses.df``.

    Examples
    --------
    Use a dataframe column to write a per-pose fixed-position JSON file:

    >>> from protflow.tools import PoseCol, SampleSequencePottsMPNNParams
    >>> params = SampleSequencePottsMPNNParams()
    >>> params.inference.fixed_positions_custom = PoseCol("fixed_positions")
    """

    def __new__(cls, col_name: str) -> "PoseCol":
        """Build the marker string for ``col_name``.

        Parameters
        ----------
        col_name : str
            Referenced ``poses.df`` column.

        Returns
        -------
        PoseCol
            String subclass instance whose value equals ``col_name``.
        """
        marker = str.__new__(cls, col_name)
        return marker

    @property
    def col_name(self) -> str:
        """Return the referenced column name as a plain ``str``.

        Returns
        -------
        str
            Column name stored by this marker.
        """
        plain_name = str(self)
        return plain_name
[docs]
@dataclass
class PottsMPNNModelParams:
    """Model section of a PottsMPNN YAML config.

    Every field may also hold a :class:`PoseCol` to take its value from a
    ``poses.df`` column.

    Attributes
    ----------
    check_path : str or PoseCol
        Checkpoint location. Relative paths resolve against the PottsMPNN
        checkout, because commands are executed from ``POTTSMPNN_DIR``.
    hidden_dim : int or PoseCol
        Hidden dimension used when building the model.
    edge_features : int or PoseCol
        Number of edge features.
    potts_dim : int or PoseCol
        Dimension of the Potts representation.
    num_layers : int or PoseCol
        Number of encoder and decoder layers.
    num_edges : int or PoseCol
        Number of structural neighbors.
    vocab : int or PoseCol
        Vocabulary size the checkpoint expects.
    """

    # checkpoint (relative to the PottsMPNN checkout)
    check_path: str | PoseCol = "vanilla_model_weights/pottsmpnn_20.pt"
    # architecture; must agree with the checkpoint above
    hidden_dim: int | PoseCol = 128
    edge_features: int | PoseCol = 128
    potts_dim: int | PoseCol = 400
    num_layers: int | PoseCol = 3
    num_edges: int | PoseCol = 48
    vocab: int | PoseCol = 21
[docs]
@dataclass
class SampleSequenceInferenceParams:
    """Store ``sample_seqs.py`` inference configuration fields.

    Every field may also hold a :class:`PoseCol`. Only the names listed in
    ``batchable_params`` (the ``*_custom`` helpers) can take pose-specific
    values while still allowing batched execution.

    Attributes
    ----------
    num_samples : int or PoseCol
        Number of sequences sampled per input structure.
    temperature : float or PoseCol
        Autoregressive sampling temperature.
    noise : float or PoseCol
        Coordinate noise added during inference.
    skip_gaps : bool or PoseCol
        Whether upstream parsing should skip structural gaps.
    fix_decoding_order : bool or PoseCol
        Whether to use a fixed decoding order.
    decoding_order_offset : int or PoseCol
        Offset applied to the fixed decoding order.
    optimization_mode : str or PoseCol
        Optimization mode, typically ``"none"``, ``"potts"``, or ``"nodes"``.
    optimization_temperature : float or PoseCol
        Temperature used during sequence optimization.
    binding_energy_optimization : str or PoseCol
        Binding-energy optimization mode, typically ``"none"``, ``"both"``,
        or ``"only"``.
    binding_energy_json : str or None or PoseCol
        Path to upstream binding-energy partition JSON.
    binding_energy_cutoff : float or PoseCol
        Interface cutoff in Angstrom used for binding-energy optimization.
    optimize_pdb : bool or PoseCol
        Optimize sequences read from the input PDB files.
    optimize_fasta : str or PoseCol
        FASTA file whose sequences should be optimized.
    write_pdb : bool or PoseCol
        Whether PottsMPNN should write redesigned PDB files.
    fixed_positions_json, pssm_json, omit_AA_json, bias_AA_json, tied_positions_json, bias_by_res_json : str or PoseCol
        Paths to upstream ProteinMPNN-style constraint and bias JSON files.
    tied_epistasis : bool or PoseCol
        Upstream ``tied_epistasis`` flag, forwarded unchanged to the YAML
        config. NOTE(review): presumably controls epistatic coupling between
        tied positions; confirm against the upstream PottsMPNN docs.
    fixed_positions_custom, pssm_custom, omit_AA_custom, bias_AA_custom, tied_positions_custom, bias_by_res_custom : str or PoseCol
        ProtFlow helpers that can be populated from ``PoseCol`` values. They
        are written to JSON files in the runner work directory before running
        PottsMPNN, and are not emitted into the upstream YAML themselves.
    omit_AAs : list of str or PoseCol
        Amino acids globally omitted from sampling.
    pssm_threshold : float or PoseCol
        PSSM threshold passed to PottsMPNN.
    pssm_multi : float or PoseCol
        PSSM mixing weight.
    pssm_log_odds_flag : bool or PoseCol
        Whether to use PSSM log odds.
    pssm_bias_flag : bool or PoseCol
        Whether to use PSSM biasing.
    """

    num_samples: int | PoseCol = 1
    temperature: float | PoseCol = 0.1
    noise: float | PoseCol = 0.0
    skip_gaps: bool | PoseCol = False
    fix_decoding_order: bool | PoseCol = True
    decoding_order_offset: int | PoseCol = 0
    optimization_mode: str | PoseCol = "potts"
    optimization_temperature: float | PoseCol = 0.0
    binding_energy_optimization: str | PoseCol = "none"
    binding_energy_json: str | None | PoseCol = None
    binding_energy_cutoff: float | PoseCol = 8
    optimize_pdb: bool | PoseCol = False
    optimize_fasta: str | PoseCol = ""
    write_pdb: bool | PoseCol = True
    fixed_positions_json: str | PoseCol = ""
    pssm_json: str | PoseCol = ""
    omit_AA_json: str | PoseCol = ""
    bias_AA_json: str | PoseCol = ""
    tied_positions_json: str | PoseCol = ""
    tied_epistasis: bool | PoseCol = False
    bias_by_res_json: str | PoseCol = ""
    fixed_positions_custom: str | PoseCol = ""
    pssm_custom: str | PoseCol = ""
    omit_AA_custom: str | PoseCol = ""
    bias_AA_custom: str | PoseCol = ""
    tied_positions_custom: str | PoseCol = ""
    bias_by_res_custom: str | PoseCol = ""
    # default_factory gives every instance its own list
    omit_AAs: list[str] | PoseCol = field(default_factory=list)
    pssm_threshold: float | PoseCol = 0.0
    pssm_multi: float | PoseCol = 0.0
    pssm_log_odds_flag: bool | PoseCol = False
    pssm_bias_flag: bool | PoseCol = False
    # unannotated on purpose: a plain class attribute, not a dataclass field
    batchable_params = [
        "fixed_positions_custom",
        "pssm_custom",
        "omit_AA_custom",
        "bias_AA_custom",
        "tied_positions_custom",
        "bias_by_res_custom",
    ]
[docs]
@dataclass
class SampleSequenceParams:
    """Top-level keys of a ``sample_seqs.py`` YAML config.

    The I/O fields (``out_dir``, ``out_name``, ``input_list``, ``input_dir``)
    are overwritten by ProtFlow for each generated config.

    Attributes
    ----------
    dev : str or PoseCol
        Device string handed to PottsMPNN, usually ``"cuda"`` or ``"cpu"``.
    out_dir : str or PoseCol
        Output directory, set by the runner.
    out_name : str or PoseCol
        Output basename, set by the runner.
    input_list : str or PoseCol
        Path of the generated PottsMPNN input list.
    input_dir : str or PoseCol
        Directory holding the input PDB files.
    chain_dict_json : str or None or PoseCol
        Optional upstream chain-design JSON path.
    chain_dict_custom : str or PoseCol
        ProtFlow helper whose per-pose chain dictionaries are written to JSON.
    model : PottsMPNNModelParams
        Checkpoint and architecture settings.
    inference : SampleSequenceInferenceParams
        Sampling and optimization settings.
    """

    # execution device
    dev: str | PoseCol = "cuda"
    # run-specific I/O locations, filled in during config generation
    out_dir: str | PoseCol = ""
    out_name: str | PoseCol = ""
    input_list: str | PoseCol = ""
    input_dir: str | PoseCol = ""
    # chain selection: upstream JSON path or ProtFlow-generated helper
    chain_dict_json: str | None | PoseCol = None
    chain_dict_custom: str | PoseCol = ""
    # nested YAML sections; factories keep instances independent
    model: PottsMPNNModelParams = field(default_factory=PottsMPNNModelParams)
    inference: SampleSequenceInferenceParams = field(default_factory=SampleSequenceInferenceParams)
    # plain class attribute (no annotation), not a dataclass field
    batchable_params = ["chain_dict_custom"]
[docs]
@dataclass
class EnergyPredictionInferenceParams:
    """Inference section of an ``energy_prediction.py`` YAML config.

    Any field may hold a :class:`PoseCol`; only ``binding_energy_custom`` is
    listed in ``batchable_params``.

    Attributes
    ----------
    ddG : bool or PoseCol
        When ``True``, report mutant minus wild-type energies.
    mean_norm : bool or PoseCol
        Mean-center the predicted mutation energies.
    max_tokens : int or PoseCol
        Token budget for upstream batching.
    filter : bool or PoseCol
        Return only mutants that have experimental energies.
    binding_energy_json : str or None or PoseCol
        Path to an upstream binding-energy partition JSON.
    binding_energy_custom : str or PoseCol
        ProtFlow helper for pose-specific binding-energy JSON payloads.
    binding_energy_cutoff : float or PoseCol
        Interface cutoff in Angstrom.
    skip_gaps : bool or PoseCol
        Skip structural gaps during upstream parsing.
    noise : float or PoseCol
        Coordinate noise added during inference.
    chain_dict : str or None or PoseCol
        Optional upstream chain dictionary setting.
    chain_ranges : str or None or PoseCol
        Optional JSON path used by upstream heatmap plotting.
    exclude_chains : list of str or None or PoseCol
        Chains left out of mutation-energy scoring.
    """

    # energy reporting
    ddG: bool | PoseCol = True
    mean_norm: bool | PoseCol = False
    max_tokens: int | PoseCol = 20000
    filter: bool | PoseCol = False  # name mirrors the upstream YAML key
    # binding-energy partition
    binding_energy_json: str | None | PoseCol = None
    binding_energy_custom: str | PoseCol = ""
    binding_energy_cutoff: float | PoseCol = 8
    # structure parsing / noise
    skip_gaps: bool | PoseCol = False
    noise: float | PoseCol = 0.0
    # chain handling
    chain_dict: str | None | PoseCol = None
    chain_ranges: str | None | PoseCol = None
    exclude_chains: list[str] | None | PoseCol = None
    # plain class attribute (no annotation), not a dataclass field
    batchable_params = ["binding_energy_custom"]
[docs]
@dataclass
class EnergyPredictionParams:
    """Top-level keys of an ``energy_prediction.py`` YAML config.

    The I/O fields (``out_dir``, ``out_name``, ``input_list``, ``input_dir``)
    are overwritten by ProtFlow for each generated config.

    Attributes
    ----------
    dev : str or PoseCol
        Device string handed to PottsMPNN, usually ``"cuda"`` or ``"cpu"``.
    out_dir : str or PoseCol
        Output directory, set by the runner.
    out_name : str or PoseCol
        Output basename, set by the runner.
    input_list : str or PoseCol
        Path of the generated PottsMPNN input list.
    input_dir : str or PoseCol
        Directory holding the input PDB files.
    mutant_fasta : str or None or PoseCol
        FASTA file listing mutants to score. When both mutant inputs are
        ``None``, upstream PottsMPNN runs a deep mutational scan.
    mutant_csv : str or None or PoseCol
        CSV file listing mutants to score.
    model : PottsMPNNModelParams
        Checkpoint and architecture settings.
    inference : EnergyPredictionInferenceParams
        Mutation-energy prediction settings.
    """

    # execution device
    dev: str | PoseCol = "cuda"
    # run-specific I/O locations, filled in during config generation
    out_dir: str | PoseCol = ""
    out_name: str | PoseCol = ""
    input_list: str | PoseCol = ""
    input_dir: str | PoseCol = ""
    # mutant definitions; leaving both as None requests a deep mutational scan
    mutant_fasta: str | None | PoseCol = None
    mutant_csv: str | None | PoseCol = None
    # nested YAML sections; factories keep instances independent
    model: PottsMPNNModelParams = field(default_factory=PottsMPNNModelParams)
    inference: EnergyPredictionInferenceParams = field(default_factory=EnergyPredictionInferenceParams)
[docs]
class PottsMPNNParamsBase:
    """Share YAML config helpers across typed PottsMPNN parameter classes.

    This mixin is inherited by :class:`SampleSequencePottsMPNNParams` and
    :class:`EnergyPredictionPottsMPNNParams`. It is not intended to be
    instantiated directly.

    Attributes
    ----------
    script : str
        Normalized upstream script key expected by :class:`PottsMPNN`.
    """

    script: ClassVar[str]

    def _compile_attrs_dict(self, flat: bool = False) -> dict[str, Any]:
        """Return parameter values as a nested or flattened dictionary.

        Parameters
        ----------
        flat : bool, optional
            If ``True``, return dot-separated keys for nested fields.

        Returns
        -------
        dict
            Parameter values including ProtFlow ``*_custom`` helper fields.
        """
        if flat:
            return {".".join(path): value for path, value, _ in _iter_param_values(self)}
        return _params_to_dict(self, include_custom=True)

    def _non_batchable_attrs(self) -> list[Any]:
        """Return parameter values that prevent batched execution.

        Returns
        -------
        list
            Values from fields that are not declared as batch-compatible.
        """
        return [value for _, value, is_batchable in _iter_param_values(self) if not is_batchable]

    def _params_are_batchable(self) -> bool:
        """Return whether all ``PoseCol`` values can be materialized per batch.

        Returns
        -------
        bool
            ``True`` if all pose-specific values are stored in batch-compatible
            fields.
        """
        return not any(isinstance(value, PoseCol) for value in self._non_batchable_attrs())

    def resolve_pose_cols_batched(self, poses: Poses, n_batches: int, work_dir: str) -> list[str]:
        """Write batched configs while materializing batch-compatible ``PoseCol`` values.

        Each batch gets ``work_dir/batch_<i>`` (absolute) containing copies of
        its input PDBs, an input list of pose descriptions, per-batch JSON
        files for ``PoseCol``-backed helpers, and ``config.yaml``.

        Parameters
        ----------
        poses : Poses
            Input poses whose dataframe columns may be referenced by
            :class:`PoseCol`.
        n_batches : int
            Maximum number of config batches to create.
        work_dir : str
            Runner work directory.

        Returns
        -------
        list of str
            Paths to generated YAML config files.

        Raises
        ------
        KeyError
            If a referenced ``PoseCol`` column is missing from ``poses.df``.
        ValueError
            If a non-batchable ``PoseCol`` is encountered (internal error).
        """
        # validate PoseCol references
        self._check_pose_cols(poses)
        # split poses into job-sized batches
        batches = _split_pose_dataframe(poses=poses, n_batches=n_batches)
        config_files = []
        for i, pose_batch in enumerate(batches, start=1):
            # stage batch input PDBs and metadata
            batch_params = copy.deepcopy(self)
            batch_dir = os.path.abspath(os.path.join(work_dir, f"batch_{i}"))
            batch_input_dir = os.path.join(batch_dir, "input_pdbs")
            json_dir = os.path.join(batch_dir, "json_files")
            os.makedirs(batch_input_dir, exist_ok=True)
            os.makedirs(json_dir, exist_ok=True)
            for pose_path in pose_batch["poses"].to_list():
                shutil.copy(pose_path, os.path.join(batch_input_dir, os.path.basename(pose_path)))
            input_list_fn = os.path.join(batch_dir, "input_list.txt")
            pose_descriptions = pose_batch["poses_description"].to_list()
            _write_lines(input_list_fn, pose_descriptions)
            batch_params.out_dir = os.path.join(batch_dir, "outputs")  # pylint: disable=w0201
            batch_params.out_name = f"batch_{i}"  # pylint: disable=w0201
            batch_params.input_list = input_list_fn  # pylint: disable=w0201
            batch_params.input_dir = batch_input_dir  # pylint: disable=w0201
            # materialize PoseCol-backed JSON files
            for path, value, is_batchable in _iter_param_values(self):
                if not isinstance(value, PoseCol):
                    continue
                if not is_batchable:
                    raise ValueError(f"Internal error: non-batchable PoseCol reached batched setup: {'.'.join(path)}")
                json_path = os.path.join(json_dir, f"batch_{i}_{'_'.join(path)}.json")
                _write_json(json_path, {row["poses_description"]: row[str(value)] for _, row in pose_batch.iterrows()})
                _set_nested_attr(batch_params, _custom_path_to_json_path(path), json_path)
            # write final YAML
            config_path = os.path.join(batch_dir, "config.yaml")
            batch_params.to_yaml(config_path)
            config_files.append(config_path)
        return config_files

    def resolve_pose_cols(self, poses: Poses, n_batches: int, work_dir: str) -> tuple[bool, list[str]]:
        """Return batch mode and generated config paths.

        Parameters
        ----------
        poses : Poses
            Input poses used for config generation.
        n_batches : int
            Maximum number of batch configs to write.
        work_dir : str
            Runner work directory.

        Returns
        -------
        tuple
            ``(batched, config_files)`` where ``batched`` records whether the
            configs represent pose batches.
        """
        # choose batched mode only when every PoseCol can be encoded per batch
        if self._params_are_batchable():
            return True, self.resolve_pose_cols_batched(poses=poses, n_batches=n_batches, work_dir=work_dir)
        return False, self.resolve_pose_cols_unbatched(poses=poses, work_dir=work_dir)

    def resolve_pose_cols_unbatched(self, poses: Poses, work_dir: str) -> list[str]:
        """Write one config per pose for non-batchable ``PoseCol`` values.

        All paths written into the configs, and every returned config path,
        are absolute. PottsMPNN commands run from inside the PottsMPNN
        checkout, so relative paths would resolve against the wrong directory.
        The batched code path also uses absolute batch directories.

        Parameters
        ----------
        poses : Poses
            Input poses used for config generation.
        work_dir : str
            Runner work directory.

        Returns
        -------
        list of str
            Paths to generated YAML config files.

        Raises
        ------
        KeyError
            If a referenced ``PoseCol`` column is missing from ``poses.df``.
        """
        # validate PoseCol references
        self._check_pose_cols(poses)
        # absolute paths stay valid after the command changes directory
        work_dir = os.path.abspath(work_dir)
        # setup output directories
        config_dir = os.path.join(work_dir, "config_files")
        input_list_dir = os.path.join(work_dir, "input_lists")
        json_dir = os.path.join(work_dir, "json_files")
        output_dir = os.path.join(work_dir, "outputs")
        for path in (config_dir, input_list_dir, json_dir, output_dir):
            os.makedirs(path, exist_ok=True)
        # write one config per pose
        config_files = []
        for pose in poses:
            pose_params = copy.deepcopy(self)
            desc = pose["poses_description"]
            input_list = os.path.join(input_list_dir, f"{desc}_input_list.txt")
            _write_lines(input_list, [desc])
            pose_params.out_dir = output_dir  # pylint: disable=w0201
            pose_params.out_name = desc  # pylint: disable=w0201
            pose_params.input_list = input_list  # pylint: disable=w0201
            # abspath first: dirname("x.pdb") would be "" for a bare filename
            pose_params.input_dir = os.path.dirname(os.path.abspath(pose["poses"]))  # pylint: disable=w0201
            # resolve PoseCol values for this pose
            for path, value, _ in _iter_param_values(self):
                if not isinstance(value, PoseCol):
                    continue
                if path[-1].endswith("_custom"):
                    # *_custom helpers become a one-entry JSON keyed by description
                    json_path = os.path.join(json_dir, f"{desc}_{'_'.join(path)}.json")
                    _write_json(json_path, {desc: pose[str(value)]})
                    _set_nested_attr(pose_params, _custom_path_to_json_path(path), json_path)
                else:
                    _set_nested_attr(pose_params, path, pose[str(value)])
            config_path = os.path.join(config_dir, f"{desc}_config.yaml")
            pose_params.to_yaml(config_path)
            config_files.append(config_path)
        return config_files

    def to_yaml(self, out_path: str) -> None:
        """Write this parameter set as a PottsMPNN YAML config.

        Parameters
        ----------
        out_path : str
            Destination YAML path. Parent directories are created if needed.
        """
        # exclude *_custom helpers from upstream YAML
        os.makedirs(os.path.dirname(os.path.abspath(out_path)), exist_ok=True)
        with open(out_path, "w", encoding="UTF-8") as handle:
            yaml.safe_dump(_params_to_dict(self, include_custom=False), handle, sort_keys=False)

    def _check_pose_cols(self, poses: Poses) -> None:
        """Validate that all ``PoseCol`` references exist in ``poses.df``.

        Parameters
        ----------
        poses : Poses
            Input poses whose dataframe is checked.

        Raises
        ------
        KeyError
            If any referenced dataframe column is missing.
        """
        # collect missing dataframe columns once for a clear error
        referenced = {str(value) for _, value, _ in _iter_param_values(self) if isinstance(value, PoseCol)}
        missing = sorted(referenced - set(poses.df.columns))
        if missing:
            raise KeyError(f"PoseCol column(s) not found in poses.df: {missing}")
[docs]
@dataclass
class SampleSequencePottsMPNNParams(SampleSequenceParams, PottsMPNNParamsBase):
    """Parameter set for the upstream ``sample_seqs.py`` script.

    Combines the ``sample_seqs`` YAML fields with the shared config helpers.
    Pass it to :meth:`PottsMPNN.run` with ``script="sample_seqs"``, which is
    the default script.

    Examples
    --------
    >>> params = SampleSequencePottsMPNNParams()
    >>> params.inference.num_samples = 8
    >>> params.inference.temperature = 0.2
    """

    # ClassVar keeps the script key out of the dataclass fields and the YAML
    script: ClassVar[str] = "sample_seqs"
[docs]
@dataclass
class EnergyPredictionPottsMPNNParams(EnergyPredictionParams, PottsMPNNParamsBase):
    """Parameter set for the upstream ``energy_prediction.py`` script.

    Combines the ``energy_prediction`` YAML fields with the shared config
    helpers. Pass it to :meth:`PottsMPNN.run` with
    ``script="energy_prediction"``.

    Examples
    --------
    >>> params = EnergyPredictionPottsMPNNParams(mutant_csv="mutations.csv")
    >>> params.inference.ddG = True
    """

    # ClassVar keeps the script key out of the dataclass fields and the YAML
    script: ClassVar[str] = "energy_prediction"
[docs]
def params_to_config(
    poses: Poses,
    n_batches: int,
    work_dir: str,
    params: SampleSequencePottsMPNNParams | EnergyPredictionPottsMPNNParams,
) -> tuple[bool, list[str]]:
    """Write the PottsMPNN YAML configs for a run.

    The work is delegated to ``params.resolve_pose_cols``.

    Parameters
    ----------
    poses : Poses
        Input poses for this run.
    n_batches : int
        Upper bound on the number of batch configs.
    work_dir : str
        Runner work directory.
    params : SampleSequencePottsMPNNParams or EnergyPredictionPottsMPNNParams
        Typed params object for the selected upstream script.

    Returns
    -------
    tuple
        ``(batched, config_files)``; ``config_files`` lists the YAML paths.

    Raises
    ------
    ValueError
        If ``params`` is ``None``.
    """
    # explicit params keep script-specific defaults visible to the caller
    if params is not None:
        return params.resolve_pose_cols(poses=poses, n_batches=n_batches, work_dir=work_dir)
    raise ValueError("PottsMPNN params must not be None.")
[docs]
def fasta_to_df(fasta_file: str, desc_col_name: str = "description", seq_col_name: str = "sequence") -> pd.DataFrame:
    """Read a FASTA file into a two-column dataframe.

    Blank lines are ignored and multi-line sequences are concatenated. Lines
    that appear before the first ``>`` header are discarded.

    Parameters
    ----------
    fasta_file : str
        Path of the FASTA file to read.
    desc_col_name : str, optional
        Column name for the header text (without the leading ``>``).
    seq_col_name : str, optional
        Column name for the concatenated sequence.

    Returns
    -------
    pandas.DataFrame
        One row per FASTA record, with exactly the two requested columns,
        even when the file holds no records.
    """
    # a tiny hand-written parser keeps the module free of a FASTA dependency
    entries = []
    header = None
    pieces: list[str] = []
    with open(fasta_file, "r", encoding="UTF-8") as fh:
        stripped_lines = (raw.strip() for raw in fh)
        for text in filter(None, stripped_lines):
            if not text.startswith(">"):
                pieces.append(text)
                continue
            # a new header closes the record collected so far
            if header is not None:
                entries.append({desc_col_name: header, seq_col_name: "".join(pieces)})
            header, pieces = text[1:].strip(), []
    # flush the final record, which has no following header to close it
    if header is not None:
        entries.append({desc_col_name: header, seq_col_name: "".join(pieces)})
    return pd.DataFrame(entries, columns=[desc_col_name, seq_col_name])
[docs]
def collect_scores_sample_seqs(work_dir: str, batched: bool, include_scores: list[str] | None = None) -> pd.DataFrame:
    """Collect ``sample_seqs.py`` FASTA and loss outputs into score rows.

    The collector reads raw sampled FASTA files, optional optimized FASTA
    files, and ``*_av_loss.csv`` metrics. It writes one FASTA file per
    returned row under ``output_fastas`` so :class:`RunnerOutput` can treat
    sampled sequences as pose locations.

    Each per-row FASTA contains the optimized Potts sequence when the row has
    one. Otherwise it contains the raw sampled sequence. Rows that only exist
    in the raw FASTA, after the outer merge, therefore keep their real
    sequence instead of a missing value.

    Parameters
    ----------
    work_dir : str
        Runner work directory containing generated configs and PottsMPNN
        outputs.
    batched : bool
        Whether the run used batched configs.
    include_scores : list of str, optional
        Reserved for API consistency; currently ignored.

    Returns
    -------
    pandas.DataFrame
        Score rows with normalized ``description`` values, sequences, optional
        PottsMPNN loss metrics, and per-sequence FASTA ``location`` paths.

    Raises
    ------
    FileNotFoundError
        If no sample FASTA outputs are found.
    """
    # include_scores is unused because sample outputs are scalar and FASTA-based
    del include_scores
    # discover output directories from generated configs
    configs = _load_run_configs(work_dir)
    out_dirs = _output_dirs_from_configs(configs, batched=batched)
    av_loss_files = _glob_output_files(out_dirs, "*_av_loss.csv")
    raw_seq_files = [
        path for path in _glob_output_files(out_dirs, "*.fasta")
        if "_optimized_" not in os.path.splitext(os.path.basename(path))[0]
    ]
    optimized_seq_files = _glob_output_files(out_dirs, "*_optimized_*.fasta")
    if not raw_seq_files and not optimized_seq_files:
        raise FileNotFoundError(f"No PottsMPNN FASTA outputs found under {work_dir}.")

    # parse raw sequences, optimized sequences, and loss metrics
    raw_df = _read_sample_fastas(raw_seq_files, configs, "sequence")
    optimized_df = _read_sample_fastas(optimized_seq_files, configs, "optimized_potts_sequence")
    av_loss_df = _read_av_loss_files(av_loss_files, configs)

    merge_keys = ["raw_description", "description", "sample_idx"]
    if not optimized_df.empty and not raw_df.empty:
        scores = raw_df.merge(optimized_df, on=merge_keys, how="outer")
    elif not optimized_df.empty:
        scores = optimized_df
    else:
        scores = raw_df
    if not av_loss_df.empty:
        scores = scores.merge(av_loss_df, on=merge_keys, how="left")

    # pick the sequence written for each row: optimized when present, raw
    # otherwise (the outer merge leaves NaN on raw-only rows, which would
    # previously have been written to the FASTA as the text "nan")
    out_sequences = scores.get("optimized_potts_sequence")
    raw_sequences = scores.get("sequence")
    if out_sequences is None:
        out_sequences = raw_sequences
    elif raw_sequences is not None:
        out_sequences = out_sequences.fillna(raw_sequences)

    # write per-sequence FASTA files for RunnerOutput locations
    fasta_output_dir = os.path.join(work_dir, "output_fastas")
    os.makedirs(fasta_output_dir, exist_ok=True)
    locations = []
    for idx, row in scores.iterrows():
        fasta_path = os.path.abspath(os.path.join(fasta_output_dir, f"{row['description']}.fa"))
        with open(fasta_path, "w", encoding="UTF-8") as handle:
            handle.write(f">{row['description']}\n{out_sequences.loc[idx]}\n")
        locations.append(fasta_path)
    scores["location"] = locations
    return scores
[docs]
def collect_scores_energy_prediction(work_dir: str, batched: bool, include_scores: list[str] | None = None) -> pd.DataFrame:
    """Collect ``energy_prediction.py`` CSV outputs into score rows.

    PottsMPNN can emit many mutation rows per input pose. The collector stores
    the full per-pose mutation table as JSON sidecars under ``output_scores``
    and returns one ProtFlow score row per input structure. When a sibling
    ``*_stats.csv`` file exists next to a ``*_scores.csv`` file, its
    ``Pearson r`` value for the pose is added as
    ``energy_prediction_pearson_r``.

    Parameters
    ----------
    work_dir : str
        Runner work directory containing generated configs and PottsMPNN
        outputs.
    batched : bool
        Whether the run used batched configs.
    include_scores : list of str, optional
        Reserved for API consistency; currently ignored.

    Returns
    -------
    pandas.DataFrame
        One row per input pose with the JSON sidecar path and number of scored
        mutations.

    Raises
    ------
    FileNotFoundError
        If no ``*_scores.csv`` files are found.
    """
    # include_scores is unused because full per-pose CSV rows are stored sidecar-style
    del include_scores
    configs = _load_run_configs(work_dir)
    out_dirs = _output_dirs_from_configs(configs, batched=batched)
    score_files = _glob_output_files(out_dirs, "*_scores.csv")
    if not score_files:
        raise FileNotFoundError(f"No PottsMPNN energy prediction score files found under {work_dir}.")

    # split combined score CSVs into per-pose JSON sidecars
    output_dir = os.path.join(work_dir, "output_scores")
    os.makedirs(output_dir, exist_ok=True)
    suffix = "_scores.csv"
    rows = []
    for score_file in score_files:
        score_df = pd.read_csv(score_file)
        # every glob hit ends with the suffix; rewrite only that suffix so a
        # matching substring elsewhere in the directory path is left alone
        stats_file = score_file[: -len(suffix)] + "_stats.csv"
        # read the optional stats table once per score file, not once per pose
        stats_df = pd.read_csv(stats_file) if os.path.isfile(stats_file) else None
        for desc in score_df["pdb"].unique():
            pose_df = score_df[score_df["pdb"] == desc]
            pose_scorefile = os.path.abspath(os.path.join(output_dir, f"{desc}.json"))
            pose_df.to_json(pose_scorefile, orient="records")
            row = {
                "description": desc,
                "energy_prediction_scorefile": pose_scorefile,
                "energy_prediction_n_mutations": len(pose_df.index),
            }
            if stats_df is not None:
                pose_stats = stats_df[stats_df["pdb"] == desc]
                if not pose_stats.empty:
                    row["energy_prediction_pearson_r"] = pose_stats.iloc[0].get("Pearson r")
            rows.append(row)
    return pd.DataFrame(rows)
[docs]
def collect_scores(work_dir: str, script: str, batched: bool, include_scores: list[str] | None = None) -> pd.DataFrame:
    """Route score collection to the collector for ``script``.

    Parameters
    ----------
    work_dir : str
        Runner work directory.
    script : str
        Script alias or path used for the run; it is normalized via
        ``_script_key`` before dispatch.
    batched : bool
        Whether the run used batched configs.
    include_scores : list of str, optional
        Forwarded unchanged to the selected collector.

    Returns
    -------
    pandas.DataFrame
        Score rows from the script-specific collector.

    Raises
    ------
    NotImplementedError
        If no collector exists for ``script``.
    """
    script_key = _script_key(script)
    if script_key == "sample_seqs":
        collector = collect_scores_sample_seqs
    elif script_key == "energy_prediction":
        collector = collect_scores_energy_prediction
    else:
        raise NotImplementedError(f"Score collection is not implemented for PottsMPNN script: {script}")
    return collector(work_dir=work_dir, batched=batched, include_scores=include_scores)
def _script_key(script: str | None) -> str:
"""Normalize a script path or alias to its basename key."""
if not script:
return "sample_seqs"
return os.path.splitext(os.path.basename(str(script)))[0]
def _check_sample_descriptions(poses: Poses) -> None:
    """Raise ValueError if any pose description ends in a reserved suffix.

    The suffixes ``_optimized_potts`` and ``_optimized_nodes`` are reserved
    for upstream PottsMPNN optimized outputs, so input poses must not use them.
    """
    reserved = ("_optimized_potts", "_optimized_nodes")
    offenders = []
    for desc in poses.df["poses_description"].to_list():
        if str(desc).endswith(reserved):
            offenders.append(desc)
    if not offenders:
        return
    raise ValueError(
        "PottsMPNN sample output parsing reserves descriptions ending in "
        f"{reserved}. Rename these poses first: {offenders}"
    )
def _fill_missing_locations(scores: pd.DataFrame, poses: Poses, index_layers: int) -> pd.DataFrame:
    """Add a ``location`` column pointing score rows at their input pose files.

    If ``scores`` already has ``location`` it is returned untouched. Otherwise
    each row's ``description`` is stripped of its last ``index_layers``
    underscore-separated tokens and looked up in
    ``poses.df["poses_description"]`` to obtain the ``poses`` path.

    Raises
    ------
    ValueError
        If any row cannot be mapped to an input pose.
    """
    if "location" in scores.columns:
        return scores
    filled = scores.copy()
    lookup_keys = filled["description"].astype(str)
    if index_layers:
        # drop the trailing index tokens so the key matches the input description
        lookup_keys = lookup_keys.str.split("_").str[:-index_layers].str.join("_")
    location_by_desc = poses.df.set_index("poses_description")["poses"]
    filled["location"] = lookup_keys.map(location_by_desc)
    unmapped = filled["location"].isna()
    if unmapped.any():
        missing = filled.loc[unmapped, "description"].to_list()
        raise ValueError(f"Could not map PottsMPNN score rows back to input poses: {missing}")
    return filled
def _split_pose_dataframe(poses: Poses, n_batches: int) -> list[pd.DataFrame]:
    """Chunk ``poses.df`` into at most ``n_batches`` frames via ``split_list``.

    A falsy or non-positive ``n_batches`` is treated as 1. Each chunk gets a
    fresh 0-based index. No poses means no chunks.
    """
    if not len(poses):
        return []
    n_sublists = max(1, n_batches or 1)
    index_chunks = split_list(list(poses.df.index), n_sublists=n_sublists)
    chunks = []
    for chunk in index_chunks:
        chunks.append(poses.df.loc[chunk].reset_index(drop=True))
    return chunks
def _iter_param_values(obj: Any, path: tuple[str, ...] = ()) -> list[tuple[tuple[str, ...], Any, bool]]:
    """Flatten a (possibly nested) dataclass into ``(path, value, batchable)`` triples.

    The result is a list built eagerly, not a generator. Nested dataclass
    fields are descended into and their leaves prefixed with the field name.
    ``batchable`` is True when the leaf's name appears in the owning object's
    ``batchable_params`` attribute. Non-dataclass input gives an empty list.
    """
    if not is_dataclass(obj):
        return []
    batchable = set(getattr(obj, "batchable_params", []))
    flattened: list[tuple[tuple[str, ...], Any, bool]] = []
    for param_field in fields(obj):
        value = getattr(obj, param_field.name)
        field_path = path + (param_field.name,)
        if is_dataclass(value):
            flattened += _iter_param_values(value, field_path)
        else:
            flattened.append((field_path, value, param_field.name in batchable))
    return flattened
def _params_to_dict(obj: Any, include_custom: bool) -> dict[str, Any]:
    """Recursively turn a params dataclass into a plain dict for YAML output.

    Nested dataclasses become nested dicts and ``PoseCol`` markers become
    their string form. When ``include_custom`` is False, fields whose names end
    with ``_custom`` are left out at every nesting level.
    """
    result: dict[str, Any] = {}
    for param_field in fields(obj):
        key = param_field.name
        value = getattr(obj, key)
        if key.endswith("_custom") and not include_custom:
            continue
        if is_dataclass(value):
            result[key] = _params_to_dict(value, include_custom=include_custom)
        elif isinstance(value, PoseCol):
            result[key] = str(value)
        else:
            result[key] = value
    return result
def _set_nested_attr(obj: Any, path: tuple[str, ...], value: Any) -> None:
    """Assign ``value`` to the attribute reached by following ``path`` from ``obj``.

    Every component except the last is resolved with ``getattr``; the last one
    is set on the object found that way.
    """
    parents, leaf = path[:-1], path[-1]
    owner = obj
    for attr_name in parents:
        owner = getattr(owner, attr_name)
    setattr(owner, leaf, value)
def _custom_path_to_json_path(path: tuple[str, ...]) -> tuple[str, ...]:
    """Map a ``..._custom`` parameter path to the matching ``..._json`` path.

    Raises
    ------
    ValueError
        If the last path component does not end with ``_custom``.
    """
    suffix = "_custom"
    name = path[-1]
    if not name.endswith(suffix):
        raise ValueError(f"Expected custom parameter path, got {'.'.join(path)}")
    # upstream PottsMPNN reads JSON file fields, not ProtFlow's *_custom helpers
    json_name = name[: -len(suffix)] + "_json"
    return path[:-1] + (json_name,)
def _write_lines(path: str, lines: list[str]) -> None:
    """Write ``lines`` joined by newlines (no trailing newline) to ``path``.

    Missing parent directories are created first; an existing file is
    overwritten.
    """
    parent_dir = os.path.dirname(os.path.abspath(path))
    os.makedirs(parent_dir, exist_ok=True)
    with open(path, "w", encoding="UTF-8") as out_file:
        out_file.write("\n".join(lines))
def _write_json(path: str, payload: dict[str, Any]) -> None:
    """Serialize ``payload`` as JSON into ``path``, creating parent directories.

    Uses ``json.dump`` defaults (compact single line, default separators);
    an existing file is overwritten.
    """
    parent_dir = os.path.dirname(os.path.abspath(path))
    os.makedirs(parent_dir, exist_ok=True)
    with open(path, "w", encoding="UTF-8") as out_file:
        json.dump(payload, out_file)
def _load_run_configs(work_dir: str) -> list[dict[str, Any]]:
    """Load every generated PottsMPNN YAML config found under ``work_dir``.

    Unbatched configs come from ``config_files/*.yaml`` and batched ones from
    ``batch_*/config.yaml``. Each group is sorted, and unbatched configs come
    first. Each returned dict also gets:

    - ``_config_path``: the YAML file it was loaded from;
    - ``_input_descriptions``: base names read from its ``input_list``;
    - ``out_dir``: rewritten as an absolute path.
    """
    unbatched_paths = sorted(glob(os.path.join(work_dir, "config_files", "*.yaml")))
    batch_paths = sorted(glob(os.path.join(work_dir, "batch_*", "config.yaml")))
    loaded: list[dict[str, Any]] = []
    for yaml_path in [*unbatched_paths, *batch_paths]:
        with open(yaml_path, "r", encoding="UTF-8") as stream:
            # a falsy load result (e.g. an empty document) becomes an empty mapping
            cfg = yaml.safe_load(stream) or {}
            cfg["_config_path"] = yaml_path
            cfg["_input_descriptions"] = _read_input_descriptions(cfg.get("input_list"))
            # NOTE(review): a relative out_dir resolves against this process's CWD,
            # not the config's location — confirm configs are written with absolute paths
            cfg["out_dir"] = os.path.abspath(cfg["out_dir"])
        loaded.append(cfg)
    return loaded
def _read_input_descriptions(input_list: str | None) -> list[str]:
"""Read base pose descriptions from a PottsMPNN input list."""
if not input_list:
return []
with open(input_list, "r", encoding="UTF-8") as handle:
return [line.strip().split("|", maxsplit=1)[0] for line in handle if line.strip()]
def _output_dirs_from_configs(configs: list[dict[str, Any]], batched: bool) -> list[str]:
    """Return the sorted, de-duplicated ``out_dir`` values of ``configs``.

    ``batched`` is accepted for interface symmetry but not used, because both
    batched and unbatched configs declare ``out_dir`` explicitly.
    """
    del batched
    unique_dirs = set()
    for cfg in configs:
        unique_dirs.add(cfg["out_dir"])
    return sorted(unique_dirs)
def _glob_output_files(out_dirs: list[str], pattern: str) -> list[str]:
    """Return all files matching ``pattern`` in any of ``out_dirs``, sorted."""
    matches: list[str] = []
    for out_dir in out_dirs:
        matches.extend(glob(os.path.join(out_dir, pattern)))
    return sorted(matches)
def _config_for_output_file(path: str, configs: list[dict[str, Any]]) -> dict[str, Any]:
    """Identify which generated config produced the output file ``path``.

    A config matches when its ``out_dir`` equals the file's directory and the
    file stem is ``out_name``, ``out_name_av_loss``, ``out_name_scores`` or
    starts with ``out_name_optimized_``. Among several matches, the one with
    the longest ``out_name`` wins; ties go to the earliest config.

    Raises
    ------
    ValueError
        If no config matches.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    out_dir = os.path.abspath(os.path.dirname(path))
    candidates = []
    for cfg in configs:
        if os.path.abspath(cfg["out_dir"]) != out_dir:
            continue
        prefix = str(cfg["out_name"])
        exact_stems = {prefix, f"{prefix}_av_loss", f"{prefix}_scores"}
        if stem in exact_stems or stem.startswith(f"{prefix}_optimized_"):
            candidates.append(cfg)
    if not candidates:
        raise ValueError(f"Could not match PottsMPNN output file to a generated config: {path}")
    # the most specific (longest) out_name disambiguates overlapping prefixes
    return max(candidates, key=lambda cfg: len(str(cfg["out_name"])))
def _canonical_sample_description(raw_description: str, input_descriptions: list[str]) -> tuple[str, int]:
    """Convert a PottsMPNN sample name into a ProtFlow merge description.

    Rules, in order:

    1. If ``raw_description`` equals a known input description, it becomes
       ``<input>_0001`` with sample index 1.
    2. If it is ``<input>_<n>`` for a known input and a decimal ``n``, it
       becomes ``<input>_<n+1:04d>`` with index ``n + 1``. Longer inputs are
       tried first, so inputs that contain underscores win over their
       shorter prefixes.
    3. Otherwise, a trailing ``_<n>`` token is treated the same way.
    4. Anything else gets ``_0001`` appended with index 1.

    Parameters
    ----------
    raw_description : str
        Record name as written by PottsMPNN.
    input_descriptions : list of str
        Base descriptions of the poses fed into the run.

    Returns
    -------
    tuple
        ``(description, sample_idx)`` with a 1-based ``sample_idx``.
    """
    # prefer the longest input description to handle names containing underscores
    for input_desc in sorted(input_descriptions, key=len, reverse=True):
        if raw_description == input_desc:
            return f"{input_desc}_0001", 1
        prefix = f"{input_desc}_"
        if raw_description.startswith(prefix):
            suffix = raw_description[len(prefix):]
            # isdecimal (not isdigit) guarantees int() accepts the text
            if suffix.isdecimal():
                sample_idx = int(suffix) + 1
                return f"{input_desc}_{sample_idx:04d}", sample_idx
    # fall back to a trailing numeric token; requiring an actual "_" separator
    # avoids the old crash on purely numeric names such as "123"
    base, sep, tail = raw_description.rpartition("_")
    if sep and tail.isdecimal():
        sample_idx = int(tail) + 1
        return f"{base}_{sample_idx:04d}", sample_idx
    return f"{raw_description}_0001", 1
def _read_sample_fastas(files: list[str], configs: list[dict[str, Any]], seq_col: str) -> pd.DataFrame:
    """Parse sampled or optimized FASTA outputs into normalized rows.

    Each record yields a row with ``raw_description`` (the FASTA header),
    ``description`` and ``sample_idx`` (from
    ``_canonical_sample_description`` using the producing config's input
    descriptions), and the sequence stored under ``seq_col``.
    """
    records = []
    for fasta_path in files:
        cfg = _config_for_output_file(fasta_path, configs)
        known_inputs = cfg["_input_descriptions"]
        parsed = fasta_to_df(fasta_path, desc_col_name="raw_description", seq_col_name=seq_col)
        for raw_desc, sequence in zip(parsed["raw_description"], parsed[seq_col]):
            description, sample_idx = _canonical_sample_description(raw_desc, known_inputs)
            records.append({
                "raw_description": raw_desc,
                "description": description,
                "sample_idx": sample_idx,
                seq_col: sequence,
            })
    return pd.DataFrame(records)
def _read_av_loss_files(files: list[str], configs: list[dict[str, Any]]) -> pd.DataFrame:
    """Read ``*_av_loss.csv`` files into rows keyed like the FASTA rows.

    Every upstream column is kept. Each row also gets ``raw_description`` (the
    original ``pdb`` value), plus ``description`` and ``sample_idx`` from
    ``_canonical_sample_description``. This lets the rows merge with
    ``_read_sample_fastas`` output.

    The ``pdb`` column is read as text. Without that, numeric-looking pose
    names are parsed as numbers (and may be upcast to floats when rows are
    iterated), so they no longer match the string descriptions in the FASTA
    rows.
    """
    # keep upstream metrics while replacing pdb with normalized descriptions
    rows = []
    for av_loss_file in files:
        cfg = _config_for_output_file(av_loss_file, configs)
        df = pd.read_csv(av_loss_file, dtype={"pdb": str})
        # to_dict("records") keeps per-column types instead of a row-wise upcast
        for record in df.to_dict(orient="records"):
            raw_description = record["pdb"]
            description, sample_idx = _canonical_sample_description(raw_description, cfg["_input_descriptions"])
            record["raw_description"] = raw_description
            record["description"] = description
            record["sample_idx"] = sample_idx
            rows.append(record)
    return pd.DataFrame(rows)