From cf5058cf9f52f9e796f663dce54400c4a5c333bf Mon Sep 17 00:00:00 2001 From: uealex <16141991+uealex@user.noreply.gitee.com> Date: Fri, 7 Nov 2025 14:13:59 +0800 Subject: [PATCH] feat: add bitflip benchmark --- cqlib/circuits/bitflip.py | 272 ++++++++++++++++++++++++++++++++++ tests/circuit/test_bitflip.py | 86 +++++++++++ 2 files changed, 358 insertions(+) create mode 100644 cqlib/circuits/bitflip.py create mode 100644 tests/circuit/test_bitflip.py diff --git a/cqlib/circuits/bitflip.py b/cqlib/circuits/bitflip.py new file mode 100644 index 0000000..07ef092 --- /dev/null +++ b/cqlib/circuits/bitflip.py @@ -0,0 +1,272 @@ +# This code is part of cqlib. +# +# (C) Copyright China Telecom Quantum Group, QuantumCTek Co., Ltd., +# Center for Excellence in Quantum Information and Quantum Physics 2024. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +""" +bitflip benchmark +""" + +import numpy as np +import pymatching as pm +from scipy.sparse import csr_matrix, lil_matrix + +from cqlib.circuits import Circuit +from cqlib.circuits.circuit import IntQubit + + +class TEMPCIR(Circuit): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def CONDX(self, qubit: IntQubit): + """ + Temporary conditional X gate implementation. + + Args: + qubit (IntQubit): Target qubit for the conditional X operation. + + Returns: + None: This is a temporary implementation that only prints a message. + """ + # Temporary implementation + print("Temporary fallback for missing_function") + return None + + +class RepetitionCode: + def __init__(self, num_rounds: int, code_distance: int = 5) -> None: + """ + Initialize a repetition code error correction circuit. + + Args: + num_rounds (int): Number of error correction rounds to perform. + Must be >= 1. + code_distance (int, optional): Distance of the repetition code. + Determines number of physical qubits. Defaults to 5. + + Raises: + AssertionError: If num_rounds < 1. + """ + assert num_rounds >= 1, "at least 1 rounds" + self.num_rounds = num_rounds + self.code_distance = code_distance + self.num_ancilla = code_distance - 1 + self.num_qubits = self.code_distance + self.num_ancilla + + self.data_qubits = [2 * i for i in range(self.code_distance)] + self.ancilla_qubits = [2 * i + 1 for i in range(self.num_ancilla)] + + self.circuit = self.gen_cir() + + def _cx(self, qc: Circuit, control: int, target: int): + """ + Apply a CNOT gate using the circuit's native CZ gate. + + Implements CNOT using the identity: CNOT = H ⋅ CZ ⋅ H + + Args: + qc (Circuit): The quantum circuit to apply the gate to. + control (int): Index of the control qubit. + target (int): Index of the target qubit. + """ + qc.h(qubit=target) + qc.cz(control_qubit=control, target_qubit=target) + qc.h(qubit=target) + + def gen_cir(self): + """ + Generate the repetition code circuit for error detection. + + Creates a circuit that: + 1. Initializes all qubits using measurement and conditional-X + 2. For each round: + - Entangles ancilla qubits with adjacent data qubits using CNOT + - Measures and resets ancilla qubits + 3. Finally measures all data qubits + + Returns: + TEMPCIR: The constructed quantum circuit with measurements. + """ + # init + # qc:Circuit = Circuit(qubits=self.num_qubits) + qc: TEMPCIR = TEMPCIR(qubits=self.num_qubits) + # reset all qubits + for q_idx in range(self.num_qubits): + qc.measure(q_idx) + qc.CONDX(q_idx) + + for _ in range(self.num_rounds): + # each round, cnot between ancilla with data qubits + for q_idx in self.ancilla_qubits: + self._cx(qc, q_idx - 1, q_idx) + self._cx(qc, q_idx + 1, q_idx) + for q_idx in self.ancilla_qubits: + # this measure result should be saved + qc.measure(q_idx) + qc.CONDX(q_idx) + # Final read-out + for q_idx in self.data_qubits: + # this measure result should be saved + qc.measure(q_idx) + + def check_matrix(self) -> csr_matrix: + """ + Get the parity check matrix for the repetition code. + + The parity check matrix H encodes the stabilizer measurements and their + relationships across rounds. It is used by the decoder to identify + likely error patterns from syndrome measurements. + + Returns: + csr_matrix: Sparse parity check matrix in compressed row format. + """ + return build_repetition_code_parity_check_matrix( + self.code_distance, self.num_rounds + ) + + +def build_repetition_code_parity_check_matrix( + code_distance: int, num_rounds: int +) -> csr_matrix: + """Build sparse parity-check matrix for the repetition code. + + Constructs a binary parity check matrix H that relates measurement outcomes + to error syndromes in the repetition code. The matrix has the structure: + + [ Time-like detectors | Bridge detectors | Logical observable ] + + where: + - Time-like detectors: Compare ancilla measurements between rounds + - Bridge detectors: Connect final ancilla measurements to data qubits + - Logical observable: Track the logical qubit value + + Matrix dimensions are: + - Rows: (num_ancilla * num_rounds + code_distance) measurements + - Cols: (num_ancilla * (num_rounds + 1) + 1) detectors/observables + + Args: + code_distance (int): Distance of the repetition code. Determines + number of physical data qubits (code_distance) and number of + ancilla qubits (code_distance - 1). + num_rounds (int): Number of measurement rounds performed. + Each round measures all ancilla qubits. + + Returns: + csr_matrix: Sparse parity check matrix in compressed row format. + Matrix entries are boolean (0/1) indicating which measurements + participate in each syndrome or logical observable. + + Example: + For d=3, r=2 (distance 3, 2 rounds), matrix structure is: + - 2 ancilla measurements per round × 2 rounds = 4 rows + - 3 final data measurements = 3 more rows + - 2 time-like detectors per round = 4 columns + - 2 bridge detectors = 2 more columns + - 1 logical observable = 1 more column + """ + num_ancilla = code_distance - 1 + num_measurements = num_ancilla * num_rounds + code_distance + num_time_like_detectors = num_ancilla * num_rounds + num_bridge_detectors = num_ancilla + num_logical_observables = 1 + + H = lil_matrix( + ( + num_measurements, + num_time_like_detectors + num_bridge_detectors + num_logical_observables, + ), + dtype=bool, + ) + + column_index = 0 + for round_index in range(num_rounds): + for ancilla_index in range(num_ancilla): + current_measurement = round_index * num_ancilla + ancilla_index + if round_index == 0: + H[current_measurement, column_index] = 1 + else: + previous_measurement = (round_index - 1) * num_ancilla + ancilla_index + H[current_measurement, column_index] = 1 + H[previous_measurement, column_index] = 1 + column_index += 1 + + data_start_index = num_ancilla * num_rounds + for ancilla_index in range(num_ancilla): + last_measurement = (num_rounds - 1) * num_ancilla + ancilla_index + left_data_qubit = data_start_index + ancilla_index + right_data_qubit = data_start_index + ancilla_index + 1 + H[last_measurement, column_index] = 1 + H[left_data_qubit, column_index] = 1 + H[right_data_qubit, column_index] = 1 + column_index += 1 + + rightmost_data_qubit = data_start_index + code_distance - 1 + H[rightmost_data_qubit, column_index] = 1 + + return H.tocsr() + + +def build_pymatching_decoder( + parity_check_matrix: csr_matrix, physical_error_probability: float +) -> pm.Matching: + """ + Create a minimum-weight perfect matching decoder using pymatching. + + Args: + parity_check_matrix (csr_matrix): The sparse parity check matrix H + describing the code's stabilizer measurements. + physical_error_probability (float): The probability of a physical + error occurring on any single qubit. + + Returns: + pm.Matching: Configured pymatching decoder ready to process syndromes. + The decoder uses error weights derived from the physical error rate. + """ + detector_submatrix = parity_check_matrix[:, :-1].T.astype(np.uint8) + log_likelihood_ratio = np.log( + (1 - physical_error_probability) / physical_error_probability + ) + weights = np.ones(parity_check_matrix.shape[1]) * log_likelihood_ratio + return pm.Matching(detector_submatrix, weights=weights) + + +def decode_logical_observables( + decoder: pm.Matching, + parity_check_matrix: csr_matrix, + measurement_results: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + """ + Decode logical qubit values from measurement results using the decoder. + + Args: + decoder (pm.Matching): The initialized pymatching decoder. + parity_check_matrix (csr_matrix): The parity check matrix H that + defines the code structure. + measurement_results (np.ndarray): Raw measurement results from the + quantum circuit, including both stabilizer and data measurements. + + Returns: + tuple[np.ndarray, np.ndarray]: A tuple containing: + - predicted_logical_values: The decoder's best guess of the + logical qubit values after error correction + - actual_logical_values: The raw logical observable values + before error correction + """ + syndrome_matrix = (measurement_results @ parity_check_matrix) & 1 + syndrome_bits = syndrome_matrix[:, :-1].astype(np.uint8) + actual_logical_values = syndrome_matrix[:, -1] + + predicted_error_patterns = np.vstack([decoder.decode(row) for row in syndrome_bits]) + predicted_logical_values = ( + predicted_error_patterns @ parity_check_matrix[:, -1].toarray().flatten() + ) & 1 + + return predicted_logical_values, actual_logical_values diff --git a/tests/circuit/test_bitflip.py b/tests/circuit/test_bitflip.py new file mode 100644 index 0000000..ba6d590 --- /dev/null +++ b/tests/circuit/test_bitflip.py @@ -0,0 +1,86 @@ +# This code is part of cqlib. +# +# (C) Copyright China Telecom Quantum Group, QuantumCTek Co., Ltd., +# Center for Excellence in Quantum Information and Quantum Physics 2024. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +""" +Test bitflip +""" + +import numpy as np +import pymatching as pm +from scipy.sparse import csr_matrix + +from cqlib.circuits.bitflip import ( + build_pymatching_decoder, + build_repetition_code_parity_check_matrix, + decode_logical_observables, +) + + +def small_parity_check_matrix(): + return build_repetition_code_parity_check_matrix(code_distance=3, num_rounds=2) + + +def test_parity_check_matrix_shape(): + H = small_parity_check_matrix() + # For distance=3, num_rounds=2: A=2 + # rows = 2*2 + 3 = 7 + # cols = 2*2 + 2 + 1 = 7 + assert H.shape == (7, 7) + assert isinstance(H, csr_matrix) + assert H.nnz > 0 # non-zero entries + + +def test_build_pymatching_decoder(): + error_prob = 0.1 + decoder = build_pymatching_decoder(small_parity_check_matrix(), error_prob) + assert isinstance(decoder, pm.Matching) + assert decoder.num_detectors > 0 + + +def test_decode_logical_observables_returns_array(): + code_distance = 5 + num_rounds = 5 + num_shots = 10 + depolarization = 0.4 + + # simulate measurement + parity_check_matrix = build_repetition_code_parity_check_matrix( + code_distance, num_rounds + ) + raw_measurements = np.array( + [ + [0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0], + [1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1], + [0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0], + [0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0], + [1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0], + [0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1], + [0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0], + ] + ) + + decoder = build_pymatching_decoder( + parity_check_matrix, physical_error_probability=depolarization + ) + predicted_logical, actual_logical = decode_logical_observables( + decoder, parity_check_matrix, raw_measurements + ) + + print(predicted_logical, actual_logical) + + assert isinstance(predicted_logical, np.ndarray) + assert predicted_logical.shape == (num_shots,) + assert actual_logical.shape == (num_shots,) + assert np.all((predicted_logical == 0) | (predicted_logical == 1)) # only 0 or 1 -- Gitee