# Source code for pytket.extensions.stim.backends.stim_backend

# Copyright 2020-2024 Quantinuum
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import cast, List, Optional, Sequence, Union
from uuid import uuid4

import numpy as np
import stim  # type: ignore

from pytket.backends import (
    Backend,
    CircuitNotRunError,
    CircuitStatus,
    ResultHandle,
    StatusEnum,
)
from pytket.backends.backendinfo import BackendInfo
from pytket.backends.backendresult import BackendResult
from pytket.backends.resulthandle import _ResultIdTuple
from pytket.circuit import Circuit, OpType
from pytket.passes import (
    BasePass,
    DecomposeBoxes,
    FlattenRegisters,
    RebaseCustom,
    RemoveRedundancies,
    SequencePass,
)
from pytket.predicates import (
    DefaultRegisterPredicate,
    GateSetPredicate,
    NoClassicalControlPredicate,
    Predicate,
)
from pytket.unit_id import Qubit
from pytket.utils.outcomearray import OutcomeArray
from pytket.utils.results import KwargTypes
from pytket.extensions.stim._metadata import __extension_version__

# Lookup table from each pytket op type handled by this module to the name of
# the stim instruction appended for it.
_gate = {
    # Single-qubit gates
    OpType.noop: "I",
    OpType.X: "X",
    OpType.Y: "Y",
    OpType.Z: "Z",
    OpType.H: "H",
    OpType.S: "S",
    OpType.SX: "SQRT_X",
    OpType.SXdg: "SQRT_X_DAG",
    # Two-qubit gates
    OpType.CX: "CX",
    OpType.CY: "CY",
    OpType.CZ: "CZ",
    OpType.ISWAPMax: "ISWAP",
    OpType.SWAP: "SWAP",
    # Measurement and reset
    OpType.Measure: "M",
    OpType.Reset: "R",
}
# Shared BackendInfo instance, built from the extension version and the set of
# op types listed in _gate.
_backend_info = BackendInfo(
    "StimBackend", None, __extension_version__, None, set(_gate)
)


def _int_double(x: float) -> int:
    # return (2x) mod 8 if x is close to a half-integer, otherwise error
    y = 2 * x
    n = int(np.round(y))
    if np.isclose(y, n):
        return n % 8
    else:
        raise ValueError("Non-Clifford angle encountered")


def _tk1_to_cliff(a: float, b: float, c: float) -> Circuit:
    """Build a one-qubit circuit of H and S gates for a Clifford tk1(a, b, c).

    Each angle is converted to a repetition count with ``_int_double``, so a
    ``ValueError`` propagates if any angle is not close to a multiple of 1/2.

    :param a: first tk1 angle
    :param b: second tk1 angle
    :param c: third tk1 angle
    :return: a new single-qubit circuit containing only H and S gates, plus a
        phase adjustment
    """
    reps_a, reps_b, reps_c = (_int_double(angle) for angle in (a, b, c))
    replacement = Circuit(1)
    # Gates are appended in the order c, b, a.
    for _ in range(reps_c):
        replacement.S(0)
    for _ in range(reps_b):
        replacement.H(0)
        replacement.S(0)
        replacement.H(0)
    for _ in range(reps_a):
        replacement.S(0)
    # Add a phase of -0.25 for every S gate appended above.
    total_s_gates = reps_a + reps_b + reps_c
    replacement.add_phase(-0.25 * total_s_gates)
    return replacement


def _process_one_circuit(circ: Circuit, n_shots: int) -> BackendResult:
    """Translate ``circ`` into a stim circuit, sample it and collect the shots.

    :param circ: circuit to simulate; every operation type must be a key of
        ``_gate``
    :param n_shots: number of shots to sample
    :return: result whose shots hold one readout per shot, with one column per
        bit of ``circ`` in the order of ``circ.bits``; a bit that no
        measurement writes to reads as 0 in every shot
    :raises ValueError: if two measurements write to the same bit
    """
    qubits = circ.qubits
    bits = circ.bits
    c = stim.Circuit()
    # Maps each measured bit to the column holding its outcome in the sample
    # array: measurements occupy columns in the order they are appended.
    readout_column = {}
    for cmd in circ.get_commands():
        optype = cmd.op.type
        args = cmd.args
        if optype == OpType.Measure:
            qb, cb = args
            assert isinstance(qb, Qubit)
            # Constant-time duplicate check instead of rebuilding a set of all
            # readout bits after every command.
            if cb in readout_column:
                raise ValueError("Measurement overwritten")
            readout_column[cb] = len(readout_column)
            c.append_operation("M", [qubits.index(qb)])
        else:
            qbs = [qubits.index(cast(Qubit, arg)) for arg in args]
            c.append_operation(_gate[optype], qbs)
    sampler = c.compile_sampler()
    batch = sampler.sample(n_shots)
    # batch[k,:] has the measurements in the order they were added to the stim
    # circuit. We want them to be returned in bit order, so resolve the column
    # of every bit once rather than once per shot. None marks a bit that was
    # never measured.
    columns = [readout_column.get(cb) for cb in bits]
    return BackendResult(
        shots=OutcomeArray.from_readouts(
            [
                [False if col is None else batch[k, col] for col in columns]
                for k in range(n_shots)
            ]
        )
    )


class StimBackend(Backend):
    """
    Backend for simulating Clifford circuits using Stim
    """

    _supports_shots = True
    _supports_counts = True

    @property
    def required_predicates(self) -> List[Predicate]:
        """Predicates associated with this backend.

        :return: default-register, gate-set (restricted to the keys of
            ``_gate``) and no-classical-control predicates
        """
        return [
            DefaultRegisterPredicate(),
            GateSetPredicate(set(_gate.keys())),
            NoClassicalControlPredicate(),
        ]

    def rebase_pass(self) -> BasePass:
        """Return a pass rebasing to CX, H and S.

        Single-qubit tk1 replacements are produced by ``_tk1_to_cliff``.
        """
        return RebaseCustom({OpType.CX, OpType.H, OpType.S}, Circuit(), _tk1_to_cliff)  # type: ignore

    def default_compilation_pass(self, optimisation_level: int = 1) -> BasePass:
        """Return the compilation pass sequence for this backend.

        :param optimisation_level: accepted but not used; the same sequence is
            returned for every value
        :return: box decomposition, register flattening, the rebase pass and
            redundancy removal, in that order
        """
        # No optimization.
        return SequencePass(
            [
                DecomposeBoxes(),
                FlattenRegisters(),
                self.rebase_pass(),
                RemoveRedundancies(),
            ]
        )

    @property
    def backend_info(self) -> Optional[BackendInfo]:
        """Return the module-level ``BackendInfo`` shared by all instances."""
        return _backend_info

    @property
    def _result_id_type(self) -> _ResultIdTuple:
        """Result handles carry a single string identifier."""
        return (str,)

    def process_circuits(
        self,
        circuits: Sequence[Circuit],
        n_shots: Optional[Union[int, Sequence[int]]] = None,
        valid_check: bool = True,
        **kwargs: KwargTypes,
    ) -> List[ResultHandle]:
        """Simulate each circuit immediately and cache its result.

        :param circuits: circuits to simulate
        :param n_shots: number of shots, either a single value applied to
            every circuit or one value per circuit
        :param valid_check: if True, run ``self._check_all_circuits`` on the
            circuits before simulating
        :param kwargs: unused
        :return: one freshly generated handle per circuit, in input order
        :raises ValueError: if ``n_shots`` is iterable and its length differs
            from the number of circuits, or if a shot count is missing for
            any circuit
        """
        circuits = list(circuits)
        n_shots_list: List[Optional[int]]
        if hasattr(n_shots, "__iter__"):
            # Materialise so that any iterable (tuple, generator, ...) has a
            # length and can be traversed again below.
            n_shots_list = list(cast(Sequence[int], n_shots))
            if len(n_shots_list) != len(circuits):
                raise ValueError("The length of n_shots and circuits must match")
        else:
            # convert n_shots to a list
            n_shots_list = [cast(Optional[int], n_shots)] * len(circuits)
        # Fail early with a clear message instead of handing None to the stim
        # sampler; checked before any circuit is run or cached.
        if any(shots is None for shots in n_shots_list):
            raise ValueError("n_shots must be specified for every circuit")

        if valid_check:
            self._check_all_circuits(circuits)

        handle_list = []
        for circuit, n_shots_circ in zip(circuits, n_shots_list):
            handle = ResultHandle(str(uuid4()))
            self._cache[handle] = {
                "result": _process_one_circuit(circuit, cast(int, n_shots_circ))
            }
            handle_list.append(handle)
        return handle_list

    def circuit_status(self, handle: ResultHandle) -> CircuitStatus:
        """Report the status of a previously returned handle.

        :param handle: handle to look up
        :return: a COMPLETED status if the handle is in the cache
        :raises CircuitNotRunError: if the handle is not in the cache
        """
        if handle in self._cache:
            return CircuitStatus(StatusEnum.COMPLETED)
        raise CircuitNotRunError(handle)