Source code for protflow.metrics.generic_metric_runner

"""
Generic metric runner for ProtFlow.

This module exposes :class:`GenericMetric`, a lightweight :class:`protflow.runners.Runner`
that executes any importable Python function over the poses stored in a
:class:`protflow.poses.Poses` object. The target function must accept a single
pose path as its first positional argument and return a JSON-serializable value.
Additional keyword arguments can be forwarded through the runner's ``options``
dictionary.

How it works
------------
``GenericMetric.run()`` resolves the working directory and jobstarter, splits
``poses.poses_list()`` into manageable chunks, and starts one worker command
per chunk. Each worker re-enters this module as a small CLI program, imports
the requested module and function dynamically, evaluates the function on every
pose path in its chunk, and stores the results as JSON. The parent process then
concatenates the worker outputs and merges them back into ``poses.df`` through
``RunnerOutput``.

Walkthrough
-----------
The example below calculates the radius of gyration for every pose by reusing
``protflow.utils.metrics.calc_rog_of_pdb``:

.. code-block:: python

    from protflow.poses import Poses
    from protflow.jobstarters import SbatchArrayJobstarter
    from protflow.metrics.generic_metric_runner import GenericMetric

    poses = Poses(
        poses=["/data/designs/design_0001.pdb", "/data/designs/design_0002.pdb"],
        work_dir="/data/protflow_runs"
    )
    cpu_jobstarter = SbatchArrayJobstarter(max_cores=10)

    rog = GenericMetric(
        module="protflow.utils.metrics",
        function="calc_rog_of_pdb",
        options={"chain": "A"},
        jobstarter=cpu_jobstarter,
    )

    poses = rog.run(poses=poses, prefix="rog")

    # GenericMetric stores the returned value in <prefix>_data.
    print(poses.df[["poses_description", "rog_data"]])

In that run, ``GenericMetric`` will:

1. Build ``/data/protflow_runs/rog`` as its working directory.
2. Split the input pose paths into chunks based on ``max_cores`` and a hard
   limit of 100 poses per command.
3. Launch worker commands that call ``calc_rog_of_pdb(pose_path, chain="A")``.
4. Save intermediate JSON files such as ``out_0.json``.
5. Merge the combined results back into ``poses.df`` as
   ``rog_data``, ``rog_description``, and ``rog_location``.

This module is intended for simple, embarrassingly parallel per-pose metrics.
If your function needs multiple inputs, non-JSON output, or a richer output
schema than a single ``data`` column, a dedicated runner is usually a better
fit.
"""

# import general
import os
import json
import shlex
import logging
import importlib

# import dependencies
import pandas as pd

# import customs
from protflow.poses import Poses
from protflow.runners import Runner, RunnerOutput
from protflow import load_config_path, require_config
from protflow.jobstarters import JobStarter, split_list

class GenericMetric(Runner):
    """
    Run a simple Python metric function over every pose in a :class:`Poses`.

    ``GenericMetric`` is the most lightweight metric runner in ProtFlow. You
    point it at an importable module and a function name, optionally provide a
    shared ``options`` dictionary, and the runner takes care of chunking the
    pose list, dispatching jobs through a :class:`JobStarter`, collecting the
    JSON outputs, and merging the results back into ``poses.df``.

    The target function contract is intentionally small:

    - The first positional argument must be the pose path.
    - Optional keyword arguments can be supplied via ``options``.
    - The return value must be serializable to JSON.

    The resulting metric value is stored in ``<prefix>_data`` after the run is
    merged back into ``poses.df``.
    """

    def __init__(self, python_path: str | None = None, module: str | None = None, function: str | None = None, options: dict | None = None, jobstarter: JobStarter | None = None, overwrite: bool = False):
        """
        Initialize a generic per-pose metric runner.

        Parameters
        ----------
        python_path : str | None, optional
            Python interpreter used to launch worker commands. If omitted,
            ``<PROTFLOW_ENV>/python`` from the ProtFlow config is used.
        module : str | None, optional
            Importable module path that contains the target metric function.
        function : str | None, optional
            Name of the function to call inside ``module``.
        options : dict | None, optional
            Keyword arguments forwarded to the target function for every pose.
        jobstarter : JobStarter | None, optional
            Default jobstarter used when ``run()`` is called without one.
        overwrite : bool, optional
            Whether existing runner scorefiles should be recomputed by default.

        Raises
        ------
        ValueError
            If ``jobstarter`` or ``options`` has an unsupported type.
        """
        # setup config
        config = require_config()
        self.set_python_path(python_path or os.path.join(load_config_path(config, "PROTFLOW_ENV"), "python"))

        # setup runner
        self.set_module(module)
        self.set_function(function)
        self.set_jobstarter(jobstarter)
        self.set_options(options)
        self.overwrite = overwrite

    def __str__(self):
        return "GenericMetric"

    ########################## Input ################################################
    def set_module(self, module: str | None) -> None:
        """
        Set the importable module path that contains the metric function.

        Parameters
        ----------
        module : str | None
            Importable module path, for example ``"protflow.utils.metrics"``.
        """
        self.module = module

    def set_python_path(self, python_path: str) -> None:
        """Set the Python interpreter used for worker execution."""
        self.python_path = python_path

    def set_function(self, function: str | None) -> None:
        """
        Set the function name to import from ``self.module``.

        Parameters
        ----------
        function : str | None
            Attribute name of the target metric function.
        """
        self.function = function

    def set_jobstarter(self, jobstarter: JobStarter | None) -> None:
        """
        Set the default jobstarter for this runner instance.

        Parameters
        ----------
        jobstarter : JobStarter | None
            Jobstarter used when ``run()`` does not receive one explicitly.

        Raises
        ------
        ValueError
            If ``jobstarter`` is neither ``None`` nor a :class:`JobStarter`.
        """
        if jobstarter is not None and not isinstance(jobstarter, JobStarter):
            raise ValueError(f"Parameter :jobstarter: must be of type JobStarter. type(jobstarter) = {type(jobstarter)}")
        self.jobstarter = jobstarter

    def set_options(self, options: dict | None) -> None:
        """
        Set shared keyword arguments for the metric function.

        Parameters
        ----------
        options : dict | None
            Keyword arguments forwarded as ``function(pose, **options)``.

        Raises
        ------
        ValueError
            If ``options`` is neither ``None`` nor a dictionary.
        """
        if options is not None and not isinstance(options, dict):
            raise ValueError(f"Parameter :options: must be of type dict. type(options) = {type(options)}")
        self.options = options

    ########################## Calculations ################################################
    def run(self, poses: Poses, prefix: str, python_path: str | None = None, module: str | None = None, function: str | None = None, options: dict | None = None, jobstarter: JobStarter | None = None, overwrite: bool = False) -> Poses:
        """
        Execute the configured metric function across all poses.

        Parameters
        ----------
        poses : Poses
            Input poses. The pose file paths are taken from
            ``poses.poses_list()``.
        prefix : str
            Prefix used for the runner work directory, cached scorefile, and
            merged result columns.
        python_path : str | None, optional
            Python interpreter used for worker commands. Defaults to the value
            configured on the runner instance.
        module : str | None, optional
            Importable module path for the metric function. Defaults to the
            value configured on the runner instance.
        function : str | None, optional
            Function name inside ``module``. Defaults to the value configured
            on the runner instance.
        options : dict | None, optional
            Shared keyword arguments forwarded to the metric function. A falsy
            value (``None`` or ``{}``) falls back to the value configured on
            the runner instance.
        jobstarter : JobStarter | None, optional
            Jobstarter used for this invocation. The candidates are handed to
            ``generic_run_setup`` in the order ``run(jobstarter)``,
            ``self.jobstarter``, ``poses.default_jobstarter``.
        overwrite : bool, optional
            If ``True`` (here or on the runner instance), recompute the metric
            even when the cached scorefile already exists.

        Returns
        -------
        Poses
            The object returned by ``RunnerOutput.return_poses()``; the worker
            results (columns ``data``, ``description``, ``location``) are
            merged under ``prefix``.

        Raises
        ------
        ValueError
            If ``module`` or ``function`` is not set, or if ``options`` is
            neither ``None`` nor a dictionary.
        RuntimeError
            If a worker did not produce its output file, or if fewer output
            rows are collected than input poses. Both usually indicate failed
            worker jobs.

        Examples
        --------
        .. code-block:: python

            from protflow.metrics.generic_metric_runner import GenericMetric

            rog = GenericMetric(
                module="protflow.utils.metrics",
                function="calc_rog_of_pdb",
                options={"chain": "A"},
            )

            poses = rog.run(poses=poses, prefix="rog", jobstarter=cpu_jobstarter)

        Notes
        -----
        Internally, ``run()`` launches this module as a worker script for each
        pose chunk. Each worker writes a JSON file with the columns ``data``,
        ``description``, and ``location``. The parent process concatenates
        those files and hands the table to :class:`RunnerOutput`.
        """
        # resolve per-call arguments against the defaults stored on the runner
        python_path = python_path or self.python_path
        module = module or self.module
        function = function or self.function
        options = options or self.options
        overwrite = overwrite or self.overwrite

        # fail early instead of launching workers with '--module None' / '--function None'
        if not module or not function:
            raise ValueError(f"Parameters :module: and :function: must both be set. module = {module}, function = {function}")
        if not (isinstance(options, dict) or options is None):
            raise ValueError(f"Parameter :options: must be of type dict. type(options) = {type(options)}")

        # prep working directory and jobstarter
        work_dir, jobstarter = self.generic_run_setup(
            poses=poses,
            prefix=prefix,
            jobstarters=[jobstarter, self.jobstarter, poses.default_jobstarter]
        )

        logging.info(f"Running metric {function} of module {module} in {work_dir} on {len(poses.df.index)} poses.")

        # reuse scores from a previous run unless overwrite was requested
        scorefile = os.path.join(work_dir, f"{prefix}_{function}_generic_metric.{poses.storage_format}")
        if (scores := self.check_for_existing_scorefile(scorefile=scorefile, overwrite=overwrite)) is not None:
            logging.info(f"Found existing scorefile at {scorefile}. Returning {len(scores.index)} poses from previous run without running calculations.")
            output = RunnerOutput(poses=poses, results=scores, prefix=prefix)
            return output.return_poses()

        # split poses into number of max_cores lists, but not more than 100 poses per sublist (otherwise, argument list too long error occurs)
        if len(poses.df.index) / jobstarter.max_cores < 100:
            poses_sublists = split_list(poses.poses_list(), n_sublists=jobstarter.max_cores)
        else:
            poses_sublists = split_list(poses.poses_list(), element_length=100)
        out_files = [os.path.join(poses.work_dir, prefix, f"out_{index}.json") for index, _ in enumerate(poses_sublists)]

        # remove outputs left behind by an earlier (possibly crashed) run so they cannot be mistaken for fresh results
        for out_file in out_files:
            if os.path.isfile(out_file):
                os.remove(out_file)

        # quote path arguments so that whitespace or shell metacharacters in paths do not break the command.
        # python_path is left untouched on purpose, since it may itself be a multi-word launcher command.
        cmds = [
            f"{python_path} {__file__} --poses {shlex.quote(','.join(poses_sublist))} --out {shlex.quote(out_file)} --module {module} --function {function}"
            for out_file, poses_sublist in zip(out_files, poses_sublists)
        ]

        # options are shared by all workers and passed through a single JSON file
        if options:
            options_path = os.path.join(poses.work_dir, prefix, f"{prefix}_options.json")
            with open(options_path, "w", encoding="UTF-8") as f:
                json.dump(options, f)
            cmds = [f"{cmd} --options {shlex.quote(options_path)}" for cmd in cmds]

        # run command
        jobstarter.start(
            cmds = cmds,
            jobname = f"{function}_generic_metric",
            output_path = work_dir
        )

        # a worker that crashed never writes its output file; report that explicitly
        missing = [out_file for out_file in out_files if not os.path.isfile(out_file)]
        if missing:
            raise RuntimeError(f"{len(missing)} of {len(out_files)} worker output files are missing (first missing: {missing[0]}). Some runs might have crashed!")

        # collect individual DataFrames into one
        scores = pd.concat([pd.read_json(output) for output in out_files]).reset_index(drop=True)
        if len(scores.index) < len(poses.df.index):
            raise RuntimeError("Number of output poses is smaller than number of input poses. Some runs might have crashed!")

        logging.info(f"Saving scores of generic metric runner with function {function} at {scorefile}.")
        self.save_runner_scorefile(scores=scores, scorefile=scorefile)

        # create standardised output for poses class:
        output = RunnerOutput(
            poses = poses,
            results = scores,
            prefix = prefix,
        )
        logging.info(f"{function} completed. Returning scores.")
        return output.return_poses()
def main(args):
    """Worker entrypoint used by :meth:`GenericMetric.run`.

    Reads a comma-separated list of pose paths from ``args.poses``, imports
    ``args.function`` from ``args.module``, calls it once per pose path and
    writes a table with the columns ``data``, ``description`` and ``location``
    to ``args.out`` as JSON.

    Parameters
    ----------
    args : argparse.Namespace
        Object with the attributes ``poses``, ``out``, ``module``,
        ``function`` and ``options``. ``options`` is either falsy or the path
        to a JSON file holding keyword arguments for the metric function.
    """
    pose_paths = args.poses.split(",")

    # resolve the metric callable from its dotted module path and attribute name
    metric = getattr(importlib.import_module(args.module), args.function)

    # keyword arguments are optional; without an options file the metric gets only the pose path
    kwargs = {}
    if args.options:
        with open(args.options, "r", encoding="UTF-8") as handle:
            kwargs = json.load(handle)

    # evaluate the metric for every pose
    values = []
    for path in pose_paths:
        values.append(metric(path, **kwargs))

    # description is the file name without directory and extension
    names = []
    for path in pose_paths:
        stem, _ext = os.path.splitext(os.path.basename(path))
        names.append(stem)

    # assemble and save the results table
    table = pd.DataFrame({"data": values, "description": names, "location": pose_paths[:]})
    table.to_json(args.out)
if __name__ == "__main__":
    # command-line interface used by the worker commands that GenericMetric.run() builds
    import argparse

    argparser = argparse.ArgumentParser(
        description="Worker script of GenericMetric: evaluates one function on a list of pose files and writes the results as JSON.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    argparser.add_argument("--poses", type=str, required=True, help="Comma-separated list of pose file paths to evaluate.")
    argparser.add_argument("--out", type=str, required=True, help="Path of the JSON file the results are written to.")
    argparser.add_argument("--module", type=str, required=True, help="Importable module path that contains the metric function.")
    argparser.add_argument("--function", type=str, required=True, help="Name of the metric function inside --module.")
    argparser.add_argument("--options", type=str, default=None, help="Path to a JSON file with keyword arguments passed to the metric function for every pose.")
    arguments = argparser.parse_args()

    main(arguments)