Source code for quantum_executor.executor

"""The QuantumExecutor orchestrates quantum job splitting, dispatching, execution, and (optionally) result merging."""

import importlib.util
import logging
import threading
from collections.abc import Callable
from collections.abc import Sequence
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import as_completed
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from typing import Union

from quantum_executor.dispatch import Dispatch
from quantum_executor.job_runner import run_single_job_static
from quantum_executor.result_collector import MergedResultCollector
from quantum_executor.result_collector import ResultCollector
from quantum_executor.virtual_provider import VirtualProvider

if TYPE_CHECKING:  # pragma: no cover
    from quantum_executor.dispatch import DispatchDict

logger = logging.getLogger(__name__)


def load_policies_from_folder(  # pylint: disable=too-many-branches
    folder_path: str, raise_exc: bool = False
) -> dict[str, dict[str, Callable[..., Any]]]:
    """Dynamically load split and/or merge policies from Python files in a folder.

    Every regular ``*.py`` file directly inside ``folder_path`` is imported as a
    module named after the file (without the ``.py`` suffix). A module contributes
    a policy when it exposes a callable ``split`` and/or a callable ``merge``.

    Parameters
    ----------
    folder_path : str
        Path to the folder containing policy files. Relative and absolute paths
        are both supported.
    raise_exc : bool, optional
        If True, raise on missing folder or load errors. Otherwise, log a warning.
        Defaults to False.

    Returns
    -------
    Dict[str, Dict[str, Callable[..., Any]]]
        Mapping policy names → dict with keys "split" and/or "merge".

    Raises
    ------
    FileNotFoundError
        If the folder does not exist and ``raise_exc`` is True.
    ImportError
        If a file cannot be loaded, fails while executing, or defines neither
        ``split`` nor ``merge``, and ``raise_exc`` is True.
    """
    logger.debug("Loading policies from folder '%s'...", folder_path)
    folder = Path(folder_path)

    if not folder.exists():
        if raise_exc:
            raise FileNotFoundError(f"Folder '{folder_path}' does not exist.")
        logger.warning("Folder '%s' does not exist; creating it.", folder_path)
        folder.mkdir(parents=True, exist_ok=True)
        return {}

    policies: dict[str, dict[str, Callable[..., Any]]] = {}

    # Materialise and sort the listing up front: the result order becomes
    # deterministic, and executing modules (which may create __pycache__) cannot
    # disturb the directory iteration.
    for fname in sorted(folder.iterdir()):
        if not fname.is_file() or not fname.name.endswith(".py"):
            continue

        name = fname.name[:-3]
        # iterdir() already yields paths that include the folder, so the entry is
        # used as-is; joining it to folder_path again would duplicate the prefix
        # for relative folders.
        spec = importlib.util.spec_from_file_location(name, fname)
        if spec is None or spec.loader is None:
            msg = f"Cannot load module '{name}' from '{fname}'."
            if raise_exc:
                raise ImportError(msg)
            logger.warning(msg)
            continue

        module = importlib.util.module_from_spec(spec)
        try:
            spec.loader.exec_module(module)
        except Exception as e:  # pylint: disable=broad-except
            msg = f"Error loading '{fname}': {e}"
            if raise_exc:
                raise ImportError(msg) from e
            logger.warning(msg)
            continue

        funcs: dict[str, Callable[..., Any]] = {}
        for role in ("split", "merge"):
            candidate = getattr(module, role, None)
            if callable(candidate):
                funcs[role] = candidate

        if funcs:
            policies[name] = funcs
        else:
            msg = f"Policy file '{fname}' defines neither split nor merge."
            if raise_exc:
                raise ImportError(msg)
            logger.warning(msg)

    return policies
def add_policy_from_file(file_path: str, policy_folder: str, raise_exc: bool = False) -> None:
    """Validate a policy module and copy its file into the policies folder.

    The file is first imported to check that it executes and exposes a callable
    ``split`` and/or ``merge``. Only then is it copied, unchanged, to
    ``policy_folder`` under its original file name (the folder is created when
    missing, and an existing file of the same name is overwritten).

    Parameters
    ----------
    file_path : str
        Path to the Python file containing the policy.
    policy_folder : str
        Folder where policies live.
    raise_exc : bool, optional
        If True, raise on errors. Otherwise, log warnings. Defaults to False.

    Raises
    ------
    ImportError
        If the file cannot be loaded, fails while executing, or defines neither
        ``split`` nor ``merge``, and ``raise_exc`` is True.
    """
    spec = importlib.util.spec_from_file_location("policy_mod", file_path)
    if spec is None or spec.loader is None:
        msg = f"Cannot load module from '{file_path}'."
        if raise_exc:
            raise ImportError(msg)
        logger.warning(msg)
        return

    module = importlib.util.module_from_spec(spec)
    try:
        spec.loader.exec_module(module)
    except Exception as e:  # pylint: disable=broad-except
        msg = f"Error importing policy from '{file_path}': {e}"
        if raise_exc:
            raise ImportError(msg) from e
        logger.warning(msg)
        return

    has_split = callable(getattr(module, "split", None))
    has_merge = callable(getattr(module, "merge", None))
    if not (has_split or has_merge):
        msg = f"Policy file '{file_path}' defines neither split nor merge."
        if raise_exc:
            raise ImportError(msg)
        logger.warning(msg)
        return

    source = Path(file_path)
    folder = Path(policy_folder)
    folder.mkdir(parents=True, exist_ok=True)
    dest = folder / source.name

    # Read the whole source before opening the destination for writing, so that
    # registering a file that already lives in the policies folder does not
    # truncate it. The content is written verbatim: the source already carries its
    # own docstring, and prepending the raw docstring text would leave unquoted
    # prose at the top of the copy and make it unimportable.
    content = source.read_text(encoding="utf-8")
    dest.write_text(content, encoding="utf-8")

    logger.info("Policy '%s' copied to '%s'.", source.name, policy_folder)
class QuantumExecutor:
    """Manage splitting, dispatching, execution, and optional merging of quantum jobs.

    Parameters
    ----------
    providers_info : Dict[str, Dict[str, Any]], optional
        A dictionary mapping provider names to their respective API keys or configuration.
        The keys should be the provider names (e.g., "ionq", "azure") and the values should be
        dictionaries containing the necessary parameters for initialization.
        For example:
        {
            "ionq": {"api_key": "your-api-key-here"},
            "azure": {"subscription_id": "your-subscription-id", "resource_group": "your-resource-group"},
            "braket": {"aws_access_key_id": "your-access-key-id", "aws_secret_access_key": "your-secret-access-key"},
            "local_aer": {},
            "qbraid": {"api_key": "your-api-key-here"},
        }
    providers : List[str], optional
        Which providers to initialize.
    policies_folder : str, optional
        Where to load/save policy files.
    max_workers : int, optional
        Max processes for async execution.
    raise_exc : bool, optional
        If True, propagate initialization or policy-load errors.
    virtual_provider : VirtualProvider, optional
        If provided, use this instead of creating a new one.
    """

    _default_split = "uniform"

    def __init__(  # pylint: disable=too-many-arguments too-many-positional-arguments
        self,
        providers_info: dict[str, dict[str, Any]] | None = None,
        providers: list[str] | None = None,
        policies_folder: str = __file__.replace("executor.py", "policies"),
        max_workers: int | None = None,
        raise_exc: bool = False,
        virtual_provider: VirtualProvider | None = None,
    ) -> None:
        """Set up the provider layer and load the policies found in ``policies_folder``.

        Parameters
        ----------
        providers_info : Dict[str, Dict[str, Any]], optional
            Provider name → initialization parameters. Ignored when
            ``virtual_provider`` is given (its own provider info is reused).
        providers : List[str], optional
            Which providers to initialize. Ignored when ``virtual_provider`` is
            given (the names of its providers are used instead).
        policies_folder : str, optional
            Where to load/save policy files.
        max_workers : int, optional
            Max processes for async execution.
        raise_exc : bool, optional
            If True, propagate initialization or policy-load errors.
        virtual_provider : VirtualProvider, optional
            If provided, use this instead of creating a new one.
        """
        self._policies_folder = policies_folder
        self._max_workers = max_workers
        self._raise_exc = raise_exc

        if virtual_provider is None:
            self._providers_info = providers_info or {}
            self._providers = providers
            self._virtual_provider = VirtualProvider(
                providers_info=self._providers_info,
                include=self._providers,
                raise_exc=self._raise_exc,
            )
        else:
            # Mirror the configuration of the supplied provider so that jobs run in
            # worker processes are set up from the same information.
            self._providers_info = virtual_provider._providers_info
            self._providers = list(virtual_provider.get_providers().keys())
            self._virtual_provider = virtual_provider

        self._policies = load_policies_from_folder(self._policies_folder, raise_exc=self._raise_exc)
        logger.info("QuantumExecutor initialized.")

    def generate_dispatch(  # pylint: disable=too-many-positional-arguments too-many-arguments too-many-locals
        self,
        circuits: Any | Sequence[Any],  # noqa: ANN401
        shots: int | Sequence[int],
        backends: dict[str, list[str]],
        split_policy: str = _default_split,
        split_data: dict[str, Any] | None = None,
    ) -> tuple[Dispatch, dict[str, Any]]:
        """Split one or more circuits into jobs based on the specified split policy.

        Parameters
        ----------
        circuits : Any or Sequence[Any]
            Quantum circuit or list of quantum circuits. A ``str``/``bytes`` value
            is treated as a single circuit, never as a sequence of circuits.
        shots : int or Sequence[int]
            Number of shots or list of numbers of shots.
            If a list, it must match the length of `circuits`.
            If a single int, all circuits will use the same number of shots.
            With a single circuit, a one-element sequence is also accepted.
        backends : dict[str, list[str]]
            Provider → list of backends.
        split_policy : str, optional
            Which split policy to use.
        split_data : dict, optional
            Initial data for split policy.

        Returns
        -------
        tuple[Dispatch, dict[str, Any]]
            A Dispatch object containing the jobs and any updated split data.

        Raises
        ------
        ValueError
            If the number of shot values does not match the number of circuits.
        KeyError
            If ``split_policy`` is unknown or has no split function.
        """
        # str and bytes satisfy the Sequence ABC, but iterating them would feed the
        # policy one character at a time; a textual circuit (e.g. a serialized
        # program) must go through the single-circuit path instead.
        is_multi = isinstance(circuits, Sequence) and not isinstance(circuits, (str, bytes, bytearray))

        if is_multi:
            circuits = list(circuits)
            shots_list = [shots] * len(circuits) if isinstance(shots, int) else list(shots)
            if len(shots_list) != len(circuits):
                raise ValueError(
                    "When passing multiple circuits, shots must be a single int or a list of the same length."
                )

            split_fn = self.get_split_policy(split_policy)
            aggregated: dict[str, dict[str, list[Any]]] = {}
            split_data = split_data or {}

            for circ, sh in zip(circuits, shots_list, strict=False):
                # Each call receives the data returned by the previous one, so a
                # policy can carry state across circuits.
                disp_i, updated_split_data = split_fn(circ, sh, backends, self._virtual_provider, split_data)
                split_data = updated_split_data
                for prov, back_map in disp_i.to_dict().items():
                    agg_back_map = aggregated.setdefault(prov, {})
                    for back, jobs in back_map.items():
                        agg_back_map.setdefault(back, []).extend(jobs)

            return Dispatch(aggregated), split_data or {}

        if isinstance(shots, Sequence):
            # Exactly one value is acceptable here; an empty sequence used to fail
            # with a bare IndexError on shots[0].
            if len(shots) != 1:
                raise ValueError("When passing a single circuit, shots must be a single int, not a list.")
            shots = shots[0]

        # Single-circuit path
        split_fn = self.get_split_policy(split_policy)
        split_data = split_data or {}
        return split_fn(  # type: ignore[no-any-return]
            circuits,
            shots,
            backends,
            self._virtual_provider,
            split_data,
        )

    def run_experiment(  # pylint: disable=too-many-positional-arguments too-many-arguments
        self,
        circuits: Any | Sequence[Any],  # noqa: ANN401
        shots: int | Sequence[int],
        backends: dict[str, list[str]],
        split_policy: str = _default_split,
        merge_policy: str | None = None,
        multiprocess: bool = False,
        wait: bool = True,
        split_data: dict[str, Any] | None = None,
        merge_data: dict[str, Any] | None = None,
        max_workers: int | None = None,
    ) -> ResultCollector | MergedResultCollector:
        """Split circuits into jobs, dispatch them, and optionally merge results.

        This is ``generate_dispatch`` followed by ``run_dispatch``.

        Parameters
        ----------
        circuits : Any or Sequence[Any]
            Quantum circuit or list of quantum circuits.
        shots : int or Sequence[int]
            Number of shots or list of numbers of shots.
            If a list, it must match the length of `circuits`.
            If a single int, all circuits will use the same number of shots.
        backends : dict[str, list[str]]
            Provider → list of backends.
        split_policy : str, optional
            Which split policy to use.
        merge_policy : str or None, optional
            Which merge policy to use; if None, skip merging.
        multiprocess : bool, optional
            If True, run jobs in parallel processes.
        wait : bool, optional
            If True, block until execution (and merge) finishes.
        split_data : dict, optional
            Initial data for split policy.
        merge_data : dict, optional
            Initial data for merge policy; when omitted or empty, the split data
            returned by the split policy is used.
        max_workers : int, optional
            Override for max parallel processes.

        Returns
        -------
        ResultCollector or MergedResultCollector
            Unmerged collector if `merge_policy` is None, otherwise merged.
        """
        logger.info(
            "Experiment start: split_policy='%s', merge_policy='%s'",
            split_policy,
            merge_policy,
        )
        dispatch_obj, updated_split = self.generate_dispatch(
            circuits=circuits,
            shots=shots,
            backends=backends,
            split_policy=split_policy,
            split_data=split_data,
        )
        return self.run_dispatch(
            dispatch=dispatch_obj,
            multiprocess=multiprocess,
            wait=wait,
            max_workers=max_workers,
            merge_policy=merge_policy,
            merge_data=merge_data or updated_split,
        )

    # pylint: disable=too-many-positional-arguments too-many-arguments too-many-locals too-many-branches
    def run_dispatch(  # pylint: disable=too-many-statements
        self,
        dispatch: Union[Dispatch, "DispatchDict"],
        multiprocess: bool = False,
        wait: bool = True,
        max_workers: int | None = None,
        merge_policy: str | None = None,
        merge_data: dict[str, Any] | None = None,
    ) -> ResultCollector | MergedResultCollector:
        """Execute all jobs in a Dispatch and optionally merge their results.

        A job that raises does not abort the run: the error is logged and
        ``{"error": str(exc)}`` is stored as that job's result. A failing merge is
        handled the same way.

        Parameters
        ----------
        dispatch : Dispatch or DispatchDict
            Jobs to execute.
        multiprocess : bool, optional
            If True, run in parallel processes.
        wait : bool, optional
            If True, block until execution (and merge) finishes. If False, jobs run
            in a background thread (sequentially if multiprocess=False, or
            gathering + merging in threads if multiprocess=True).
        max_workers : int, optional
            Override for max parallel processes.
        merge_policy : str or None, optional
            Which merge policy to apply after dispatch.
        merge_data : dict, optional
            Initial data for merge policy.

        Returns
        -------
        ResultCollector or MergedResultCollector
            Raw results if `merge_policy` is None, otherwise merged.
        """
        logger.info(
            "Dispatch start: multiprocess=%s, wait=%s, merge_policy=%s",
            multiprocess,
            wait,
            merge_policy,
        )
        if not isinstance(dispatch, Dispatch):
            dispatch = Dispatch(dispatch)

        collector = ResultCollector()
        jobs = list(dispatch.all_jobs())
        if not jobs:
            logger.warning("No jobs to dispatch.")
            collector.complete = True
            return collector if merge_policy is None else MergedResultCollector(collector)

        for prov, back, job in jobs:
            collector.register_job_mapping(job, prov, back)

        def _run_sequential() -> None:
            """Run all jobs sequentially."""
            for prov, back, job in jobs:
                try:
                    res = run_single_job_static(
                        prov,
                        back,
                        job.circuit,
                        job.shots,
                        job.configuration or {},
                        self._providers_info,
                        self._providers,
                        self._raise_exc,
                        virtual_provider=self._virtual_provider,
                    )
                except Exception as e:  # pylint: disable=broad-except
                    logger.error("Error fetching result for Job %s: %s", job.id, e)
                    res = {"error": str(e)}
                collector.store_result(job, res)
            collector.complete = True

        if not multiprocess:
            if wait:
                _run_sequential()
            else:
                threading.Thread(target=_run_sequential, daemon=True).start()
        else:
            executor = ProcessPoolExecutor(max_workers or self._max_workers)
            futures: dict[Any, Any] = {}
            for prov, back, job in jobs:
                # The virtual provider is not handed to the workers (None is passed
                # in its place); they receive the provider info and names instead.
                futures[
                    executor.submit(
                        run_single_job_static,
                        prov,
                        back,
                        job.circuit,
                        job.shots,
                        job.configuration or {},
                        self._providers_info,
                        self._providers,
                        self._raise_exc,
                        None,
                    )
                ] = job

            def _gather() -> None:
                """Collect worker results as they finish, then release the pool."""
                try:
                    for fut in as_completed(futures):
                        job_obj = futures[fut]
                        try:
                            res = fut.result()
                        except Exception as e:  # pylint: disable=broad-except
                            logger.error("Error fetching result for Job %s: %s", job_obj.id, e)
                            res = {"error": str(e)}
                        collector.store_result(job_obj, res)
                    collector.complete = True
                finally:
                    # Release the pool even if storing a result fails part-way.
                    executor.shutdown(wait=False)

            if wait:
                _gather()
            else:
                threading.Thread(target=_gather, daemon=True).start()

        if merge_policy is None:
            return collector

        merged = MergedResultCollector(collector)

        def _merge_dispatch() -> None:
            """Merge results using the specified merge policy."""
            collector.wait_for_completion()
            try:
                merge_fn = self.get_merge_policy(merge_policy)
                md = merge_data or {}
                results, final = merge_fn(collector.get_results(), md)
            except Exception as e:  # pylint: disable=broad-except
                logger.error("Dispatch merge error: %s", e)
                results, final = {"error": str(e)}, {}
                md = {}
            merged.set_merged_results(results, md, final)
            logger.info("Dispatch merge '%s' done.", merge_policy)

        if wait:
            _merge_dispatch()
        else:
            threading.Thread(target=_merge_dispatch, daemon=True).start()

        return merged

    def get_split_policy(self, name: str) -> Callable[..., Any]:
        """Get a split policy by name.

        Parameters
        ----------
        name : str
            Policy name.

        Returns
        -------
        Callable[..., Any]
            The split policy function.

        Raises
        ------
        KeyError
            If not found or policy lacks a split.
        """
        try:
            p = self._policies[name]["split"]
            if not callable(p):
                raise KeyError(f"Split policy '{name}' not found.")
            return p
        except KeyError:
            # One uniform message whether the name, the "split" entry, or a
            # callable value is what is missing.
            raise KeyError(f"Split policy '{name}' not found.") from None

    def get_merge_policy(self, name: str) -> Callable[..., Any]:
        """Get a merge policy by name.

        Parameters
        ----------
        name : str
            Policy name.

        Returns
        -------
        Callable[..., Any]
            The merge policy function.

        Raises
        ------
        KeyError
            If not found or policy lacks a merge.
        """
        try:
            p = self._policies[name]["merge"]
            if not callable(p):
                raise KeyError(f"Merge policy '{name}' not found.")
            return p
        except KeyError:
            # One uniform message whether the name, the "merge" entry, or a
            # callable value is what is missing.
            raise KeyError(f"Merge policy '{name}' not found.") from None

    def add_policy(
        self,
        name: str,
        split_policy: Callable[..., Any] | None = None,
        merge_policy: Callable[..., Any] | None = None,
    ) -> None:
        """Dynamically add or update a policy (split and/or merge).

        The entry stored under ``name`` is replaced as a whole: a function that is
        not passed in this call is not kept from a previous registration. The
        policy is only registered in memory; no file is written.

        Parameters
        ----------
        name : str
            Policy name.
        split_policy : Callable[..., Any], optional
            Split function.
        merge_policy : Callable[..., Any], optional
            Merge function.

        Raises
        ------
        ValueError
            If neither ``split_policy`` nor ``merge_policy`` is provided.
        """
        entry: dict[str, Callable[..., Any]] = {}
        if split_policy:
            entry["split"] = split_policy
        if merge_policy:
            entry["merge"] = merge_policy
        if not entry:
            raise ValueError("At least one of split_policy or merge_policy must be provided.")
        self._policies[name] = entry
        logger.info("Policy '%s' added/updated.", name)

    def add_policy_from_file(self, file_path: str, raise_exc: bool | None = None) -> None:
        """Load a policy from file and (re)load all policies.

        The file is copied into this executor's policies folder, after which the
        in-memory policy table is rebuilt from that folder. Policies registered
        only through ``add_policy`` are therefore dropped by this call.

        Parameters
        ----------
        file_path : str
            Path to policy file.
        raise_exc : bool, optional
            If True, raise on errors. Defaults to constructor setting.
        """
        eff = raise_exc if raise_exc is not None else self._raise_exc
        add_policy_from_file(file_path, self._policies_folder, raise_exc=eff)
        self._policies = load_policies_from_folder(self._policies_folder, raise_exc=eff)

    @property
    def policies(self) -> list[str]:
        """List all loaded policy names.

        Returns
        -------
        List[str]
            List of loaded policy names.
        """
        return list(self._policies)

    @property
    def virtual_provider(self) -> VirtualProvider:
        """Get the virtual provider.

        Returns
        -------
        VirtualProvider
            The virtual provider instance.
        """
        return self._virtual_provider

    @staticmethod
    def default_providers() -> list[str]:
        """Get a list of default providers.

        Returns
        -------
        List[str]
            The provider names reported by ``VirtualProvider.default_providers()``.
        """
        return list(VirtualProvider.default_providers())