API Reference

The starting point for the PlanqTN Python API is the planqtn package.

PlanqTN is a library for creating and analyzing tensor network quantum error correction codes.

To build tensor network codes manually, use the planqtn.TensorNetwork class for the network and the planqtn.StabilizerCodeTensorEnumerator class for its nodes, alongside the planqtn.Legos class for predefined parity check matrices.

Example

Put together a tensor network from stabilizer code tensors and compute the weight enumerator polynomial.

>>> from planqtn import TensorNetwork
>>> from planqtn import StabilizerCodeTensorEnumerator
>>> from planqtn import Legos
>>> # Create tensor network from stabilizer code tensors
>>> nodes = [StabilizerCodeTensorEnumerator(tensor_id="z0", h=Legos.z_rep_code(3)),
...          StabilizerCodeTensorEnumerator(tensor_id="x1", h=Legos.x_rep_code(3)),
...          StabilizerCodeTensorEnumerator(tensor_id="z2", h=Legos.z_rep_code(3))]
>>> tn = TensorNetwork(nodes)
>>> # Add traces to define contraction pattern
>>> tn.self_trace("z0", "x1", [0], [0])
>>> tn.self_trace("x1", "z2", [1], [0])
>>> # Compute weight enumerator polynomial
>>> wep = tn.stabilizer_enumerator_polynomial()
>>> print(wep)
{0:1, 2:2, 3:8, 4:13, 5:8}

To build tensor network codes automatically, you can use the classes in the planqtn.networks module, which contains universal tensor network layouts for stabilizer codes as well as layouts for specific code families.

Example

Generate the tensor network for the 5x5 rotated surface code and calculate the weight enumerator polynomial.

>>> from planqtn.networks import RotatedSurfaceCodeTN
>>> tn = RotatedSurfaceCodeTN(5)
>>> for power, coeff in tn.stabilizer_enumerator_polynomial().items():
...     print(f"{power}: {coeff}")
0: 1
2: 8
4: 72
6: 534
8: 3715
10: 25816
12: 158448
14: 782532
16: 2726047
18: 5115376
20: 5136632
22: 2437206
24: 390829
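A quick consistency check (plain Python, independent of PlanqTN): the coefficients of the unnormalized stabilizer weight enumerator count the stabilizer group elements by weight, so for an [[n, k]] code they must sum to 2^(n - k). For the distance-5 rotated surface code above, n = 25 and k = 1:

```python
# Coefficients printed above: weight -> count of stabilizer elements.
coeffs = {
    0: 1, 2: 8, 4: 72, 6: 534, 8: 3715, 10: 25816, 12: 158448,
    14: 782532, 16: 2726047, 18: 5115376, 20: 5136632,
    22: 2437206, 24: 390829,
}
# An [[n, k]] stabilizer code has exactly 2^(n - k) stabilizer elements.
assert sum(coeffs.values()) == 2 ** (25 - 1)  # 16777216
```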

Legos

Collection of predefined quantum error correction tensor "legos".

This class provides a library of pre-defined stabilizer code tensors and quantum operations that can be used as building blocks for quantum error correction codes. Each lego represents a specific quantum code or operation with its associated parity check matrix.

The class includes various types of tensors:

  • Encoding tensors for specific quantum codes ([[6,0,3]], [[5,1,2]], etc.)
  • Repetition codes for basic error correction
  • Stopper tensors for terminating tensor networks
  • Identity and Hadamard operations
  • Well-known codes like the Steane code and Quantum Reed-Muller codes
Example
>>> from planqtn import Legos
>>> from planqtn.symplectic import sprint
>>> # Get the Hadamard tensor
>>> Legos.h
GF([[1, 0, 0, 1],
    [0, 1, 1, 0]], order=2)
>>> # Get the stopper_x tensor
>>> Legos.stopper_x
GF([[1, 0]], order=2)
>>> # Get the stopper_z tensor
>>> Legos.stopper_z
GF([[0, 1]], order=2)
>>> # Get a Z-repetition code with distance 3
>>> # and print it in a nice symplectic format
>>> sprint(Legos.z_rep_code(d=3))
___|11_
___|_11
111|___
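To make the symplectic layout concrete, here is a hypothetical plain-Python reconstruction of the z_rep_code(3) matrix (it mirrors the construction shown in the source below, but does not use PlanqTN itself); the [X | Z] halves of each row correspond to the two sides of the | separator that sprint prints:

```python
d = 3
rows = []
for i in range(d - 1):              # Z-type generators Z_i Z_{i+1}
    g = [0] * (2 * d)
    g[d + i] = g[d + i + 1] = 1     # set bits in the Z half
    rows.append(g)
rows.append([1] * d + [0] * d)      # X-type generator X_0 X_1 X_2

def fmt(half):
    return "".join("1" if b else "_" for b in half)

for r in rows:                      # render like sprint: "xxx|zzz"
    print(fmt(r[:d]) + "|" + fmt(r[d:]))
# prints:
# ___|11_
# ___|_11
# 111|___
```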
Source code in planqtn/legos.py
class Legos:
    """Collection of predefined quantum error correction tensor "legos".

    This class provides a library of pre-defined stabilizer code tensors
    and quantum operations that can be used as building blocks for quantum
    error correction codes. Each lego represents a specific quantum code
    or operation with its associated parity check matrix.

    The class includes various types of tensors:

    - Encoding tensors for specific quantum codes (`[[6,0,3]]`, `[[5,1,2]]`, etc.)
    - Repetition codes for basic error correction
    - Stopper tensors for terminating tensor networks
    - Identity and Hadamard operations
    - Well-known codes like the Steane code and Quantum Reed-Muller codes

    Example:
        ```python
        >>> from planqtn.symplectic import sprint
        >>> # Get the Hadamard tensor
        >>> Legos.h
        GF([[1, 0, 0, 1],
            [0, 1, 1, 0]], order=2)
        >>> # Get the stopper_x tensor
        >>> Legos.stopper_x
        GF([[1, 0]], order=2)
        >>> # Get the stopper_z tensor
        >>> Legos.stopper_z
        GF([[0, 1]], order=2)
        >>> # Get a Z-repetition code with distance 3
        >>> # and print it in a nice symplectic format
        >>> sprint(Legos.z_rep_code(d=3))
        ___|11_
        ___|_11
        111|___

        ```
    """

    enconding_tensor_603 = GF2(
        [
            [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0],
            [1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0],
            [0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1],
        ]
    )

    stab_code_parity_422 = GF2(
        [
            [1, 1, 1, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 1, 1, 1, 1],
        ]
    )

    # fmt: off
    steane_code_813_encoding_tensor = GF2([
        [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
    ])
    # fmt: on

    @staticmethod
    def z_rep_code(d: int = 3) -> GF2:
        """Generate a Z-type repetition code parity check matrix.

        Creates a repetition code that protects against bit-flip errors using
        Z-type stabilizers. The code has distance d and encodes 1 logical qubit
        in d physical qubits. It is also the Z-spider in the ZX-calculus.

        Args:
            d: Distance of the repetition code (default: 3).

        Returns:
            GF2: Parity check matrix for the Z-repetition code.
        """
        gens = []
        for i in range(d - 1):
            g = GF2.Zeros(2 * d)
            g[[d + i, d + i + 1]] = 1
            gens.append(g)
        g = GF2.Zeros(2 * d)
        g[np.arange(d)] = 1
        gens.append(g)
        return GF2(gens)

    @staticmethod
    def x_rep_code(d: int = 3) -> GF2:
        """Generate an X-type repetition code parity check matrix.

        Creates a repetition code that protects against phase-flip errors using
        X-type stabilizers. The code has distance d and encodes 1 logical qubit
        in d physical qubits. It is also the X-spider in the ZX-calculus.

        Args:
            d: Distance of the repetition code (default: 3).

        Returns:
            GF2: Parity check matrix for the X-repetition code.
        """
        gens = []
        for i in range(d - 1):
            g = GF2.Zeros(2 * d)
            g[[i, i + 1]] = 1
            gens.append(g)
        g = GF2.Zeros(2 * d)
        g[np.arange(d, 2 * d)] = 1
        gens.append(g)
        return GF2(gens)

    identity = GF2(
        [
            [1, 1, 0, 0],
            [0, 0, 1, 1],
        ]
    )
    """the identity tensor is the Bell state, the |00> + |11> state"""

    encoding_tensor_512 = GF2(
        [
            [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 1, 1, 1, 0],
            [1, 1, 0, 0, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 1, 1, 0, 1],
        ]
    )
    """the [[5,1,2]] subspace tensor of the [[4,2,2]] code, i.e. with the logical leg, leg 5 traced
    out with the identity stopper from the [[6,0,3]] encoding tensor."""

    encoding_tensor_512_x = GF2(
        [
            [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 1, 0, 0, 0, 0, 0],
        ]
    )
    """the X-only version of the [planqtn.Legos.encoding_tensor_512][]"""

    encoding_tensor_512_z = GF2(
        [
            [0, 0, 0, 0, 0, 1, 1, 1, 1, 0],
            [0, 0, 0, 0, 0, 0, 1, 1, 0, 1],
        ]
    )
    """the Z-only version of the [planqtn.Legos.encoding_tensor_512][]"""

    h = GF2(
        [
            [1, 0, 0, 1],
            [0, 1, 1, 0],
        ]
    )
    """the Hadamard tensor"""

    stopper_x = GF2([Pauli.X.to_gf2()])
    """the X-type stopper tensor, the |+> state, corresponds to the Pauli X operator."""

    stopper_z = GF2([Pauli.Z.to_gf2()])
    """the Z-type stopper tensor, the |0> state, corresponds to the Pauli Z operator."""

    stopper_y = GF2([Pauli.Y.to_gf2()])
    """the Y-type stopper tensor, the |+i> state, corresponds to the Pauli Y operator."""

    stopper_i = GF2([Pauli.I.to_gf2()])
    """the identity stopper tensor, which is the free qubit subspace, corresponds to the
    Pauli I operator."""

encoding_tensor_512 = GF2([[1, 1, 1, 1, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 1, 1, 1, 1, 0], [1, 1, 0, 0, 1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 1, 1, 0, 1]])

the [[5,1,2]] subspace tensor of the [[4,2,2]] code, i.e. the [[6,0,3]] encoding tensor with the logical leg (leg 5) traced out with the identity stopper.

encoding_tensor_512_x = GF2([[1, 1, 1, 1, 0, 0, 0, 0, 0, 0], [1, 1, 0, 0, 1, 0, 0, 0, 0, 0]])

the X-only version of the planqtn.Legos.encoding_tensor_512

encoding_tensor_512_z = GF2([[0, 0, 0, 0, 0, 1, 1, 1, 1, 0], [0, 0, 0, 0, 0, 0, 1, 1, 0, 1]])

the Z-only version of the planqtn.Legos.encoding_tensor_512

h = GF2([[1, 0, 0, 1], [0, 1, 1, 0]])

the Hadamard tensor

identity = GF2([[1, 1, 0, 0], [0, 0, 1, 1]])

the identity tensor is the Bell state, the |00> + |11> state

stopper_i = GF2([Pauli.I.to_gf2()])

the identity stopper tensor, which is the free qubit subspace, corresponds to the Pauli I operator.

stopper_x = GF2([Pauli.X.to_gf2()])

the X-type stopper tensor, the |+> state, corresponds to the Pauli X operator.

stopper_y = GF2([Pauli.Y.to_gf2()])

the Y-type stopper tensor, the |+i> state, corresponds to the Pauli Y operator.

stopper_z = GF2([Pauli.Z.to_gf2()])

the Z-type stopper tensor, the |0> state, corresponds to the Pauli Z operator.
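The stoppers are single-qubit symplectic row vectors in [x | z] form. A minimal illustrative sketch of that mapping (the dictionary below is hypothetical, mirroring what the Pauli.*.to_gf2() values shown above evaluate to; Y = [1, 1] is the standard symplectic convention):

```python
# Illustrative single-qubit symplectic [x | z] encoding of the Paulis.
PAULI_GF2 = {"I": [0, 0], "X": [1, 0], "Z": [0, 1], "Y": [1, 1]}

assert PAULI_GF2["X"] == [1, 0]  # matches Legos.stopper_x == GF([[1, 0]])
assert PAULI_GF2["Z"] == [0, 1]  # matches Legos.stopper_z == GF([[0, 1]])
```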

x_rep_code(d=3) staticmethod

Generate an X-type repetition code parity check matrix.

Creates a repetition code that protects against phase-flip errors using X-type stabilizers. The code has distance d and encodes 1 logical qubit in d physical qubits. It is also the X-spider in the ZX-calculus.

Parameters:

    d (int): Distance of the repetition code. Default: 3.

Returns:

    GF2: Parity check matrix for the X-repetition code.

Source code in planqtn/legos.py
@staticmethod
def x_rep_code(d: int = 3) -> GF2:
    """Generate an X-type repetition code parity check matrix.

    Creates a repetition code that protects against phase-flip errors using
    X-type stabilizers. The code has distance d and encodes 1 logical qubit
    in d physical qubits. It is also the X-spider in the ZX-calculus.

    Args:
        d: Distance of the repetition code (default: 3).

    Returns:
        GF2: Parity check matrix for the X-repetition code.
    """
    gens = []
    for i in range(d - 1):
        g = GF2.Zeros(2 * d)
        g[[i, i + 1]] = 1
        gens.append(g)
    g = GF2.Zeros(2 * d)
    g[np.arange(d, 2 * d)] = 1
    gens.append(g)
    return GF2(gens)

z_rep_code(d=3) staticmethod

Generate a Z-type repetition code parity check matrix.

Creates a repetition code that protects against bit-flip errors using Z-type stabilizers. The code has distance d and encodes 1 logical qubit in d physical qubits. It is also the Z-spider in the ZX-calculus.

Parameters:

    d (int): Distance of the repetition code. Default: 3.

Returns:

    GF2: Parity check matrix for the Z-repetition code.

Source code in planqtn/legos.py
@staticmethod
def z_rep_code(d: int = 3) -> GF2:
    """Generate a Z-type repetition code parity check matrix.

    Creates a repetition code that protects against bit-flip errors using
    Z-type stabilizers. The code has distance d and encodes 1 logical qubit
    in d physical qubits. It is also the Z-spider in the ZX-calculus.

    Args:
        d: Distance of the repetition code (default: 3).

    Returns:
        GF2: Parity check matrix for the Z-repetition code.
    """
    gens = []
    for i in range(d - 1):
        g = GF2.Zeros(2 * d)
        g[[d + i, d + i + 1]] = 1
        gens.append(g)
    g = GF2.Zeros(2 * d)
    g[np.arange(d)] = 1
    gens.append(g)
    return GF2(gens)
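A parity check matrix is only valid if all of its rows commute, i.e. the symplectic product of any two rows vanishes mod 2. A small self-contained check for the z_rep_code(3) matrix, written out by hand (no PlanqTN dependency):

```python
d = 3
# z_rep_code(3) in [X | Z] form: Z0 Z1, Z1 Z2 and the all-X row.
H = [[0, 0, 0, 1, 1, 0],
     [0, 0, 0, 0, 1, 1],
     [1, 1, 1, 0, 0, 0]]

def symplectic_product(a, b, n):
    """x_a . z_b + z_a . x_b (mod 2) -- zero iff the operators commute."""
    return (sum(a[i] * b[n + i] for i in range(n))
            + sum(a[n + i] * b[i] for i in range(n))) % 2

# Every pair of generators must commute.
assert all(symplectic_product(H[i], H[j], d) == 0
           for i in range(len(H)) for j in range(len(H)))
```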

StabilizerCodeTensorEnumerator

Tensor enumerator for a stabilizer code.

Source code in planqtn/stabilizer_tensor_enumerator.py
class StabilizerCodeTensorEnumerator:
    """Tensor enumerator for a stabilizer code."""

    def __init__(
        self,
        h: GF2,
        tensor_id: TensorId = 0,
        legs: Optional[List[TensorLeg]] = None,
        coset_flipped_legs: Optional[List[Tuple[Tuple[Any, int], GF2]]] = None,
        annotation: Optional[LegoAnnotation] = None,
    ):
        """Construct a stabilizer code tensor enumerator.

        A `StabilizerCodeTensorEnumerator` is essentially an object-oriented wrapper around
        a parity check matrix. It supports self-tracing, tensor product, and conjoining
        with other `StabilizerCodeTensorEnumerator` instances. As such, it is the building block
        of tensor networks in the [TensorNetwork][planqtn.tensor_network.TensorNetwork] class.

        The class also supports the enumeration of the scalar stabilizer weight enumerator of the
        code via brute force. There can be legs left open, in which case the weight enumerator
        becomes a tensor weight enumerator. Weight truncation is supported for approximate
        enumeration. Coset support is represented by `coset_flipped_legs`.

        Args:
            h: The parity check matrix.
            tensor_id: The ID of the tensor.
            legs: The legs of the tensor.
            coset_flipped_legs: The coset flipped legs of the tensor.
            annotation: The annotation of the tensor for hints for visualization in PlanqTN Studio.

        Raises:
            AssertionError: If the legs are not valid.
        """
        self.h = h
        self.annotation = annotation

        self.tensor_id = tensor_id
        if len(self.h.shape) == 1:
            self.n = self.h.shape[0] // 2
            self.k = self.n - 1
        else:
            self.n = self.h.shape[1] // 2
            self.k = self.n - self.h.shape[0]

        self.legs = (
            [(self.tensor_id, leg) for leg in range(self.n)] if legs is None else legs
        )
        # print(f"Legs: {self.legs} because n = {self.n}, {self.h.shape}")
        assert (
            len(self.legs) == self.n
        ), f"Number of legs {len(self.legs)} != qubit count {self.n} for h: {self.h}"
        # a dict is a wonky tensor - TODO: rephrase this to proper tensor
        self._stabilizer_enums: Dict[sympy.Tuple, UnivariatePoly] = {}

        self.coset_flipped_legs = []
        if coset_flipped_legs is not None:
            self.coset_flipped_legs = coset_flipped_legs
            for leg, pauli in self.coset_flipped_legs:
                assert (
                    leg in self.legs
                ), f"Leg in coset not found: {leg} - legs: {self.legs}"
                assert len(pauli) == 2 and isinstance(
                    pauli, GF2
                ), f"Invalid pauli in coset: {pauli} on leg {leg}"
            # print(f"Coset flipped legs validated. Setting to {self.coset_flipped_legs}")

    def __str__(self) -> str:
        return f"TensorEnum({self.tensor_id})"

    def __repr__(self) -> str:
        return f"TensorEnum({self.tensor_id})"

    def set_tensor_id(self, tensor_id: TensorId) -> None:
        """Set the tensor ID and update all legs to use the new ID.

        Updates the tensor_id attribute and modifies all legs that reference
        the old tensor_id to use the new one.

        Args:
            tensor_id: New tensor ID to assign to this tensor.
        """
        for l, leg in enumerate(self.legs):
            if leg[0] == self.tensor_id:
                self.legs[l] = (tensor_id, leg[1])
        self.tensor_id = tensor_id

    def _key(self, e: GF2) -> Tuple[int, ...]:
        return tuple(e.astype(np.uint8).tolist())

    def is_stabilizer(self, op: GF2) -> bool:
        """Check if an operator is a stabilizer of this code.

        Determines whether the given operator commutes with all stabilizers
        of the code by checking if op * omega * h^T = 0.

        Args:
            op: Operator to check (as GF2 vector).

        Returns:
            bool: True if op is a stabilizer, False otherwise.
        """
        return 0 == np.count_nonzero(op @ omega(self.n) @ self.h.T)

    def _remove_leg(self, legs: Dict[TensorLeg, int], leg: TensorLeg) -> None:
        pos = legs[leg]
        del legs[leg]
        for k in legs.keys():
            if legs[k] > pos:
                legs[k] -= 1

    def _remove_legs(
        self, legs: Dict[TensorLeg, int], legs_to_remove: List[TensorLeg]
    ) -> None:
        for leg in legs_to_remove:
            self._remove_leg(legs, leg)

    def _validate_legs(self, legs: List[TensorLeg]) -> List[TensorLeg]:
        return [leg for leg in legs if leg not in self.legs]

    def with_coset_flipped_legs(
        self, coset_flipped_legs: List[Tuple[TensorLeg, GF2]]
    ) -> "StabilizerCodeTensorEnumerator":
        """Create a new tensor enumerator with coset-flipped legs.

        Creates a copy of this tensor enumerator with the specified coset-flipped
        legs. This is used for coset weight enumerator calculations.

        Args:
            coset_flipped_legs: List of (leg, coset_error) pairs specifying
                which legs have coset errors applied.

        Returns:
            StabilizerCodeTensorEnumerator: New tensor enumerator with coset-flipped legs.
        """
        return StabilizerCodeTensorEnumerator(
            self.h, self.tensor_id, self.legs, coset_flipped_legs
        )

    def tensor_with(
        self, other: "StabilizerCodeTensorEnumerator"
    ) -> "StabilizerCodeTensorEnumerator":
        """Create the tensor product with another tensor enumerator.

        Computes the tensor product of this tensor with another tensor enumerator.
        The resulting tensor has the combined parity check matrix and all legs
        from both tensors.

        Args:
            other: The other tensor enumerator to tensor with.

        Returns:
            StabilizerCodeTensorEnumerator: The tensor product of the two tensors.
        """
        new_h = tensor_product(self.h, other.h)
        if np.array_equal(new_h, GF2([[0]])):
            return StabilizerCodeTensorEnumerator(
                new_h, tensor_id=self.tensor_id, legs=[]
            )
        return StabilizerCodeTensorEnumerator(
            new_h, tensor_id=self.tensor_id, legs=self.legs + other.legs
        )

    def self_trace(
        self, legs1: Sequence[int | TensorLeg], legs2: Sequence[int | TensorLeg]
    ) -> "StabilizerCodeTensorEnumerator":
        """Perform self-tracing by contracting pairs of legs within this tensor.

        Contracts pairs of legs within the same tensor, effectively performing
        a partial trace operation. The legs are paired up and contracted together.

        Args:
            legs1: First set of legs to contract (must match length of legs2).
            legs2: Second set of legs to contract (must match length of legs1).

        Returns:
            StabilizerCodeTensorEnumerator: New tensor with contracted legs removed.

        Raises:
            AssertionError: If legs1 and legs2 have different lengths.
        """
        assert len(legs1) == len(legs2)
        legs1_indexed: List[TensorLeg] = _index_legs(self.tensor_id, legs1)
        legs2_indexed: List[TensorLeg] = _index_legs(self.tensor_id, legs2)
        leg2col = {leg: i for i, leg in enumerate(self.legs)}

        new_h = self.h
        for leg1, leg2 in zip(legs1_indexed, legs2_indexed):
            new_h = self_trace(new_h, leg2col[leg1], leg2col[leg2])
            self._remove_legs(leg2col, [leg1, leg2])

        new_legs = [
            leg
            for leg in self.legs
            if leg not in legs1_indexed and leg not in legs2_indexed
        ]
        return StabilizerCodeTensorEnumerator(
            new_h, tensor_id=self.tensor_id, legs=new_legs
        )

    def conjoin(
        self,
        other: "StabilizerCodeTensorEnumerator",
        legs1: Sequence[int | TensorLeg],
        legs2: Sequence[int | TensorLeg],
    ) -> "StabilizerCodeTensorEnumerator":
        """Creates a new tensor enumerator by conjoining two of them.

        Creates a new tensor enumerator by contracting the specified legs between
        this tensor and another tensor. The legs of the other tensor will become
        the legs of the new tensor.

        Args:
            other: The other tensor enumerator to conjoin with.
            legs1: Legs from this tensor to contract.
            legs2: Legs from the other tensor to contract.

        Returns:
            StabilizerCodeTensorEnumerator: The conjoined tensor enumerator.
        """
        if self.tensor_id == other.tensor_id:
            return self.self_trace(legs1, legs2)
        assert len(legs1) == len(legs2)
        legs1_indexed: List[TensorLeg] = _index_legs(self.tensor_id, legs1)
        legs2_indexed: List[TensorLeg] = _index_legs(other.tensor_id, legs2)

        leg2col = {leg: i for i, leg in enumerate(self.legs)}
        # for example 2 3 4 | 2 4 8 will become
        # as legs2_offset = 5
        # {2: 0, 3: 1, 4: 2, 7: 3, 11: 4, 13: 5}
        leg2col.update({leg: len(self.legs) + i for i, leg in enumerate(other.legs)})

        new_h = conjoin(
            self.h,
            other.h,
            self.legs.index(legs1_indexed[0]),
            other.legs.index(legs2_indexed[0]),
        )
        self._remove_legs(leg2col, [legs1_indexed[0], legs2_indexed[0]])

        for leg1, leg2 in zip(legs1_indexed[1:], legs2_indexed[1:]):
            new_h = self_trace(new_h, leg2col[leg1], leg2col[leg2])
            self._remove_legs(leg2col, [leg1, leg2])

        new_legs = [leg for leg in self.legs if leg not in legs1_indexed]
        new_legs += [leg for leg in other.legs if leg not in legs2_indexed]

        return StabilizerCodeTensorEnumerator(
            new_h, tensor_id=self.tensor_id, legs=new_legs
        )

    def _brute_force_stabilizer_enumerator_from_parity(
        self,
        open_legs: Sequence[TensorLeg] = (),
        verbose: bool = False,
        progress_reporter: ProgressReporter = DummyProgressReporter(),
        truncate_length: Optional[int] = None,
    ) -> Union[TensorEnumerator, UnivariatePoly]:

        open_legs = _index_legs(self.tensor_id, open_legs)
        invalid_legs = self._validate_legs(open_legs)
        if len(invalid_legs) > 0:
            raise ValueError(
                f"Can't leave legs open for tensor: {invalid_legs}, they don't exist on node "
                f"{self.tensor_id} with legs:\n{self.legs}"
            )

        open_cols = [self.legs.index(leg) for leg in open_legs]

        coset = GF2.Zeros(2 * self.n)
        if self.coset_flipped_legs is not None:
            for leg, pauli in self.coset_flipped_legs:
                assert leg in self.legs, f"Leg in coset not found: {leg}"
                assert len(pauli) == 2 and isinstance(
                    pauli, GF2
                ), f"Invalid pauli in coset: {pauli} on leg {leg}"
                coset[self.legs.index(leg)] = pauli[0]
                coset[self.legs.index(leg) + self.n] = pauli[1]

        collector = (
            _SimpleStabilizerCollector(
                coset,
                open_cols,
                verbose,
                truncate_length=truncate_length,
            )
            if open_cols == []
            else _TensorElementCollector(
                coset,
                open_cols,
                verbose,
                progress_reporter,
                truncate_length=truncate_length,
            )
        )

        h_reduced = gauss(self.h)
        h_reduced = h_reduced[~np.all(h_reduced == 0, axis=1)]
        r = len(h_reduced)

        for i in progress_reporter.iterate(
            iterable=range(2**r),
            desc=(
                f"Brute force WEP calc for [[{self.n}, {self.k}]] tensor "
                f"{self.tensor_id} - {r} generators"
            ),
            total_size=2**r,
        ):
            picked_generators = GF2(list(np.binary_repr(i, width=r)), dtype=int)
            if r == 0:
                if i > 0:
                    continue
                stabilizer = GF2.Zeros(self.n * 2)
            else:
                stabilizer = picked_generators @ h_reduced

            collector.collect(stabilizer)
        collector.finalize()
        return collector.tensor_wep

    def stabilizer_enumerator_polynomial(
        self,
        open_legs: Sequence[TensorLeg] = (),
        verbose: bool = False,
        progress_reporter: ProgressReporter = DummyProgressReporter(),
        truncate_length: Optional[int] = None,
    ) -> Union[TensorEnumerator, UnivariatePoly]:
        """Compute the stabilizer enumerator polynomial.

        Note that this is a brute force method, and is not efficient for large codes, use it with
        the [planqtn.progress_reporter.TqdmProgressReporter][] to get time estimates.
        If open_legs is empty, returns the scalar stabilizer enumerator polynomial.
        If open_legs is not empty, returns a sparse tensor with non-zero values on
        the open legs.

        Args:
            open_legs: List of legs to leave open.
            verbose: Whether to print verbose output.
            progress_reporter: Progress reporter to use.
            truncate_length: Maximum weight to truncate the enumerator at.

        Returns:
            wep: The stabilizer weight enumerator polynomial.
        """
        wep = self._brute_force_stabilizer_enumerator_from_parity(
            open_legs=open_legs,
            verbose=verbose,
            progress_reporter=progress_reporter,
            truncate_length=truncate_length,
        )
        return wep

    def trace_with_stopper(
        self, stopper: GF2, traced_leg: int | TensorLeg
    ) -> "StabilizerCodeTensorEnumerator":
        """Trace this tensor with a stopper tensor on the specified leg.

        Contracts this tensor with a stopper tensor (representing a measurement
        or boundary condition) on the specified leg.

        Args:
            stopper: The stopper tensor to contract with (as a 1x2 GF2 matrix).
            traced_leg: The leg to contract with the stopper.

        Returns:
            StabilizerCodeTensorEnumerator: New tensor with the stopper contraction applied.
        """
        res = self.conjoin(
            StabilizerCodeTensorEnumerator(stopper, tensor_id="stopper"),
            [traced_leg],
            [0],
        )
        res.annotation = self.annotation
        return res
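The brute-force enumeration behind stabilizer_enumerator_polynomial can be sketched in a few lines of plain Python (an illustrative simplification that ignores open legs, cosets and truncation): iterate over all 2^r generator combinations and tally the Pauli weight of each resulting stabilizer. Using the [[4,2,2]] matrix from Legos.stab_code_parity_422:

```python
from itertools import product

# [[4,2,2]] parity check in [X | Z] form: XXXX and ZZZZ.
H = [[1, 1, 1, 1, 0, 0, 0, 0],
     [0, 0, 0, 0, 1, 1, 1, 1]]
n = len(H[0]) // 2

wep = {}
for picks in product([0, 1], repeat=len(H)):   # all 2^r generator subsets
    op = [0] * (2 * n)
    for pick, row in zip(picks, H):
        if pick:
            op = [(a + b) % 2 for a, b in zip(op, row)]
    # Pauli weight: qubits acted on by X, Y or Z (non-zero in either half).
    w = sum(1 for i in range(n) if op[i] or op[n + i])
    wep[w] = wep.get(w, 0) + 1

print(wep)  # {0: 1, 4: 3} -- the group {IIII, XXXX, ZZZZ, YYYY}
```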

__init__(h, tensor_id=0, legs=None, coset_flipped_legs=None, annotation=None)

Construct a stabilizer code tensor enumerator.

A StabilizerCodeTensorEnumerator is essentially an object-oriented wrapper around a parity check matrix. It supports self-tracing, tensor product, and conjoining with other StabilizerCodeTensorEnumerator instances. As such, it is the building block of tensor networks in the TensorNetwork class.

The class also supports the enumeration of the scalar stabilizer weight enumerator of the code via brute force. There can be legs left open, in which case the weight enumerator becomes a tensor weight enumerator. Weight truncation is supported for approximate enumeration. Coset support is represented by coset_flipped_legs.

Parameters:

    h (GF2): The parity check matrix. Required.
    tensor_id (TensorId): The ID of the tensor. Default: 0.
    legs (Optional[List[TensorLeg]]): The legs of the tensor. Default: None.
    coset_flipped_legs (Optional[List[Tuple[Tuple[Any, int], GF2]]]): The coset-flipped legs of the tensor. Default: None.
    annotation (Optional[LegoAnnotation]): The annotation of the tensor, used as a visualization hint in PlanqTN Studio. Default: None.

Raises:

    AssertionError: If the legs are not valid.

Source code in planqtn/stabilizer_tensor_enumerator.py
def __init__(
    self,
    h: GF2,
    tensor_id: TensorId = 0,
    legs: Optional[List[TensorLeg]] = None,
    coset_flipped_legs: Optional[List[Tuple[Tuple[Any, int], GF2]]] = None,
    annotation: Optional[LegoAnnotation] = None,
):
    """Construct a stabilizer code tensor enumerator.

    A `StabilizerCodeTensorEnumerator` is essentially an object-oriented wrapper around
    a parity check matrix. It supports self-tracing, tensor product, and conjoining
    with other `StabilizerCodeTensorEnumerator` instances. As such, it is the building block
    of tensor networks in the [TensorNetwork][planqtn.tensor_network.TensorNetwork] class.

    The class also supports the enumeration of the scalar stabilizer weight enumerator of the
    code via brute force. There can be legs left open, in which case the weight enumerator
    becomes a tensor weight enumerator. Weight truncation is supported for approximate
    enumeration. Coset support is represented by `coset_flipped_legs`.

    Args:
        h: The parity check matrix.
        tensor_id: The ID of the tensor.
        legs: The legs of the tensor.
        coset_flipped_legs: The coset flipped legs of the tensor.
        annotation: The annotation of the tensor for hints for visualization in PlanqTN Studio.

    Raises:
        AssertionError: If the legs are not valid.
    """
    self.h = h
    self.annotation = annotation

    self.tensor_id = tensor_id
    if len(self.h.shape) == 1:
        self.n = self.h.shape[0] // 2
        self.k = self.n - 1
    else:
        self.n = self.h.shape[1] // 2
        self.k = self.n - self.h.shape[0]

    self.legs = (
        [(self.tensor_id, leg) for leg in range(self.n)] if legs is None else legs
    )
    # print(f"Legs: {self.legs} because n = {self.n}, {self.h.shape}")
    assert (
        len(self.legs) == self.n
    ), f"Number of legs {len(self.legs)} != qubit count {self.n} for h: {self.h}"
    # a dict is a wonky tensor - TODO: rephrase this to proper tensor
    self._stabilizer_enums: Dict[sympy.Tuple, UnivariatePoly] = {}

    self.coset_flipped_legs = []
    if coset_flipped_legs is not None:
        self.coset_flipped_legs = coset_flipped_legs
        for leg, pauli in self.coset_flipped_legs:
            assert (
                leg in self.legs
            ), f"Leg in coset not found: {leg} - legs: {self.legs}"
            assert len(pauli) == 2 and isinstance(
                pauli, GF2
            ), f"Invalid pauli in coset: {pauli} on leg {leg}"

conjoin(other, legs1, legs2)

Creates a new tensor enumerator by conjoining two of them.

Creates a new tensor enumerator by contracting the specified legs between this tensor and another tensor. The remaining legs of both tensors become the legs of the new tensor.
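As a rough mental model (not the library's implementation, which works on the parity check matrices directly): conjoining pairs up elements of the two stabilizer groups that agree on the contracted legs, then drops those legs. A stdlib-only sketch, ignoring Pauli signs as the GF(2) picture does; `conjoin_sketch` is a hypothetical helper:

```python
from itertools import product

def conjoin_sketch(h1, n1, h2, n2, leg1, leg2):
    """Contract leg1 of code 1 with leg2 of code 2: keep pairs of group
    elements whose Paulis agree on the contracted legs, drop those legs."""
    def group(h, n):
        ops = []
        for picks in product([0, 1], repeat=len(h)):
            op = [0] * (2 * n)
            for pick, row in zip(picks, h):
                if pick:
                    op = [(a + b) % 2 for a, b in zip(op, row)]
            ops.append(op)
        return ops

    joined = set()
    for g1 in group(h1, n1):
        for g2 in group(h2, n2):
            if (g1[leg1], g1[leg1 + n1]) != (g2[leg2], g2[leg2 + n2]):
                continue
            keep1 = [q for q in range(n1) if q != leg1]
            keep2 = [q for q in range(n2) if q != leg2]
            x = [g1[q] for q in keep1] + [g2[q] for q in keep2]
            z = [g1[q + n1] for q in keep1] + [g2[q + n2] for q in keep2]
            joined.add(tuple(x + z))
    return joined

# two ZZ checks glued through one shared leg give ZZ on the outer legs
h_zz = [[0, 0, 1, 1]]
print(sorted(conjoin_sketch(h_zz, 2, h_zz, 2, 1, 0)))
# [(0, 0, 0, 0), (0, 0, 1, 1)]
```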

Parameters:

- other (StabilizerCodeTensorEnumerator): The other tensor enumerator to conjoin with. Required.
- legs1 (Sequence[int | TensorLeg]): Legs from this tensor to contract. Required.
- legs2 (Sequence[int | TensorLeg]): Legs from the other tensor to contract. Required.

Returns:

- StabilizerCodeTensorEnumerator: The conjoined tensor enumerator.

Source code in planqtn/stabilizer_tensor_enumerator.py
def conjoin(
    self,
    other: "StabilizerCodeTensorEnumerator",
    legs1: Sequence[int | TensorLeg],
    legs2: Sequence[int | TensorLeg],
) -> "StabilizerCodeTensorEnumerator":
    """Creates a new tensor enumerator by conjoining two of them.

    Creates a new tensor enumerator by contracting the specified legs between
    this tensor and another tensor. The legs of the other tensor will become
    the legs of the new tensor.

    Args:
        other: The other tensor enumerator to conjoin with.
        legs1: Legs from this tensor to contract.
        legs2: Legs from the other tensor to contract.

    Returns:
        StabilizerCodeTensorEnumerator: The conjoined tensor enumerator.
    """
    if self.tensor_id == other.tensor_id:
        return self.self_trace(legs1, legs2)
    assert len(legs1) == len(legs2)
    legs1_indexed: List[TensorLeg] = _index_legs(self.tensor_id, legs1)
    legs2_indexed: List[TensorLeg] = _index_legs(other.tensor_id, legs2)

    leg2col = {leg: i for i, leg in enumerate(self.legs)}
    # for example 2 3 4 | 2 4 8 will become
    # as legs2_offset = 5
    # {2: 0, 3: 1, 4: 2, 7: 3, 11: 4, 13: 5}
    leg2col.update({leg: len(self.legs) + i for i, leg in enumerate(other.legs)})

    new_h = conjoin(
        self.h,
        other.h,
        self.legs.index(legs1_indexed[0]),
        other.legs.index(legs2_indexed[0]),
    )
    self._remove_legs(leg2col, [legs1_indexed[0], legs2_indexed[0]])

    for leg1, leg2 in zip(legs1_indexed[1:], legs2_indexed[1:]):
        new_h = self_trace(new_h, leg2col[leg1], leg2col[leg2])
        self._remove_legs(leg2col, [leg1, leg2])

    new_legs = [leg for leg in self.legs if leg not in legs1_indexed]
    new_legs += [leg for leg in other.legs if leg not in legs2_indexed]

    return StabilizerCodeTensorEnumerator(
        new_h, tensor_id=self.tensor_id, legs=new_legs
    )

is_stabilizer(op)

Check if an operator is a stabilizer of this code.

Determines whether the given operator commutes with all stabilizers of the code by checking if op * omega * h^T = 0.
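The check `op * omega * h^T = 0` can be sketched with plain 0/1 vectors; `commutes_with_all` below is an illustrative stand-in, not the library routine (which uses galois GF2 arrays and numpy):

```python
def commutes_with_all(op, h, n):
    """op * omega * h^T == 0 over GF(2): omega swaps the X and Z halves,
    so the product reduces to the symplectic inner product of op with
    every row of h."""
    swapped = op[n:] + op[:n]  # apply omega
    return all(
        sum(a * b for a, b in zip(swapped, row)) % 2 == 0
        for row in h
    )

# Z-repetition code on 3 qubits: Z1Z2, Z2Z3 in (x|z) form
h_z3 = [[0, 0, 0, 1, 1, 0],
        [0, 0, 0, 0, 1, 1]]
print(commutes_with_all([0, 0, 0, 1, 0, 1], h_z3, 3))  # Z1Z3 -> True
print(commutes_with_all([1, 0, 0, 0, 0, 0], h_z3, 3))  # X1   -> False
```

Note that this test accepts every operator that commutes with the generators (the normalizer), of which the stabilizer group itself is a subgroup.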

Parameters:

- op (GF2): Operator to check (as a GF2 vector). Required.

Returns:

- bool: True if op is a stabilizer, False otherwise.

Source code in planqtn/stabilizer_tensor_enumerator.py
def is_stabilizer(self, op: GF2) -> bool:
    """Check if an operator is a stabilizer of this code.

    Determines whether the given operator commutes with all stabilizers
    of the code by checking if op * omega * h^T = 0.

    Args:
        op: Operator to check (as GF2 vector).

    Returns:
        bool: True if op is a stabilizer, False otherwise.
    """
    return 0 == np.count_nonzero(op @ omega(self.n) @ self.h.T)

self_trace(legs1, legs2)

Perform self-tracing by contracting pairs of legs within this tensor.

Contracts pairs of legs within the same tensor, effectively performing a partial trace operation. The legs are paired up and contracted together.
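One way to picture the contraction rule, as a hedged stdlib-only sketch: keep the group elements whose Pauli on one traced leg matches the Pauli on the other (signs ignored, as in the GF(2) picture), then drop both legs. `self_trace_sketch` is illustrative only, not the library routine, which manipulates the parity check matrix directly:

```python
from itertools import product

def self_trace_sketch(h, n, leg1, leg2):
    """Brute-force self-trace: filter the group on matching Paulis
    across the two traced legs, then project out those legs."""
    kept = set()
    for picks in product([0, 1], repeat=len(h)):
        op = [0] * (2 * n)
        for pick, row in zip(picks, h):
            if pick:
                op = [(a + b) % 2 for a, b in zip(op, row)]
        if (op[leg1], op[leg1 + n]) != (op[leg2], op[leg2 + n]):
            continue
        keep = [q for q in range(n) if q not in (leg1, leg2)]
        kept.add(tuple(op[q] for q in keep) + tuple(op[q + n] for q in keep))
    return kept

# [[4,2,2]] code checks XXXX and ZZZZ; tracing legs 0 and 1 leaves
# the Bell-pair stabilizers {II, ZZ, XX, YY} on the two remaining legs
h_422 = [[1, 1, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 1, 1, 1]]
print(sorted(self_trace_sketch(h_422, 4, 0, 1)))
# [(0, 0, 0, 0), (0, 0, 1, 1), (1, 1, 0, 0), (1, 1, 1, 1)]
```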

Parameters:

- legs1 (Sequence[int | TensorLeg]): First set of legs to contract (must match length of legs2). Required.
- legs2 (Sequence[int | TensorLeg]): Second set of legs to contract (must match length of legs1). Required.

Returns:

- StabilizerCodeTensorEnumerator: New tensor with contracted legs removed.

Raises:

- AssertionError: If legs1 and legs2 have different lengths.

Source code in planqtn/stabilizer_tensor_enumerator.py
def self_trace(
    self, legs1: Sequence[int | TensorLeg], legs2: Sequence[int | TensorLeg]
) -> "StabilizerCodeTensorEnumerator":
    """Perform self-tracing by contracting pairs of legs within this tensor.

    Contracts pairs of legs within the same tensor, effectively performing
    a partial trace operation. The legs are paired up and contracted together.

    Args:
        legs1: First set of legs to contract (must match length of legs2).
        legs2: Second set of legs to contract (must match length of legs1).

    Returns:
        StabilizerCodeTensorEnumerator: New tensor with contracted legs removed.

    Raises:
        AssertionError: If legs1 and legs2 have different lengths.
    """
    assert len(legs1) == len(legs2)
    legs1_indexed: List[TensorLeg] = _index_legs(self.tensor_id, legs1)
    legs2_indexed: List[TensorLeg] = _index_legs(self.tensor_id, legs2)
    leg2col = {leg: i for i, leg in enumerate(self.legs)}

    new_h = self.h
    for leg1, leg2 in zip(legs1_indexed, legs2_indexed):
        new_h = self_trace(new_h, leg2col[leg1], leg2col[leg2])
        self._remove_legs(leg2col, [leg1, leg2])

    new_legs = [
        leg
        for leg in self.legs
        if leg not in legs1_indexed and leg not in legs2_indexed
    ]
    return StabilizerCodeTensorEnumerator(
        new_h, tensor_id=self.tensor_id, legs=new_legs
    )

set_tensor_id(tensor_id)

Set the tensor ID and update all legs to use the new ID.

Updates the tensor_id attribute and modifies all legs that reference the old tensor_id to use the new one.

Parameters:

- tensor_id (TensorId): New tensor ID to assign to this tensor. Required.
Source code in planqtn/stabilizer_tensor_enumerator.py
def set_tensor_id(self, tensor_id: TensorId) -> None:
    """Set the tensor ID and update all legs to use the new ID.

    Updates the tensor_id attribute and modifies all legs that reference
    the old tensor_id to use the new one.

    Args:
        tensor_id: New tensor ID to assign to this tensor.
    """
    for l, leg in enumerate(self.legs):
        if leg[0] == self.tensor_id:
            self.legs[l] = (tensor_id, leg[1])
    self.tensor_id = tensor_id

stabilizer_enumerator_polynomial(open_legs=(), verbose=False, progress_reporter=DummyProgressReporter(), truncate_length=None)

Compute the stabilizer enumerator polynomial.

Note that this is a brute-force method and is not efficient for large codes; use it with the planqtn.progress_reporter.TqdmProgressReporter to get time estimates. If open_legs is empty, returns the scalar stabilizer enumerator polynomial. If open_legs is not empty, returns a sparse tensor with non-zero values on the open legs.
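A sketch of what brute-force enumeration with `truncate_length` amounts to — terms above the truncation weight are simply dropped, giving an approximate enumerator. This is illustrative only; the library operates on GF2 arrays and returns a `UnivariatePoly`:

```python
from itertools import product

def truncated_wep(h, n, truncate_length=None):
    """Brute-force weight enumerator that drops terms above
    truncate_length (hypothetical helper, not the library code)."""
    wep = {}
    for picks in product([0, 1], repeat=len(h)):
        op = [0] * (2 * n)
        for pick, row in zip(picks, h):
            if pick:
                op = [(a + b) % 2 for a, b in zip(op, row)]
        w = sum(1 for q in range(n) if op[q] or op[q + n])
        if truncate_length is None or w <= truncate_length:
            wep[w] = wep.get(w, 0) + 1
    return wep

h_422 = [[1, 1, 1, 1, 0, 0, 0, 0],  # XXXX
         [0, 0, 0, 0, 1, 1, 1, 1]]  # ZZZZ
print(truncated_wep(h_422, 4))      # {0: 1, 4: 3}
print(truncated_wep(h_422, 4, 2))   # {0: 1}
```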

Parameters:

- open_legs (Sequence[TensorLeg]): List of legs to leave open. Default: ().
- verbose (bool): Whether to print verbose output. Default: False.
- progress_reporter (ProgressReporter): Progress reporter to use. Default: DummyProgressReporter().
- truncate_length (Optional[int]): Maximum weight to truncate the enumerator at. Default: None.

Returns:

- wep (Union[TensorEnumerator, UnivariatePoly]): The stabilizer weight enumerator polynomial.

Source code in planqtn/stabilizer_tensor_enumerator.py
def stabilizer_enumerator_polynomial(
    self,
    open_legs: Sequence[TensorLeg] = (),
    verbose: bool = False,
    progress_reporter: ProgressReporter = DummyProgressReporter(),
    truncate_length: Optional[int] = None,
) -> Union[TensorEnumerator, UnivariatePoly]:
    """Compute the stabilizer enumerator polynomial.

    Note that this is a brute-force method and is not efficient for large codes; use it with
    the [planqtn.progress_reporter.TqdmProgressReporter][] to get time estimates.
    If open_legs is empty, returns the scalar stabilizer enumerator polynomial.
    If open_legs is not empty, returns a sparse tensor with non-zero values on
    the open legs.

    Args:
        open_legs: List of legs to leave open.
        verbose: Whether to print verbose output.
        progress_reporter: Progress reporter to use.
        truncate_length: Maximum weight to truncate the enumerator at.

    Returns:
        wep: The stabilizer weight enumerator polynomial.
    """
    wep = self._brute_force_stabilizer_enumerator_from_parity(
        open_legs=open_legs,
        verbose=verbose,
        progress_reporter=progress_reporter,
        truncate_length=truncate_length,
    )
    return wep

tensor_with(other)

Create the tensor product with another tensor enumerator.

Computes the tensor product of this tensor with another tensor enumerator. The resulting tensor has the combined parity check matrix and all legs from both tensors.
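In symplectic form the tensor product corresponds to a direct sum of the two parity check matrices: each check acts on its own block of the combined qubits. A stdlib-only sketch (`tensor_product_sketch` is illustrative, not the library's `tensor_product`):

```python
def tensor_product_sketch(h1, n1, h2, n2):
    """Block-combine two symplectic (x|z) parity check matrices so the
    product code's checks act on n1 + n2 qubits."""
    rows = []
    for r in h1:
        x, z = r[:n1], r[n1:]
        rows.append(x + [0] * n2 + z + [0] * n2)
    for r in h2:
        x, z = r[:n2], r[n2:]
        rows.append([0] * n1 + x + [0] * n1 + z)
    return rows

h_zz = [[0, 0, 1, 1]]   # ZZ on 2 qubits
h_xx = [[1, 1, 0, 0]]   # XX on 2 qubits
for row in tensor_product_sketch(h_zz, 2, h_xx, 2):
    print(row)
# [0, 0, 0, 0, 1, 1, 0, 0]
# [0, 0, 1, 1, 0, 0, 0, 0]
```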

Parameters:

- other (StabilizerCodeTensorEnumerator): The other tensor enumerator to tensor with. Required.

Returns:

- StabilizerCodeTensorEnumerator: The tensor product of the two tensors.

Source code in planqtn/stabilizer_tensor_enumerator.py
def tensor_with(
    self, other: "StabilizerCodeTensorEnumerator"
) -> "StabilizerCodeTensorEnumerator":
    """Create the tensor product with another tensor enumerator.

    Computes the tensor product of this tensor with another tensor enumerator.
    The resulting tensor has the combined parity check matrix and all legs
    from both tensors.

    Args:
        other: The other tensor enumerator to tensor with.

    Returns:
        StabilizerCodeTensorEnumerator: The tensor product of the two tensors.
    """
    new_h = tensor_product(self.h, other.h)
    if np.array_equal(new_h, GF2([[0]])):
        return StabilizerCodeTensorEnumerator(
            new_h, tensor_id=self.tensor_id, legs=[]
        )
    return StabilizerCodeTensorEnumerator(
        new_h, tensor_id=self.tensor_id, legs=self.legs + other.legs
    )

trace_with_stopper(stopper, traced_leg)

Trace this tensor with a stopper tensor on the specified leg.

Contracts this tensor with a stopper tensor (representing a measurement or boundary condition) on the specified leg.
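A stopper is a single-leg tensor given by its 1x2 (x|z) row — e.g. [0, 1] for a Z-type stopper. Contracting with it keeps only the group elements acting as identity or as the stopper's Pauli on that leg, then drops the leg. A hedged sketch (not the library code, which delegates to conjoin):

```python
from itertools import product

def trace_with_stopper_sketch(h, n, leg, stopper):
    """Brute-force stopper trace: filter the group on the traced leg
    carrying I or the stopper's Pauli, then drop the leg. Signs ignored;
    hypothetical helper, not the planqtn API."""
    kept = set()
    for picks in product([0, 1], repeat=len(h)):
        op = [0] * (2 * n)
        for pick, row in zip(picks, h):
            if pick:
                op = [(a + b) % 2 for a, b in zip(op, row)]
        pauli = (op[leg], op[leg + n])
        if pauli not in {(0, 0), (stopper[0], stopper[1])}:
            continue
        keep = [q for q in range(n) if q != leg]
        kept.add(tuple(op[q] for q in keep) + tuple(op[q + n] for q in keep))
    return kept

h_z3 = [[0, 0, 0, 1, 1, 0],   # Z1Z2
        [0, 0, 0, 0, 1, 1]]   # Z2Z3
# a Z stopper on leg 2 leaves the full two-qubit Z group {II, IZ, ZI, ZZ}
print(sorted(trace_with_stopper_sketch(h_z3, 3, 2, [0, 1])))
# [(0, 0, 0, 0), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 1, 1)]
```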

Parameters:

- stopper (GF2): The stopper tensor to contract with (as a 1x2 GF2 matrix). Required.
- traced_leg (int | TensorLeg): The leg to contract with the stopper. Required.

Returns:

- StabilizerCodeTensorEnumerator: New tensor with the stopper contraction applied.

Source code in planqtn/stabilizer_tensor_enumerator.py
def trace_with_stopper(
    self, stopper: GF2, traced_leg: int | TensorLeg
) -> "StabilizerCodeTensorEnumerator":
    """Trace this tensor with a stopper tensor on the specified leg.

    Contracts this tensor with a stopper tensor (representing a measurement
    or boundary condition) on the specified leg.

    Args:
        stopper: The stopper tensor to contract with (as a 1x2 GF2 matrix).
        traced_leg: The leg to contract with the stopper.

    Returns:
        StabilizerCodeTensorEnumerator: New tensor with the stopper contraction applied.
    """
    res = self.conjoin(
        StabilizerCodeTensorEnumerator(stopper, tensor_id="stopper"),
        [traced_leg],
        [0],
    )
    res.annotation = self.annotation
    return res

with_coset_flipped_legs(coset_flipped_legs)

Create a new tensor enumerator with coset-flipped legs.

Creates a copy of this tensor enumerator with the specified coset-flipped legs. This is used for coset weight enumerator calculations.
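Coset enumeration shifts the whole stabilizer group by a fixed error before counting weights — the quantity that coset-flipped legs feed into. A stdlib sketch with a hypothetical helper, not the library code:

```python
from itertools import product

def coset_wep_sketch(h, n, coset):
    """Weight enumerator of the coset: the error `coset` times each
    element of the group spanned by the rows of h (symplectic form)."""
    wep = {}
    for picks in product([0, 1], repeat=len(h)):
        op = list(coset)  # start from the coset representative
        for pick, row in zip(picks, h):
            if pick:
                op = [(a + b) % 2 for a, b in zip(op, row)]
        w = sum(1 for q in range(n) if op[q] or op[q + n])
        wep[w] = wep.get(w, 0) + 1
    return wep

h_z3 = [[0, 0, 0, 1, 1, 0],   # Z1Z2
        [0, 0, 0, 0, 1, 1]]   # Z2Z3
x1 = [1, 0, 0, 0, 0, 0]       # coset error: X on qubit 0
wep = coset_wep_sketch(h_z3, 3, x1)
print(dict(sorted(wep.items())))  # {1: 1, 2: 2, 3: 1}
```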

Parameters:

- coset_flipped_legs (List[Tuple[TensorLeg, GF2]]): List of (leg, coset_error) pairs specifying which legs have coset errors applied. Required.

Returns:

- StabilizerCodeTensorEnumerator: New tensor enumerator with coset-flipped legs.

Source code in planqtn/stabilizer_tensor_enumerator.py
def with_coset_flipped_legs(
    self, coset_flipped_legs: List[Tuple[TensorLeg, GF2]]
) -> "StabilizerCodeTensorEnumerator":
    """Create a new tensor enumerator with coset-flipped legs.

    Creates a copy of this tensor enumerator with the specified coset-flipped
    legs. This is used for coset weight enumerator calculations.

    Args:
        coset_flipped_legs: List of (leg, coset_error) pairs specifying
            which legs have coset errors applied.

    Returns:
        StabilizerCodeTensorEnumerator: New tensor enumerator with coset-flipped legs.
    """
    return StabilizerCodeTensorEnumerator(
        self.h, self.tensor_id, self.legs, coset_flipped_legs
    )

TensorNetwork

A tensor network for contracting stabilizer code tensor enumerators.

Source code in planqtn/tensor_network.py
class TensorNetwork:
    """A tensor network for contracting stabilizer code tensor enumerators."""

    def __init__(
        self,
        nodes: Union[
            Iterable[StabilizerCodeTensorEnumerator],
            Dict[TensorId, StabilizerCodeTensorEnumerator],
        ],
        truncate_length: Optional[int] = None,
    ):
        """Construct a tensor network.

        This class represents a tensor network composed of
        [`StabilizerCodeTensorEnumerator`][planqtn.StabilizerCodeTensorEnumerator]
        nodes that can be contracted together to compute stabilizer enumerator polynomials.
        The contraction can either follow the manually specified trace ordering, or use an
        automated, hyperoptimized contraction ordering found by the `cotengra` library.

        The tensor network maintains a collection of nodes (tensors) and traces (contraction
        operations between nodes). It can compute weight enumerator polynomials for
        stabilizer codes by contracting the network according to the specified traces.

        Args:
            nodes: Either an iterable of
                [`StabilizerCodeTensorEnumerator`][planqtn.StabilizerCodeTensorEnumerator] objects
                or a dictionary mapping tensor IDs to them.
            truncate_length: Optional maximum length for truncating enumerator polynomials.

        Raises:
            ValueError: If the nodes have inconsistent indexing.
            ValueError: If there are colliding index values in the nodes.
        """
        if isinstance(nodes, dict):
            for k, v in nodes.items():
                if k != v.tensor_id:
                    raise ValueError(
                        f"Nodes dict passed in with inconsistent indexing, "
                        f"{k} != {v.tensor_id} for {v}."
                    )
            self.nodes: Dict[TensorId, StabilizerCodeTensorEnumerator] = nodes
        else:
            nodes_dict = {node.tensor_id: node for node in nodes}
            if len(nodes_dict) < len(list(nodes)):
                raise ValueError(f"There are colliding index values of nodes: {nodes}")
            self.nodes = nodes_dict

        self._traces: List[Trace] = []
        self._cot_tree = None
        self._cot_traces: Optional[List[Trace]] = None

        self._legs_left_to_join: Dict[TensorId, List[TensorLeg]] = {
            idx: [] for idx in self.nodes.keys()
        }
        # self.open_legs = [n.legs for n in self.nodes]

        self._wep: Optional[TensorEnumerator | UnivariatePoly] = None
        self._ptes: Dict[TensorId, _PartiallyTracedEnumerator] = {}
        self._coset: Optional[GF2] = None
        self.truncate_length: Optional[int] = truncate_length

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TensorNetwork):
            return False

        # Compare nodes
        if set(self.nodes.keys()) != set(other.nodes.keys()):
            return False

        for idx in self.nodes:
            if (self.nodes[idx].h != other.nodes[idx].h).any():
                return False
            if self.nodes[idx].legs != other.nodes[idx].legs:
                return False
            if (
                self.nodes[idx].coset_flipped_legs
                != other.nodes[idx].coset_flipped_legs
            ):
                return False

        # Compare traces - convert only the hashable parts to tuples
        def trace_to_comparable(
            trace: Trace,
        ) -> Tuple[TensorId, TensorId, Tuple[TensorLeg, ...], Tuple[TensorLeg, ...]]:
            node_idx1, node_idx2, join_legs1, join_legs2 = trace
            return (node_idx1, node_idx2, tuple(join_legs1), tuple(join_legs2))

        self_traces = {trace_to_comparable(t) for t in self._traces}
        other_traces = {trace_to_comparable(t) for t in other._traces}

        if self_traces != other_traces:
            return False

        return True

    def __hash__(self) -> int:
        # Hash the nodes
        nodes_hash = 0
        for idx in sorted(self.nodes.keys()):
            node = self.nodes[idx]
            nodes_hash ^= hash(
                (
                    idx,
                    tuple(map(tuple, node.h)),
                    tuple(node.legs),
                    (
                        tuple(map(tuple, node.coset_flipped_legs))
                        if node.coset_flipped_legs
                        else None
                    ),
                )
            )

        # Hash the traces - convert only the hashable parts to tuples
        def trace_to_hashable(
            trace: Trace,
        ) -> Tuple[TensorId, TensorId, Tuple[TensorLeg, ...], Tuple[TensorLeg, ...]]:
            node_idx1, node_idx2, join_legs1, join_legs2 = trace
            return (node_idx1, node_idx2, tuple(join_legs1), tuple(join_legs2))

        traces_hash = hash(tuple(sorted(trace_to_hashable(t) for t in self._traces)))

        return nodes_hash ^ traces_hash

    def qubit_to_node_and_leg(self, q: int) -> Tuple[TensorId, TensorLeg]:
        """Map a qubit index to its corresponding node and leg.

        This method maps a global qubit index to the specific node and leg
        that represents that qubit in the tensor network. This is an abstract method
        that must be implemented by subclasses that have a representation for qubits.

        Args:
            q: Global qubit index.

        Returns:
            node_id: Node ID of the node that represents the qubit.
            leg: Leg that represents the qubit.

        Raises:
            NotImplementedError: This method must be implemented by subclasses.
        """  # noqa: DAR202
        raise NotImplementedError(
            f"qubit_to_node_and_leg() is not implemented for {type(self)}!"
        )

    def n_qubits(self) -> int:
        """Get the total number of qubits in the tensor network.

        Returns the total number of qubits represented by this tensor network. This is an abstract
        method that must be implemented by subclasses that have a representation for qubits.

        Returns:
            int: Total number of qubits.

        Raises:
            NotImplementedError: This method must be implemented by subclasses.
        """  # noqa: DAR202
        raise NotImplementedError(f"n_qubits() is not implemented for {type(self)}")

    def _reset_wep(self, keep_cot: bool = False) -> None:

        self._wep = None

        prev_traces = deepcopy(self._traces)
        self._traces = []
        self._legs_left_to_join = {idx: [] for idx in self.nodes.keys()}

        for trace in prev_traces:
            self.self_trace(trace[0], trace[1], [trace[2][0]], [trace[3][0]])

        self._ptes = {}
        self._coset = GF2.Zeros(2 * self.n_qubits())

        if not keep_cot:
            self._cot_tree = None
            self._cot_traces = None

    def set_coset(self, coset_error: GF2 | Tuple[List[int], List[int]]) -> None:
        """Set the coset error for the tensor network.

        Sets the coset error that will be used for coset weight enumerator calculations.
        The coset error should follow the qubit numbering defined in
         [`qubit_to_node_and_leg`][planqtn.TensorNetwork.qubit_to_node_and_leg] which maps the index
        to a node ID. Both [`qubit_to_node_and_leg`][planqtn.TensorNetwork.qubit_to_node_and_leg]
        and [`n_qubits`][planqtn.TensorNetwork.n_qubits] are abstract classes, and thus this method
        can only be called on a subclass that implements these methods, see the
        [`planqtn.networks`][planqtn.networks] module for examples.

        There are two possible ways to pass the coset_error:

        - a tuple of two lists of qubit indices, one for the `Z` errors and one for the `X` errors
        - a `galois.GF2` array of length `2 * tn.n_qubits()` for the `tn` tensor network. This is a
            symplectic operator representation on the `n` qubits of the tensor network.

        Args:
            coset_error: The coset error specification.

        Raises:
            ValueError: If the coset error has the wrong number of qubits.
        """
        self._reset_wep(keep_cot=True)

        self._coset = GF2.Zeros(2 * self.n_qubits())

        if isinstance(coset_error, tuple):
            for i in coset_error[0]:
                self._coset[i] = 1
            for i in coset_error[1]:
                self._coset[i + self.n_qubits()] = 1
        elif isinstance(coset_error, GF2):
            self._coset = coset_error

        n = len(self._coset) // 2
        if n != self.n_qubits():
            raise ValueError(
                f"Can't set coset with {n} qubits for a {self.n_qubits()} qubit code."
            )

        z_errors = np.argwhere(self._coset[n:] == 1).flatten()
        x_errors = np.argwhere(self._coset[:n] == 1).flatten()

        node_legs_to_flip = defaultdict(list)

        for q in range(n):
            is_z = q in z_errors
            is_x = q in x_errors
            node_idx, leg = self.qubit_to_node_and_leg(q)

            self.nodes[node_idx].coset_flipped_legs = []
            if not is_z and not is_x:
                continue
            # print(f"q{q} -> {node_idx, leg}")
            node_legs_to_flip[node_idx].append((leg, GF2([is_x, is_z])))

        for node_idx, coset_flipped_legs in node_legs_to_flip.items():

            # print(node_idx, f" will have flipped {coset_flipped_legs}")

            self.nodes[node_idx] = self.nodes[node_idx].with_coset_flipped_legs(
                coset_flipped_legs
            )

    def self_trace(
        self,
        node_idx1: TensorId,
        node_idx2: TensorId,
        join_legs1: Sequence[int | TensorLeg],
        join_legs2: Sequence[int | TensorLeg],
    ) -> None:
        """Add a trace operation between two nodes in the tensor network.

        Defines a contraction between two nodes by specifying which legs to join.
        This operation is added to the trace schedule and will be executed when
        the tensor network is contracted.

        Args:
            node_idx1: ID of the first node to trace.
            node_idx2: ID of the second node to trace.
            join_legs1: Legs from the first node to contract.
            join_legs2: Legs from the second node to contract.

        Raises:
            ValueError: If the weight enumerator has already been computed.
        """
        if self._wep is not None:
            raise ValueError(
                "Tensor network weight enumerator is already traced, no new tracing schedule is "
                "allowed."
            )
        join_legs1_indexed = _index_legs(node_idx1, join_legs1)
        join_legs2_indexed = _index_legs(node_idx2, join_legs2)

        # print(f"adding trace {node_idx1, node_idx2, join_legs1, join_legs2}")
        self._traces.append(
            (node_idx1, node_idx2, join_legs1_indexed, join_legs2_indexed)
        )

        self._legs_left_to_join[node_idx1] += join_legs1_indexed
        self._legs_left_to_join[node_idx2] += join_legs2_indexed

    def traces_to_dot(self) -> None:
        """Print the tensor network traces in DOT format.

        Prints the traces (contractions) between nodes in a format that can be
        used to visualize the tensor network structure. Each trace is printed
        as a directed edge between nodes.
        """
        print("-----")
        # print(self.open_legs)
        # for n, legs in enumerate(self.open_legs):
        #     for leg in legs:
        #         print(f"n{n} -> n{n}_{leg}")

        for node_idx1, node_idx2, join_legs1, join_legs2 in self._traces:
            for _ in zip(join_legs1, join_legs2):
                print(f"n{node_idx1} -> n{node_idx2} ")

    def _cotengra_tree_from_traces(
        self,
        free_legs: List[TensorLeg],
        leg_indices: Dict[TensorLeg, str],
    ) -> ctg.ContractionTree:
        inputs, output, size_dict, input_names = self._prep_cotengra_inputs(
            leg_indices, free_legs, True
        )

        path = []
        terms = [{node_idx} for node_idx in input_names]

        def idx(node_id: TensorId) -> int:
            for i, term in enumerate(terms):
                if node_id in term:
                    return i
            assert False, (
                "This should not happen, nodes should always be present in at least one of the "
                "terms."
            )

        for node_idx1, node_idx2, _, _ in self._traces:
            i, j = sorted([idx(node_idx1), idx(node_idx2)])
            # print((node_idx1, node_idx2), f"=> {i,j}", terms)
            if i == j:
                continue
            path.append({i, j})
            term2 = terms.pop(j)
            term1 = terms.pop(i)
            terms.append(term1.union(term2))
        return ctg.ContractionTree.from_path(
            inputs, output, size_dict, path=path, check=True
        )

    def analyze_traces(
        self,
        cotengra: bool = False,
        each_step: bool = False,
        details: bool = False,
        **cotengra_opts: Any,
    ) -> Tuple[ctg.ContractionTree, int]:
        """Analyze the trace operations and optionally optimize the contraction path.

        Analyzes the current trace schedule and can optionally use cotengra to
        find an optimal contraction path. This is useful for understanding the
        computational complexity of the tensor network contraction.

        Args:
            cotengra: If True, use cotengra to optimize the contraction path.
            each_step: If True, print details for each contraction step.
            details: If True, print detailed analysis information.
            **cotengra_opts: Additional options to pass to cotengra.

        Returns:
            Tuple[ctg.ContractionTree, int]: The contraction tree and the maximum number of
                legs left to trace on any partially traced enumerator (PTE) during contraction.
        """
        free_legs, leg_indices, index_to_legs = self._collect_legs()
        tree = None

        node_to_free_legs = defaultdict(list)
        for leg in free_legs:
            for node_idx, node in self.nodes.items():
                if leg in node.legs:
                    node_to_free_legs[node.tensor_id].append(leg)

        new_tn = TensorNetwork(deepcopy(self.nodes))

        # pylint: disable=W0212
        new_tn._traces = deepcopy(self._traces)
        if cotengra:

            new_tn._traces, tree = self._cotengra_contraction(
                free_legs,
                leg_indices,
                index_to_legs,
                details,
                TqdmProgressReporter() if details else DummyProgressReporter(),
                **cotengra_opts,
            )
        else:
            tree = self._cotengra_tree_from_traces(free_legs, leg_indices)

        # pylint: disable=W0212
        new_tn._legs_left_to_join = deepcopy(self._legs_left_to_join)

        pte_nodes: Dict[TensorId, int] = {}
        max_pte_legs = 0
        if details:
            print(
                "========================== ======= === === === == ==============================="
            )
            print(
                "========================== TRACE SCHEDULE ANALYSIS ============================="
            )
            print(
                "========================== ======= === === === == ==============================="
            )
            print(
                f"    Total legs to trace: "
                f"{sum(len(legs) for legs in new_tn._legs_left_to_join.values())}"
            )
        pte_leg_numbers: Dict[TensorId, int] = defaultdict(int)

        for node_idx1, node_idx2, join_legs1, join_legs2 in new_tn._traces:
            if each_step:
                print(
                    f"==== trace {node_idx1, node_idx2, join_legs1, join_legs2} ==== "
                )

            for leg in join_legs1:
                new_tn._legs_left_to_join[node_idx1].remove(leg)
            for leg in join_legs2:
                new_tn._legs_left_to_join[node_idx2].remove(leg)

            if node_idx1 not in pte_nodes and node_idx2 not in pte_nodes:
                next_pte = 0 if len(pte_nodes) == 0 else max(pte_nodes.values()) + 1
                if each_step:
                    print(f"New PTE: {next_pte}")
                pte_nodes[node_idx1] = next_pte
                pte_nodes[node_idx2] = next_pte
            elif node_idx1 in pte_nodes and node_idx2 not in pte_nodes:
                pte_nodes[node_idx2] = pte_nodes[node_idx1]
            elif node_idx2 in pte_nodes and node_idx1 not in pte_nodes:
                pte_nodes[node_idx1] = pte_nodes[node_idx2]
            elif pte_nodes[node_idx1] == pte_nodes[node_idx2]:
                if each_step:
                    print(f"self trace in PTE {pte_nodes[node_idx1]}")
            else:
                if each_step:
                    print(f"MERGE of {pte_nodes[node_idx1]} and {pte_nodes[node_idx2]}")
                removed_pte = pte_nodes[node_idx2]
                merged_pte = pte_nodes[node_idx1]
                for node_idx, pte_node in pte_nodes.items():
                    if pte_node == removed_pte:
                        pte_nodes[node_idx] = merged_pte

            if details:
                print(f"    pte nodes: {pte_nodes}")
            if each_step:
                print(
                    f"    Total legs to trace: "
                    f"{sum(len(legs) for legs in new_tn._legs_left_to_join.values())}"
                )

            pte_leg_numbers = defaultdict(int)

            for node_idx, pte_node in pte_nodes.items():
                pte_leg_numbers[pte_node] += len(new_tn._legs_left_to_join[node_idx])

            if each_step:
                print(f"     PTEs num traceable legs: {dict(pte_leg_numbers)}")

            biggest_legs = max(pte_leg_numbers.values())

            max_pte_legs = max(max_pte_legs, biggest_legs)
            if each_step:
                print(f"    Biggest PTE legs: {biggest_legs} vs MAX: {max_pte_legs}")
        if details:
            print("=== Final state ==== ")
            print(f"pte nodes: {pte_nodes}")

            print(
                f"all nodes {set(pte_nodes.keys()) == set(new_tn.nodes.keys())} "
                f"and all nodes are in a single PTE: {len(set(pte_nodes.values())) == 1}"
            )
            print(
                f"Total legs to trace: "
                f"{sum(len(legs) for legs in new_tn._legs_left_to_join.values())}"
            )
            print(f"PTEs num traceable legs: {dict(pte_leg_numbers)}")
            print(f"Maximum PTE legs: {max_pte_legs}")
        return tree, max_pte_legs

    def conjoin_nodes(
        self,
        verbose: bool = False,
        progress_reporter: ProgressReporter = DummyProgressReporter(),
    ) -> "StabilizerCodeTensorEnumerator":
        """Conjoin all nodes in the tensor network according to the trace schedule.

        Executes all the trace operations defined in the tensor network to produce
        a single tensor enumerator. This tensor enumerator will have the conjoined parity check
        matrix. However, running the weight enumerator calculation on this conjoined node would
        use the brute force method, and as such would typically be more expensive than using the
        [`stabilizer_enumerator_polynomial`][planqtn.TensorNetwork.stabilizer_enumerator_polynomial]
        method.

        Args:
            verbose: If True, print verbose output during contraction.
            progress_reporter: Progress reporter for tracking the contraction process.

        Returns:
            StabilizerCodeTensorEnumerator: The contracted tensor enumerator.
        """
        # If there's only one node and no traces, return it directly
        if len(self.nodes) == 1 and len(self._traces) == 0:
            return list(self.nodes.values())[0]

        # Map from node_idx to the index of its PTE in ptes list
        nodes = list(self.nodes.values())
        ptes: List[Tuple[StabilizerCodeTensorEnumerator, Set[TensorId]]] = [
            (node, {node.tensor_id}) for node in nodes
        ]
        node_to_pte = {node.tensor_id: i for i, node in enumerate(nodes)}

        for node_idx1, node_idx2, join_legs1, join_legs2 in progress_reporter.iterate(
            self._traces, "Conjoining nodes", len(self._traces)
        ):
            if verbose:
                print(
                    f"==== trace {node_idx1, node_idx2, join_legs1, join_legs2} ==== "
                )

            join_legs1 = _index_legs(node_idx1, join_legs1)
            join_legs2 = _index_legs(node_idx2, join_legs2)

            pte1_idx = node_to_pte[node_idx1]
            pte2_idx = node_to_pte[node_idx2]

            # Case 1: Both nodes are in the same PTE
            if pte1_idx == pte2_idx:
                if verbose:
                    print(
                        f"Self trace in PTE containing both {node_idx1} and {node_idx2}"
                    )
                pte, nodes_in_pte = ptes[pte1_idx]
                new_pte = pte.self_trace(join_legs1, join_legs2)
                ptes[pte1_idx] = (new_pte, nodes_in_pte)

            # Case 2: Nodes are in different PTEs - merge them
            else:
                if verbose:
                    print(f"Merging PTEs containing {node_idx1} and {node_idx2}")
                pte1, nodes1 = ptes[pte1_idx]
                pte2, nodes2 = ptes[pte2_idx]
                new_pte = pte1.conjoin(pte2, legs1=join_legs1, legs2=join_legs2)
                merged_nodes = nodes1.union(nodes2)

                # Update the first PTE with merged result
                ptes[pte1_idx] = (new_pte, merged_nodes)
                # Remove the second PTE
                ptes.pop(pte2_idx)

                # Update node_to_pte mappings
                for node_idx in nodes2:
                    node_to_pte[node_idx] = pte1_idx
                # Adjust indices for all nodes in PTEs after the removed one
                for node_idx, pte_idx in node_to_pte.items():
                    if pte_idx > pte2_idx:
                        node_to_pte[node_idx] = pte_idx - 1

            if verbose:
                print("H:")
                sprint(ptes[0][0].h)

        # If we have multiple components at the end, tensor them together
        if len(ptes) > 1:
            for other in ptes[1:]:
                ptes[0] = (ptes[0][0].tensor_with(other[0]), ptes[0][1].union(other[1]))

        return ptes[0][0]

    def _collect_legs(
        self,
    ) -> Tuple[
        List[TensorLeg],
        Dict[TensorLeg, str],
        Dict[str, List[Tuple[TensorId, TensorLeg]]],
    ]:
        leg_indices = {}
        index_to_legs = {}
        current_index = 0
        free_legs = []
        # Iterate over each node in the tensor network
        for node_idx, node in self.nodes.items():
            # Iterate over each leg in the node
            for leg in node.legs:
                current_idx_name = f"{leg}"
                # If the leg is already indexed, skip it
                if leg in leg_indices:
                    continue
                # Assign the current index to the leg
                leg_indices[leg] = current_idx_name
                index_to_legs[current_idx_name] = [(node_idx, leg)]
                open_leg = True
                # Check for traces and assign the same index to traced legs
                for node_idx1, node_idx2, join_legs1, join_legs2 in self._traces:
                    idx = -1
                    if leg in join_legs1:
                        idx = join_legs1.index(leg)
                    elif leg in join_legs2:
                        idx = join_legs2.index(leg)
                    else:
                        continue
                    open_leg = False
                    current_idx_name = f"{join_legs1[idx]}_{join_legs2[idx]}"
                    leg_indices[join_legs1[idx]] = current_idx_name
                    leg_indices[join_legs2[idx]] = current_idx_name
                    index_to_legs[current_idx_name] = [
                        (node_idx1, join_legs1[idx]),
                        (node_idx2, join_legs2[idx]),
                    ]
                # Move to the next index
                if open_leg:
                    free_legs.append(leg)
                current_index += 1
        return free_legs, leg_indices, index_to_legs

    def _prep_cotengra_inputs(
        self,
        leg_indices: Dict[TensorLeg, str],
        free_legs: List[TensorLeg],
        verbose: bool = False,
    ) -> Tuple[List[Tuple[str, ...]], List[str], Dict[str, int], List[str]]:
        inputs = []
        output: List[str] = []
        size_dict = {leg: 2 for leg in leg_indices.values()}

        input_names = []

        for node_idx, node in self.nodes.items():
            inputs.append(tuple(leg_indices[leg] for leg in node.legs))
            input_names.append(str(node_idx))
            if verbose:
                # Print the indices for each node
                for leg in node.legs:
                    print(
                        f"  Leg {leg}: Index {leg_indices[leg]} "
                        f"{'OPEN' if leg in free_legs else 'traced'}"
                    )
        if verbose:
            print(input_names)
            print(inputs)
            print(output)
            print(size_dict)
        return inputs, output, size_dict, input_names

    def _traces_from_cotengra_tree(
        self,
        tree: ctg.ContractionTree,
        index_to_legs: Dict[str, List[Tuple[TensorId, TensorLeg]]],
        inputs: List[Tuple[str, ...]],
    ) -> List[Trace]:
        def legs_to_contract(l: frozenset, r: frozenset) -> List[Trace]:
            res = []
            left_indices = sum((list(inputs[leaf_idx]) for leaf_idx in l), [])
            right_indices = sum((list(inputs[leaf_idx]) for leaf_idx in r), [])
            for idx1 in left_indices:
                if idx1 in right_indices:
                    legs = index_to_legs[idx1]
                    res.append((legs[0][0], legs[1][0], [legs[0][1]], [legs[1][1]]))
            return res

        # We convert the tree back to a list of traces
        traces = []
        for _, l, r in tree.traverse():
            # at each step we have to find the nodes that share indices in the two merged subsets
            new_traces = legs_to_contract(l, r)
            traces += new_traces

        trace_indices = []
        for t in traces:
            assert t in self._traces, f"{t} not in traces. Traces: {self._traces}"
            idx = self._traces.index(t)
            trace_indices.append(idx)

        assert set(trace_indices) == set(
            range(len(self._traces))
        ), "Some traces are missing from cotengra tree\n" + "\n".join(
            [
                str(self._traces[i])
                for i in set(range(len(self._traces))) - set(trace_indices)
            ]
        )
        return traces

    def _cotengra_contraction(
        self,
        free_legs: List[TensorLeg],
        leg_indices: Dict[TensorLeg, str],
        index_to_legs: Dict[str, List[Tuple[TensorId, TensorLeg]]],
        verbose: bool = False,
        progress_reporter: ProgressReporter = DummyProgressReporter(),
        **cotengra_opts: Any,
    ) -> Tuple[
        List[Trace],
        ctg.ContractionTree,
    ]:

        if self._cot_traces is not None:
            return self._cot_traces, self._cot_tree

        inputs, output, size_dict, _ = self._prep_cotengra_inputs(
            leg_indices, free_legs, verbose
        )

        cotengra_params = {
            "minimize": "combo",
            "parallel": False,
            # kahypar is not installed by default, but if the user has it installed
            # it is used by default; otherwise our current default method is greedy
            "methods": [
                "kahypar" if importlib.util.find_spec("kahypar") else "greedy",
                "labels",
            ],
            "optlib": "cmaes",
        }
        cotengra_params.update(cotengra_opts)
        opt = ctg.HyperOptimizer(
            **cotengra_params,
            progbar=not isinstance(progress_reporter, DummyProgressReporter),
        )

        self._cot_tree = opt.search(inputs, output, size_dict)

        self._cot_traces = self._traces_from_cotengra_tree(
            self._cot_tree, index_to_legs=index_to_legs, inputs=inputs
        )

        return self._cot_traces, self._cot_tree

    # weight_enumerator_polynomial
    # - pass in a list of bool for each node True: stabilizer False: normalizer

    def stabilizer_enumerator_polynomial(
        self,
        open_legs: Sequence[TensorLeg] = (),
        verbose: bool = False,
        progress_reporter: ProgressReporter = DummyProgressReporter(),
        cotengra: bool = True,
    ) -> TensorEnumerator | UnivariatePoly:
        """Returns the reduced stabilizer enumerator polynomial for the tensor network.

        If open_legs is not empty, then the returned tensor enumerator polynomial is a dictionary of
        tensor keys to UnivariatePoly objects.

        Args:
            open_legs: The legs that are open in the tensor network. If empty, the result is a
                       scalar weight enumerator polynomial of type `UnivariatePoly`, otherwise it
                       is a dictionary of `TensorEnumeratorKey` keys to `UnivariatePoly` objects.
            verbose: If True, print verbose output.
            progress_reporter: The progress reporter to use, defaults to no progress reporting
                              (`DummyProgressReporter`), can be set to `TqdmProgressReporter` for
                              progress reporting on the console, or any other custom
                              `ProgressReporter` subclass.
            cotengra: If True, use cotengra to contract the tensor network, otherwise use the order
                      the traces were constructed.

        Returns:
            TensorEnumerator | UnivariatePoly: The reduced stabilizer enumerator polynomial for
                the tensor network.
        """
        if self._wep is not None:
            return self._wep

        assert (
            progress_reporter is not None
        ), "Progress reporter must be provided, it is None"

        with progress_reporter.enter_phase("collecting legs"):
            free_legs, leg_indices, index_to_legs = self._collect_legs()

        open_legs_per_node = defaultdict(list)
        for node_idx, node in self.nodes.items():
            for leg in node.legs:
                if leg not in free_legs:
                    open_legs_per_node[node_idx].append(_index_leg(node_idx, leg))

        for node_idx, leg_index in open_legs:
            open_legs_per_node[node_idx].append(_index_leg(node_idx, leg_index))

        if verbose:
            print("open_legs_per_node", open_legs_per_node)
        traces = self._traces
        if cotengra and len(self.nodes) > 0 and len(self._traces) > 0:
            with progress_reporter.enter_phase("cotengra contraction"):
                traces, _ = self._cotengra_contraction(
                    free_legs, leg_indices, index_to_legs, verbose, progress_reporter
                )
        summed_legs = [leg for leg in free_legs if leg not in open_legs]

        if len(self._traces) == 0 and len(self.nodes) == 1:
            return list(self.nodes.items())[0][1].stabilizer_enumerator_polynomial(
                verbose=verbose,
                progress_reporter=progress_reporter,
                truncate_length=self.truncate_length,
                open_legs=open_legs,
            )

        # parity_check_enums = {}

        for node_idx, node in self.nodes.items():
            traced_legs = open_legs_per_node[node_idx]
            # TODO: figure out tensor caching
            # traced_leg_indices = "".join(
            #     [str(i) for i in sorted([node.legs.index(leg) for leg in traced_legs])]
            # )
            # hkey = sstr(gauss(node.h)) + ";" + traced_leg_indices
            # if hkey not in parity_check_enums:
            #     parity_check_enums[hkey] = node.stabilizer_enumerator_polynomial(
            #         open_legs=traced_legs
            #     )
            # else:
            #     print("Found one!")
            #     calc = node.stabilizer_enumerator_polynomial(open_legs=traced_legs)
            #     assert (
            #         calc == parity_check_enums[hkey]
            #     ), f"for key {hkey}\n calc\n{calc}\n vs retrieved\n{parity_check_enums[hkey]}"

            # call the right type here...
            tensor = node.stabilizer_enumerator_polynomial(
                open_legs=traced_legs,
                verbose=verbose,
                progress_reporter=progress_reporter,
                truncate_length=self.truncate_length,
            )
            if isinstance(tensor, UnivariatePoly):
                tensor = {(): tensor}
            self._ptes[node_idx] = _PartiallyTracedEnumerator(
                nodes={node_idx},
                tracable_legs=open_legs_per_node[node_idx],
                tensor=tensor,  # deepcopy(parity_check_enums[hkey]),
                truncate_length=self.truncate_length,
            )

        for node_idx1, node_idx2, join_legs1, join_legs2 in progress_reporter.iterate(
            traces, f"Tracing {len(traces)} legs", len(traces)
        ):
            if verbose:
                print(
                    f"==== trace {node_idx1, node_idx2, join_legs1, join_legs2} ==== "
                )
                print(
                    f"Total legs left to join: "
                    f"{sum(len(legs) for legs in self._legs_left_to_join.values())}"
                )
            node1_pte = self._ptes[node_idx1]
            node2_pte = self._ptes[node_idx2]

            # print(f"PTEs: {node1_pte}, {node2_pte}")
            # check that the length of the tensor is a power of 4

            if node1_pte == node2_pte:
                # both nodes are in the same PTE!
                if verbose:
                    print(f"self trace within PTE {node1_pte}")
                pte = node1_pte.self_trace(
                    join_legs1=[
                        (node_idx1, leg) if isinstance(leg, int) else leg
                        for leg in join_legs1
                    ],
                    join_legs2=[
                        (node_idx2, leg) if isinstance(leg, int) else leg
                        for leg in join_legs2
                    ],
                    progress_reporter=progress_reporter,
                    verbose=verbose,
                )
                for node_idx in pte.nodes:
                    self._ptes[node_idx] = pte
                self._legs_left_to_join[node_idx1] = [
                    leg
                    for leg in self._legs_left_to_join[node_idx1]
                    if leg not in join_legs1
                ]
                self._legs_left_to_join[node_idx2] = [
                    leg
                    for leg in self._legs_left_to_join[node_idx2]
                    if leg not in join_legs2
                ]
            else:
                if verbose:
                    print(f"MERGING two components {node1_pte} and {node2_pte}")
                    print(f"node1_pte {node1_pte}:")
                    for k in list(node1_pte.tensor.keys()):
                        v = node1_pte.tensor[k]
                        print(Pauli.to_str(*k), end=" ")
                        print(v)
                    print(f"node2_pte {node2_pte}:")
                    for k in list(node2_pte.tensor.keys()):
                        v = node2_pte.tensor[k]
                        print(Pauli.to_str(*k), end=" ")
                        print(v)
                pte = node1_pte.merge_with(
                    node2_pte,
                    join_legs1=[
                        (node_idx1, leg) if isinstance(leg, int) else leg
                        for leg in join_legs1
                    ],
                    join_legs2=[
                        (node_idx2, leg) if isinstance(leg, int) else leg
                        for leg in join_legs2
                    ],
                    verbose=verbose,
                    progress_reporter=progress_reporter,
                )

                for node_idx in pte.nodes:
                    self._ptes[node_idx] = pte
                self._legs_left_to_join[node_idx1] = [
                    leg
                    for leg in self._legs_left_to_join[node_idx1]
                    if leg not in join_legs1
                ]
                self._legs_left_to_join[node_idx2] = [
                    leg
                    for leg in self._legs_left_to_join[node_idx2]
                    if leg not in join_legs2
                ]

            node1_pte = self._ptes[node_idx1]

            if verbose:
                print(
                    f"PTE nodes: {node1_pte.nodes if node1_pte is not None else None}"
                )
                print(
                    f"PTE traceable legs: "
                    f"{node1_pte.tracable_legs if node1_pte is not None else None}"
                )
            if verbose:
                print("PTE tensor: ")
            for k in list(node1_pte.tensor.keys() if node1_pte is not None else []):
                v = node1_pte.tensor[k] if node1_pte is not None else UnivariatePoly()
                # if not 0 in v:
                #     continue
                if verbose:
                    print(Pauli.to_str(*k), end=" ")
                    print(v, end="")
                if self.truncate_length is None:
                    continue
                if v.minw()[0] > self.truncate_length:
                    del pte.tensor[k]
                    if verbose:
                        print(" -- removed")
                else:
                    pte.tensor[k].truncate_inplace(self.truncate_length)
                    if verbose:
                        print(" -- truncated")
            if verbose:
                print(f"PTEs: {self._ptes}")

        if verbose:
            print("summed legs: ", summed_legs)
            print("PTEs: ", self._ptes)
        if len(set(self._ptes.values())) > 1:
            if verbose:
                print(
                    f"tensoring {len(set(self._ptes.values()))} disjoint PTEs: {self._ptes}"
                )

            pte_list = list(set(self._ptes.values()))
            pte = pte_list[0]
            for pte2 in pte_list[1:]:
                pte = pte.tensor_product(
                    pte2, verbose=verbose, progress_reporter=progress_reporter
                )

        if len(pte.tensor) > 1:
            if verbose:
                print(f"final PTE is a tensor: {pte}")
                if len(pte.tensor) > 5000:
                    print(
                        f"There are {len(pte.tensor)} keys in the final PTE, skipping printing."
                    )
                else:
                    for k in list(pte.tensor.keys()):
                        v = pte.tensor[k]
                        if verbose:
                            print(Pauli.to_str(*k), end=" ")
                            print(v)

            self._wep = pte.ordered_key_tensor(
                open_legs,
                progress_reporter=progress_reporter,
                verbose=verbose,
            )
        else:
            self._wep = pte.tensor[()]
            if verbose:
                print(f"final scalar wep: {self._wep}")
            self._wep = self._wep.normalize(verbose=verbose)
            if verbose:
                print(f"final normalized scalar wep: {self._wep}")
        return self._wep

    def stabilizer_enumerator(
        self,
        verbose: bool = False,
        progress_reporter: ProgressReporter = DummyProgressReporter(),
    ) -> Dict[int, int]:
        """Compute the stabilizer weight enumerator.

        Computes the weight enumerator polynomial and returns it as a dictionary
        mapping weights to coefficients. This is a convenience method that
        calls stabilizer_enumerator_polynomial() and extracts the dictionary.

        Args:
            verbose: If True, print verbose output.
            progress_reporter: Progress reporter for tracking computation.

        Returns:
            Dict[int, int]: Weight enumerator as a dictionary mapping weights to counts.
        """
        wep = self.stabilizer_enumerator_polynomial(
            verbose=verbose, progress_reporter=progress_reporter
        )
        assert isinstance(wep, UnivariatePoly)
        return wep.dict

    def set_truncate_length(self, truncate_length: int) -> None:
        """Set the truncation length for weight enumerator polynomials.

        Sets the maximum weight to keep in weight enumerator polynomials.
        This affects all subsequent computations and resets any cached results.

        Args:
            truncate_length: Maximum weight to keep in enumerator polynomials.
        """
        self.truncate_length = truncate_length
        self._reset_wep(keep_cot=True)
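
The `_cotengra_tree_from_traces` helper above converts a manual trace schedule into the
pairwise-merge path that cotengra's `ContractionTree.from_path` expects. A minimal standalone
sketch of that path-building step (hypothetical `path_from_traces` helper, plain Python,
no cotengra dependency):

```python
# Sketch: turn a pairwise trace schedule into a contraction path. Each step
# records the positions of the two current terms being merged; the merged term
# is re-appended at the end of the term list, mirroring cotengra's convention.
from typing import List, Set, Tuple


def path_from_traces(names: List[str], traces: List[Tuple[str, str]]) -> List[Set[int]]:
    terms: List[Set[str]] = [{n} for n in names]
    path: List[Set[int]] = []

    def idx(name: str) -> int:
        # Find the current term containing this node.
        return next(i for i, t in enumerate(terms) if name in t)

    for a, b in traces:
        i, j = sorted((idx(a), idx(b)))
        if i == j:
            # Self-trace within an already-merged term: no new path step.
            continue
        path.append({i, j})
        t2 = terms.pop(j)  # pop the larger index first so i stays valid
        t1 = terms.pop(i)
        terms.append(t1 | t2)
    return path
```

Repeated traces between nodes that already share a term contribute no extra path steps,
which is why the real method skips them with `if i == j: continue`.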

__init__(nodes, truncate_length=None)

Construct a tensor network.

This class represents a tensor network composed of StabilizerCodeTensorEnumerator nodes that can be contracted together to compute stabilizer enumerator polynomials. The trace ordering can be left to use the original manual ordering or use automated, hyperoptimized contraction ordering using the cotengra library.

The tensor network maintains a collection of nodes (tensors) and traces (contraction operations between nodes). It can compute weight enumerator polynomials for stabilizer codes by contracting the network according to the specified traces.

Parameters:

    nodes (Union[Iterable[StabilizerCodeTensorEnumerator], Dict[TensorId, StabilizerCodeTensorEnumerator]]):
        The nodes of the network, either as an iterable of StabilizerCodeTensorEnumerator
        objects or as a dictionary mapping tensor IDs to StabilizerCodeTensorEnumerator
        objects. Required.

    truncate_length (Optional[int]):
        Optional maximum length for truncating enumerator polynomials. Default: None.

Raises:

    ValueError: If the nodes have inconsistent indexing.

    ValueError: If there are colliding index values in the nodes.

Source code in planqtn/tensor_network.py
def __init__(
    self,
    nodes: Union[
        Iterable[StabilizerCodeTensorEnumerator],
        Dict[TensorId, StabilizerCodeTensorEnumerator],
    ],
    truncate_length: Optional[int] = None,
):
    """Construct a tensor network.

    This class represents a tensor network composed of
    [`StabilizerCodeTensorEnumerator`][planqtn.StabilizerCodeTensorEnumerator]
    nodes that can be contracted together to compute stabilizer enumerator polynomials.
    The trace ordering can be left to use the original manual ordering or use automated,
    hyperoptimized contraction ordering using the `cotengra` library.

    The tensor network maintains a collection of nodes (tensors) and traces (contraction
    operations between nodes). It can compute weight enumerator polynomials for
    stabilizer codes by contracting the network according to the specified traces.

    Args:
        nodes: The nodes of the network, either as an iterable of
            [`StabilizerCodeTensorEnumerator`][planqtn.StabilizerCodeTensorEnumerator] objects
            or as a dictionary mapping tensor IDs to them.
        truncate_length: Optional maximum length for truncating enumerator polynomials.

    Raises:
        ValueError: If the nodes have inconsistent indexing.
        ValueError: If there are colliding index values in the nodes.
    """
    if isinstance(nodes, dict):
        for k, v in nodes.items():
            if k != v.tensor_id:
                raise ValueError(
                    f"Nodes dict passed in with inconsitent indexing, "
                    f"{k} != {v.tensor_id} for {v}."
                )
        self.nodes: Dict[TensorId, StabilizerCodeTensorEnumerator] = nodes
    else:
        nodes_dict = {node.tensor_id: node for node in nodes}
        if len(nodes_dict) < len(list(nodes)):
            raise ValueError(f"There are colliding index values of nodes: {nodes}")
        self.nodes = nodes_dict

    self._traces: List[Trace] = []
    self._cot_tree = None
    self._cot_traces: Optional[List[Trace]] = None

    self._legs_left_to_join: Dict[TensorId, List[TensorLeg]] = {
        idx: [] for idx in self.nodes.keys()
    }
    # self.open_legs = [n.legs for n in self.nodes]

    self._wep: Optional[TensorEnumerator | UnivariatePoly] = None
    self._ptes: Dict[TensorId, _PartiallyTracedEnumerator] = {}
    self._coset: Optional[GF2] = None
    self.truncate_length: Optional[int] = truncate_length

analyze_traces(cotengra=False, each_step=False, details=False, **cotengra_opts)

Analyze the trace operations and optionally optimize the contraction path.

Analyzes the current trace schedule and can optionally use cotengra to find an optimal contraction path. This is useful for understanding the computational complexity of the tensor network contraction.

Parameters:

  • cotengra (bool): if True, use cotengra to optimize the contraction path. Defaults to False.
  • each_step (bool): if True, print details for each contraction step. Defaults to False.
  • details (bool): if True, print detailed analysis information. Defaults to False.
  • **cotengra_opts (Any): additional options to pass to cotengra.

Returns:

  • Tuple[ctg.ContractionTree, int]: the contraction tree and total cost.
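The analysis groups nodes into partially traced enumerator (PTE) components as the trace schedule is replayed. This bookkeeping can be sketched in plain Python; the function below is an illustrative stand-in for the internal logic, not part of the planqtn API:

```python
def group_into_ptes(traces):
    """Assign each traced node to a PTE component id, merging as traces join them."""
    pte_nodes = {}  # node id -> PTE component id
    for node1, node2, _, _ in traces:
        if node1 not in pte_nodes and node2 not in pte_nodes:
            # Two so-far untouched nodes start a new component.
            next_pte = max(pte_nodes.values(), default=-1) + 1
            pte_nodes[node1] = pte_nodes[node2] = next_pte
        elif node1 in pte_nodes and node2 not in pte_nodes:
            pte_nodes[node2] = pte_nodes[node1]
        elif node2 in pte_nodes and node1 not in pte_nodes:
            pte_nodes[node1] = pte_nodes[node2]
        elif pte_nodes[node1] != pte_nodes[node2]:
            # Merge: relabel every node of the second component.
            removed, kept = pte_nodes[node2], pte_nodes[node1]
            for node, comp in pte_nodes.items():
                if comp == removed:
                    pte_nodes[node] = kept
    return pte_nodes

# Two traces chain three nodes into a single component:
print(group_into_ptes([("z0", "x1", [], []), ("x1", "z2", [], [])]))
# → {'z0': 0, 'x1': 0, 'z2': 0}
```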

Source code in planqtn/tensor_network.py
def analyze_traces(
    self,
    cotengra: bool = False,
    each_step: bool = False,
    details: bool = False,
    **cotengra_opts: Any,
) -> Tuple[ctg.ContractionTree, int]:
    """Analyze the trace operations and optionally optimize the contraction path.

    Analyzes the current trace schedule and can optionally use cotengra to
    find an optimal contraction path. This is useful for understanding the
    computational complexity of the tensor network contraction.

    Args:
        cotengra: If True, use cotengra to optimize the contraction path.
        each_step: If True, print details for each contraction step.
        details: If True, print detailed analysis information.
        **cotengra_opts: Additional options to pass to cotengra.

    Returns:
        Tuple[ctg.ContractionTree, int]: The contraction tree and total cost.
    """
    free_legs, leg_indices, index_to_legs = self._collect_legs()
    tree = None

    node_to_free_legs = defaultdict(list)
    for leg in free_legs:
        for node_idx, node in self.nodes.items():
            if leg in node.legs:
                node_to_free_legs[node.tensor_id].append(leg)

    new_tn = TensorNetwork(deepcopy(self.nodes))

    # pylint: disable=W0212
    new_tn._traces = deepcopy(self._traces)
    if cotengra:

        new_tn._traces, tree = self._cotengra_contraction(
            free_legs,
            leg_indices,
            index_to_legs,
            details,
            TqdmProgressReporter() if details else DummyProgressReporter(),
            **cotengra_opts,
        )
    else:
        tree = self._cotengra_tree_from_traces(free_legs, leg_indices)

    # pylint: disable=W0212
    new_tn._legs_left_to_join = deepcopy(self._legs_left_to_join)

    pte_nodes: Dict[TensorId, int] = {}
    max_pte_legs = 0
    if details:
        print(
            "========================== ======= === === === == ==============================="
        )
        print(
            "========================== TRACE SCHEDULE ANALYSIS ============================="
        )
        print(
            "========================== ======= === === === == ==============================="
        )
        print(
            f"    Total legs to trace: "
            f"{sum(len(legs) for legs in new_tn._legs_left_to_join.values())}"
        )
    pte_leg_numbers: Dict[TensorId, int] = defaultdict(int)

    for node_idx1, node_idx2, join_legs1, join_legs2 in new_tn._traces:
        if each_step:
            print(
                f"==== trace {node_idx1, node_idx2, join_legs1, join_legs2} ==== "
            )

        for leg in join_legs1:
            new_tn._legs_left_to_join[node_idx1].remove(leg)
        for leg in join_legs2:
            new_tn._legs_left_to_join[node_idx2].remove(leg)

        if node_idx1 not in pte_nodes and node_idx2 not in pte_nodes:
            next_pte = 0 if len(pte_nodes) == 0 else max(pte_nodes.values()) + 1
            if each_step:
                print(f"New PTE: {next_pte}")
            pte_nodes[node_idx1] = next_pte
            pte_nodes[node_idx2] = next_pte
        elif node_idx1 in pte_nodes and node_idx2 not in pte_nodes:
            pte_nodes[node_idx2] = pte_nodes[node_idx1]
        elif node_idx2 in pte_nodes and node_idx1 not in pte_nodes:
            pte_nodes[node_idx1] = pte_nodes[node_idx2]
        elif pte_nodes[node_idx1] == pte_nodes[node_idx2]:
            if each_step:
                print(f"self trace in PTE {pte_nodes[node_idx1]}")
        else:
            if each_step:
                print(f"MERGE of {pte_nodes[node_idx1]} and {pte_nodes[node_idx2]}")
            removed_pte = pte_nodes[node_idx2]
            merged_pte = pte_nodes[node_idx1]
            for node_idx, pte_node in pte_nodes.items():
                if pte_node == removed_pte:
                    pte_nodes[node_idx] = merged_pte

        if details:
            print(f"    pte nodes: {pte_nodes}")
        if each_step:
            print(
                f"    Total legs to trace: "
                f"{sum(len(legs) for legs in new_tn._legs_left_to_join.values())}"
            )

        pte_leg_numbers = defaultdict(int)

        for node_idx, pte_node in pte_nodes.items():
            pte_leg_numbers[pte_node] += len(new_tn._legs_left_to_join[node_idx])

        if each_step:
            print(f"     PTEs num tracable legs: {dict(pte_leg_numbers)}")

        biggest_legs = max(pte_leg_numbers.values())

        max_pte_legs = max(max_pte_legs, biggest_legs)
        if each_step:
            print(f"    Biggest PTE legs: {biggest_legs} vs MAX: {max_pte_legs}")
    if details:
        print("=== Final state ==== ")
        print(f"pte nodes: {pte_nodes}")

        print(
            f"all nodes {set(pte_nodes.keys()) == set(new_tn.nodes.keys())} "
            f"and all nodes are in a single PTE: {len(set(pte_nodes.values())) == 1}"
        )
        print(
            f"Total legs to trace: "
            f"{sum(len(legs) for legs in new_tn._legs_left_to_join.values())}"
        )
        print(f"PTEs num tracable legs: {dict(pte_leg_numbers)}")
        print(f"Maximum PTE legs: {max_pte_legs}")
    return tree, max_pte_legs

conjoin_nodes(verbose=False, progress_reporter=DummyProgressReporter())

Conjoin all nodes in the tensor network according to the trace schedule.

Executes all the trace operations defined in the tensor network to produce a single tensor enumerator. This tensor enumerator will have the conjoined parity check matrix. However, running a weight enumerator calculation on this conjoined node would use the brute-force method, and as such would typically be more expensive than using the stabilizer_enumerator_polynomial method.
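Internally, the contraction merges partially traced components: when two components merge, the second is popped from the component list and every later index shifts down by one. A minimal, self-contained sketch of that index bookkeeping (assumed names, not the planqtn API):

```python
ptes = ["A", "B", "C", "D"]          # stand-ins for partially traced components
node_to_pte = {"a": 0, "b": 1, "c": 2, "d": 3}

merged_idx, removed_idx = 1, 2       # merge component 2 into component 1
ptes.pop(removed_idx)

# Nodes of the removed component now live in the merged one ...
node_to_pte["c"] = merged_idx
# ... and every index past the removed slot shifts down by one.
for node, idx in node_to_pte.items():
    if idx > removed_idx:
        node_to_pte[node] = idx - 1

print(ptes)          # → ['A', 'B', 'D']
print(node_to_pte)   # → {'a': 0, 'b': 1, 'c': 1, 'd': 2}
```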

Parameters:

  • verbose (bool): if True, print verbose output during contraction. Defaults to False.
  • progress_reporter (ProgressReporter): progress reporter for tracking the contraction process. Defaults to DummyProgressReporter().

Returns:

  • StabilizerCodeTensorEnumerator: the contracted tensor enumerator.

Source code in planqtn/tensor_network.py
def conjoin_nodes(
    self,
    verbose: bool = False,
    progress_reporter: ProgressReporter = DummyProgressReporter(),
) -> "StabilizerCodeTensorEnumerator":
    """Conjoin all nodes in the tensor network according to the trace schedule.

    Executes all the trace operations defined in the tensor network to produce
    a single tensor enumerator. This tensor enumerator will have the conjoined parity check
    matrix. However, running a weight enumerator calculation on this conjoined node would use
    the brute force method, and as such would typically be more expensive than using the
    [`stabilizer_enumerator_polynomial`][planqtn.TensorNetwork.stabilizer_enumerator_polynomial]
    method.

    Args:
        verbose: If True, print verbose output during contraction.
        progress_reporter: Progress reporter for tracking the contraction process.

    Returns:
        StabilizerCodeTensorEnumerator: The contracted tensor enumerator.
    """
    # If there's only one node and no traces, return it directly
    if len(self.nodes) == 1 and len(self._traces) == 0:
        return list(self.nodes.values())[0]

    # Map from node_idx to the index of its PTE in ptes list
    nodes = list(self.nodes.values())
    ptes: List[Tuple[StabilizerCodeTensorEnumerator, Set[TensorId]]] = [
        (node, {node.tensor_id}) for node in nodes
    ]
    node_to_pte = {node.tensor_id: i for i, node in enumerate(nodes)}

    for node_idx1, node_idx2, join_legs1, join_legs2 in progress_reporter.iterate(
        self._traces, "Conjoining nodes", len(self._traces)
    ):
        if verbose:
            print(
                f"==== trace {node_idx1, node_idx2, join_legs1, join_legs2} ==== "
            )

        join_legs1 = _index_legs(node_idx1, join_legs1)
        join_legs2 = _index_legs(node_idx2, join_legs2)

        pte1_idx = node_to_pte[node_idx1]
        pte2_idx = node_to_pte[node_idx2]

        # Case 1: Both nodes are in the same PTE
        if pte1_idx == pte2_idx:
            if verbose:
                print(
                    f"Self trace in PTE containing both {node_idx1} and {node_idx2}"
                )
            pte, nodes_in_pte = ptes[pte1_idx]
            new_pte = pte.self_trace(join_legs1, join_legs2)
            ptes[pte1_idx] = (new_pte, nodes_in_pte)

        # Case 2: Nodes are in different PTEs - merge them
        else:
            if verbose:
                print(f"Merging PTEs containing {node_idx1} and {node_idx2}")
            pte1, nodes1 = ptes[pte1_idx]
            pte2, nodes2 = ptes[pte2_idx]
            new_pte = pte1.conjoin(pte2, legs1=join_legs1, legs2=join_legs2)
            merged_nodes = nodes1.union(nodes2)

            # Update the first PTE with merged result
            ptes[pte1_idx] = (new_pte, merged_nodes)
            # Remove the second PTE
            ptes.pop(pte2_idx)

            # Update node_to_pte mappings
            for node_idx in nodes2:
                node_to_pte[node_idx] = pte1_idx
            # Adjust indices for all nodes in PTEs after the removed one
            for node_idx, pte_idx in node_to_pte.items():
                if pte_idx > pte2_idx:
                    node_to_pte[node_idx] = pte_idx - 1

        if verbose:
            print("H:")
            sprint(ptes[0][0].h)

    # If we have multiple components at the end, tensor them together
    if len(ptes) > 1:
        for other in ptes[1:]:
            ptes[0] = (ptes[0][0].tensor_with(other[0]), ptes[0][1].union(other[1]))

    return ptes[0][0]

n_qubits()

Get the total number of qubits in the tensor network.

Returns the total number of qubits represented by this tensor network. This is an abstract method that must be implemented by subclasses that have a representation for qubits.

Returns:

  • int: total number of qubits.

Raises:

  • NotImplementedError: this method must be implemented by subclasses.

Source code in planqtn/tensor_network.py
def n_qubits(self) -> int:
    """Get the total number of qubits in the tensor network.

    Returns the total number of qubits represented by this tensor network. This is an abstract
    method that must be implemented by subclasses that have a representation for qubits.

    Returns:
        int: Total number of qubits.

    Raises:
        NotImplementedError: This method must be implemented by subclasses.
    """  # noqa: DAR202
    raise NotImplementedError(f"n_qubits() is not implemented for {type(self)}")

qubit_to_node_and_leg(q)

Map a qubit index to its corresponding node and leg.

This method maps a global qubit index to the specific node and leg that represents that qubit in the tensor network. This is an abstract method that must be implemented by subclasses that have a representation for qubits.

Parameters:

  • q (int): global qubit index. Required.

Returns:

  • node_id (TensorId): node ID that represents the qubit.
  • leg (TensorLeg): leg that represents the qubit.

Raises:

  • NotImplementedError: this method must be implemented by subclasses.

Source code in planqtn/tensor_network.py
def qubit_to_node_and_leg(self, q: int) -> Tuple[TensorId, TensorLeg]:
    """Map a qubit index to its corresponding node and leg.

    This method maps a global qubit index to the specific node and leg
    that represents that qubit in the tensor network. This is an abstract method
    that must be implemented by subclasses that have a representation for qubits.

    Args:
        q: Global qubit index.

    Returns:
        node_id: Node ID that represents the qubit.
        leg: Leg that represents the qubit.


    Raises:
        NotImplementedError: This method must be implemented by subclasses.
    """  # noqa: DAR202
    raise NotImplementedError(
        f"qubit_to_node_and_leg() is not implemented for {type(self)}!"
    )

self_trace(node_idx1, node_idx2, join_legs1, join_legs2)

Add a trace operation between two nodes in the tensor network.

Defines a contraction between two nodes by specifying which legs to join. This operation is added to the trace schedule and will be executed when the tensor network is contracted.
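The schedule bookkeeping can be sketched in plain Python: each trace records the two node IDs with their node-indexed legs, and accumulates per node the legs still to be joined. This is an illustrative stand-in for what self_trace does, not the planqtn API itself:

```python
from collections import defaultdict

traces = []
legs_left_to_join = defaultdict(list)

def record_trace(node1, node2, legs1, legs2):
    # Index bare leg numbers by their node ID, mirroring the internal _index_legs helper.
    indexed1 = [(node1, leg) for leg in legs1]
    indexed2 = [(node2, leg) for leg in legs2]
    traces.append((node1, node2, indexed1, indexed2))
    legs_left_to_join[node1] += indexed1
    legs_left_to_join[node2] += indexed2

# The chain of repetition-code legos from the module-level example:
record_trace("z0", "x1", [0], [0])
record_trace("x1", "z2", [1], [0])
print(legs_left_to_join["x1"])  # → [('x1', 0), ('x1', 1)]
```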

Parameters:

  • node_idx1 (TensorId): ID of the first node to trace. Required.
  • node_idx2 (TensorId): ID of the second node to trace. Required.
  • join_legs1 (Sequence[int | TensorLeg]): legs from the first node to contract. Required.
  • join_legs2 (Sequence[int | TensorLeg]): legs from the second node to contract. Required.

Raises:

  • ValueError: if the weight enumerator has already been computed.

Source code in planqtn/tensor_network.py
def self_trace(
    self,
    node_idx1: TensorId,
    node_idx2: TensorId,
    join_legs1: Sequence[int | TensorLeg],
    join_legs2: Sequence[int | TensorLeg],
) -> None:
    """Add a trace operation between two nodes in the tensor network.

    Defines a contraction between two nodes by specifying which legs to join.
    This operation is added to the trace schedule and will be executed when
    the tensor network is contracted.

    Args:
        node_idx1: ID of the first node to trace.
        node_idx2: ID of the second node to trace.
        join_legs1: Legs from the first node to contract.
        join_legs2: Legs from the second node to contract.

    Raises:
        ValueError: If the weight enumerator has already been computed.
    """
    if self._wep is not None:
        raise ValueError(
            "Tensor network weight enumerator is already traced no new tracing schedule is "
            "allowed."
        )
    join_legs1_indexed = _index_legs(node_idx1, join_legs1)
    join_legs2_indexed = _index_legs(node_idx2, join_legs2)

    # print(f"adding trace {node_idx1, node_idx2, join_legs1, join_legs2}")
    self._traces.append(
        (node_idx1, node_idx2, join_legs1_indexed, join_legs2_indexed)
    )

    self._legs_left_to_join[node_idx1] += join_legs1_indexed
    self._legs_left_to_join[node_idx2] += join_legs2_indexed

set_coset(coset_error)

Set the coset error for the tensor network.

Sets the coset error that will be used for coset weight enumerator calculations. The coset error should follow the qubit numbering defined in qubit_to_node_and_leg which maps the index to a node ID. Both qubit_to_node_and_leg and n_qubits are abstract methods, and thus this method can only be called on a subclass that implements them; see the planqtn.networks module for examples.

There are two possible ways to pass the coset_error:

  • a tuple of two lists of qubit indices, one for the X errors and one for the Z errors
  • a galois.GF2 array of length 2 * tn.n_qubits() for the tn tensor network. This is a symplectic operator representation on the n qubits of the tensor network.
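For illustration, the tuple form can be converted into the symplectic vector form with plain Python. This is a hypothetical helper, assuming (as in the quoted source below) that the X block occupies the first n entries and the Z block the last n:

```python
def to_symplectic(x_qubits, z_qubits, n):
    """Build a length-2n binary symplectic vector: X part first, Z part second."""
    vec = [0] * (2 * n)
    for q in x_qubits:
        vec[q] = 1          # X error on qubit q
    for q in z_qubits:
        vec[q + n] = 1      # Z error on qubit q
    return vec

# X error on qubit 0 and Z error on qubit 2 of a 3-qubit code:
print(to_symplectic([0], [2], 3))  # → [1, 0, 0, 0, 0, 1]
```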

Parameters:

  • coset_error (GF2 | Tuple[List[int], List[int]]): the coset error specification. Required.

Raises:

  • ValueError: if the coset error has the wrong number of qubits.

Source code in planqtn/tensor_network.py
def set_coset(self, coset_error: GF2 | Tuple[List[int], List[int]]) -> None:
    """Set the coset error for the tensor network.

    Sets the coset error that will be used for coset weight enumerator calculations.
    The coset error should follow the qubit numbering defined in
     [`qubit_to_node_and_leg`][planqtn.TensorNetwork.qubit_to_node_and_leg] which maps the index
    to a node ID. Both [`qubit_to_node_and_leg`][planqtn.TensorNetwork.qubit_to_node_and_leg]
    and [`n_qubits`][planqtn.TensorNetwork.n_qubits] are abstract methods, and thus this method
    can only be called on a subclass that implements these methods, see the
    [`planqtn.networks`][planqtn.networks] module for examples.

    There are two possible ways to pass the coset_error:

    - a tuple of two lists of qubit indices, one for the `X` errors and one for the `Z` errors
    - a `galois.GF2` array of length `2 * tn.n_qubits()` for the `tn` tensor network. This is a
        symplectic operator representation on the `n` qubits of the tensor network.

    Args:
        coset_error: The coset error specification.

    Raises:
        ValueError: If the coset error has the wrong number of qubits.
    """
    self._reset_wep(keep_cot=True)

    self._coset = GF2.Zeros(2 * self.n_qubits())

    if isinstance(coset_error, tuple):
        for i in coset_error[0]:
            self._coset[i] = 1
        for i in coset_error[1]:
            self._coset[i + self.n_qubits()] = 1
    elif isinstance(coset_error, GF2):
        self._coset = coset_error

    n = len(self._coset) // 2
    if n != self.n_qubits():
        raise ValueError(
            f"Can't set coset with {n} qubits for a {self.n_qubits()} qubit code."
        )

    z_errors = np.argwhere(self._coset[n:] == 1).flatten()
    x_errors = np.argwhere(self._coset[:n] == 1).flatten()

    node_legs_to_flip = defaultdict(list)

    for q in range(n):
        is_z = q in z_errors
        is_x = q in x_errors
        node_idx, leg = self.qubit_to_node_and_leg(q)

        self.nodes[node_idx].coset_flipped_legs = []
        if not is_z and not is_x:
            continue
        # print(f"q{q} -> {node_idx, leg}")
        node_legs_to_flip[node_idx].append((leg, GF2([is_x, is_z])))

    for node_idx, coset_flipped_legs in node_legs_to_flip.items():

        # print(node_idx, f" will have flipped {coset_flipped_legs}")

        self.nodes[node_idx] = self.nodes[node_idx].with_coset_flipped_legs(
            coset_flipped_legs
        )

set_truncate_length(truncate_length)

Set the truncation length for weight enumerator polynomials.

Sets the maximum weight to keep in weight enumerator polynomials. This affects all subsequent computations and resets any cached results.
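The effect on a weight enumerator stored as a weight-to-coefficient dictionary can be illustrated with plain Python (a sketch, assuming truncation drops all weights above the limit):

```python
# A few low-weight terms of a weight enumerator, as a weight -> coefficient dict:
wep = {0: 1, 2: 8, 4: 72, 6: 534}

truncate_length = 4
# Keep only terms whose weight does not exceed the truncation length.
truncated = {w: c for w, c in wep.items() if w <= truncate_length}
print(truncated)  # → {0: 1, 2: 8, 4: 72}
```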

Parameters:

  • truncate_length (int): maximum weight to keep in enumerator polynomials. Required.
Source code in planqtn/tensor_network.py
def set_truncate_length(self, truncate_length: int) -> None:
    """Set the truncation length for weight enumerator polynomials.

    Sets the maximum weight to keep in weight enumerator polynomials.
    This affects all subsequent computations and resets any cached results.

    Args:
        truncate_length: Maximum weight to keep in enumerator polynomials.
    """
    self.truncate_length = truncate_length
    self._reset_wep(keep_cot=True)

stabilizer_enumerator(verbose=False, progress_reporter=DummyProgressReporter())

Compute the stabilizer weight enumerator.

Computes the weight enumerator polynomial and returns it as a dictionary mapping weights to coefficients. This is a convenience method that calls stabilizer_enumerator_polynomial() and extracts the dictionary.

Parameters:

  • verbose (bool): if True, print verbose output. Defaults to False.
  • progress_reporter (ProgressReporter): progress reporter for tracking computation. Defaults to DummyProgressReporter().

Returns:

  • Dict[int, int]: weight enumerator as a dictionary mapping weights to counts.

Source code in planqtn/tensor_network.py
def stabilizer_enumerator(
    self,
    verbose: bool = False,
    progress_reporter: ProgressReporter = DummyProgressReporter(),
) -> Dict[int, int]:
    """Compute the stabilizer weight enumerator.

    Computes the weight enumerator polynomial and returns it as a dictionary
    mapping weights to coefficients. This is a convenience method that
    calls stabilizer_enumerator_polynomial() and extracts the dictionary.

    Args:
        verbose: If True, print verbose output.
        progress_reporter: Progress reporter for tracking computation.

    Returns:
        Dict[int, int]: Weight enumerator as a dictionary mapping weights to counts.
    """
    wep = self.stabilizer_enumerator_polynomial(
        verbose=verbose, progress_reporter=progress_reporter
    )
    assert isinstance(wep, UnivariatePoly)
    return wep.dict

stabilizer_enumerator_polynomial(open_legs=(), verbose=False, progress_reporter=DummyProgressReporter(), cotengra=True)

Returns the reduced stabilizer enumerator polynomial for the tensor network.

If open_legs is not empty, then the returned tensor enumerator polynomial is a dictionary of tensor keys to UnivariatePoly objects.

Parameters:

  • open_legs (Sequence[TensorLeg]): the legs that are open in the tensor network. If empty, the result is a scalar weight enumerator polynomial of type UnivariatePoly; otherwise it is a dictionary of TensorEnumeratorKey keys to UnivariatePoly objects. Defaults to ().
  • verbose (bool): if True, print verbose output. Defaults to False.
  • progress_reporter (ProgressReporter): the progress reporter to use. Defaults to no progress reporting (DummyProgressReporter); can be set to TqdmProgressReporter for progress reporting on the console, or any other custom ProgressReporter subclass.
  • cotengra (bool): if True, use cotengra to contract the tensor network; otherwise use the order in which the traces were constructed. Defaults to True.

Returns:

  • TensorEnumerator | UnivariatePoly: the reduced stabilizer enumerator polynomial for the tensor network.

Source code in planqtn/tensor_network.py
def stabilizer_enumerator_polynomial(
    self,
    open_legs: Sequence[TensorLeg] = (),
    verbose: bool = False,
    progress_reporter: ProgressReporter = DummyProgressReporter(),
    cotengra: bool = True,
) -> TensorEnumerator | UnivariatePoly:
    """Returns the reduced stabilizer enumerator polynomial for the tensor network.

    If open_legs is not empty, then the returned tensor enumerator polynomial is a dictionary of
    tensor keys to UnivariatePoly objects.

    Args:
        open_legs: The legs that are open in the tensor network. If empty, the result is a
                   scalar weight enumerator polynomial of type `UnivariatePoly`, otherwise it
                   is a dictionary of `TensorEnumeratorKey` keys to `UnivariatePoly` objects.
        verbose: If True, print verbose output.
        progress_reporter: The progress reporter to use, defaults to no progress reporting
                          (`DummyProgressReporter`), can be set to `TqdmProgressReporter` for
                          progress reporting on the console, or any other custom
                          `ProgressReporter` subclass.
        cotengra: If True, use cotengra to contract the tensor network, otherwise use the order
                  the traces were constructed.

    Returns:
        TensorEnumerator: The reduced stabilizer enumerator polynomial for the tensor network.
    """
    if self._wep is not None:
        return self._wep

    assert (
        progress_reporter is not None
    ), "Progress reporter must be provided, it is None"

    with progress_reporter.enter_phase("collecting legs"):
        free_legs, leg_indices, index_to_legs = self._collect_legs()

    open_legs_per_node = defaultdict(list)
    for node_idx, node in self.nodes.items():
        for leg in node.legs:
            if leg not in free_legs:
                open_legs_per_node[node_idx].append(_index_leg(node_idx, leg))

    for node_idx, leg_index in open_legs:
        open_legs_per_node[node_idx].append(_index_leg(node_idx, leg_index))

    if verbose:
        print("open_legs_per_node", open_legs_per_node)
    traces = self._traces
    if cotengra and len(self.nodes) > 0 and len(self._traces) > 0:
        with progress_reporter.enter_phase("cotengra contraction"):
            traces, _ = self._cotengra_contraction(
                free_legs, leg_indices, index_to_legs, verbose, progress_reporter
            )
    summed_legs = [leg for leg in free_legs if leg not in open_legs]

    if len(self._traces) == 0 and len(self.nodes) == 1:
        return list(self.nodes.items())[0][1].stabilizer_enumerator_polynomial(
            verbose=verbose,
            progress_reporter=progress_reporter,
            truncate_length=self.truncate_length,
            open_legs=open_legs,
        )

    # parity_check_enums = {}

    for node_idx, node in self.nodes.items():
        traced_legs = open_legs_per_node[node_idx]
        # TODO: figure out tensor caching
        # traced_leg_indices = "".join(
        #     [str(i) for i in sorted([node.legs.index(leg) for leg in traced_legs])]
        # )
        # hkey = sstr(gauss(node.h)) + ";" + traced_leg_indices
        # if hkey not in parity_check_enums:
        #     parity_check_enums[hkey] = node.stabilizer_enumerator_polynomial(
        #         open_legs=traced_legs
        #     )
        # else:
        #     print("Found one!")
        #     calc = node.stabilizer_enumerator_polynomial(open_legs=traced_legs)
        #     assert (
        #         calc == parity_check_enums[hkey]
        #     ), f"for key {hkey}\n calc\n{calc}\n vs retrieved\n{parity_check_enums[hkey]}"

        # call the right type here...
        tensor = node.stabilizer_enumerator_polynomial(
            open_legs=traced_legs,
            verbose=verbose,
            progress_reporter=progress_reporter,
            truncate_length=self.truncate_length,
        )
        if isinstance(tensor, UnivariatePoly):
            tensor = {(): tensor}
        self._ptes[node_idx] = _PartiallyTracedEnumerator(
            nodes={node_idx},
            tracable_legs=open_legs_per_node[node_idx],
            tensor=tensor,  # deepcopy(parity_check_enums[hkey]),
            truncate_length=self.truncate_length,
        )

    for node_idx1, node_idx2, join_legs1, join_legs2 in progress_reporter.iterate(
        traces, f"Tracing {len(traces)} legs", len(traces)
    ):
        if verbose:
            print(
                f"==== trace {node_idx1, node_idx2, join_legs1, join_legs2} ==== "
            )
            print(
                f"Total legs left to join: "
                f"{sum(len(legs) for legs in self._legs_left_to_join.values())}"
            )
        node1_pte = self._ptes[node_idx1]
        node2_pte = self._ptes[node_idx2]

        # print(f"PTEs: {node1_pte}, {node2_pte}")
        # check that the length of the tensor is a power of 4

        if node1_pte == node2_pte:
            # both nodes are in the same PTE!
            if verbose:
                print(f"self trace within PTE {node1_pte}")
            pte = node1_pte.self_trace(
                join_legs1=[
                    (node_idx1, leg) if isinstance(leg, int) else leg
                    for leg in join_legs1
                ],
                join_legs2=[
                    (node_idx2, leg) if isinstance(leg, int) else leg
                    for leg in join_legs2
                ],
                progress_reporter=progress_reporter,
                verbose=verbose,
            )
            for node_idx in pte.nodes:
                self._ptes[node_idx] = pte
            self._legs_left_to_join[node_idx1] = [
                leg
                for leg in self._legs_left_to_join[node_idx1]
                if leg not in join_legs1
            ]
            self._legs_left_to_join[node_idx2] = [
                leg
                for leg in self._legs_left_to_join[node_idx2]
                if leg not in join_legs2
            ]
        else:
            if verbose:
                print(f"MERGING two components {node1_pte} and {node2_pte}")
                print(f"node1_pte {node1_pte}:")
                for k in list(node1_pte.tensor.keys()):
                    v = node1_pte.tensor[k]
                    print(Pauli.to_str(*k), end=" ")
                    print(v)
                print(f"node2_pte {node2_pte}:")
                for k in list(node2_pte.tensor.keys()):
                    v = node2_pte.tensor[k]
                    print(Pauli.to_str(*k), end=" ")
                    print(v)
            pte = node1_pte.merge_with(
                node2_pte,
                join_legs1=[
                    (node_idx1, leg) if isinstance(leg, int) else leg
                    for leg in join_legs1
                ],
                join_legs2=[
                    (node_idx2, leg) if isinstance(leg, int) else leg
                    for leg in join_legs2
                ],
                verbose=verbose,
                progress_reporter=progress_reporter,
            )

            for node_idx in pte.nodes:
                self._ptes[node_idx] = pte
            self._legs_left_to_join[node_idx1] = [
                leg
                for leg in self._legs_left_to_join[node_idx1]
                if leg not in join_legs1
            ]
            self._legs_left_to_join[node_idx2] = [
                leg
                for leg in self._legs_left_to_join[node_idx2]
                if leg not in join_legs2
            ]

        node1_pte = self._ptes[node_idx1]

        if verbose:
            print(
                f"PTE nodes: {node1_pte.nodes if node1_pte is not None else None}"
            )
            print(
                f"PTE tracable legs: "
                f"{node1_pte.tracable_legs if node1_pte is not None else None}"
            )
        if verbose:
            print("PTE tensor: ")
        for k in list(node1_pte.tensor.keys() if node1_pte is not None else []):
            v = node1_pte.tensor[k] if node1_pte is not None else UnivariatePoly()
            # if not 0 in v:
            #     continue
            if verbose:
                print(Pauli.to_str(*k), end=" ")
                print(v, end="")
            if self.truncate_length is None:
                continue
            if v.minw()[0] > self.truncate_length:
                del pte.tensor[k]
                if verbose:
                    print(" -- removed")
            else:
                pte.tensor[k].truncate_inplace(self.truncate_length)
                if verbose:
                    print(" -- truncated")
        if verbose:
            print(f"PTEs: {self._ptes}")

    if verbose:
        print("summed legs: ", summed_legs)
        print("PTEs: ", self._ptes)
    if len(set(self._ptes.values())) > 1:
        if verbose:
            print(
                f"tensoring {len(set(self._ptes.values()))} disjoint PTEs: {self._ptes}"
            )

        pte_list = list(set(self._ptes.values()))
        pte = pte_list[0]
        for pte2 in pte_list[1:]:
            pte = pte.tensor_product(
                pte2, verbose=verbose, progress_reporter=progress_reporter
            )

    if len(pte.tensor) > 1:
        if verbose:
            print(f"final PTE is a tensor: {pte}")
            if len(pte.tensor) > 5000:
                print(
                    f"There are {len(pte.tensor)} keys in the final PTE, skipping printing."
                )
            else:
                for k in list(pte.tensor.keys()):
                    v = pte.tensor[k]
                    if verbose:
                        print(Pauli.to_str(*k), end=" ")
                        print(v)

        self._wep = pte.ordered_key_tensor(
            open_legs,
            progress_reporter=progress_reporter,
            verbose=verbose,
        )
    else:
        self._wep = pte.tensor[()]
        if verbose:
            print(f"final scalar wep: {self._wep}")
        self._wep = self._wep.normalize(verbose=verbose)
        if verbose:
            print(f"final normalized scalar wep: {self._wep}")
    return self._wep

traces_to_dot()

Print the tensor network traces in DOT format.

Prints the traces (contractions) between nodes in a format that can be used to visualize the tensor network structure. Each trace is printed as a directed edge between nodes.

Source code in planqtn/tensor_network.py
def traces_to_dot(self) -> None:
    """Print the tensor network traces in DOT format.

    Prints the traces (contractions) between nodes in a format that can be
    used to visualize the tensor network structure. Each trace is printed
    as a directed edge between nodes.
    """
    print("-----")
    # print(self.open_legs)
    # for n, legs in enumerate(self.open_legs):
    #     for leg in legs:
    #         print(f"n{n} -> n{n}_{leg}")

    for node_idx1, node_idx2, join_legs1, join_legs2 in self._traces:
        for _ in zip(join_legs1, join_legs2):
            print(f"n{node_idx1} -> n{node_idx2} ")
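
For illustration, the edge-printing rule above can be reproduced as a small standalone helper (hypothetical; the real method prints directly to stdout rather than returning text):

```python
# Hypothetical standalone version of the traces_to_dot rule: one DOT edge
# per joined leg pair in each trace. Not part of planqtn itself.
def traces_to_dot_lines(traces):
    lines = []
    for node_idx1, node_idx2, join_legs1, join_legs2 in traces:
        for _ in zip(join_legs1, join_legs2):
            lines.append(f"n{node_idx1} -> n{node_idx2} ")
    return lines

# The three-lego repetition-code network from the package-level example:
traces = [("z0", "x1", [0], [0]), ("x1", "z2", [1], [0])]
print("\n".join(traces_to_dot_lines(traces)))
# nz0 -> nx1
# nx1 -> nz2
```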

UnivariatePoly

A class for univariate integer polynomials.

Source code in planqtn/poly.py
class UnivariatePoly:
    """A class for univariate integer polynomials."""

    def __init__(
        self, d: Optional[Union["UnivariatePoly", Dict[int, int]]] = None
    ) -> None:
        """Construct a univariate integer polynomial.

        This class represents univariate polynomials as a dictionary mapping
        powers to coefficients. It's specifically designed for weight enumerator
        polynomials, where coefficients are typically integers.

        The class provides basic polynomial operations like addition, multiplication,
        normalization, and MacWilliams dual computation. It also supports truncation
        and homogenization for bivariate polynomials.

        Attributes:
            dict: Dictionary mapping integer powers to integer coefficients.
            num_vars: Number of variables (always 1 for univariate).

        Raises:
            ValueError: If the input is not a dictionary or a UnivariatePoly.

        Example:
            ```python

            >>> # Create a polynomial: 1 + 3x + 2x^2
            >>> poly = UnivariatePoly({0: 1, 1: 3, 2: 2})

            >>> # Add polynomials
            >>> result = poly + UnivariatePoly({1: 1, 3: 1})

            >>> # Multiply by scalar
            >>> scaled = poly * 2

            >>> # Get minimum weight term
            >>> min_weight, coeff = poly.minw()
            >>> min_weight
            0
            >>> coeff
            1

            ```

        Args:
            d: The dictionary of powers and coefficients.
        """
        self.dict: Dict[int, int] = {}
        self.num_vars = 1
        if isinstance(d, UnivariatePoly):
            self.dict.update(d.dict)
        elif d is not None and isinstance(d, dict):
            self.dict.update(d)
            if len(d) > 0:
                first_key = list(self.dict.keys())[0]
                assert isinstance(first_key, int)
        elif d is not None:
            raise ValueError(f"Unrecognized type: {type(d)}")

    def is_scalar(self) -> bool:
        """Check if the polynomial is a scalar (constant term only).

        Returns:
            bool: True if the polynomial has only a constant term (power 0).
        """
        return len(self.dict) == 1 and set(self.dict.keys()) == {0}

    def add_inplace(self, other: "UnivariatePoly") -> None:
        """Add another polynomial to this one in-place.

        Args:
            other: The polynomial to add to this one.

        Raises:
            AssertionError: If the polynomials have different numbers of variables.
        """
        assert other.num_vars == self.num_vars
        for k, v in other.dict.items():
            self.dict[k] = self.dict.get(k, 0) + v

    def __add__(self, other: "UnivariatePoly") -> "UnivariatePoly":
        assert other.num_vars == self.num_vars
        res = UnivariatePoly(self.dict)
        for k, v in other.dict.items():
            res.dict[k] = res.dict.get(k, 0) + v
        return res

    def minw(self) -> Tuple[Any, int]:
        """Get the minimum weight term and its coefficient.

        Returns:
            Tuple containing the minimum power and its coefficient.
        """
        min_w = min(self.dict.keys())
        min_coeff = self.dict[min_w]
        return min_w, min_coeff

    def leading_order_poly(self) -> "UnivariatePoly":
        """Get the polynomial containing only the minimum weight term.

        Returns:
            UnivariatePoly: A new polynomial with only the minimum weight term.
        """
        min_w = min(self.dict.keys())
        min_coeff = self.dict[min_w]
        return UnivariatePoly({min_w: min_coeff})

    def __getitem__(self, i: Any) -> int:
        return self.dict.get(i, 0)

    def items(self) -> Generator[Tuple[Any, int], None, None]:
        """Yield items from the polynomial.

        Yields:
            Tuple[Any, int]: A tuple of the power and coefficient.
        """
        yield from self.dict.items()

    def __len__(self) -> int:
        return len(self.dict)

    def normalize(self, verbose: bool = False) -> "UnivariatePoly":
        """Normalize the polynomial by dividing by the constant term if it's greater than 1.

        Args:
            verbose: If True, print normalization information.

        Returns:
            UnivariatePoly: The normalized polynomial.
        """
        if 0 in self.dict and self.dict[0] > 1:
            if verbose:
                print(f"normalizing WEP by 1/{self.dict[0]}")
            return self / self.dict[0]
        return self

    def __str__(self) -> str:
        return (
            "{"
            + ", ".join([f"{w}:{self.dict[w]}" for w in sorted(list(self.dict.keys()))])
            + "}"
        )

    def __repr__(self) -> str:
        return f"UnivariatePoly({repr(self.dict)})"

    def __truediv__(self, n: int) -> "UnivariatePoly":
        if isinstance(n, int):
            return UnivariatePoly({k: int(v // n) for k, v in self.dict.items()})
        raise TypeError(f"Cannot divide UnivariatePoly by {type(n)}")

    def __eq__(self, value: object) -> bool:
        if isinstance(value, (int, float)):
            # use .get to avoid a KeyError when there is no constant term
            return self.dict.get(0, 0) == value
        if isinstance(value, UnivariatePoly):
            return self.dict == value.dict
        return False

    def __hash__(self) -> int:
        # dicts are unhashable, so hash an order-independent frozen view
        return hash(frozenset(self.dict.items()))

    def __mul__(self, n: Union[int, float, "UnivariatePoly"]) -> "UnivariatePoly":
        if isinstance(n, (int, float)):
            return UnivariatePoly({k: int(n * v) for k, v in self.dict.items()})
        if isinstance(n, UnivariatePoly):
            res = UnivariatePoly()
            for d1, coeff1 in self.dict.items():
                for d2, coeff2 in n.dict.items():
                    res.dict[d1 + d2] = res.dict.get(d1 + d2, 0) + coeff1 * coeff2
            return res
        raise TypeError(f"Cannot multiply UnivariatePoly by {type(n)}")

    def _homogenize(self, n: int) -> "BivariatePoly":
        """Homogenize a univariate polynomial to a bivariate polynomial.

        Converts A(z) to A(w,z) = w^n * A(z/w), where w represents the dual weight
        and z represents the actual weight. This is used in MacWilliams duality.

        Args:
            n: The degree of homogenization.

        Returns:
            BivariatePoly: The homogenized bivariate polynomial.
        """
        return BivariatePoly({Monomial((n - k, k)): v for k, v in self.dict.items()})

    def truncate_inplace(self, n: int) -> None:
        """Truncate the polynomial to terms with power <= n in-place.

        Args:
            n: Maximum power to keep in the polynomial.
        """
        self.dict = {k: v for k, v in self.dict.items() if k <= n}

    def macwilliams_dual(
        self, n: int, k: int, to_normalizer: bool = True
    ) -> "UnivariatePoly":
        """Convert this weight enumerator polynomial to its MacWilliams dual.

        The MacWilliams duality theorem relates the weight enumerator polynomial
        of a code to that of its dual code. This method implements the qubit
        transformation A(z) -> B(z) = (1 + 3z)^n * A((1 - z)/(1 + 3z)) / 2^(n-k).

        Args:
            n: Length of the code.
            k: Dimension of the code.
            to_normalizer: If True, compute the normalizer enumerator polynomial.
                          If False, compute the weight enumerator polynomial.
                          This affects the normalization factors.

        Returns:
            UnivariatePoly: The MacWilliams dual weight enumerator polynomial.
        """
        factors = [4**k, 2**k] if to_normalizer else [2**k, 4**k]
        homogenized: BivariatePoly = self._homogenize(n) * factors[0]
        z, w = symbols("w z")
        sp_homogenized = homogenized.to_sympy([w, z])

        sympy_substituted = Poly(
            sp_homogenized.subs({w: (w + 3 * z) / 2, z: (w - z) / 2}).simplify()
            / factors[1],
            w,
            z,
        )

        monomial_powers_substituted: BivariatePoly = BivariatePoly.from_sympy(
            sympy_substituted
        )

        single_var_dict = {}

        for key, value in monomial_powers_substituted.dict.items():
            assert key[1] not in single_var_dict
            single_var_dict[key[1]] = value

        return UnivariatePoly(single_var_dict)

__init__(d=None)

Construct a univariate integer polynomial.

This class represents univariate polynomials as a dictionary mapping powers to coefficients. It's specifically designed for weight enumerator polynomials, where coefficients are typically integers.

The class provides basic polynomial operations like addition, multiplication, normalization, and MacWilliams dual computation. It also supports truncation and homogenization for bivariate polynomials.

Attributes:

- `dict`: Dictionary mapping integer powers to integer coefficients.
- `num_vars`: Number of variables (always 1 for univariate).

Raises:

- `ValueError`: If the input is not a dictionary or a UnivariatePoly.

Example
>>> # Create a polynomial: 1 + 3x + 2x^2
>>> poly = UnivariatePoly({0: 1, 1: 3, 2: 2})

>>> # Add polynomials
>>> result = poly + UnivariatePoly({1: 1, 3: 1})

>>> # Multiply by scalar
>>> scaled = poly * 2

>>> # Get minimum weight term
>>> min_weight, coeff = poly.minw()
>>> min_weight
0
>>> coeff
1

Parameters:

- `d` (`Optional[Union[UnivariatePoly, Dict[int, int]]]`, default `None`): The dictionary of powers and coefficients.
Source code in planqtn/poly.py
def __init__(
    self, d: Optional[Union["UnivariatePoly", Dict[int, int]]] = None
) -> None:
    """Construct a univariate integer polynomial.

    This class represents univariate polynomials as a dictionary mapping
    powers to coefficients. It's specifically designed for weight enumerator
    polynomials, where coefficients are typically integers.

    The class provides basic polynomial operations like addition, multiplication,
    normalization, and MacWilliams dual computation. It also supports truncation
    and homogenization for bivariate polynomials.

    Attributes:
        dict: Dictionary mapping integer powers to integer coefficients.
        num_vars: Number of variables (always 1 for univariate).

    Raises:
        ValueError: If the input is not a dictionary or a UnivariatePoly.

    Example:
        ```python

        >>> # Create a polynomial: 1 + 3x + 2x^2
        >>> poly = UnivariatePoly({0: 1, 1: 3, 2: 2})

        >>> # Add polynomials
        >>> result = poly + UnivariatePoly({1: 1, 3: 1})

        >>> # Multiply by scalar
        >>> scaled = poly * 2

        >>> # Get minimum weight term
        >>> min_weight, coeff = poly.minw()
        >>> min_weight
        0
        >>> coeff
        1

        ```

    Args:
        d: The dictionary of powers and coefficients.
    """
    self.dict: Dict[int, int] = {}
    self.num_vars = 1
    if isinstance(d, UnivariatePoly):
        self.dict.update(d.dict)
    elif d is not None and isinstance(d, dict):
        self.dict.update(d)
        if len(d) > 0:
            first_key = list(self.dict.keys())[0]
            assert isinstance(first_key, int)
    elif d is not None:
        raise ValueError(f"Unrecognized type: {type(d)}")

add_inplace(other)

Add another polynomial to this one in-place.

Parameters:

- `other` (`UnivariatePoly`, required): The polynomial to add to this one.

Raises:

- `AssertionError`: If the polynomials have different numbers of variables.

Source code in planqtn/poly.py
def add_inplace(self, other: "UnivariatePoly") -> None:
    """Add another polynomial to this one in-place.

    Args:
        other: The polynomial to add to this one.

    Raises:
        AssertionError: If the polynomials have different numbers of variables.
    """
    assert other.num_vars == self.num_vars
    for k, v in other.dict.items():
        self.dict[k] = self.dict.get(k, 0) + v

is_scalar()

Check if the polynomial is a scalar (constant term only).

Returns:

- `bool`: True if the polynomial has only a constant term (power 0).

Source code in planqtn/poly.py
def is_scalar(self) -> bool:
    """Check if the polynomial is a scalar (constant term only).

    Returns:
        bool: True if the polynomial has only a constant term (power 0).
    """
    return len(self.dict) == 1 and set(self.dict.keys()) == {0}

items()

Yield items from the polynomial.

Yields:

- `Tuple[Any, int]`: A tuple of the power and coefficient.

Source code in planqtn/poly.py
def items(self) -> Generator[Tuple[Any, int], None, None]:
    """Yield items from the polynomial.

    Yields:
        Tuple[Any, int]: A tuple of the power and coefficient.
    """
    yield from self.dict.items()

leading_order_poly()

Get the polynomial containing only the minimum weight term.

Returns:

- `UnivariatePoly`: A new polynomial with only the minimum weight term.

Source code in planqtn/poly.py
def leading_order_poly(self) -> "UnivariatePoly":
    """Get the polynomial containing only the minimum weight term.

    Returns:
        UnivariatePoly: A new polynomial with only the minimum weight term.
    """
    min_w = min(self.dict.keys())
    min_coeff = self.dict[min_w]
    return UnivariatePoly({min_w: min_coeff})

macwilliams_dual(n, k, to_normalizer=True)

Convert this weight enumerator polynomial to its MacWilliams dual.

The MacWilliams duality theorem relates the weight enumerator polynomial of a code to that of its dual code. This method implements the qubit transformation A(z) -> B(z) = (1 + 3z)^n * A((1 - z)/(1 + 3z)) / 2^(n-k).

Parameters:

- `n` (`int`, required): Length of the code.
- `k` (`int`, required): Dimension of the code.
- `to_normalizer` (`bool`, default `True`): If True, compute the normalizer enumerator polynomial; if False, compute the weight enumerator polynomial. This affects the normalization factors.

Returns:

- `UnivariatePoly`: The MacWilliams dual weight enumerator polynomial.

Source code in planqtn/poly.py
def macwilliams_dual(
    self, n: int, k: int, to_normalizer: bool = True
) -> "UnivariatePoly":
    """Convert this weight enumerator polynomial to its MacWilliams dual.

    The MacWilliams duality theorem relates the weight enumerator polynomial
    of a code to that of its dual code. This method implements the qubit
    transformation A(z) -> B(z) = (1 + 3z)^n * A((1 - z)/(1 + 3z)) / 2^(n-k).

    Args:
        n: Length of the code.
        k: Dimension of the code.
        to_normalizer: If True, compute the normalizer enumerator polynomial.
                      If False, compute the weight enumerator polynomial.
                      This affects the normalization factors.

    Returns:
        UnivariatePoly: The MacWilliams dual weight enumerator polynomial.
    """
    factors = [4**k, 2**k] if to_normalizer else [2**k, 4**k]
    homogenized: BivariatePoly = self._homogenize(n) * factors[0]
    z, w = symbols("w z")
    sp_homogenized = homogenized.to_sympy([w, z])

    sympy_substituted = Poly(
        sp_homogenized.subs({w: (w + 3 * z) / 2, z: (w - z) / 2}).simplify()
        / factors[1],
        w,
        z,
    )

    monomial_powers_substituted: BivariatePoly = BivariatePoly.from_sympy(
        sympy_substituted
    )

    single_var_dict = {}

    for key, value in monomial_powers_substituted.dict.items():
        assert key[1] not in single_var_dict
        single_var_dict[key[1]] = value

    return UnivariatePoly(single_var_dict)
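
As a sanity check on the transformation, the substitution w -> (w + 3z)/2, z -> (w - z)/2 can be re-derived in pure Python with binomial expansion. This is a hedged sketch of the `to_normalizer=True` path, not the library's sympy-based implementation; `macwilliams_dual_sketch` is a hypothetical helper name.

```python
from math import comb
from collections import defaultdict

def macwilliams_dual_sketch(a, n, k):
    """Hedged re-derivation of the qubit MacWilliams transform (the
    to_normalizer=True path): expand each homogenized term
    w**(n-d) * z**d under w -> (w + 3z)/2, z -> (w - z)/2.
    Illustrative only -- not the sympy-based planqtn implementation."""
    acc = defaultdict(int)
    for d, coeff in a.items():
        # binomial expansion of (w + 3z)**(n-d) * (w - z)**d
        for i in range(n - d + 1):
            for j in range(d + 1):
                acc[i + j] += coeff * comb(n - d, i) * 3**i * comb(d, j) * (-1) ** j
    # the per-factor halvings and the 4^k / 2^k factors combine to 2^k / 2^n
    return {p: v * 2**k // 2**n for p, v in sorted(acc.items()) if v}

# Weight enumerator of the [[5,1,3]] code: A(z) = 1 + 15 z^4
print(macwilliams_dual_sketch({0: 1, 4: 15}, n=5, k=1))
# -> {0: 1, 3: 30, 4: 15, 5: 18}  (its normalizer enumerator)
```

For a self-dual example, the two-qubit Bell-state code's enumerator {0: 1, 2: 3} maps to itself under this transform with n=2, k=0.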

minw()

Get the minimum weight term and its coefficient.

Returns:

- `Tuple[Any, int]`: Tuple containing the minimum power and its coefficient.

Source code in planqtn/poly.py
def minw(self) -> Tuple[Any, int]:
    """Get the minimum weight term and its coefficient.

    Returns:
        Tuple containing the minimum power and its coefficient.
    """
    min_w = min(self.dict.keys())
    min_coeff = self.dict[min_w]
    return min_w, min_coeff

normalize(verbose=False)

Normalize the polynomial by dividing by the constant term if it's greater than 1.

Parameters:

- `verbose` (`bool`, default `False`): If True, print normalization information.

Returns:

- `UnivariatePoly`: The normalized polynomial.

Source code in planqtn/poly.py
def normalize(self, verbose: bool = False) -> "UnivariatePoly":
    """Normalize the polynomial by dividing by the constant term if it's greater than 1.

    Args:
        verbose: If True, print normalization information.

    Returns:
        UnivariatePoly: The normalized polynomial.
    """
    if 0 in self.dict and self.dict[0] > 1:
        if verbose:
            print(f"normalizing WEP by 1/{self.dict[0]}")
        return self / self.dict[0]
    return self

truncate_inplace(n)

Truncate the polynomial to terms with power <= n in-place.

Parameters:

- `n` (`int`, required): Maximum power to keep in the polynomial.
Source code in planqtn/poly.py
def truncate_inplace(self, n: int) -> None:
    """Truncate the polynomial to terms with power <= n in-place.

    Args:
        n: Maximum power to keep in the polynomial.
    """
    self.dict = {k: v for k, v in self.dict.items() if k <= n}

Stabilizer tensor enumerator module.

The unit of the tensor network is a stabilizer code encoding tensor (quantum lego), represented by the StabilizerCodeTensorEnumerator class and defined by a parity check matrix.

The main methods are:

- stabilizer_enumerator_polynomial: Brute force calculation of the stabilizer enumerator polynomial for the stabilizer code.
- trace_with_stopper: Traces the lego leg with a stopper.
- conjoin: Conjoins two lego pieces into a new lego piece.
- self_trace: Traces a leg with itself.
- with_coset_flipped_legs: Adds coset flipped legs to the lego piece.
- tensor_with: Tensor product of two lego pieces.
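
As a sketch of what the brute-force enumerator computes, here is a minimal pure-Python enumeration over a symplectic parity check matrix. This is an illustrative stand-in (with a hypothetical helper name), not the library's GF2-based implementation:

```python
from itertools import product

def brute_force_wep(h):
    """Sketch of the brute-force weight enumerator: enumerate all GF(2)
    combinations of the generator rows of a symplectic parity check
    matrix h ([x-part | z-part]) and tally Pauli weights.
    Illustrative stand-in for stabilizer_enumerator_polynomial."""
    r, n = len(h), len(h[0]) // 2
    wep = {}
    for combo in product([0, 1], repeat=r):
        v = [0] * (2 * n)
        for pick, row in zip(combo, h):
            if pick:
                v = [a ^ b for a, b in zip(v, row)]
        # a qubit is acted on non-trivially if its x or z bit is set
        w = sum(1 for q in range(n) if v[q] or v[n + q])
        wep[w] = wep.get(w, 0) + 1
    return wep

# Two-qubit Bell-state code, generators XX and ZZ:
print(brute_force_wep([[1, 1, 0, 0], [0, 0, 1, 1]]))
# -> {0: 1, 2: 3}
```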

TensorId = str | int | Tuple[int, int] module-attribute

The tensor id can be a string, an integer, or a tuple of two integers.

TensorLeg = Tuple[TensorId, int] module-attribute

The tensor leg is a tuple of a tensor id and a leg index.

TensorEnumeratorKey = Tuple[int, ...] module-attribute

The tensor enumerator key is a tuple of integers.

TensorEnumerator = Dict[TensorEnumeratorKey, UnivariatePoly] module-attribute

The tensor enumerator is a dictionary of tuples of integers and univariate polynomials.

The planqtn.networks package

The planqtn.networks module contains tensor network layouts.

For universally applicable tensor network layouts, see:

For specific networks, see:

CompassCodeDualSurfaceCodeLayoutTN

Bases: SurfaceCodeTN

A tensor network representation of compass codes using dual surface code layout.

This class implements a compass code using the dual doubled surface code equivalence described by Cao & Lackey in the expansion pack paper. The compass code is constructed by applying gauge operations to a surface code based on a coloring pattern.

Parameters:

- `coloring` (`ndarray`, required): Array specifying the coloring pattern for the compass code.
- `lego` (`Callable[[TensorId], GF2]`, default `lambda node: Legos.encoding_tensor_512`): Function that returns the lego tensor for each node.
- `coset_error` (`Optional[GF2]`, default `None`): Optional coset error for weight enumerator calculations.
- `truncate_length` (`Optional[int]`, default `None`): Optional maximum weight for truncating enumerators.
Source code in planqtn/networks/compass_code.py
class CompassCodeDualSurfaceCodeLayoutTN(SurfaceCodeTN):
    """A tensor network representation of compass codes using dual surface code layout.

    This class implements a compass code using the dual doubled surface code equivalence
    described by Cao & Lackey in the expansion pack paper. The compass code is constructed
    by applying gauge operations to a surface code based on a coloring pattern.

    Args:
        coloring: Array specifying the coloring pattern for the compass code.
        lego: Function that returns the lego tensor for each node.
        coset_error: Optional coset error for weight enumerator calculations.
        truncate_length: Optional maximum weight for truncating enumerators.
    """

    def __init__(
        self,
        coloring: np.ndarray,
        *,
        lego: Callable[[TensorId], GF2] = lambda node: Legos.encoding_tensor_512,
        coset_error: Optional[GF2] = None,
        truncate_length: Optional[int] = None,
    ):
        """Create a square compass code based on the coloring.

        Creates a compass code using the dual doubled surface code equivalence
        described by Cao & Lackey in the expansion pack paper.

        Args:
            coloring: Array specifying the coloring pattern for the compass code.
            lego: Function that returns the lego tensor for each node.
            coset_error: Optional coset error for weight enumerator calculations.
            truncate_length: Optional maximum weight for truncating enumerators.
        """
        # See d3_compass_code_numbering.png for numbering - for an (r,c) qubit in the compass code,
        # the (2r, 2c) is the coordinate of the lego in the dual surface code.
        d = len(coloring) + 1
        super().__init__(d=d, lego=lego, truncate_length=truncate_length)
        gauge_idxs = [
            (r, c) for r in range(1, 2 * d - 1, 2) for c in range(1, 2 * d - 1, 2)
        ]
        for tensor_id, color in zip(gauge_idxs, np.reshape(coloring, (d - 1) ** 2)):
            self.nodes[tensor_id] = self.nodes[tensor_id].trace_with_stopper(
                Legos.stopper_z if color == 2 else Legos.stopper_x, 4
            )

        self._q_to_node = [(2 * r, 2 * c) for c in range(d) for r in range(d)]
        self.n = d * d
        self.coloring = coloring

        self.set_coset(
            coset_error if coset_error is not None else GF2.Zeros(2 * self.n)
        )
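
The gauge-lego indexing used in `__init__` above can be reproduced standalone: compass-code qubit (r, c) maps to dual-surface-code lego (2r, 2c), so the gauge legos occupy the odd-odd grid sites in between. A minimal sketch for distance d = 3:

```python
# Reproduction of the gauge_idxs comprehension from the __init__ shown
# above, for a distance-3 compass code (d = len(coloring) + 1 = 3).
d = 3
gauge_idxs = [
    (r, c) for r in range(1, 2 * d - 1, 2) for c in range(1, 2 * d - 1, 2)
]
print(gauge_idxs)
# -> [(1, 1), (1, 3), (3, 1), (3, 3)]
```

Each of these four sites receives a stopper lego (`stopper_z` or `stopper_x`) according to the corresponding entry of the coloring array.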

__init__(coloring, *, lego=lambda node: Legos.encoding_tensor_512, coset_error=None, truncate_length=None)

Create a square compass code based on the coloring.

Creates a compass code using the dual doubled surface code equivalence described by Cao & Lackey in the expansion pack paper.

Parameters:

- `coloring` (`ndarray`, required): Array specifying the coloring pattern for the compass code.
- `lego` (`Callable[[TensorId], GF2]`, default `lambda node: Legos.encoding_tensor_512`): Function that returns the lego tensor for each node.
- `coset_error` (`Optional[GF2]`, default `None`): Optional coset error for weight enumerator calculations.
- `truncate_length` (`Optional[int]`, default `None`): Optional maximum weight for truncating enumerators.
Source code in planqtn/networks/compass_code.py
def __init__(
    self,
    coloring: np.ndarray,
    *,
    lego: Callable[[TensorId], GF2] = lambda node: Legos.encoding_tensor_512,
    coset_error: Optional[GF2] = None,
    truncate_length: Optional[int] = None,
):
    """Create a square compass code based on the coloring.

    Creates a compass code using the dual doubled surface code equivalence
    described by Cao & Lackey in the expansion pack paper.

    Args:
        coloring: Array specifying the coloring pattern for the compass code.
        lego: Function that returns the lego tensor for each node.
        coset_error: Optional coset error for weight enumerator calculations.
        truncate_length: Optional maximum weight for truncating enumerators.
    """
    # See d3_compass_code_numbering.png for numbering - for an (r,c) qubit in the compass code,
    # the (2r, 2c) is the coordinate of the lego in the dual surface code.
    d = len(coloring) + 1
    super().__init__(d=d, lego=lego, truncate_length=truncate_length)
    gauge_idxs = [
        (r, c) for r in range(1, 2 * d - 1, 2) for c in range(1, 2 * d - 1, 2)
    ]
    for tensor_id, color in zip(gauge_idxs, np.reshape(coloring, (d - 1) ** 2)):
        self.nodes[tensor_id] = self.nodes[tensor_id].trace_with_stopper(
            Legos.stopper_z if color == 2 else Legos.stopper_x, 4
        )

    self._q_to_node = [(2 * r, 2 * c) for c in range(d) for r in range(d)]
    self.n = d * d
    self.coloring = coloring

    self.set_coset(
        coset_error if coset_error is not None else GF2.Zeros(2 * self.n)
    )

CssTannerCodeTN

Bases: TensorNetwork

A tensor network representation of CSS codes using Tanner graph structure.

This class constructs a tensor network from X and Z parity check matrices (Hx and Hz), representing a CSS code. The tensor network connects qubit tensors to check tensors according to the non-zero entries in the parity check matrices.
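
The Tanner-graph wiring rule can be sketched in pure Python: qubit q connects to exactly the X checks (rows of hx) and Z checks (rows of hz) whose row has a 1 in column q. This hypothetical helper stands in for the `np.nonzero(hx[:, q])` calls in the source:

```python
def qubit_check_adjacency(hx, hz):
    """Sketch of the Tanner-graph wiring rule: qubit q connects to the
    X checks (rows of hx) and Z checks (rows of hz) with a 1 in column q.
    Pure-Python stand-in for the np.nonzero(hx[:, q]) calls in the source."""
    n = len(hx[0])
    return {
        q: (
            [r for r, row in enumerate(hx) if row[q]],
            [r for r, row in enumerate(hz) if row[q]],
        )
        for q in range(n)
    }

# A small CSS example: one X check on all four qubits, two Z checks:
hx = [[1, 1, 1, 1]]
hz = [[1, 1, 0, 0], [0, 0, 1, 1]]
print(qubit_check_adjacency(hx, hz)[0])
# -> ([0], [0])
```

The number of X and Z legs on each qubit tensor is then the length of these lists, which is how the construction below sizes each qubit lego.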

Source code in planqtn/networks/css_tanner_code.py
class CssTannerCodeTN(TensorNetwork):
    """A tensor network representation of CSS codes using Tanner graph structure.

    This class constructs a tensor network from X and Z parity check matrices (Hx and Hz),
    representing a CSS code. The tensor network connects qubit tensors to check tensors
    according to the non-zero entries in the parity check matrices.
    """

    def __init__(
        self,
        hx: np.ndarray,
        hz: np.ndarray,
    ):
        """Construct a CSS code tensor network from X and Z parity check matrices.

        Args:
            hx: X-type parity check matrix.
            hz: Z-type parity check matrix.
        """
        self.n: int = hx.shape[1]
        self.r: int = hz.shape[0]  # rz, but not used

        q_tensors: List[StabilizerCodeTensorEnumerator] = []
        traces: List[Trace] = []
        self.q_to_leg_and_node: List[Tuple[TensorId, TensorLeg]] = []

        for q in range(self.n):
            x_stabs = np.nonzero(hx[:, q])[0]
            n_x_legs = len(x_stabs)
            z_stabs = np.nonzero(hz[:, q])[0]
            n_z_legs = len(z_stabs)

            h0 = StabilizerCodeTensorEnumerator(
                Legos.h,
                tensor_id=f"q{q}.h0",
                annotation=LegoAnnotation(type=LegoType.H, short_name=f"h0{q}"),
            )
            h1 = StabilizerCodeTensorEnumerator(
                Legos.h,
                tensor_id=f"q{q}.h1",
                annotation=LegoAnnotation(type=LegoType.H, short_name=f"h1{q}"),
            )

            x = StabilizerCodeTensorEnumerator(
                Legos.x_rep_code(2 + n_x_legs),
                tensor_id=f"q{q}.x",
                annotation=LegoAnnotation(type=LegoType.XREP, short_name=f"x{q}"),
            )

            z = StabilizerCodeTensorEnumerator(
                Legos.x_rep_code(2 + n_z_legs),
                tensor_id=f"q{q}.z",
                annotation=LegoAnnotation(type=LegoType.XREP, short_name=f"z{q}"),
            )

            # leg numbering for the spiders: 0 for logical, 1 for physical,
            # rest is to the check nodes
            # going left to right:
            # I -> h0 -> Z [leg0  (legs to Z check 2...n_z_legs) leg1] -> h1 ->
            # X[leg0  (legs to X check 2...n_x_legs) -> dangling physical leg 1] -> x
            i_stopper = StabilizerCodeTensorEnumerator(
                Legos.stopper_i,
                tensor_id=f"q{q}.id",
                annotation=LegoAnnotation(type=LegoType.STOPPER_I, short_name=f"id{q}"),
            )
            q_tensors.append(i_stopper)
            q_tensors.append(h0)
            q_tensors.append(z)
            q_tensors.append(h1)
            q_tensors.append(x)

            traces.append(
                (
                    i_stopper.tensor_id,
                    h0.tensor_id,
                    [(f"q{q}.id", 0)],
                    [(f"q{q}.h0", 0)],
                )
            )
            traces.append(
                (
                    h0.tensor_id,
                    z.tensor_id,
                    [(h0.tensor_id, 1)],
                    [(z.tensor_id, 0)],
                )
            )
            traces.append(
                (
                    h1.tensor_id,
                    z.tensor_id,
                    [(h1.tensor_id, 0)],
                    [(z.tensor_id, 1)],
                )
            )
            traces.append(
                (
                    h1.tensor_id,
                    x.tensor_id,
                    [(h1.tensor_id, 1)],
                    [(x.tensor_id, 0)],
                )
            )

        q_legs = [2] * self.n
        gx_tensors = []
        for i, gx in enumerate(hx):
            qs = np.nonzero(gx)[0].astype(int)
            g_tensor = StabilizerCodeTensorEnumerator(
                Legos.z_rep_code(len(qs)),
                f"x{i}",
                annotation=LegoAnnotation(
                    type=LegoType.ZREP,
                    short_name=f"x{i}",
                ),
            )
            # print(f"=== x tensor {g_tensor.tensor_id} -> {qs} === ")

            gx_tensors.append(g_tensor)
            for g_leg, q in enumerate(qs):
                x_tensor_id = f"q{q}.x"
                traces.append(
                    (
                        g_tensor.tensor_id,
                        x_tensor_id,
                        [(g_tensor.tensor_id, g_leg)],
                        [(x_tensor_id, q_legs[q])],
                    )
                )
                q_legs[q] += 1
        gz_tensors = []
        q_legs = [2] * self.n

        for i, gz in enumerate(hz):
            qs = np.nonzero(gz)[0].astype(int)
            g_tensor = StabilizerCodeTensorEnumerator(
                Legos.z_rep_code(len(qs)),
                f"z{i}",
                annotation=LegoAnnotation(
                    type=LegoType.ZREP,
                    short_name=f"z{i}",
                ),
            )
            gz_tensors.append(g_tensor)
            for g_leg, q in enumerate(qs):
                z_tensor_id = f"q{q}.z"
                traces.append(
                    (
                        g_tensor.tensor_id,
                        z_tensor_id,
                        [(g_tensor.tensor_id, g_leg)],
                        [(z_tensor_id, q_legs[q])],
                    )
                )
                q_legs[q] += 1
        super().__init__(q_tensors + gx_tensors + gz_tensors)

        for t in traces:
            self.self_trace(*t)

    def n_qubits(self) -> int:
        """Get the total number of qubits in the tensor network.

        Returns:
            int: Total number of qubits represented by this tensor network.
        """
        return self.n

    def qubit_to_node_and_leg(self, q: int) -> Tuple[TensorId, TensorLeg]:
        """Map a qubit index to its corresponding node and leg.

        Args:
            q: Global qubit index.

        Returns:
            Tuple[TensorId, TensorLeg]: Node ID and leg that represent the qubit.
        """
        return f"q{q}.x", (f"q{q}.x", 1)

__init__(hx, hz)

Construct a CSS code tensor network from X and Z parity check matrices.

Parameters:

- hx (ndarray): X-type parity check matrix. Required.
- hz (ndarray): Z-type parity check matrix. Required.

n_qubits()

Get the total number of qubits in the tensor network.

Returns:

- int: Total number of qubits represented by this tensor network.


qubit_to_node_and_leg(q)

Map a qubit index to its corresponding node and leg.

Parameters:

- q (int): Global qubit index. Required.

Returns:

- Tuple[TensorId, TensorLeg]: Node ID and leg that represent the qubit.


RotatedSurfaceCodeTN

Bases: TensorNetwork

A tensor network representation of rotated surface codes.

This class constructs a tensor network for a rotated surface code of distance d. The rotated surface code has a checkerboard pattern of X and Z stabilizers, with appropriate boundary conditions for the rotated geometry.
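The column-major qubit indexing used by this class can be sketched as follows (illustrative helper, not the planqtn API): qubit `q` lives on grid node `(q % d, q // d)`, with the dangling physical leg at index 4.

```python
# Hypothetical helper mirroring RotatedSurfaceCodeTN.qubit_to_node_and_leg:
# column-major ordering over a d x d grid of legos.
def qubit_to_node(q, d):
    return (q % d, q // d)

d = 5
# the first column holds qubits 0..4, the second column 5..9, and so on
print([qubit_to_node(q, d) for q in range(7)])
```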

Source code in planqtn/networks/rotated_surface_code.py
class RotatedSurfaceCodeTN(TensorNetwork):
    """A tensor network representation of rotated surface codes.

    This class constructs a tensor network for a rotated surface code of distance d.
    The rotated surface code has a checkerboard pattern of X and Z stabilizers,
    with appropriate boundary conditions for the rotated geometry.
    """

    def __init__(
        self,
        d: int,
        lego: Callable[[TensorId], GF2] = lambda i: Legos.encoding_tensor_512,
        coset_error: Optional[GF2] = None,
        truncate_length: Optional[int] = None,
    ):
        """Construct a rotated surface code tensor network.

        Args:
            d: Distance of the surface code.
            lego: Function that returns the lego tensor for each node.
            coset_error: Optional coset error for weight enumerator calculations.
            truncate_length: Optional maximum weight for truncating enumerators.
        """
        nodes: Dict[TensorId, StabilizerCodeTensorEnumerator] = {
            (r, c): StabilizerCodeTensorEnumerator(
                lego((r, c)),
                tensor_id=(r, c),
            )
            # col major ordering
            for r in range(d)
            for c in range(d)
        }

        for c in range(d):
            # top Z boundary (X type checks, Z type logical)
            nodes[(0, c)] = nodes[(0, c)].trace_with_stopper(
                Legos.stopper_x, 3 if c % 2 == 0 else 0
            )
            # bottom Z boundary (X type checks, Z type logical)
            nodes[(d - 1, c)] = nodes[(d - 1, c)].trace_with_stopper(
                Legos.stopper_x, 1 if c % 2 == 0 else 2
            )

        for r in range(d):
            # left X boundary (Z type checks, X type logical)
            nodes[r, 0] = nodes[(r, 0)].trace_with_stopper(
                Legos.stopper_z, 0 if r % 2 == 0 else 1
            )
            # right X boundary (Z type checks, X type logical)
            nodes[(r, d - 1)] = nodes[(r, d - 1)].trace_with_stopper(
                Legos.stopper_z, 2 if r % 2 == 0 else 3
            )

        # for r in range(1,4):
        #     # bulk
        #     for c in range(1,4):

        super().__init__(nodes, truncate_length=truncate_length)

        for radius in range(1, d):
            for i in range(radius + 1):
                # extending the right boundary
                self.self_trace(
                    (i, radius - 1),
                    (i, radius),
                    [3 if (i + radius) % 2 == 0 else 2],
                    [0 if (i + radius) % 2 == 0 else 1],
                )
                if 0 < i < radius:
                    self.self_trace(
                        (i - 1, radius),
                        (i, radius),
                        [2 if (i + radius) % 2 == 0 else 1],
                        [3 if (i + radius) % 2 == 0 else 0],
                    )
                # extending the bottom boundary
                self.self_trace(
                    (radius - 1, i),
                    (radius, i),
                    [2 if (i + radius) % 2 == 0 else 1],
                    [3 if (i + radius) % 2 == 0 else 0],
                )
                if 0 < i < radius:
                    self.self_trace(
                        (radius, i - 1),
                        (radius, i),
                        [3 if (i + radius) % 2 == 0 else 2],
                        [0 if (i + radius) % 2 == 0 else 1],
                    )
        self.n = d * d
        self.d = d

        if coset_error is not None:
            self.set_coset(coset_error=coset_error)

    def qubit_to_node_and_leg(self, q: int) -> Tuple[TensorId, TensorLeg]:
        """Map a qubit index to its corresponding node and leg.

        The rotated surface code uses column major ordering.

        Args:
            q: Global qubit index.

        Returns:
            Tuple[TensorId, TensorLeg]: Node ID and leg that represent the qubit.
        """
        # col major ordering
        node = (q % self.d, q // self.d)
        return node, (node, 4)

    def n_qubits(self) -> int:
        """Get the total number of qubits in the tensor network.

        Returns:
            int: Total number of qubits represented by this tensor network.
        """
        return self.n

__init__(d, lego=lambda i: Legos.encoding_tensor_512, coset_error=None, truncate_length=None)

Construct a rotated surface code tensor network.

Parameters:

- d (int): Distance of the surface code. Required.
- lego (Callable[[TensorId], GF2]): Function that returns the lego tensor for each node. Default: lambda i: Legos.encoding_tensor_512.
- coset_error (Optional[GF2]): Optional coset error for weight enumerator calculations. Default: None.
- truncate_length (Optional[int]): Optional maximum weight for truncating enumerators. Default: None.

n_qubits()

Get the total number of qubits in the tensor network.

Returns:

- int: Total number of qubits represented by this tensor network.


qubit_to_node_and_leg(q)

Map a qubit index to its corresponding node and leg.

The rotated surface code uses column major ordering.

Parameters:

- q (int): Global qubit index. Required.

Returns:

- Tuple[TensorId, TensorLeg]: Node ID and leg that represent the qubit.


StabilizerMeasurementStatePrepTN

Bases: TensorNetwork

Measurement-based state preparation circuit layout.

A universal tensor network layout based on the measurement-based state preparation circuit described in the following work:

Cao, ChunJun, Michael J. Gullans, Brad Lackey, and Zitao Wang. 2024. “Quantum Lego Expansion Pack: Enumerators from Tensor Networks.” PRX Quantum 5 (3): 030313. https://doi.org/10.1103/PRXQuantum.5.030313.
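The constructor below reads a 2n-column symplectic parity check matrix and, for each qubit q, inspects the pair (H[i, q], H[i, q + n]) to decide whether check i acts on q with X (1,0), Z (0,1), Y (1,1, unsupported), or identity (0,0). A plain-NumPy sketch of that decoding, using assumed example data (not the planqtn API):

```python
import numpy as np

# [[5,1,3]] code stabilizers XZZXI, IXZZX, XIXZZ, ZXIXZ in symplectic
# [X part | Z part] form (example data assumed for illustration)
h = np.array([
    [1, 0, 0, 1, 0,  0, 1, 1, 0, 0],
    [0, 1, 0, 0, 1,  0, 0, 1, 1, 0],
    [1, 0, 1, 0, 0,  0, 0, 0, 1, 1],
    [0, 1, 0, 1, 0,  1, 0, 0, 0, 1],
])
assert h.shape[1] % 2 == 0  # symplectic matrices have an even column count
n = h.shape[1] // 2

# map each (x, z) bit pair to the Pauli it represents on that qubit
pauli = {(0, 0): "I", (1, 0): "X", (0, 1): "Z", (1, 1): "Y"}
rows = [
    "".join(pauli[(h[i, q], h[i, q + n])] for q in range(n))
    for i in range(h.shape[0])
]
print(rows)  # ['XZZXI', 'IXZZX', 'XIXZZ', 'ZXIXZ']
```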

Source code in planqtn/networks/stabilizer_measurement_state_prep.py
class StabilizerMeasurementStatePrepTN(TensorNetwork):
    """Measurement-based state preparation circuit layout.

    A universal tensor network layout based on the measurement-based state preparation
    circuit layout described in the following work:

    Cao, ChunJun, Michael J. Gullans, Brad Lackey, and Zitao Wang. 2024.
    “Quantum Lego Expansion Pack: Enumerators from Tensor Networks.”
    PRX Quantum 5 (3): 030313. https://doi.org/10.1103/PRXQuantum.5.030313.
    """

    def __init__(self, parity_check_matrix: np.ndarray):
        """Construct a stabilizer measurement state preparation tensor network.

        Args:
            parity_check_matrix: The parity check matrix of the stabilizer code.

        Raises:
            ValueError: If the parity check matrix is not symplectic.
            NotImplementedError: If Y stabilizers are in the parity check matrix.
                It is not implemented yet.
        """
        self.parity_check_matrix = parity_check_matrix
        if parity_check_matrix.shape[1] % 2 == 1:
            raise ValueError(f"Not a symplectic matrix: {parity_check_matrix}")

        r = parity_check_matrix.shape[0]
        n = parity_check_matrix.shape[1] // 2
        traces = []

        self.q_to_leg_and_node: List[Tuple[TensorId, TensorLeg]] = []

        checks = []
        check_stoppers = []
        for i in range(r):
            weight = np.count_nonzero(parity_check_matrix[i])
            check = StabilizerCodeTensorEnumerator(
                h=Legos.z_rep_code(weight + 2),
                tensor_id=f"check{i}",
                annotation=LegoAnnotation(
                    type=LegoType.ZREP,
                    description="check{i}",
                    name=f"check{i}",
                    x=1 + i,
                    y=0,
                ),
            )
            x_state_prep = StabilizerCodeTensorEnumerator(
                h=Legos.stopper_x,
                tensor_id=f"x_state_prep{i}",
                annotation=LegoAnnotation(
                    type=LegoType.STOPPER_X,
                    description="xsp{i}",
                    name=f"xsp{i}",
                    x=1 + i - 0.25,
                    y=0,
                ),
            )
            check_stoppers.append(x_state_prep)
            x_meas = StabilizerCodeTensorEnumerator(
                h=Legos.stopper_x,
                tensor_id=f"x_meas{i}",
                annotation=LegoAnnotation(
                    type=LegoType.STOPPER_X,
                    description=f"xmeas{i}",
                    name=f"xmeas{i}",
                    x=1 + i + 0.25,
                    y=0,
                ),
            )
            check_stoppers.append(x_meas)

            traces.append(
                (
                    x_state_prep.tensor_id,
                    check.tensor_id,
                    [(x_state_prep.tensor_id, 0)],
                    [(check.tensor_id, 0)],
                )
            )
            traces.append(
                (
                    x_meas.tensor_id,
                    check.tensor_id,
                    [(x_meas.tensor_id, 0)],
                    [(check.tensor_id, 1)],
                )
            )
            checks.append(check)

        next_check_legs = [2] * r
        q_tensors = []
        op_tensors = []

        # for each qubit we create merged tensors across all checks
        for q in range(n):
            q_logical_id = StabilizerCodeTensorEnumerator(
                h=Legos.stopper_i,
                tensor_id=f"ql{q}",
                annotation=LegoAnnotation(
                    type=LegoType.STOPPER_I,
                    description="stopper_i",
                    name=f"stopper_i{q}",
                    x=0,
                    y=1 + q,
                ),
            )
            q_tensors.append(q_logical_id)
            physical_leg = (q_logical_id.tensor_id, (q_logical_id.tensor_id, 0))
            for i in range(r):
                op = tuple(parity_check_matrix[i, (q, q + n)])

                if op == (0, 0):
                    continue

                if op == (1, 0):
                    x_check = StabilizerCodeTensorEnumerator(
                        h=Legos.x_rep_code(3),
                        tensor_id=f"q{q}.x{i}",
                        annotation=LegoAnnotation(
                            type=LegoType.XREP,
                            description="x",
                            name=f"x{q}.{i}",
                            x=1 + i,
                            y=1 + q,
                        ),
                    )
                    op_tensors.append(x_check)

                    traces.append(
                        (
                            physical_leg[0],
                            x_check.tensor_id,
                            [physical_leg[1]],
                            [(x_check.tensor_id, 0)],
                        )
                    )

                    traces.append(
                        (
                            x_check.tensor_id,
                            checks[i].tensor_id,
                            [(x_check.tensor_id, 1)],
                            [(checks[i].tensor_id, next_check_legs[i])],
                        )
                    )
                    next_check_legs[i] += 1
                    physical_leg = (x_check.tensor_id, (x_check.tensor_id, 2))

                elif op == (0, 1):
                    z_check = StabilizerCodeTensorEnumerator(
                        h=Legos.z_rep_code(3),
                        tensor_id=f"q{q}.z{i}",
                        annotation=LegoAnnotation(
                            type=LegoType.ZREP,
                            description="z",
                            name=f"z{q}.{i}",
                            x=1 + i,
                            y=1 + q,
                        ),
                    )
                    op_tensors.append(z_check)

                    traces.append(
                        (
                            physical_leg[0],
                            z_check.tensor_id,
                            [physical_leg[1]],
                            [(z_check.tensor_id, 0)],
                        )
                    )
                    h = StabilizerCodeTensorEnumerator(
                        h=Legos.h,
                        tensor_id=f"q{q}.hz{i}",
                        annotation=LegoAnnotation(
                            type=LegoType.H,
                            description="h",
                            name=f"h{q}.{i}",
                            x=1 + i,
                            y=1 + q - 0.5,
                        ),
                    )
                    op_tensors.append(h)

                    traces.append(
                        (
                            z_check.tensor_id,
                            h.tensor_id,
                            [(z_check.tensor_id, 1)],
                            [(h.tensor_id, 0)],
                        )
                    )
                    traces.append(
                        (
                            h.tensor_id,
                            checks[i].tensor_id,
                            [(h.tensor_id, 1)],
                            [(checks[i].tensor_id, next_check_legs[i])],
                        )
                    )
                    next_check_legs[i] += 1
                    physical_leg = (z_check.tensor_id, (z_check.tensor_id, 2))

                else:
                    raise NotImplementedError("Y stabilizer is not implemented yet...")
            self.q_to_leg_and_node.append(physical_leg)

        super().__init__(
            nodes={
                n.tensor_id: n for n in q_tensors + checks + op_tensors + check_stoppers
            }
        )

        for t in traces:
            self.self_trace(*t)

    def n_qubits(self) -> int:
        """Get the total number of qubits in the tensor network.

        Returns:
            int: Total number of qubits represented by this tensor network.
        """
        return int(self.parity_check_matrix.shape[1] // 2)

    def qubit_to_node_and_leg(self, q: int) -> Tuple[TensorId, TensorLeg]:
        """Map a qubit index to its corresponding node and leg.

        Args:
            q: Global qubit index.

        Returns:
            Tuple[TensorId, TensorLeg]: Node ID and leg that represent the qubit.
        """
        return self.q_to_leg_and_node[q]

__init__(parity_check_matrix)

Construct a stabilizer measurement state preparation tensor network.

Parameters:

Name Type Description Default
parity_check_matrix ndarray

The parity check matrix of the stabilizer code.

required

Raises:

Type Description
ValueError

If the parity check matrix is not symplectic.

NotImplementedError

If the parity check matrix contains Y stabilizers; this is not implemented yet.

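The symplectic requirement simply means the matrix splits into an X half and a Z half of equal width. A minimal NumPy sketch (plain NumPy, not calling planqtn itself) of a valid input, the symplectic `[X | Z]` parity check matrix of the 3-qubit Z repetition code, together with the even-column check the constructor performs:

```python
import numpy as np

# Symplectic parity check matrix [X | Z] for the 3-qubit repetition code:
# stabilizers Z0 Z1 and Z1 Z2, so the X block is all zeros.
h = np.array([
    [0, 0, 0, 1, 1, 0],
    [0, 0, 0, 0, 1, 1],
])

# The constructor's sanity check: an odd number of columns
# cannot be split into X and Z halves.
assert h.shape[1] % 2 == 0
n = h.shape[1] // 2  # number of physical qubits
print(n)  # 3
```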
Source code in planqtn/networks/stabilizer_measurement_state_prep.py
def __init__(self, parity_check_matrix: np.ndarray):
    """Construct a stabilizer measurement state preparation tensor network.

    Args:
        parity_check_matrix: The parity check matrix of the stabilizer code.

    Raises:
        ValueError: If the parity check matrix is not symplectic.
        NotImplementedError: If the parity check matrix contains Y stabilizers;
            this is not implemented yet.
    """
    self.parity_check_matrix = parity_check_matrix
    if parity_check_matrix.shape[1] % 2 == 1:
        raise ValueError(f"Not a symplectic matrix: {parity_check_matrix}")

    r = parity_check_matrix.shape[0]
    n = parity_check_matrix.shape[1] // 2
    traces = []

    self.q_to_leg_and_node: List[Tuple[TensorId, TensorLeg]] = []

    checks = []
    check_stoppers = []
    for i in range(r):
        weight = np.count_nonzero(parity_check_matrix[i])
        check = StabilizerCodeTensorEnumerator(
            h=Legos.z_rep_code(weight + 2),
            tensor_id=f"check{i}",
            annotation=LegoAnnotation(
                type=LegoType.ZREP,
                description=f"check{i}",
                name=f"check{i}",
                x=1 + i,
                y=0,
            ),
        )
        x_state_prep = StabilizerCodeTensorEnumerator(
            h=Legos.stopper_x,
            tensor_id=f"x_state_prep{i}",
            annotation=LegoAnnotation(
                type=LegoType.STOPPER_X,
                description=f"xsp{i}",
                name=f"xsp{i}",
                x=1 + i - 0.25,
                y=0,
            ),
        )
        check_stoppers.append(x_state_prep)
        x_meas = StabilizerCodeTensorEnumerator(
            h=Legos.stopper_x,
            tensor_id=f"x_meas{i}",
            annotation=LegoAnnotation(
                type=LegoType.STOPPER_X,
                description=f"xmeas{i}",
                name=f"xmeas{i}",
                x=1 + i + 0.25,
                y=0,
            ),
        )
        check_stoppers.append(x_meas)

        traces.append(
            (
                x_state_prep.tensor_id,
                check.tensor_id,
                [(x_state_prep.tensor_id, 0)],
                [(check.tensor_id, 0)],
            )
        )
        traces.append(
            (
                x_meas.tensor_id,
                check.tensor_id,
                [(x_meas.tensor_id, 0)],
                [(check.tensor_id, 1)],
            )
        )
        checks.append(check)

    next_check_legs = [2] * r
    q_tensors = []
    op_tensors = []

    # for each qubit we create merged tensors across all checks
    for q in range(n):
        q_logical_id = StabilizerCodeTensorEnumerator(
            h=Legos.stopper_i,
            tensor_id=f"ql{q}",
            annotation=LegoAnnotation(
                type=LegoType.STOPPER_I,
                description="stopper_i",
                name=f"stopper_i{q}",
                x=0,
                y=1 + q,
            ),
        )
        q_tensors.append(q_logical_id)
        physical_leg = (q_logical_id.tensor_id, (q_logical_id.tensor_id, 0))
        for i in range(r):
            op = tuple(parity_check_matrix[i, (q, q + n)])

            if op == (0, 0):
                continue

            if op == (1, 0):
                x_check = StabilizerCodeTensorEnumerator(
                    h=Legos.x_rep_code(3),
                    tensor_id=f"q{q}.x{i}",
                    annotation=LegoAnnotation(
                        type=LegoType.XREP,
                        description="x",
                        name=f"x{q}.{i}",
                        x=1 + i,
                        y=1 + q,
                    ),
                )
                op_tensors.append(x_check)

                traces.append(
                    (
                        physical_leg[0],
                        x_check.tensor_id,
                        [physical_leg[1]],
                        [(x_check.tensor_id, 0)],
                    )
                )

                traces.append(
                    (
                        x_check.tensor_id,
                        checks[i].tensor_id,
                        [(x_check.tensor_id, 1)],
                        [(checks[i].tensor_id, next_check_legs[i])],
                    )
                )
                next_check_legs[i] += 1
                physical_leg = (x_check.tensor_id, (x_check.tensor_id, 2))

            elif op == (0, 1):
                z_check = StabilizerCodeTensorEnumerator(
                    h=Legos.z_rep_code(3),
                    tensor_id=f"q{q}.z{i}",
                    annotation=LegoAnnotation(
                        type=LegoType.ZREP,
                        description="z",
                        name=f"z{q}.{i}",
                        x=1 + i,
                        y=1 + q,
                    ),
                )
                op_tensors.append(z_check)

                traces.append(
                    (
                        physical_leg[0],
                        z_check.tensor_id,
                        [physical_leg[1]],
                        [(z_check.tensor_id, 0)],
                    )
                )
                h = StabilizerCodeTensorEnumerator(
                    h=Legos.h,
                    tensor_id=f"q{q}.hz{i}",
                    annotation=LegoAnnotation(
                        type=LegoType.H,
                        description="h",
                        name=f"h{q}.{i}",
                        x=1 + i,
                        y=1 + q - 0.5,
                    ),
                )
                op_tensors.append(h)

                traces.append(
                    (
                        z_check.tensor_id,
                        h.tensor_id,
                        [(z_check.tensor_id, 1)],
                        [(h.tensor_id, 0)],
                    )
                )
                traces.append(
                    (
                        h.tensor_id,
                        checks[i].tensor_id,
                        [(h.tensor_id, 1)],
                        [(checks[i].tensor_id, next_check_legs[i])],
                    )
                )
                next_check_legs[i] += 1
                physical_leg = (z_check.tensor_id, (z_check.tensor_id, 2))

            else:
                raise NotImplementedError("Y stabilizer is not implemented yet...")
        self.q_to_leg_and_node.append(physical_leg)

    super().__init__(
        nodes={
            n.tensor_id: n for n in q_tensors + checks + op_tensors + check_stoppers
        }
    )

    for t in traces:
        self.self_trace(*t)

n_qubits()

Get the total number of qubits in the tensor network.

Returns:

Name Type Description
int int

Total number of qubits represented by this tensor network.

Source code in planqtn/networks/stabilizer_measurement_state_prep.py
def n_qubits(self) -> int:
    """Get the total number of qubits in the tensor network.

    Returns:
        int: Total number of qubits represented by this tensor network.
    """
    return int(self.parity_check_matrix.shape[1] // 2)

qubit_to_node_and_leg(q)

Map a qubit index to its corresponding node and leg.

Parameters:

Name Type Description Default
q int

Global qubit index.

required

Returns:

Type Description
Tuple[TensorId, TensorLeg]

Tuple[TensorId, TensorLeg]: Node ID and leg that represent the qubit.

Source code in planqtn/networks/stabilizer_measurement_state_prep.py
def qubit_to_node_and_leg(self, q: int) -> Tuple[TensorId, TensorLeg]:
    """Map a qubit index to its corresponding node and leg.

    Args:
        q: Global qubit index.

    Returns:
        Tuple[TensorId, TensorLeg]: Node ID and leg that represent the qubit.
    """
    return self.q_to_leg_and_node[q]

StabilizerTannerCodeTN

Bases: TensorNetwork

A tensor network representation of stabilizer codes using Tanner graph structure.

This class constructs a tensor network from a parity check matrix H, where each row of H represents a stabilizer generator and each column represents a qubit. The tensor network is built by connecting check tensors to qubit tensors according to the non-zero entries in the parity check matrix.

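Per qubit, the constructor reads the pair `(h[i, q], h[i, q + n])` to decide whether check `i` acts on qubit `q` with X, Z, or not at all. A small sketch of that lookup on a hypothetical 2-qubit matrix (plain NumPy, not calling planqtn itself):

```python
import numpy as np

# Symplectic [X | Z] matrix with stabilizers X0 X1 and Z0 Z1 (n = 2).
h = np.array([
    [1, 1, 0, 0],  # X0 X1
    [0, 0, 1, 1],  # Z0 Z1
])
n = h.shape[1] // 2

# The same (x_bit, z_bit) lookup the constructor performs per (check, qubit).
ops = {(0, 0): "I", (1, 0): "X", (0, 1): "Z", (1, 1): "Y"}
table = [[ops[(h[i, q], h[i, q + n])] for q in range(n)] for i in range(h.shape[0])]
print(table)  # [['X', 'X'], ['Z', 'Z']]
```

The `(1, 1)` case is why Y stabilizers are rejected: neither branch of the constructor handles it.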
Source code in planqtn/networks/stabilizer_tanner_code.py
class StabilizerTannerCodeTN(TensorNetwork):
    """A tensor network representation of stabilizer codes using Tanner graph structure.

    This class constructs a tensor network from a parity check matrix H, where each
    row of H represents a stabilizer generator and each column represents a qubit.
    The tensor network is built by connecting check tensors to qubit tensors according
    to the non-zero entries in the parity check matrix.
    """

    def __init__(self, h: np.ndarray):
        """Construct a stabilizer Tanner code tensor network.

        Args:
            h: Parity check matrix in symplectic form (must have even number of columns).

        Raises:
            ValueError: If the parity check matrix is not symplectic.
        """
        self.parity_check_matrix = h
        if h.shape[1] % 2 == 1:
            raise ValueError(f"Not a symplectic matrix: {h}")

        r = h.shape[0]
        n = h.shape[1] // 2

        checks = []
        for i in range(r):
            weight = np.count_nonzero(h[i])
            check = StabilizerCodeTensorEnumerator(
                h=Legos.z_rep_code(weight + 2), tensor_id=f"check{i}"
            )
            check = check.trace_with_stopper(Legos.stopper_x, (f"check{i}", 0))
            check = check.trace_with_stopper(Legos.stopper_x, (f"check{i}", 1))
            checks.append(check)

        traces = []
        next_check_legs = [2] * r
        q_tensors = []
        self.q_to_leg_and_node: List[Tuple[TensorId, TensorLeg]] = []

        # for each qubit we create merged tensors across all checks
        for q in range(n):
            q_tensor = StabilizerCodeTensorEnumerator(
                h=Legos.stopper_i, tensor_id=f"q{q}"
            )
            physical_leg = (f"q{q}", 0)
            for i in range(r):
                op = tuple(h[i, (q, q + n)])
                if op == (0, 0):
                    continue

                if op == (1, 0):
                    q_tensor = q_tensor.conjoin(
                        StabilizerCodeTensorEnumerator(
                            h=Legos.x_rep_code(3), tensor_id=f"q{q}.c{i}"
                        ),
                        [physical_leg],
                        [0],
                    )
                    traces.append(
                        (
                            q_tensor.tensor_id,
                            checks[i].tensor_id,
                            [(f"q{q}.c{i}", 1)],
                            [next_check_legs[i]],
                        )
                    )
                    next_check_legs[i] += 1
                    physical_leg = (f"q{q}.c{i}", 2)

                elif op == (0, 1):
                    q_tensor = q_tensor.conjoin(
                        StabilizerCodeTensorEnumerator(
                            h=Legos.z_rep_code(3), tensor_id=f"q{q}.z{i}"
                        ),
                        [physical_leg],
                        [0],
                    )
                    q_tensor = q_tensor.conjoin(
                        StabilizerCodeTensorEnumerator(
                            h=Legos.h, tensor_id=f"q{q}.c{i}"
                        ),
                        [(f"q{q}.z{i}", 1)],
                        [0],
                    )
                    traces.append(
                        (
                            q_tensor.tensor_id,
                            checks[i].tensor_id,
                            [(f"q{q}.c{i}", 1)],
                            [next_check_legs[i]],
                        )
                    )
                    next_check_legs[i] += 1
                    physical_leg = (f"q{q}.z{i}", 2)

                else:
                    raise ValueError("Y stabilizer is not implemented yet...")
            q_tensors.append(q_tensor)
            self.q_to_leg_and_node.append((physical_leg[0], physical_leg))

        super().__init__(nodes={n.tensor_id: n for n in q_tensors + checks})

        for t in traces:
            self.self_trace(*t)

    def n_qubits(self) -> int:
        """Get the total number of qubits in the tensor network.

        Returns:
            int: Total number of qubits represented by this tensor network.
        """
        return int(self.parity_check_matrix.shape[1] // 2)

    def qubit_to_node_and_leg(self, q: int) -> Tuple[TensorId, TensorLeg]:
        """Map a qubit index to its corresponding node and leg.

        Returns the tensor and leg for the given qubit index.

        Args:
            q: Global qubit index.

        Returns:
            Node ID: node id for the tensor in the network
            Leg: leg that represents the qubit.
        """
        return self.q_to_leg_and_node[q]

__init__(h)

Construct a stabilizer Tanner code tensor network.

Parameters:

Name Type Description Default
h ndarray

Parity check matrix in symplectic form (must have even number of columns).

required

Raises:

Type Description
ValueError

If the parity check matrix is not symplectic, or if it contains Y stabilizers (not implemented yet).

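Each row's weight fixes the size of its check lego: the constructor builds a `z_rep_code` on `weight + 2` legs per check, where the two extra legs are closed off with X stoppers. A quick sketch of that leg count on a hypothetical 4-qubit matrix (plain NumPy, not calling planqtn itself):

```python
import numpy as np

# Rows are stabilizers in symplectic [X | Z] form (n = 4, hypothetical).
h = np.array([
    [1, 1, 1, 0, 0, 0, 0, 0],  # X0 X1 X2  (weight 3)
    [0, 0, 0, 0, 1, 0, 1, 1],  # Z0 Z2 Z3  (weight 3)
])

# Each check tensor gets weight + 2 legs: one per qubit it touches,
# plus two extra legs that are traced with X stoppers.
leg_counts = [int(np.count_nonzero(row)) + 2 for row in h]
print(leg_counts)  # [5, 5]
```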
Source code in planqtn/networks/stabilizer_tanner_code.py
def __init__(self, h: np.ndarray):
    """Construct a stabilizer Tanner code tensor network.

    Args:
        h: Parity check matrix in symplectic form (must have even number of columns).

    Raises:
        ValueError: If the parity check matrix is not symplectic,
            or if it contains Y stabilizers (not implemented yet).
    """
    self.parity_check_matrix = h
    if h.shape[1] % 2 == 1:
        raise ValueError(f"Not a symplectic matrix: {h}")

    r = h.shape[0]
    n = h.shape[1] // 2

    checks = []
    for i in range(r):
        weight = np.count_nonzero(h[i])
        check = StabilizerCodeTensorEnumerator(
            h=Legos.z_rep_code(weight + 2), tensor_id=f"check{i}"
        )
        check = check.trace_with_stopper(Legos.stopper_x, (f"check{i}", 0))
        check = check.trace_with_stopper(Legos.stopper_x, (f"check{i}", 1))
        checks.append(check)

    traces = []
    next_check_legs = [2] * r
    q_tensors = []
    self.q_to_leg_and_node: List[Tuple[TensorId, TensorLeg]] = []

    # for each qubit we create merged tensors across all checks
    for q in range(n):
        q_tensor = StabilizerCodeTensorEnumerator(
            h=Legos.stopper_i, tensor_id=f"q{q}"
        )
        physical_leg = (f"q{q}", 0)
        for i in range(r):
            op = tuple(h[i, (q, q + n)])
            if op == (0, 0):
                continue

            if op == (1, 0):
                q_tensor = q_tensor.conjoin(
                    StabilizerCodeTensorEnumerator(
                        h=Legos.x_rep_code(3), tensor_id=f"q{q}.c{i}"
                    ),
                    [physical_leg],
                    [0],
                )
                traces.append(
                    (
                        q_tensor.tensor_id,
                        checks[i].tensor_id,
                        [(f"q{q}.c{i}", 1)],
                        [next_check_legs[i]],
                    )
                )
                next_check_legs[i] += 1
                physical_leg = (f"q{q}.c{i}", 2)

            elif op == (0, 1):
                q_tensor = q_tensor.conjoin(
                    StabilizerCodeTensorEnumerator(
                        h=Legos.z_rep_code(3), tensor_id=f"q{q}.z{i}"
                    ),
                    [physical_leg],
                    [0],
                )
                q_tensor = q_tensor.conjoin(
                    StabilizerCodeTensorEnumerator(
                        h=Legos.h, tensor_id=f"q{q}.c{i}"
                    ),
                    [(f"q{q}.z{i}", 1)],
                    [0],
                )
                traces.append(
                    (
                        q_tensor.tensor_id,
                        checks[i].tensor_id,
                        [(f"q{q}.c{i}", 1)],
                        [next_check_legs[i]],
                    )
                )
                next_check_legs[i] += 1
                physical_leg = (f"q{q}.z{i}", 2)

            else:
                raise ValueError("Y stabilizer is not implemented yet...")
        q_tensors.append(q_tensor)
        self.q_to_leg_and_node.append((physical_leg[0], physical_leg))

    super().__init__(nodes={n.tensor_id: n for n in q_tensors + checks})

    for t in traces:
        self.self_trace(*t)

n_qubits()

Get the total number of qubits in the tensor network.

Returns:

Name Type Description
int int

Total number of qubits represented by this tensor network.

Source code in planqtn/networks/stabilizer_tanner_code.py
def n_qubits(self) -> int:
    """Get the total number of qubits in the tensor network.

    Returns:
        int: Total number of qubits represented by this tensor network.
    """
    return int(self.parity_check_matrix.shape[1] // 2)

qubit_to_node_and_leg(q)

Map a qubit index to its corresponding node and leg.

Returns the tensor and leg for the given qubit index.

Parameters:

Name Type Description Default
q int

Global qubit index.

required

Returns:

Name Type Description
TensorId

Node ID: node id for the tensor in the network

Leg TensorLeg

Leg that represents the qubit.

Source code in planqtn/networks/stabilizer_tanner_code.py
def qubit_to_node_and_leg(self, q: int) -> Tuple[TensorId, TensorLeg]:
    """Map a qubit index to its corresponding node and leg.

    Returns the tensor and leg for the given qubit index.

    Args:
        q: Global qubit index.

    Returns:
        Node ID: node id for the tensor in the network
        Leg: leg that represent the qubit.
    """
    return self.q_to_leg_and_node[q]

SurfaceCodeTN

Bases: TensorNetwork

A tensor network layout for the surface code.

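The layout places a lego at every (row, column) pair of matching parity on a (2d − 1) × (2d − 1) grid, so the unrotated distance-d code has d² + (d − 1)² physical qubits. A quick sketch of that count for d = 3 (plain Python, mirroring the constructor's coordinate comprehension):

```python
# Node coordinates of the d = 3 surface code layout, following the
# (row, column) numbering convention used by SurfaceCodeTN.
d = 3
last = 2 * d - 2
coords = [(r, c) for r in range(last + 1) for c in range(r % 2, last + 1, 2)]
print(len(coords))  # 13 == d**2 + (d - 1)**2 physical qubits
```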
Source code in planqtn/networks/surface_code.py
class SurfaceCodeTN(TensorNetwork):
    """A tensor network layout for the surface code."""

    def __init__(
        self,
        d: int,
        lego: Callable[[TensorId], GF2] = lambda i: Legos.encoding_tensor_512,
        coset_error: Optional[GF2] = None,
        truncate_length: Optional[int] = None,
    ):
        """Construct a surface code tensor network.

        The numbering convention is as follows for the tensor ids (row, column):

        ```
        (0,0)  (0,2)  (0,4)
            (1,1)   (1,3)
        (2,0)  (2,2)  (2,4)
            (3,1)   (3,3)
        (4,0)  (4,2)  (4,4)
        ```
        The construction is based on the following work:

        Cao, ChunJun, Michael J. Gullans, Brad Lackey, and Zitao Wang. 2024.
        “Quantum Lego Expansion Pack: Enumerators from Tensor Networks.”
        PRX Quantum 5 (3): 030313. https://doi.org/10.1103/PRXQuantum.5.030313.

        Args:
            d: The distance of the surface code.
            lego: The lego to use for the surface code.
            coset_error: The coset error to use for the surface code.
            truncate_length: The truncate length to use for the surface code.

        Raises:
            ValueError: If the distance is less than 2.
        """
        if d < 2:
            raise ValueError("Only d=2+ is supported.")

        # numbering convention:

        # (0,0)  (0,2)  (0,4)
        #    (1,1)   (1,3)
        # (2,0)  (2,2)  (2,4)
        #    (3,1)   (3,3)
        # (4,0)  (4,2)  (4,4)

        last_row = 2 * d - 2
        last_col = 2 * d - 2

        super().__init__(
            [
                StabilizerCodeTensorEnumerator(lego((r, c)), tensor_id=(r, c))
                for r in range(last_row + 1)
                for c in range(r % 2, last_col + 1, 2)
            ],
            truncate_length=truncate_length,
        )
        self._q_to_node = [
            (r, c) for r in range(last_row + 1) for c in range(r % 2, last_col + 1, 2)
        ]

        nodes = self.nodes

        # we take care of corners first

        nodes[(0, 0)] = (
            nodes[(0, 0)]
            .trace_with_stopper(Legos.stopper_z, 0)
            .trace_with_stopper(Legos.stopper_z, 1)
            .trace_with_stopper(Legos.stopper_x, 3)
        )
        nodes[(0, last_col)] = (
            nodes[(0, last_col)]
            .trace_with_stopper(Legos.stopper_z, 2)
            .trace_with_stopper(Legos.stopper_z, 3)
            .trace_with_stopper(Legos.stopper_x, 0)
        )
        nodes[(last_row, 0)] = (
            nodes[(last_row, 0)]
            .trace_with_stopper(Legos.stopper_z, 0)
            .trace_with_stopper(Legos.stopper_z, 1)
            .trace_with_stopper(Legos.stopper_x, 2)
        )
        nodes[(last_row, last_col)] = (
            nodes[(last_row, last_col)]
            .trace_with_stopper(Legos.stopper_z, 2)
            .trace_with_stopper(Legos.stopper_z, 3)
            .trace_with_stopper(Legos.stopper_x, 1)
        )

        for k in range(2, last_col, 2):
            # X boundaries on the top and bottom
            nodes[(0, k)] = (
                nodes[(0, k)]
                .trace_with_stopper(Legos.stopper_x, 0)
                .trace_with_stopper(Legos.stopper_x, 3)
            )
            nodes[(last_row, k)] = (
                nodes[(last_row, k)]
                .trace_with_stopper(Legos.stopper_x, 1)
                .trace_with_stopper(Legos.stopper_x, 2)
            )

            # Z boundaries on left and right
            nodes[(k, 0)] = (
                nodes[(k, 0)]
                .trace_with_stopper(Legos.stopper_z, 0)
                .trace_with_stopper(Legos.stopper_z, 1)
            )
            nodes[(k, last_col)] = (
                nodes[(k, last_col)]
                .trace_with_stopper(Legos.stopper_z, 2)
                .trace_with_stopper(Legos.stopper_z, 3)
            )

        # we'll trace diagonally
        for diag in range(1, last_row + 1):
            # connecting the middle to the previous diagonal's middle
            self.self_trace(
                (diag - 1, diag - 1),
                (diag, diag),
                [2 if diag % 2 == 1 else 1],
                [3 if diag % 2 == 1 else 0],
            )
            # go left until hitting the left column or the bottom row
            # and at the same time go right until hitting the right col or the top row (symmetric)
            row, col = diag + 1, diag - 1
            while row <= last_row and col >= 0:
                # going left
                self.self_trace(
                    (row - 1, col + 1),
                    (row, col),
                    [0 if row % 2 == 0 else 1],
                    [3 if row % 2 == 0 else 2],
                )

                # going right
                self.self_trace(
                    (col + 1, row - 1),
                    (col, row),
                    [3 if row % 2 == 1 else 2],
                    [0 if row % 2 == 1 else 1],
                )

                if row - 1 >= 0 and col - 1 >= 0:
                    # connect to previous diagonal
                    # on the left
                    self.self_trace(
                        (row - 1, col - 1),
                        (row, col),
                        [2 if row % 2 == 1 else 1],
                        [3 if row % 2 == 1 else 0],
                    )
                    # on the right
                    self.self_trace(
                        (col - 1, row - 1),
                        (col, row),
                        [2 if row % 2 == 1 else 1],
                        [3 if row % 2 == 1 else 0],
                    )

                row += 1
                col -= 1
            # go right until hitting the right column
        self.n = len(self.nodes)
        if coset_error is None:
            coset_error = GF2.Zeros(2 * self.n)
        self.set_coset(coset_error)

    def qubit_to_node_and_leg(self, q: int) -> Tuple[TensorId, TensorLeg]:
        """Map a qubit index to its corresponding node and leg.

        Returns the tensor and leg for the given qubit index. We follow row-major ordering, i.e. for
        this layout:
        ```
        (0,0)  (0,2)  (0,4)
            (1,1)   (1,3)
        (2,0)  (2,2)  (2,4)
            (3,1)   (3,3)
        (4,0)  (4,2)  (4,4)
        ```
        the qubits are ordered as follows:
        ```
        0  1  2
          3  4
        5  6  7
          8  9
        10 11 12
        ```

        Args:
            q: Global qubit index.

        Returns:
            Node ID: node id for the tensor in the network
            Leg: leg that represents the qubit.
        """
        return self._q_to_node[q], (self._q_to_node[q], 4)

    def n_qubits(self) -> int:
        """Get the total number of qubits in the tensor network.

        Returns:
            int: Total number of qubits represented by this tensor network.
        """
        return self.n

__init__(d, lego=lambda i: Legos.encoding_tensor_512, coset_error=None, truncate_length=None)

Construct a surface code tensor network.

The numbering convention is as follows for the tensor ids (row, column):

(0,0)  (0,2)  (0,4)
    (1,1)   (1,3)
(2,0)  (2,2)  (2,4)
    (3,1)   (3,3)
(4,0)  (4,2)  (4,4)
The construction is based on the following work:

Cao, ChunJun, Michael J. Gullans, Brad Lackey, and Zitao Wang. 2024. “Quantum Lego Expansion Pack: Enumerators from Tensor Networks.” PRX Quantum 5 (3): 030313. https://doi.org/10.1103/PRXQuantum.5.030313.

Parameters:

Name Type Description Default
d int

The distance of the surface code.

required
lego Callable[[TensorId], GF2]

The lego to use for the surface code.

lambda i: encoding_tensor_512
coset_error Optional[GF2]

The coset error to use for the surface code.

None
truncate_length Optional[int]

The truncate length to use for the surface code.

None

Raises:

Type Description
ValueError

If the distance is less than 2.

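Qubits are numbered in row-major order over the same coordinate grid, and `qubit_to_node_and_leg` maps qubit `q` to the q-th coordinate with leg 4 as its dangling physical leg. A small sketch of that mapping for d = 3 (plain Python, mirroring the constructor's `_q_to_node` list):

```python
# Row-major qubit ordering for d = 3: qubit q maps to the q-th coordinate
# in the layout; its physical leg is leg 4 of that lego.
d = 3
last = 2 * d - 2
q_to_node = [(r, c) for r in range(last + 1) for c in range(r % 2, last + 1, 2)]
print(q_to_node[0], q_to_node[3], q_to_node[12])  # (0, 0) (1, 1) (4, 4)
```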
Source code in planqtn/networks/surface_code.py
def __init__(
    self,
    d: int,
    lego: Callable[[TensorId], GF2] = lambda i: Legos.encoding_tensor_512,
    coset_error: Optional[GF2] = None,
    truncate_length: Optional[int] = None,
):
    """Construct a surface code tensor network.

    The numbering convention is as follows for the tensor ids (row, column):

    ```
    (0,0)  (0,2)  (0,4)
        (1,1)   (1,3)
    (2,0)  (2,2)  (2,4)
        (3,1)   (3,3)
    (4,0)  (4,2)  (4,4)
    ```
    The construction is based on the following work:

    Cao, ChunJun, Michael J. Gullans, Brad Lackey, and Zitao Wang. 2024.
    “Quantum Lego Expansion Pack: Enumerators from Tensor Networks.”
    PRX Quantum 5 (3): 030313. https://doi.org/10.1103/PRXQuantum.5.030313.

    Args:
        d: The distance of the surface code.
        lego: The lego to use for the surface code.
        coset_error: The coset error to use for the surface code.
        truncate_length: The truncate length to use for the surface code.

    Raises:
        ValueError: If the distance is less than 2.
    """
    if d < 2:
        raise ValueError("Only d=2+ is supported.")

    # numbering convention:

    # (0,0)  (0,2)  (0,4)
    #    (1,1)   (1,3)
    # (2,0)  (2,2)  (2,4)
    #    (3,1)   (3,3)
    # (4,0)  (4,2)  (4,4)

    last_row = 2 * d - 2
    last_col = 2 * d - 2

    super().__init__(
        [
            StabilizerCodeTensorEnumerator(lego((r, c)), tensor_id=(r, c))
            for r in range(last_row + 1)
            for c in range(r % 2, last_col + 1, 2)
        ],
        truncate_length=truncate_length,
    )
    self._q_to_node = [
        (r, c) for r in range(last_row + 1) for c in range(r % 2, last_col + 1, 2)
    ]

    nodes = self.nodes

    # we take care of corners first

    nodes[(0, 0)] = (
        nodes[(0, 0)]
        .trace_with_stopper(Legos.stopper_z, 0)
        .trace_with_stopper(Legos.stopper_z, 1)
        .trace_with_stopper(Legos.stopper_x, 3)
    )
    nodes[(0, last_col)] = (
        nodes[(0, last_col)]
        .trace_with_stopper(Legos.stopper_z, 2)
        .trace_with_stopper(Legos.stopper_z, 3)
        .trace_with_stopper(Legos.stopper_x, 0)
    )
    nodes[(last_row, 0)] = (
        nodes[(last_row, 0)]
        .trace_with_stopper(Legos.stopper_z, 0)
        .trace_with_stopper(Legos.stopper_z, 1)
        .trace_with_stopper(Legos.stopper_x, 2)
    )
    nodes[(last_row, last_col)] = (
        nodes[(last_row, last_col)]
        .trace_with_stopper(Legos.stopper_z, 2)
        .trace_with_stopper(Legos.stopper_z, 3)
        .trace_with_stopper(Legos.stopper_x, 1)
    )

    for k in range(2, last_col, 2):
        # X boundaries on the top and bottom
        nodes[(0, k)] = (
            nodes[(0, k)]
            .trace_with_stopper(Legos.stopper_x, 0)
            .trace_with_stopper(Legos.stopper_x, 3)
        )
        nodes[(last_row, k)] = (
            nodes[(last_row, k)]
            .trace_with_stopper(Legos.stopper_x, 1)
            .trace_with_stopper(Legos.stopper_x, 2)
        )

        # Z boundaries on left and right
        nodes[(k, 0)] = (
            nodes[(k, 0)]
            .trace_with_stopper(Legos.stopper_z, 0)
            .trace_with_stopper(Legos.stopper_z, 1)
        )
        nodes[(k, last_col)] = (
            nodes[(k, last_col)]
            .trace_with_stopper(Legos.stopper_z, 2)
            .trace_with_stopper(Legos.stopper_z, 3)
        )

    # we'll trace diagonally
    for diag in range(1, last_row + 1):
        # connecting the middle to the previous diagonal's middle
        self.self_trace(
            (diag - 1, diag - 1),
            (diag, diag),
            [2 if diag % 2 == 1 else 1],
            [3 if diag % 2 == 1 else 0],
        )
        # go left until hitting the left column or the bottom row
        # and at the same time go right until hitting the right col or the top row (symmetric)
        row, col = diag + 1, diag - 1
        while row <= last_row and col >= 0:
            # going left
            self.self_trace(
                (row - 1, col + 1),
                (row, col),
                [0 if row % 2 == 0 else 1],
                [3 if row % 2 == 0 else 2],
            )

            # going right
            self.self_trace(
                (col + 1, row - 1),
                (col, row),
                [3 if row % 2 == 1 else 2],
                [0 if row % 2 == 1 else 1],
            )

            if row - 1 >= 0 and col - 1 >= 0:
                # connect to previous diagonal
                # on the left
                self.self_trace(
                    (row - 1, col - 1),
                    (row, col),
                    [2 if row % 2 == 1 else 1],
                    [3 if row % 2 == 1 else 0],
                )
                # on the right
                self.self_trace(
                    (col - 1, row - 1),
                    (col, row),
                    [2 if row % 2 == 1 else 1],
                    [3 if row % 2 == 1 else 0],
                )

            row += 1
            col -= 1
        # go right until hitting the right column
    self.n = len(self.nodes)
    if coset_error is None:
        coset_error = GF2.Zeros(2 * self.n)
    self.set_coset(coset_error)
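The diagonal tracing loop above can be reproduced standalone. The helper below is a hypothetical sketch that mirrors only the trace-scheduling logic from `__init__` (which node pairs get connected), not the actual tensor contraction:

```python
# Hypothetical standalone sketch: reproduces only the trace scheduling from
# __init__ above (which node pairs get traced together), not the tensor math.
def trace_pairs(d):
    last_row = 2 * d - 2
    pairs = []
    for diag in range(1, last_row + 1):
        # middle of the diagonal connects to the previous diagonal's middle
        pairs.append(((diag - 1, diag - 1), (diag, diag)))
        row, col = diag + 1, diag - 1
        while row <= last_row and col >= 0:
            pairs.append(((row - 1, col + 1), (row, col)))  # going left
            pairs.append(((col + 1, row - 1), (col, row)))  # going right
            if row - 1 >= 0 and col - 1 >= 0:
                pairs.append(((row - 1, col - 1), (row, col)))  # left diagonal
                pairs.append(((col - 1, row - 1), (col, row)))  # right diagonal
            row += 1
            col -= 1
    return pairs

print(len(trace_pairs(2)))  # 4 traces for d=2: each corner to the middle node
```

For d=2 the layout has four corner nodes around one middle node, so exactly four traces are scheduled.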

n_qubits()

Get the total number of qubits in the tensor network.

Returns:

  • int: Total number of qubits represented by this tensor network.

Source code in planqtn/networks/surface_code.py
def n_qubits(self) -> int:
    """Get the total number of qubits in the tensor network.

    Returns:
        int: Total number of qubits represented by this tensor network.
    """
    return self.n

qubit_to_node_and_leg(q)

Map a qubit index to its corresponding node and leg.

Returns the tensor and leg for the given qubit index. We follow row-major ordering, i.e. for this layout:

(0,0)  (0,2)  (0,4)
    (1,1)   (1,3)
(2,0)  (2,2)  (2,4)
    (3,1)   (3,3)
(4,0)  (4,2)  (4,4)

the qubits are ordered as follows:

0  1  2
  3  4
5  6  7
  8  9
10 11 12

Parameters:

  • q (int): Global qubit index. Required.

Returns:

  • TensorId: Node ID of the tensor in the network.
  • TensorLeg: Leg that represents the qubit.

Source code in planqtn/networks/surface_code.py
def qubit_to_node_and_leg(self, q: int) -> Tuple[TensorId, TensorLeg]:
    """Map a qubit index to its corresponding node and leg.

    Returns the tensor and leg for the given qubit index. We follow row-major ordering, i.e. for
    this layout:
    ```
    (0,0)  (0,2)  (0,4)
        (1,1)   (1,3)
    (2,0)  (2,2)  (2,4)
        (3,1)   (3,3)
    (4,0)  (4,2)  (4,4)
    ```
    the qubits are ordered as follows:
    ```
    0  1  2
      3  4
    5  6  7
      8  9
    10 11 12
    ```

    Args:
        q: Global qubit index.

    Returns:
        Node ID: node id for the tensor in the network
        Leg: leg that represents the qubit.
    """
    return self._q_to_node[q], (self._q_to_node[q], 4)
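The row-major mapping can be rebuilt in a few lines. This hypothetical sketch reproduces the `_q_to_node` indexing convention for a distance-d layout (it is not the planqtn API itself):

```python
# Hypothetical sketch of the row-major qubit -> (row, col) mapping used by
# qubit_to_node_and_leg for a distance-d layout (not the planqtn API itself).
def qubit_to_node(d):
    last = 2 * d - 2
    return [(r, c) for r in range(last + 1) for c in range(r % 2, last + 1, 2)]

mapping = qubit_to_node(3)
print(len(mapping))  # 13 qubits for d=3
print(mapping[3])    # (1, 1): qubit 3 sits on the second row
```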

Utilities

The planqtn.progress_reporter package

Progress reporter interface and implementations for calculations.

The main class is ProgressReporter which is an abstract base class for all progress reporters.

The main methods are:

  • iterate: Iterates over an iterable and reports progress on every item.
  • enter_phase: Starts a new phase.
  • exit_phase: Ends the current phase.

The main implementations are listed below.

Progress reporters are the main mechanism for reporting progress back to the PlanqTN Studio UI from the backend jobs in real time.
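Reporters compose through the `sub_reporter` chain. The classes below are hypothetical stand-ins (not the planqtn ones) showing the idea: handle a result locally, then forward it downstream:

```python
# Hypothetical stand-in illustrating reporter composition: handle the result
# locally, then forward it down the sub_reporter chain.
class Reporter:
    def __init__(self, name, log, sub_reporter=None):
        self.name = name
        self.log = log
        self.sub_reporter = sub_reporter

    def log_result(self, result):
        self.log.append((self.name, result["level"]))
        if self.sub_reporter is not None:
            self.sub_reporter.log_result(result)

log = []
chain = Reporter("ui", log, sub_reporter=Reporter("file", log))
chain.log_result({"level": 1})
print(log)  # [('ui', 1), ('file', 1)]
```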

DummyProgressReporter

Bases: ProgressReporter

A no-op progress reporter that does nothing.

This implementation provides a null progress reporter that can be used when progress reporting is not needed. It implements all required methods but performs no actual reporting, making it useful as a default, for testing, or as a silent mode for scripts.

Source code in planqtn/progress_reporter.py
class DummyProgressReporter(ProgressReporter):
    """A no-op progress reporter that does nothing.

    This implementation provides a null progress reporter that can be used
    when progress reporting is not needed. It implements all required methods
    but performs no actual reporting, making it useful as a default or for
    testing purposes or creating a silent mode for scripts to run.
    """

    def handle_result(self, result: Dict[str, Any]) -> None:
        """Handle progress result (no-op for dummy reporter).

        The dummy reporter ignores all progress results, making it useful
        when progress reporting is not needed.

        Args:
            result: Progress result dictionary (ignored).
        """

handle_result(result)

Handle progress result (no-op for dummy reporter).

The dummy reporter ignores all progress results, making it useful when progress reporting is not needed.

Parameters:

  • result (Dict[str, Any]): Progress result dictionary (ignored). Required.
Source code in planqtn/progress_reporter.py
def handle_result(self, result: Dict[str, Any]) -> None:
    """Handle progress result (no-op for dummy reporter).

    The dummy reporter ignores all progress results, making it useful
    when progress reporting is not needed.

    Args:
        result: Progress result dictionary (ignored).
    """

IterationState

State tracking information for a single iteration phase.

This class tracks the progress and timing information for a single iteration or calculation phase. It maintains statistics like current progress, timing, and performance metrics that can be used for progress reporting and analysis.

Attributes:

  • desc (str): Description of the current iteration phase.
  • total_size (int): Total number of items to process in this iteration.
  • current_item (int): Current item being processed (0-indexed).
  • start_time (float): Timestamp when the iteration started.
  • end_time (float | None): Timestamp when the iteration ended (None if not finished).
  • duration (float | None): Total duration of the iteration in seconds (None if not finished).
  • avg_time_per_item (float | None): Average time per item in seconds (None if no items processed).

Source code in planqtn/progress_reporter.py
@attr.s
class IterationState:
    """State tracking information for a single iteration phase.

    This class tracks the progress and timing information for a single iteration
    or calculation phase. It maintains statistics like current progress, timing,
    and performance metrics that can be used for progress reporting and analysis.

    Attributes:
        desc: Description of the current iteration phase.
        total_size: Total number of items to process in this iteration.
        current_item: Current item being processed (0-indexed).
        start_time: Timestamp when the iteration started.
        end_time: Timestamp when the iteration ended (None if not finished).
        duration: Total duration of the iteration in seconds (None if not finished).
        avg_time_per_item: Average time per item in seconds (None if no items processed).
    """

    desc: str = attr.ib()
    total_size: int = attr.ib()
    current_item: int = attr.ib(default=0)
    # Use a factory so the start time is captured per instance, not once at
    # class-definition time.
    start_time: float = attr.ib(factory=time.time)
    end_time: float | None = attr.ib(default=None)
    duration: float | None = attr.ib(default=None)
    avg_time_per_item: float | None = attr.ib(default=None)

    def update(self, current_item: int | None = None) -> None:
        """Update the iteration state with progress information.

        Updates the current item count, recalculates duration, and updates
        the average time per item. If no current_item is provided, increments
        the current item by 1.

        Args:
            current_item: New current item index. If None, increments by 1.
        """
        if current_item is None:
            current_item = self.current_item + 1
        self.current_item = current_item
        self.duration = time.time() - self.start_time
        self._update_avg_time_per_item()

    def _update_avg_time_per_item(self) -> None:
        if self.current_item == 0:
            self.avg_time_per_item = None
        elif self.current_item is not None and self.duration is not None:
            self.avg_time_per_item = self.duration / self.current_item

    def end(self) -> None:
        """Mark the iteration as completed.

        Sets the end time and calculates the final duration and average time
        per item statistics.
        """
        self.end_time = time.time()
        self.duration = self.end_time - self.start_time
        self._update_avg_time_per_item()

    def __repr__(self) -> str:
        return (
            f"Iteration(desc={self.desc}, current_item={self.current_item}, "
            f"total_size={self.total_size}, duration={self.duration}, "
            f"avg_time_per_item={self.avg_time_per_item}), "
            f"start_time={self.start_time}, end_time={self.end_time}"
        )

    def to_dict(self) -> Dict[str, Any]:
        """Convert the IterationState to a dictionary for JSON serialization.

        Returns:
            Dict[str, Any]: Dictionary representation of the iteration state
                suitable for JSON serialization.
        """
        return {
            "desc": self.desc,
            "total_size": self.total_size,
            "current_item": self.current_item,
            "start_time": self.start_time,
            "end_time": self.end_time,
            "duration": self.duration,
            "avg_time_per_item": self.avg_time_per_item,
        }

end()

Mark the iteration as completed.

Sets the end time and calculates the final duration and average time per item statistics.

Source code in planqtn/progress_reporter.py
def end(self) -> None:
    """Mark the iteration as completed.

    Sets the end time and calculates the final duration and average time
    per item statistics.
    """
    self.end_time = time.time()
    self.duration = self.end_time - self.start_time
    self._update_avg_time_per_item()

to_dict()

Convert the IterationState to a dictionary for JSON serialization.

Returns:

  • Dict[str, Any]: Dictionary representation of the iteration state suitable for JSON serialization.

Source code in planqtn/progress_reporter.py
def to_dict(self) -> Dict[str, Any]:
    """Convert the IterationState to a dictionary for JSON serialization.

    Returns:
        Dict[str, Any]: Dictionary representation of the iteration state
            suitable for JSON serialization.
    """
    return {
        "desc": self.desc,
        "total_size": self.total_size,
        "current_item": self.current_item,
        "start_time": self.start_time,
        "end_time": self.end_time,
        "duration": self.duration,
        "avg_time_per_item": self.avg_time_per_item,
    }

update(current_item=None)

Update the iteration state with progress information.

Updates the current item count, recalculates duration, and updates the average time per item. If no current_item is provided, increments the current item by 1.

Parameters:

  • current_item (int | None): New current item index. If None, increments by 1. Default: None.
Source code in planqtn/progress_reporter.py
def update(self, current_item: int | None = None) -> None:
    """Update the iteration state with progress information.

    Updates the current item count, recalculates duration, and updates
    the average time per item. If no current_item is provided, increments
    the current item by 1.

    Args:
        current_item: New current item index. If None, increments by 1.
    """
    if current_item is None:
        current_item = self.current_item + 1
    self.current_item = current_item
    self.duration = time.time() - self.start_time
    self._update_avg_time_per_item()
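The bookkeeping above boils down to a handful of fields. This hypothetical, dependency-free sketch mirrors the update logic without attrs or planqtn:

```python
import time

# Hypothetical dependency-free sketch of the IterationState bookkeeping:
# track the current item, elapsed duration, and average time per item.
class MiniIterationState:
    def __init__(self, desc, total_size):
        self.desc = desc
        self.total_size = total_size
        self.current_item = 0
        self.start_time = time.time()
        self.duration = None
        self.avg_time_per_item = None

    def update(self, current_item=None):
        # No explicit index means "advance by one item".
        self.current_item = (
            self.current_item + 1 if current_item is None else current_item
        )
        self.duration = time.time() - self.start_time
        if self.current_item:
            self.avg_time_per_item = self.duration / self.current_item

state = MiniIterationState("demo", total_size=5)
for _ in range(5):
    state.update()
print(state.current_item)  # 5
```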

IterationStateEncoder

Bases: JSONEncoder

Custom JSON encoder for IterationState objects.

This encoder extends the standard JSON encoder to handle IterationState objects by converting them to dictionaries using their to_dict() method. This enables JSON serialization of progress reporting data.

Source code in planqtn/progress_reporter.py
class IterationStateEncoder(json.JSONEncoder):
    """Custom JSON encoder for IterationState objects.

    This encoder extends the standard JSON encoder to handle IterationState
    objects by converting them to dictionaries using their to_dict() method.
    This enables JSON serialization of progress reporting data.
    """

    def default(self, o: Any) -> Any:
        """Convert IterationState objects to dictionaries for JSON serialization.

        Args:
            o: Object to encode.

        Returns:
            Any: Dictionary representation if o is an IterationState,
                 otherwise delegates to parent class.
        """
        if isinstance(o, IterationState):
            return o.to_dict()
        return super().default(o)

    def __call__(self, o: Any) -> str:
        """Encode an object to JSON string.

        Args:
            o: Object to encode.

        Returns:
            str: JSON string representation of the object.
        """
        return self.encode(o)

__call__(o)

Encode an object to JSON string.

Parameters:

  • o (Any): Object to encode. Required.

Returns:

  • str: JSON string representation of the object.

Source code in planqtn/progress_reporter.py
def __call__(self, o: Any) -> str:
    """Encode an object to JSON string.

    Args:
        o: Object to encode.

    Returns:
        str: JSON string representation of the object.
    """
    return self.encode(o)

default(o)

Convert IterationState objects to dictionaries for JSON serialization.

Parameters:

  • o (Any): Object to encode. Required.

Returns:

  • Any: Dictionary representation if o is an IterationState, otherwise delegates to the parent class.

Source code in planqtn/progress_reporter.py
def default(self, o: Any) -> Any:
    """Convert IterationState objects to dictionaries for JSON serialization.

    Args:
        o: Object to encode.

    Returns:
        Any: Dictionary representation if o is an IterationState,
             otherwise delegates to parent class.
    """
    if isinstance(o, IterationState):
        return o.to_dict()
    return super().default(o)
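The same hook pattern works with any object exposing `to_dict()`. Here is a hypothetical, self-contained version of the encoder using only the standard library (the `State` class is a stand-in, not a planqtn type):

```python
import json

# Hypothetical stand-in with a to_dict() method, illustrating the
# JSONEncoder.default hook used by IterationStateEncoder.
class State:
    def __init__(self, desc, current_item):
        self.desc = desc
        self.current_item = current_item

    def to_dict(self):
        return {"desc": self.desc, "current_item": self.current_item}

class StateEncoder(json.JSONEncoder):
    def default(self, o):
        # Convert known objects; delegate everything else to the base class.
        if isinstance(o, State):
            return o.to_dict()
        return super().default(o)

payload = json.dumps({"iteration": State("demo", 3), "level": 1}, cls=StateEncoder)
print(payload)
```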

ProgressReporter

Bases: ABC

Abstract base class for progress reporting in calculations.

This class provides a framework for reporting progress during long-running calculations. It supports nested iteration phases and can be composed with other progress reporters. It is the main mechanism for reporting progress back to the PlanqTN Studio UI from backend jobs in real time.

Subclasses should implement the handle_result method to define how progress information is processed (e.g., displayed, logged, or sent to a UI).

Attributes:

  • sub_reporter: Optional nested progress reporter for composition.
  • iterator_stack (list[IterationState]): Stack of active iteration states for nested phases.
  • iteration_report_frequency: Minimum time interval between progress reports.

Source code in planqtn/progress_reporter.py
class ProgressReporter(abc.ABC):
    """Abstract base class for progress reporting in calculations.

    This class provides a framework for reporting progress during long-running
    calculations. It supports nested iteration phases and can be composed with
    other progress reporters. The main mechanism for reporting progress back to
    PlanqTN Studio UI from backend jobs in realtime.

    Subclasses should implement the `handle_result` method to define how progress
    information is processed (e.g., displayed, logged, or sent to a UI).

    Attributes:
        sub_reporter: Optional nested progress reporter for composition.
        iterator_stack: Stack of active iteration states for nested phases.
        iteration_report_frequency: Minimum time interval between progress reports.
    """

    def __init__(
        self,
        sub_reporter: Optional["ProgressReporter"] = None,
        iteration_report_frequency: float = 0.0,
    ):
        """Initialize the progress reporter.

        Args:
            sub_reporter: Optional nested progress reporter for composition.
            iteration_report_frequency: Minimum time interval between progress
                reports in seconds. If 0.0, reports on every iteration.
        """
        self.sub_reporter = sub_reporter
        self.iterator_stack: list[IterationState] = []
        self.iteration_report_frequency = iteration_report_frequency

    def __enter__(self) -> "ProgressReporter":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        pass

    @abc.abstractmethod
    def handle_result(self, result: Dict[str, Any]) -> None:
        """Handle progress result data.

        This hook method must be implemented by subclasses to define how progress
        information is processed. The result dictionary contains iteration state
        and metadata about the current progress.

        Args:
            result: Dictionary containing progress information including:
                - iteration: IterationState object or dict
                - level: Nesting level of the current iteration
                - Additional metadata specific to the implementation
        """

    def log_result(self, result: Dict[str, Any]) -> None:
        """Log progress result and propagate to sub-reporter.

        Converts IterationState objects to dictionaries for serialization,
        calls the handle_result method, and propagates the result to any
        nested sub-reporter.

        Args:
            result: Dictionary containing progress information.
        """
        # Convert IterationState to dict in the result
        serializable_result = {}
        for key, value in result.items():
            if isinstance(value, IterationState):
                serializable_result[key] = value.to_dict()
            else:
                serializable_result[key] = value

        self.handle_result(serializable_result)
        if self.sub_reporter is not None:
            self.sub_reporter.log_result(serializable_result)

    def iterate(
        self, iterable: Iterable, desc: str, total_size: int
    ) -> Generator[Any, None, None]:
        """Start a new iteration phase with progress reporting.

        Creates a generator that yields items from the iterable while tracking
        progress and reporting it at regular intervals. The iteration state is
        maintained on a stack to support nested iterations.

        Args:
            iterable: The iterable to iterate over.
            desc: Description of the iteration phase.
            total_size: Total number of items to process.

        Yields:
            Items from the iterable.
        """
        bottom_iterator_state = IterationState(
            desc, start_time=time.time(), total_size=total_size
        )
        self.iterator_stack.append(bottom_iterator_state)

        if self.sub_reporter is not None:
            iterable = self.sub_reporter.iterate(iterable, desc, total_size)
        time_last_report = time.time()
        for item in iterable:
            yield item
            bottom_iterator_state.update()
            if time.time() - time_last_report > self.iteration_report_frequency:
                time_last_report = time.time()

                self.log_result(
                    {
                        "iteration": bottom_iterator_state,
                        "level": len(self.iterator_stack),
                    }
                )
            # higher level iterators need to be updated, this is just
            # a hack to ensure that the timestamps and avg time per item
            # is updated for all iterators
            for higher_iterator in self.iterator_stack[:-1]:
                higher_iterator.update(higher_iterator.current_item)
            # print(
            #     f"{type(self)}: iteration_state {bottom_iterator_state} iterated! {item}"
            # )

        bottom_iterator_state.end()
        self.log_result(
            {"iteration": bottom_iterator_state, "level": len(self.iterator_stack)}
        )
        self.iterator_stack.pop()

    def enter_phase(self, desc: str) -> _GeneratorContextManager[Any, None, None]:
        """Enter a new calculation phase with progress tracking.

        Creates a context manager for tracking a single-step phase. This is
        useful for marking the beginning and end of calculation phases that
        don't involve iteration but should still be tracked for progress reporting.

        Args:
            desc: Description of the phase.

        Returns:
            Context manager that can be used with 'with' statement.
        """

        @contextlib.contextmanager
        def phase_iterator() -> Generator[Any, None, None]:
            yield from self.iterate(["item"], desc, total_size=1)

        return phase_iterator()

    def exit_phase(self) -> None:
        """Exit the current calculation phase.

        Removes the current iteration state from the stack, effectively
        ending the current phase. This is typically called automatically
        when using the context manager from enter_phase().
        """
        self.iterator_stack.pop()

__init__(sub_reporter=None, iteration_report_frequency=0.0)

Initialize the progress reporter.

Parameters:

  • sub_reporter (Optional[ProgressReporter]): Optional nested progress reporter for composition. Default: None.
  • iteration_report_frequency (float): Minimum time interval between progress reports in seconds. If 0.0, reports on every iteration. Default: 0.0.
Source code in planqtn/progress_reporter.py
def __init__(
    self,
    sub_reporter: Optional["ProgressReporter"] = None,
    iteration_report_frequency: float = 0.0,
):
    """Initialize the progress reporter.

    Args:
        sub_reporter: Optional nested progress reporter for composition.
        iteration_report_frequency: Minimum time interval between progress
            reports in seconds. If 0.0, reports on every iteration.
    """
    self.sub_reporter = sub_reporter
    self.iterator_stack: list[IterationState] = []
    self.iteration_report_frequency = iteration_report_frequency

enter_phase(desc)

Enter a new calculation phase with progress tracking.

Creates a context manager for tracking a single-step phase. This is useful for marking the beginning and end of calculation phases that don't involve iteration but should still be tracked for progress reporting.

Parameters:

  • desc (str): Description of the phase. Required.

Returns:

  • _GeneratorContextManager[Any, None, None]: Context manager that can be used with a 'with' statement.

Source code in planqtn/progress_reporter.py
def enter_phase(self, desc: str) -> _GeneratorContextManager[Any, None, None]:
    """Enter a new calculation phase with progress tracking.

    Creates a context manager for tracking a single-step phase. This is
    useful for marking the beginning and end of calculation phases that
    don't involve iteration but should still be tracked for progress reporting.

    Args:
        desc: Description of the phase.

    Returns:
        Context manager that can be used with 'with' statement.
    """

    @contextlib.contextmanager
    def phase_iterator() -> Generator[Any, None, None]:
        yield from self.iterate(["item"], desc, total_size=1)

    return phase_iterator()
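The context-manager trick above (wrapping a one-item iteration) can be reproduced with the standard library alone. This hypothetical sketch shows why the phase ends exactly when the `with` block exits:

```python
import contextlib

# Hypothetical sketch of the enter_phase pattern: a single-item iteration
# wrapped in a context manager, so the phase completes on block exit.
events = []

def iterate(items, desc):
    for item in items:
        yield item
    events.append(desc)  # runs when the iteration is exhausted, i.e. on exit

@contextlib.contextmanager
def enter_phase(desc):
    yield from iterate(["item"], desc)

with enter_phase("setup"):
    pass  # the phase's work happens here
print(events)  # ['setup']
```

Because the generator yields exactly once, leaving the `with` block resumes it past the loop, which is where the phase-completion bookkeeping runs.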

exit_phase()

Exit the current calculation phase.

Removes the current iteration state from the stack, effectively ending the current phase. This is typically called automatically when using the context manager from enter_phase().

Source code in planqtn/progress_reporter.py
def exit_phase(self) -> None:
    """Exit the current calculation phase.

    Removes the current iteration state from the stack, effectively
    ending the current phase. This is typically called automatically
    when using the context manager from enter_phase().
    """
    self.iterator_stack.pop()

handle_result(result) abstractmethod

Handle progress result data.

This hook method must be implemented by subclasses to define how progress information is processed. The result dictionary contains iteration state and metadata about the current progress.

Parameters:

  • result (Dict[str, Any]): Dictionary containing progress information, including:
    • iteration: IterationState object or dict
    • level: Nesting level of the current iteration
    • Additional metadata specific to the implementation
Source code in planqtn/progress_reporter.py
@abc.abstractmethod
def handle_result(self, result: Dict[str, Any]) -> None:
    """Handle progress result data.

    This hook method must be implemented by subclasses to define how progress
    information is processed. The result dictionary contains iteration state
    and metadata about the current progress.

    Args:
        result: Dictionary containing progress information including:
            - iteration: IterationState object or dict
            - level: Nesting level of the current iteration
            - Additional metadata specific to the implementation
    """

iterate(iterable, desc, total_size)

Start a new iteration phase with progress reporting.

Creates a generator that yields items from the iterable while tracking progress and reporting it at regular intervals. The iteration state is maintained on a stack to support nested iterations.

Parameters:

  • iterable (Iterable): The iterable to iterate over. Required.
  • desc (str): Description of the iteration phase. Required.
  • total_size (int): Total number of items to process. Required.

Yields:

  • Any: Items from the iterable.

Source code in planqtn/progress_reporter.py
def iterate(
    self, iterable: Iterable, desc: str, total_size: int
) -> Generator[Any, None, None]:
    """Start a new iteration phase with progress reporting.

    Creates a generator that yields items from the iterable while tracking
    progress and reporting it at regular intervals. The iteration state is
    maintained on a stack to support nested iterations.

    Args:
        iterable: The iterable to iterate over.
        desc: Description of the iteration phase.
        total_size: Total number of items to process.

    Yields:
        Items from the iterable.
    """
    bottom_iterator_state = IterationState(
        desc, start_time=time.time(), total_size=total_size
    )
    self.iterator_stack.append(bottom_iterator_state)

    if self.sub_reporter is not None:
        iterable = self.sub_reporter.iterate(iterable, desc, total_size)
    time_last_report = time.time()
    for item in iterable:
        yield item
        bottom_iterator_state.update()
        if time.time() - time_last_report > self.iteration_report_frequency:
            time_last_report = time.time()

            self.log_result(
                {
                    "iteration": bottom_iterator_state,
                    "level": len(self.iterator_stack),
                }
            )
        # higher level iterators need to be updated, this is just
        # a hack to ensure that the timestamps and avg time per item
        # is updated for all iterators
        for higher_iterator in self.iterator_stack[:-1]:
            higher_iterator.update(higher_iterator.current_item)
        # print(
        #     f"{type(self)}: iteration_state {bottom_iterator_state} iterated! {item}"
        # )

    bottom_iterator_state.end()
    self.log_result(
        {"iteration": bottom_iterator_state, "level": len(self.iterator_stack)}
    )
    self.iterator_stack.pop()
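The stacked-iteration behavior above (nested `iterate` calls pushing and popping state, with events reported at each level) can be sketched with a minimal stdlib-only reporter. `SimpleReporter` below is a hypothetical illustration of the pattern, not part of the planqtn API:

```python
class SimpleReporter:
    """Minimal sketch of the stacked-iteration pattern (not the planqtn API)."""

    def __init__(self):
        self.stack = []   # one frame per active (possibly nested) iteration
        self.events = []  # (desc, items done, nesting level) recorded at each end

    def iterate(self, iterable, desc, total_size):
        frame = {"desc": desc, "done": 0, "total": total_size}
        self.stack.append(frame)
        for item in iterable:
            yield item
            frame["done"] += 1
        # report the finished phase with its nesting level, then pop it
        self.events.append((desc, frame["done"], len(self.stack)))
        self.stack.pop()

reporter = SimpleReporter()
for row in reporter.iterate(range(2), "rows", 2):
    for col in reporter.iterate(range(3), "cols", 3):
        pass
```

Each inner phase reports at level 2 while the outer frame is still on the stack, mirroring how planqtn's reporter supports nested progress.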

log_result(result)

Log progress result and propagate to sub-reporter.

Converts IterationState objects to dictionaries for serialization, calls the handle_result method, and propagates the result to any nested sub-reporter.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `result` | `Dict[str, Any]` | Dictionary containing progress information. | *required* |
Source code in planqtn/progress_reporter.py
def log_result(self, result: Dict[str, Any]) -> None:
    """Log progress result and propagate to sub-reporter.

    Converts IterationState objects to dictionaries for serialization,
    calls the handle_result method, and propagates the result to any
    nested sub-reporter.

    Args:
        result: Dictionary containing progress information.
    """
    # Convert IterationState to dict in the result
    serializable_result = {}
    for key, value in result.items():
        if isinstance(value, IterationState):
            serializable_result[key] = value.to_dict()
        else:
            serializable_result[key] = value

    self.handle_result(serializable_result)
    if self.sub_reporter is not None:
        self.sub_reporter.log_result(serializable_result)

TqdmProgressReporter

Bases: ProgressReporter

Progress reporter that displays progress using tqdm progress bars.

This implementation uses the tqdm library to display progress bars in the terminal. It's useful for command-line applications and provides visual feedback during long-running calculations.

Attributes:

| Name | Description |
| --- | --- |
| `file` | Output stream for the progress bars (default: sys.stdout). |
| `mininterval` | Minimum time interval between progress bar updates. |

Source code in planqtn/progress_reporter.py
class TqdmProgressReporter(ProgressReporter):
    """Progress reporter that displays progress using `tqdm` progress bars.

    This implementation uses the `tqdm` library to display progress bars in the
    terminal. It's useful for command-line applications and provides visual
    feedback during long-running calculations.

    Attributes:
        file: Output stream for the progress bars (default: sys.stdout).
        mininterval: Minimum time interval between progress bar updates.
    """

    def __init__(
        self,
        file: TextIO = sys.stdout,
        mininterval: float | None = None,
        sub_reporter: Optional["ProgressReporter"] = None,
    ):
        """Initialize the `tqdm` progress reporter.

        Args:
            file: Output stream for progress bars (default: sys.stdout).
            mininterval: Minimum time interval between updates in seconds.
                If None, uses 2 seconds for large iterations (>100k items)
                or 0.1 seconds for smaller ones.
            sub_reporter: Optional nested progress reporter for composition.
        """
        super().__init__(sub_reporter)
        self.file = file
        self.mininterval = mininterval

    def iterate(
        self, iterable: Iterable, desc: str, total_size: int
    ) -> Generator[Any, None, None]:
        """Iterate with `tqdm` progress bar display.

        Overrides the parent iterate method to wrap the iteration with a `tqdm`
        progress bar that provides visual feedback in the terminal.

        Args:
            iterable: The iterable to iterate over.
            desc: Description for the progress bar.
            total_size: Total number of items to process.

        Yields:
            Items from the iterable.
        """
        t = tqdm(
            desc=desc,
            total=total_size,
            iterable=super().iterate(iterable, desc, total_size),
            file=self.file,
            # leave=False,
            mininterval=(
                self.mininterval
                if self.mininterval is not None
                else 2 if total_size > 1e5 else 0.1
            ),
        )
        yield from t
        t.close()

    def handle_result(self, result: Dict[str, Any]) -> None:
        """Handle progress result (no-op for `tqdm` reporter).

        The `tqdm` reporter doesn't need to handle results separately since
        the progress is displayed through the `tqdm` progress bar.

        Args:
            result: Progress result dictionary (ignored).
        """

__init__(file=sys.stdout, mininterval=None, sub_reporter=None)

Initialize the tqdm progress reporter.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `file` | `TextIO` | Output stream for progress bars (default: sys.stdout). | `stdout` |
| `mininterval` | `float \| None` | Minimum time interval between updates in seconds. If None, uses 2 seconds for large iterations (>100k items) or 0.1 seconds for smaller ones. | `None` |
| `sub_reporter` | `Optional[ProgressReporter]` | Optional nested progress reporter for composition. | `None` |
Source code in planqtn/progress_reporter.py
def __init__(
    self,
    file: TextIO = sys.stdout,
    mininterval: float | None = None,
    sub_reporter: Optional["ProgressReporter"] = None,
):
    """Initialize the `tqdm` progress reporter.

    Args:
        file: Output stream for progress bars (default: sys.stdout).
        mininterval: Minimum time interval between updates in seconds.
            If None, uses 2 seconds for large iterations (>100k items)
            or 0.1 seconds for smaller ones.
        sub_reporter: Optional nested progress reporter for composition.
    """
    super().__init__(sub_reporter)
    self.file = file
    self.mininterval = mininterval

handle_result(result)

Handle progress result (no-op for tqdm reporter).

The tqdm reporter doesn't need to handle results separately since the progress is displayed through the tqdm progress bar.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `result` | `Dict[str, Any]` | Progress result dictionary (ignored). | *required* |
Source code in planqtn/progress_reporter.py
def handle_result(self, result: Dict[str, Any]) -> None:
    """Handle progress result (no-op for `tqdm` reporter).

    The `tqdm` reporter doesn't need to handle results separately since
    the progress is displayed through the `tqdm` progress bar.

    Args:
        result: Progress result dictionary (ignored).
    """

iterate(iterable, desc, total_size)

Iterate with tqdm progress bar display.

Overrides the parent iterate method to wrap the iteration with a tqdm progress bar that provides visual feedback in the terminal.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `iterable` | `Iterable` | The iterable to iterate over. | *required* |
| `desc` | `str` | Description for the progress bar. | *required* |
| `total_size` | `int` | Total number of items to process. | *required* |

Yields:

| Type | Description |
| --- | --- |
| `Any` | Items from the iterable. |

Source code in planqtn/progress_reporter.py
def iterate(
    self, iterable: Iterable, desc: str, total_size: int
) -> Generator[Any, None, None]:
    """Iterate with `tqdm` progress bar display.

    Overrides the parent iterate method to wrap the iteration with a `tqdm`
    progress bar that provides visual feedback in the terminal.

    Args:
        iterable: The iterable to iterate over.
        desc: Description for the progress bar.
        total_size: Total number of items to process.

    Yields:
        Items from the iterable.
    """
    t = tqdm(
        desc=desc,
        total=total_size,
        iterable=super().iterate(iterable, desc, total_size),
        file=self.file,
        # leave=False,
        mininterval=(
            self.mininterval
            if self.mininterval is not None
            else 2 if total_size > 1e5 else 0.1
        ),
    )
    yield from t
    t.close()

The planqtn.symplectic package

Symplectic operations and utilities.

omega(n)

Create a symplectic operator for the omega matrix over GF(2).

For n qubits, the omega matrix is:

    [0    I_n]
    [I_n  0  ]

where I_n is the n x n identity matrix.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `n` | `int` | The number of qubits. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `GF2` | The symplectic operator for the omega matrix. |

Source code in planqtn/symplectic.py
def omega(n: int) -> GF2:
    """Create a symplectic operator for the omega matrix over GF(2).

    For n the omega matrix is:
    [0 `I_n`]
    [`I_n` 0]

    where `I_n` is the n x n identity matrix.

    Args:
        n: The number of qubits.

    Returns:
        The symplectic operator for the omega matrix.
    """
    return GF2(
        np.block(
            [
                [GF2.Zeros((n, n)), GF2.Identity(n)],
                [GF2.Identity(n), GF2.Zeros((n, n))],
            ]
        )
    )
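The block structure can be checked with a plain numpy sketch (0/1 arrays standing in for GF2; `omega_np` is a hypothetical helper, not the planqtn function):

```python
import numpy as np

def omega_np(n: int) -> np.ndarray:
    # [[0, I_n], [I_n, 0]] as a 0/1 matrix over GF(2)
    eye = np.eye(n, dtype=np.uint8)
    zero = np.zeros((n, n), dtype=np.uint8)
    return np.block([[zero, eye], [eye, zero]])

om = omega_np(2)
# X on qubit 0 and Z on qubit 0 in [X part | Z part] layout
x0 = np.array([1, 0, 0, 0])
z0 = np.array([0, 0, 1, 0])
```

The symplectic inner product `(u @ om @ v) % 2` is 0 exactly when the Pauli operators u and v commute; here X and Z on the same qubit anticommute, so it is 1.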

replace_with_op_on_indices(indices, op, target)

Replace target symplectic operator's operations with op on the given indices.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `indices` | `List[int]` | Indices to replace on. | *required* |
| `op` | `GF2` | The operator to replace with. | *required* |
| `target` | `GF2` | The target operator. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `GF2` | The replaced operator. |

Source code in planqtn/symplectic.py
def replace_with_op_on_indices(indices: List[int], op: GF2, target: GF2) -> GF2:
    """Replace target symplectic operator's operations with op on the given indices.

    Args:
        indices: Indices to replace on.
        op: The operator to replace with.
        target: The target operator.

    Returns:
        The replaced operator.
    """
    m = len(indices)
    n = len(op) // 2

    res = target.copy()
    res[indices] = op[:m]
    res[np.array(indices) + n] = op[m:]
    return res

sconcat(*ops)

Concatenate symplectic operators.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `*ops` | `Tuple[int, ...]` | The symplectic operators to concatenate. | `()` |

Returns:

| Type | Description |
| --- | --- |
| `Tuple[int, ...]` | The concatenated symplectic operator. |

Source code in planqtn/symplectic.py
def sconcat(*ops: Tuple[int, ...]) -> Tuple[int, ...]:
    """Concatenate symplectic operators.

    Args:
        *ops: The symplectic operators to concatenate.

    Returns:
        The concatenated symplectic operator.
    """
    ns = [len(op) // 2 for op in ops]
    return tuple(
        np.hstack(
            [  # X part
                np.concatenate([op[:n] for n, op in zip(ns, ops)]).astype(np.int8),
                # Z part
                np.concatenate([op[n:] for n, op in zip(ns, ops)]).astype(np.int8),
            ],
        ).astype(np.int8)
    )
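Because operators are stored as [X part | Z part], concatenation must interleave the halves rather than simply append the vectors. A numpy sketch of the same interleaving (`sconcat_np` is a hypothetical name mirroring the logic above):

```python
import numpy as np

def sconcat_np(*ops):
    # Stack all X halves first, then all Z halves
    ns = [len(op) // 2 for op in ops]
    xs = np.concatenate([np.asarray(op)[:n] for n, op in zip(ns, ops)])
    zs = np.concatenate([np.asarray(op)[n:] for n, op in zip(ns, ops)])
    return tuple(np.concatenate([xs, zs]).astype(int))

# X on one qubit is (1, 0); Z on one qubit is (0, 1)
xz = sconcat_np((1, 0), (0, 1))
```

The result is X on qubit 0 and Z on qubit 1 of a two-qubit system, in symplectic form.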

sprint(h, end='\n')

Print a symplectic matrix in string format.

Prints the string representation of the symplectic matrix to stdout.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `h` | `GF2` | Parity check matrix in GF2. | *required* |
| `end` | `str` | String to append at the end (default: newline). | `'\n'` |
Source code in planqtn/symplectic.py
def sprint(h: GF2, end: str = "\n") -> None:
    """Print a symplectic matrix in string format.

    Prints the string representation of the symplectic matrix to stdout.

    Args:
        h: Parity check matrix in GF2.
        end: String to append at the end (default: newline).
    """
    print(sstr(h), end=end)

sslice(op, indices)

Slice a symplectic operator.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `op` | `GF2` | The symplectic operator. | *required* |
| `indices` | `List[int] \| slice \| ndarray` | The indices to slice. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `GF2` | The sliced symplectic operator. |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If the indices are not a list, np.ndarray, or slice. |

Source code in planqtn/symplectic.py
def sslice(op: GF2, indices: List[int] | slice | np.ndarray) -> GF2:
    """Slice a symplectic operator.

    Args:
        op: The symplectic operator.
        indices: The indices to slice.

    Returns:
        The sliced symplectic operator.

    Raises:
        ValueError: If the indices are of invalid type (neither list, np.ndarray, or slice).
    """
    n = len(op) // 2

    if isinstance(indices, list | np.ndarray):
        if len(indices) == 0:
            return GF2([])
        indices = np.array(indices)
        return GF2(np.concatenate([op[indices], op[indices + n]]))

    if isinstance(indices, slice):
        x = slice(
            0 if indices.start is None else indices.start,
            n if indices.stop is None else indices.stop,
        )

        z = slice(x.start + n, x.stop + n)
        return GF2(np.concatenate([op[x], op[z]]))

    raise ValueError(f"Invalid indices: {indices}")
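The index arithmetic (selecting the chosen qubits' X bits and the matching Z bits n positions later) can be illustrated with numpy (`sslice_np` is a hypothetical helper for the list-index case only):

```python
import numpy as np

def sslice_np(op, indices):
    # Pick the chosen qubits' X bits, then their matching Z bits
    op = np.asarray(op)
    n = len(op) // 2
    idx = np.asarray(indices)
    return np.concatenate([op[idx], op[idx + n]])

# XZYI on 4 qubits: X part [1, 0, 1, 0], Z part [0, 1, 1, 0]
op = np.array([1, 0, 1, 0, 0, 1, 1, 0])
sub = sslice_np(op, [1, 2])  # restrict to qubits 1 and 2 (Z and Y)
```

The slice is itself a valid symplectic vector on the selected qubits, here ZY.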

sstr(h)

Convert a symplectic matrix to a string representation.

Creates a human-readable string representation of a symplectic matrix where X and Z parts are separated by a '|' character. Uses '_' for 0 and '1' for 1 to make the pattern more visible.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `h` | `GF2` | Parity check matrix in GF2. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `str` | `str` | String representation of the matrix. |

Source code in planqtn/symplectic.py
def sstr(h: GF2) -> str:
    """Convert a symplectic matrix to a string representation.

    Creates a human-readable string representation of a symplectic matrix
    where X and Z parts are separated by a '|' character. Uses '_' for 0
    and '1' for 1 to make the pattern more visible.

    Args:
        h: Parity check matrix in GF2.

    Returns:
        str: String representation of the matrix.
    """
    n = h.shape[1] // 2

    return "\n".join(
        "".join("_1"[int(b)] for b in row[:n])
        + "|"
        + "".join("_1"[int(b)] for b in row[n:])
        for row in h
    )
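For plain 0/1 nested lists the same rendering can be reproduced without GF2 (`sstr_np` is a hypothetical stand-in):

```python
def sstr_np(h):
    # '_' for 0, '1' for 1; X and Z halves separated by '|'
    n = len(h[0]) // 2
    return "\n".join(
        "".join("_1"[b] for b in row[:n]) + "|" + "".join("_1"[b] for b in row[n:])
        for row in h
    )

s = sstr_np([[1, 0, 0, 1], [0, 1, 1, 0]])
```

Row one reads as X on qubit 0 and Z on qubit 1; row two is the reverse.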

symp_to_str(vec, swapxz=False)

Convert a symplectic operator to a string.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `vec` | `GF2` | The symplectic operator. | *required* |
| `swapxz` | `bool` | Whether to swap X and Z. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `str` | The string representation of the symplectic operator. |

Source code in planqtn/symplectic.py
def symp_to_str(vec: GF2, swapxz: bool = False) -> str:
    """Convert a symplectic operator to a string.

    Args:
        vec: The symplectic operator.
        swapxz: Whether to swap X and Z.

    Returns:
        The string representation of the symplectic operator.
    """
    p = ["I", "X", "Z", "Y"]
    if swapxz:
        p = ["I", "Z", "X", "Y"]
    n = len(vec) // 2

    return "".join([p[2 * int(vec[i + n]) + int(vec[i])] for i in range(n)])

sympl_to_pauli_repr(op)

Convert a symplectic operator to a Pauli operator representation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `op` | `GF2` | The symplectic operator. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tuple[int, ...]` | The Pauli operator representation of the symplectic operator. |

Source code in planqtn/symplectic.py
def sympl_to_pauli_repr(op: GF2) -> Tuple[int, ...]:
    """Convert a symplectic operator to a Pauli operator representation.

    Args:
        op: The symplectic operator.

    Returns:
        The Pauli operator representation of the symplectic operator.
    """
    n = len(op) // 2
    return tuple(2 * int(op[i + n]) + int(op[i]) for i in range(n))

weight(op, skip_indices=())

Calculate the weight of a symplectic operator.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `op` | `GF2` | The symplectic operator. | *required* |
| `skip_indices` | `Sequence[int]` | Indices to skip. | `()` |

Returns:

| Type | Description |
| --- | --- |
| `int` | The weight of the symplectic operator. |

Source code in planqtn/symplectic.py
def weight(op: GF2, skip_indices: Sequence[int] = ()) -> int:
    """Calculate the weight of a symplectic operator.

    Args:
        op: The symplectic operator.
        skip_indices: Indices to skip.

    Returns:
        The weight of the symplectic operator.
    """
    n = len(op) // 2
    x_inds = np.array([i for i in range(n) if i not in skip_indices])
    z_inds = x_inds + n
    if len(x_inds) == 0 and len(z_inds) == 0:
        return 0
    return np.count_nonzero(op[x_inds] | op[z_inds])
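Weight counts the qubits on which the operator acts non-trivially, i.e. where either the X or the Z bit is set. A numpy sketch of the same computation (`weight_np` is a hypothetical stand-in):

```python
import numpy as np

def weight_np(op, skip_indices=()):
    # A qubit contributes 1 if its X bit or its Z bit is set
    op = np.asarray(op)
    n = len(op) // 2
    keep = np.array([i for i in range(n) if i not in skip_indices], dtype=int)
    if len(keep) == 0:
        return 0
    return int(np.count_nonzero(op[keep] | op[keep + n]))

op = np.array([1, 0, 1, 0, 0, 1, 1, 0])  # XZYI on 4 qubits
```

Skipping the Y qubit (index 2) drops the weight from 3 to 2.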

The planqtn.linalg package

Linear algebra utilities.

gauss(mx, noswaps=False, col_subset=None)

Perform Gauss elimination on a GF2 matrix.

Performs row reduction on a GF2 matrix to bring it to row echelon form. Optionally can restrict elimination to a subset of columns and control whether row swaps are preserved.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `mx` | `GF2` | Input GF2 matrix to eliminate. | *required* |
| `noswaps` | `bool` | If True, undo row swaps at the end. | `False` |
| `col_subset` | `Iterable[int] \| None` | Subset of columns to perform elimination on. | `None` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `GF2` | `GF2` | Matrix in row echelon form. |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If the matrix is not of GF2 type. |

Source code in planqtn/linalg.py
def gauss(
    mx: GF2, noswaps: bool = False, col_subset: Iterable[int] | None = None
) -> GF2:
    """Perform Gauss elimination on a GF2 matrix.

    Performs row reduction on a GF2 matrix to bring it to row echelon form.
    Optionally can restrict elimination to a subset of columns and control
    whether row swaps are preserved.

    Args:
        mx: Input GF2 matrix to eliminate.
        noswaps: If True, undo row swaps at the end.
        col_subset: Subset of columns to perform elimination on.

    Returns:
        GF2: Matrix in row echelon form.

    Raises:
        ValueError: If the matrix is not of GF2 type.
    """
    res: GF2 = deepcopy(mx)
    if not isinstance(mx, GF2):
        raise ValueError(f"Matrix is not of GF2 type, but instead {type(mx)}")
    if len(mx.shape) == 1:
        return res

    (rows, cols) = mx.shape

    idx = 0
    swaps = []

    if col_subset is None:
        col_subset = range(cols)

    for c in col_subset:
        assert c < cols, f"Column {c} does not exist in mx: \n{mx}"
        # if a col is all zero below, we leave it without increasing idx
        nzs = (np.flatnonzero(res[idx:, c]) + idx).tolist()
        if len(nzs) == 0:
            continue
        # find the first non-zero element in each column starting from idx
        pivot = nzs[0]

        # print(res)
        # print(f"col {c} idx {idx} pivot {pivot}")
        # print(res)

        if pivot != idx:
            # print("swapping")
            res[[pivot, idx]] = res[[idx, pivot]]
            swaps.append((pivot, idx))
            pivot = idx
        # ensure all other rows are zero in the pivot column
        # print(res)
        idxs = np.flatnonzero(res[:, c]).tolist()
        # print(idxs)
        idxs.remove(pivot)
        res[idxs] += res[pivot]

        idx += 1
        if idx == rows:
            break

    if noswaps:
        for pivot, idx in reversed(swaps):
            res[[pivot, idx]] = res[[idx, pivot]]

    return res
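The elimination loop above can be reproduced over plain 0/1 numpy arrays, where row addition is taken mod 2. `gauss_gf2` below is a hypothetical sketch with the same pivoting logic, without the galois dependency or the `noswaps`/`col_subset` options:

```python
import numpy as np

def gauss_gf2(mx):
    # Row-reduce a 0/1 matrix to reduced row echelon form over GF(2)
    res = np.array(mx, dtype=np.uint8) % 2
    rows, cols = res.shape
    idx = 0
    for c in range(cols):
        nz = np.flatnonzero(res[idx:, c])
        if len(nz) == 0:
            continue  # column is all zero below idx; move on
        pivot = nz[0] + idx
        res[[pivot, idx]] = res[[idx, pivot]]  # swap pivot row into place
        # clear every other 1 in the pivot column by adding the pivot row
        others = [r for r in np.flatnonzero(res[:, c]) if r != idx]
        res[others] = (res[others] + res[idx]) % 2
        idx += 1
        if idx == rows:
            break
    return res

reduced = gauss_gf2([[1, 1, 0], [1, 0, 1]])
```

Note that over GF(2), addition and subtraction coincide, so row elimination is just XOR.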

gauss_row_augmented(mx)

Perform Gauss elimination on a row-augmented matrix.

Creates a row-augmented matrix by appending the identity matrix to the right of the input matrix, then performs Gauss elimination. This is useful for computing matrix inverses and kernels.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `mx` | `GF2` | Input GF2 matrix. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `GF2` | `GF2` | Row-augmented matrix after Gauss elimination. |

Source code in planqtn/linalg.py
def gauss_row_augmented(mx: GF2) -> GF2:
    """Perform Gauss elimination on a row-augmented matrix.

    Creates a row-augmented matrix by appending the identity matrix to the right
    of the input matrix, then performs Gauss elimination. This is useful for
    computing matrix inverses and kernels.

    Args:
        mx: Input GF2 matrix.

    Returns:
        GF2: Row-augmented matrix after Gauss elimination.
    """
    res: GF2 = deepcopy(mx)
    return gauss(GF2(np.hstack([res, GF2.Identity(mx.shape[0])])))

invert(mx)

Invert a square GF2 matrix.

Computes the inverse of a square GF2 matrix using row-augmented Gauss elimination. The matrix must be non-singular (full rank) for the inverse to exist.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `mx` | `GF2` | Square GF2 matrix to invert. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `GF2` | `GF2` | The inverse of the input matrix. |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If the matrix is not GF2 type, not square, or singular. |

Source code in planqtn/linalg.py
def invert(mx: GF2) -> GF2:
    """Invert a square GF2 matrix.

    Computes the inverse of a square GF2 matrix using row-augmented Gauss elimination.
    The matrix must be non-singular (full rank) for the inverse to exist.

    Args:
        mx: Square GF2 matrix to invert.

    Returns:
        GF2: The inverse of the input matrix.

    Raises:
        ValueError: If the matrix is not GF2 type, not square, or singular.
    """
    if not isinstance(mx, GF2):
        raise ValueError(f"Matrix is not of GF2 type, but instead {type(mx)}")

    if len(mx.shape) == 1:
        raise ValueError("Only square matrices are allowed")
    (rows, cols) = mx.shape
    if rows != cols:
        raise ValueError(f"Can't invert a {rows} x {cols} non-square matrix.")
    n = rows
    a = gauss_row_augmented(mx)
    if not np.array_equal(a[:, :n], GF2.Identity(n)):
        raise ValueError(
            f"Matrix is singular, has rank: {np.linalg.matrix_rank(a[:,:n])}"
        )

    return GF2(a[:, n:])
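A quick GF(2) sanity check with numpy (not calling `invert` itself): over GF(2) the elementary row-addition matrix [[1, 1], [0, 1]] is its own inverse, since adding a row twice cancels it out.

```python
import numpy as np

A = np.array([[1, 1], [0, 1]], dtype=np.uint8)
# A is involutory over GF(2): A @ A reduces to the identity mod 2
prod = A @ A % 2
```

This is the kind of identity `invert` must reproduce for any non-singular GF2 input.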

right_kernel(mx)

Compute the right kernel (nullspace) of a GF2 matrix.

Computes a basis for the right kernel of the matrix, which consists of all vectors v such that mx * v = 0. Uses row-augmented Gauss elimination on the transpose of the matrix.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `mx` | `GF2` | Input GF2 matrix. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `GF2` | `GF2` | Matrix whose rows form a basis for the right kernel. |

Source code in planqtn/linalg.py
def right_kernel(mx: GF2) -> GF2:
    """Compute the right kernel (nullspace) of a GF2 matrix.

    Computes a basis for the right kernel of the matrix, which consists of
    all vectors v such that mx * v = 0. Uses row-augmented Gauss elimination
    on the transpose of the matrix.

    Args:
        mx: Input GF2 matrix.

    Returns:
        GF2: Matrix whose rows form a basis for the right kernel.
    """
    (rows, cols) = mx.shape
    a = gauss_row_augmented(mx.T)

    zero_rows = np.argwhere(np.all(a[..., :rows] == 0, axis=1)).flatten()
    if len(zero_rows) == 0:
        # an invertible matrix will have the trivial nullspace
        return GF2([GF2.Zeros(cols)])
    return GF2(a[zero_rows, rows:])
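The augmented-transpose trick above can be sketched with plain numpy: row-reduce [mx.T | I] over GF(2); rows whose left block vanishes record, in their right block, the row combinations that map to zero, i.e. a kernel basis. `right_kernel_np` is a hypothetical stand-in for illustration:

```python
import numpy as np

def right_kernel_np(mx):
    # Row-reduce [mx.T | I]; rows where the left part vanishes give the kernel
    mx = np.array(mx, dtype=np.uint8) % 2
    rows, cols = mx.shape
    a = np.hstack([mx.T, np.eye(cols, dtype=np.uint8)])
    idx = 0
    for c in range(rows):  # eliminate only over the mx.T columns
        nz = np.flatnonzero(a[idx:, c])
        if len(nz) == 0:
            continue
        piv = nz[0] + idx
        a[[piv, idx]] = a[[idx, piv]]
        others = [r for r in np.flatnonzero(a[:, c]) if r != idx]
        a[others] = (a[others] + a[idx]) % 2
        idx += 1
        if idx == a.shape[0]:
            break
    zero_rows = np.flatnonzero(~a[:, :rows].any(axis=1))
    if len(zero_rows) == 0:
        # invertible matrices have only the trivial nullspace
        return np.zeros((1, cols), dtype=np.uint8)
    return a[zero_rows, rows:]

mx = np.array([[1, 1, 0], [0, 1, 1]], dtype=np.uint8)
ker = right_kernel_np(mx)
```

Every kernel row v satisfies mx @ v = 0 mod 2; here the only basis vector is (1, 1, 1).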