"""
We have chosen to use `Snakemake <https://snakemake.readthedocs.io/en/stable/>`_
as the EasyLink workflow manager. This module is responsible for generating the
Snakemake rules to be run as well as writing them to the Snakefile.
Note we have adopted the Snakemake term "rule" to refer to a singular component
in a Snakefile (i.e. in a Snakemake pipeline) that defines input files, output files,
and the command to run to create those output files. These rules are generated
dynamically as strings and appended to the Snakefile.
"""
import os
import shlex
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
[docs]
class Rule(ABC):
"""An abstract class used to generate Snakemake rules."""
[docs]
def write_to_snakefile(self, snakefile_path) -> None:
"""Writes the rule to the Snakefile.
Parameters
----------
snakefile_path
Path to the Snakefile to write the rule to.
"""
with open(snakefile_path, "a") as f:
f.write(self.build_rule())
[docs]
@abstractmethod
def build_rule(self) -> str:
"""Builds the snakemake rule to be written to the Snakefile.
This is an abstract method and must be implemented by concrete instances.
"""
pass
[docs]
@dataclass
class TargetRule(Rule):
"""A :class:`Rule` that defines the final output of the pipeline.
Snakemake will determine the directed acyclic graph (DAG) based on this target.
"""
target_files: list[str]
"""List of final output filepaths."""
validation: str
"""Name of file created by :class:`InputValidationRule`."""
requires_spark: bool
"""Whether or not this rule requires a Spark environment to run."""
[docs]
def build_rule(self) -> str:
"""Builds the Snakemake rule for the final output of the pipeline."""
outputs = [os.path.basename(file_path) for file_path in self.target_files]
rulestring = f"""
rule all:
message: 'Grabbing final output'
localrule: True
input:
final_output={self.target_files},
validation='{self.validation}',"""
if self.requires_spark:
rulestring += f"""
term="spark_logs/spark_master_terminated.txt",
master_log="spark_logs/spark_master_log.txt",
worker_logs=gather.num_workers("spark_logs/spark_worker_log_{{scatteritem}}.txt",
),"""
rulestring += f"""
output: {outputs}
run:
import os
for input_path, output_path in zip(input.final_output, output):
os.symlink(input_path, output_path)"""
return rulestring
[docs]
@dataclass
class ImplementedRule(Rule):
"""A :class:`Rule` that defines the execution of an :class:`~easylink.implementation.Implementation`."""
name: str
"""Name of the rule."""
step_name: str
"""Name of the step this rule is implementing."""
implementation_name: str
"""Name of the ``Implementation`` to build the rule for."""
input_slots: dict[str, dict[str, str | list[str]]]
"""This ``Implementation's`` input slot attributes."""
validations: list[str]
"""Names of files created by :class:`InputValidationRule`. These files are empty
but used by Snakemake to build the graph edges of dependency on validation rules."""
output: list[str]
"""Output data filepaths."""
resources: dict | None
"""Computational resources used by executor (e.g. SLURM)."""
envvars: dict
"""Environment variables to set."""
diagnostics_dir: str
"""Directory for diagnostic files."""
image_path: str
"""Path to the Singularity image to run."""
script_cmd: str
"""Command to execute."""
requires_spark: bool
"""Whether or not this ``Implementation`` requires a Spark environment."""
is_auto_parallel: bool = False
"""Whether or not this ``Implementation`` is to be automatically run in parallel."""
[docs]
def build_rule(self) -> str:
"""Builds the Snakemake rule for this ``Implementation``."""
if self.is_auto_parallel and len(self.output) > 1:
raise NotImplementedError(
"Multiple output slots/files of AutoParallelSteps not yet supported"
)
return self._build_io() + self._build_resources() + self._build_shell_cmd()
[docs]
def _build_io(self) -> str:
"""Builds the input/output portion of the rule."""
log_path_chunk_adder = "-{chunk}" if self.is_auto_parallel else ""
# Handle output files vs directories
files = [path for path in self.output if Path(path).suffix != ""]
if len(files) == len(self.output):
output = self.output
elif len(files) == 0:
if len(self.output) != 1:
raise NotImplementedError("Multiple output directories is not supported.")
output = f"directory('{self.output[0]}')"
else:
raise NotImplementedError(
"Mixed output types (files and directories) is not supported."
)
io_str = (
f"""
rule:
name: "{self.name}"
message: "Running {self.step_name} implementation: {self.implementation_name}" """
+ self._build_input()
+ f"""
output: {output}
log: "{self.diagnostics_dir}/{self.name}-output{log_path_chunk_adder}.log"
container: "{self.image_path}" """
)
return io_str
[docs]
def _build_resources(self) -> str:
"""Builds the resources portion of the rule."""
if not self.resources:
return ""
return f"""
resources:
slurm_partition={self.resources['slurm_partition']},
mem_mb={self.resources['mem_mb']},
runtime={self.resources['runtime']},
cpus_per_task={self.resources['cpus_per_task']},
slurm_extra="--output '{self.diagnostics_dir}/{self.name}-slurm-%j.log'" """
[docs]
def _build_shell_cmd(self) -> str:
"""Builds the shell command portion of the rule."""
# TODO [MIC-5787]: handle multiple wildcards, e.g.
# output_paths = ",".join(self.output)
# wildcards_subdir = "/".join([f"{{wildcards.{wc}}}" for wc in self.wildcards])
# and then in shell cmd: export OUTPUT_PATHS={output_paths}/{wildcards_subdir}
# snakemake shell commands require wildcards to be prefaced with 'wildcards.'
output_files = ",".join(self.output).replace("{chunk}", "{wildcards.chunk}")
shell_cmd = f"""
shell:
'''
export OUTPUT_PATHS={output_files}
export DIAGNOSTICS_DIRECTORY={self.diagnostics_dir}"""
for input_slot_attrs in self.input_slots.values():
# snakemake shell commands require wildcards to be prefaced with 'wildcards.'
input_files = ",".join(input_slot_attrs["filepaths"]).replace(
"{chunk}", "{wildcards.chunk}"
)
shell_cmd += f"""
export {input_slot_attrs["env_var"]}={input_files}"""
if self.requires_spark:
shell_cmd += f"""
read -r SPARK_MASTER_URL < {{input.master_url}}
export SPARK_MASTER_URL"""
for var_name, var_value in self.envvars.items():
shell_cmd += f"""
export {var_name}={shlex.quote(str(var_value))}"""
# Log stdout/stderr to diagnostics directory
shell_cmd += f"""
{self.script_cmd} > {{log}} 2>&1
'''"""
return shell_cmd
[docs]
@dataclass
class CheckpointRule(Rule):
"""A :class:`Rule` that defines a checkpoint.
When running an :class:`~easylink.implementation.Implementation` in an auto
parallel way, we do not know until runtime how many parallel jobs there will
be (e.g. we don't know beforehand how many chunks a large incoming dataset will
be split into since the incoming dataset isn't created until runtime). The
snakemake mechanism to handle this dynamic nature is a
`checkpoint <https://snakemake.readthedocs.io/en/stable/snakefiles/rules.html#data-dependent-conditional-execution/>`_
rule along with a directory as output.
Notes
-----
There is a known `Snakemake bug <https://github.com/snakemake/snakemake/issues/3036>`_
which prevents the use of multiple checkpoints in a single Snakefile. We
work around this by generating an empty checkpoint.txt file as part of this
rule. If this file does not yet exist when trying to run the :class:`AggregationRule`,
it means that the checkpoint has not yet been executed for the
particular wildcard value(s). In this case, we manually raise a Snakemake
``IncompleteCheckpointException`` which Snakemake automatically handles
and leads to a re-evaluation after the checkpoint has successfully passed.
TODO [MIC-5658]: Thoroughly test this workaround when implementing cacheing.
"""
name: str
"""Name of the rule."""
input_files: list[str]
"""The input filepaths."""
splitter_func_name: str
"""The splitter function's name."""
output_dir: str
"""Output directory path. It must be used as an input for next rule."""
checkpoint_filepath: str
"""Path to the checkpoint file. This is only needed for the bugfix workaround."""
[docs]
def build_rule(self) -> str:
"""Builds the Snakemake rule for this checkpoint.
Checkpoint rules are a special type of rule in Snakemake that allow for dynamic
generation of output files. This rule is responsible for splitting the input
files into chunks. Note that the output of this rule is a Snakemake ``directory``
object as opposed to a specific file like typical rules have.
"""
checkpoint = f"""
checkpoint:
name: "{self.name}"
input:
files={self.input_files},
output:
output_dir=directory("{self.output_dir}"),
checkpoint_file=touch("{self.checkpoint_filepath}"),
params:
input_files=lambda wildcards, input: ",".join(input.files),
localrule: True
message: "Splitting {self.name} into chunks"
run:
splitter_utils.{self.splitter_func_name}(
input_files=list(input.files),
output_dir=output.output_dir,
desired_chunk_size_mb=0.1,
)"""
return checkpoint
[docs]
@dataclass
class AggregationRule(Rule):
"""A :class:`Rule` that aggregates the processed chunks of output data.
When running an :class:`~easylink.implementation.Implementation` in an auto
parallel way, we need to aggregate the output files from each parallel job
into a single output file.
"""
name: str
"""Name of the rule."""
input_files: list[str]
"""The input processed chunk files to aggregate."""
aggregated_output_file: str
"""The final aggregated results file."""
aggregator_func_name: str
"""The name of the aggregation function to run."""
checkpoint_filepath: str
"""Path to the checkpoint file. This is only needed for the bugfix workaround."""
checkpoint_rule_name: str
"""Name of the checkpoint rule."""
[docs]
def build_rule(self) -> str:
"""Builds the Snakemake rule for this aggregator.
When running an :class:`~easylink.step.AutoParallelStep`, we need
to aggregate the output files from each parallel job into a single output file.
This rule relies on a dynamically generated aggregation function which returns
all of the **processed** chunks (from running the ``AutoParallelStep's``
container in parallel) and uses them as inputs to the actual aggregation
rule.
Notes
-----
There is a known `Snakemake bug <https://github.com/snakemake/snakemake/issues/3036>`_
which prevents the use of multiple checkpoints in a single Snakefile. We
work around this by generating an empty checkpoint.txt file in the
:class:`~CheckpointRule`. If this file does not yet exist when trying to
aggregate, it means that the checkpoint has not yet been executed for the
particular wildcard value(s). In this case, we manually raise a Snakemake
``IncompleteCheckpointException`` which `Snakemake automatically handles
<https://snakemake.readthedocs.io/en/stable/snakefiles/rules.html#data-dependent-conditional-execution>`_
and leads to a re-evaluation after the checkpoint has successfully passed,
i.e. we replicate `Snakemake's behavior <https://github.com/snakemake/snakemake/blob/04f89d330dd94baa51f41bc796392f85bccbd231/snakemake/checkpoints.py#L42>`_.
"""
input_function = self._define_input_function()
rule = self._define_aggregator_rule()
return input_function + rule
[docs]
def _define_aggregator_rule(self):
"""Builds the rule that runs the aggregation."""
rule = f"""
rule:
name: "{self.name}"
input: get_aggregation_inputs_{self.name}
output: {[self.aggregated_output_file]}
localrule: True
message: "Aggregating {self.name}"
run:
aggregator_utils.{self.aggregator_func_name}(
input_files=list(input),
output_filepath="{self.aggregated_output_file}",
)"""
return rule