{"cells":[{"cell_type":"markdown","metadata":{},"source":["# Symbolic circuits with `pytket-qujax`\n","\n","**Download this notebook - {nb-download}`pytket-qujax_qaoa.ipynb`**\n","\n","In this notebook we will show how to manipulate symbolic circuits with the `pytket-qujax` extension. In particular, we will consider a QAOA ansatz and an Ising Hamiltonian."]},{"cell_type":"markdown","metadata":{},"source":["See the docs for [qujax](https://cqcl.github.io/qujax/) and [pytket-qujax](https://tket.quantinuum.com/extensions/pytket-qujax/)."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["from pytket import Circuit\n","from pytket.circuit.display import render_circuit_jupyter\n","from jax import numpy as jnp, random, value_and_grad, jit\n","from sympy import Symbol\n","import matplotlib.pyplot as plt"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import qujax\n","from pytket.extensions.qujax import tk_to_qujax"]},{"cell_type":"markdown","metadata":{},"source":["## QAOA\n","The Quantum Approximate Optimization Algorithm (QAOA), first introduced by [Farhi et al.](https://arxiv.org/pdf/1411.4028.pdf), is a variational quantum algorithm used to solve optimization problems. It consists of a unitary $U(\\beta, \\gamma)$ formed by alternating repetitions of $U(\\beta)=e^{-i\\beta H_B}$ and $U(\\gamma)=e^{-i\\gamma H_P}$, where $H_B$ is the mixing Hamiltonian and $H_P$ the problem Hamiltonian. The goal is to find the optimal parameters that minimize $H_P$.\n","Given a depth $d$, the final unitary is $U(\\beta, \\gamma) = U(\\beta_d)U(\\gamma_d)\\cdots U(\\beta_1)U(\\gamma_1)$. Notice that each repetition has its own parameters.\n","\n","## Problem Hamiltonian\n","QAOA uses a problem-dependent ansatz, so we first need to specify the problem that we want to solve. In this case we will consider an Ising Hamiltonian with only $Z$ interactions. Given a set $E$ of pairs of qubit indices, the problem Hamiltonian is:\n","\n","$$\n","\\begin{equation}\n","H_P = \\sum_{(i, j) \\in E}\\alpha_{ij}Z_iZ_j,\n","\\end{equation}\n","$$\n","\n","where $\\alpha_{ij}$ are the coefficients.\n","Let's build our problem Hamiltonian with random coefficients and a set of pairs for a given number of qubits:"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["n_qubits = 4\n","hamiltonian_qubit_inds = [(0, 1), (1, 2), (0, 2), (1, 3)]\n","hamiltonian_gates = [[\"Z\", \"Z\"]] * len(hamiltonian_qubit_inds)"]},{"cell_type":"markdown","metadata":{},"source":["Notice that in order to use the `random` module from JAX we first need to define a seeded key:"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["seed = 13\n","key = random.PRNGKey(seed)\n","coefficients = random.uniform(key, shape=(len(hamiltonian_qubit_inds),))"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["print(\"Gates:\\t\", hamiltonian_gates)\n","print(\"Qubits:\\t\", hamiltonian_qubit_inds)\n","print(\"Coefficients:\\t\", coefficients)"]},
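{"cell_type":"markdown","metadata":{},"source":["Since the system is small, we can brute-force the exact minimum of $H_P$ over all $2^n$ computational basis states. This is an optional check that is not part of the original tutorial: the helper `ising_energy` and the list `energies` below are introduced purely for illustration, and the exact minimum will serve as a reference value at the end of the notebook."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import itertools\n","\n","\n","# Optional check (not part of the original tutorial): evaluate H_P on every\n","# computational basis state. A bit value of 0 corresponds to a Z eigenvalue\n","# of +1 and a bit value of 1 to an eigenvalue of -1.\n","def ising_energy(bits):\n","    spins = 1 - 2 * jnp.array(bits)\n","    return sum(\n","        coef * spins[i] * spins[j]\n","        for (i, j), coef in zip(hamiltonian_qubit_inds, coefficients)\n","    )\n","\n","\n","energies = [ising_energy(bits) for bits in itertools.product((0, 1), repeat=n_qubits)]\n","print(\"Exact minimum of H_P:\", min(energies))"]},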
{"cell_type":"markdown","metadata":{},"source":["## Variational Circuit\n","Before constructing the circuit, we still need to choose the mixing Hamiltonian. In our case, we will use $X$ operators on each qubit, so $H_B = \\sum_{i=1}^{n}X_i$, where $n$ is the number of qubits. Notice that, given this mixing Hamiltonian, the unitary $U(\\beta)$ is an $X$ rotation on each qubit with angle $\\beta$.\n","As for the unitary corresponding to the problem Hamiltonian, $U(\\gamma)$, it has the following form:\n","\n","$$\n","\\begin{equation}\n","U(\\gamma)=\\prod_{(i, j) \\in E}e^{-i\\gamma\\alpha_{ij}Z_i Z_j}\n","\\end{equation}\n","$$\n","\n","The operation $e^{-i\\gamma\\alpha_{ij}Z_iZ_j}$ can be performed using two CNOT gates with qubit $i$ as control and qubit $j$ as target, and a $Z$ rotation on qubit $j$ between them with angle $\\gamma\\alpha_{ij}$.\n","Finally, the initial state generally used with QAOA is an equal superposition of all the basis states. This can be achieved by adding an initial layer of Hadamard gates on each qubit."]},{"cell_type":"markdown","metadata":{},"source":["With all the building blocks in place, let's construct the symbolic circuit using tket. Notice that in order to define the parameters, we use the `Symbol` object from the `sympy` package. More information can be found in the [documentation](https://tket.quantinuum.com/user-manual/manual_circuit.html#symbolic-circuits). In order to later convert the circuit to qujax, we also need to return the list of symbolic parameters."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["def qaoa_circuit(n_qubits, depth):\n","    circuit = Circuit(n_qubits)\n","    p_keys = []\n","\n","    # Initial State\n","    for i in range(n_qubits):\n","        circuit.H(i)\n","    for d in range(depth):\n","        # Hamiltonian unitary\n","        gamma_d = Symbol(f\"γ_{d}\")\n","        for index in range(len(hamiltonian_qubit_inds)):\n","            pair = hamiltonian_qubit_inds[index]\n","            coef = coefficients[index]\n","            circuit.CX(pair[0], pair[1])\n","            circuit.Rz(gamma_d * coef, pair[1])\n","            circuit.CX(pair[0], pair[1])\n","            circuit.add_barrier(range(0, n_qubits))\n","        p_keys.append(gamma_d)\n","\n","        # Mixing unitary\n","        beta_d = Symbol(f\"β_{d}\")\n","        for i in range(n_qubits):\n","            circuit.Rx(beta_d, i)\n","        p_keys.append(beta_d)\n","    return circuit, p_keys"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["depth = 3\n","circuit, keys = qaoa_circuit(n_qubits, depth)"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["keys"]},{"cell_type":"markdown","metadata":{},"source":["Let's check the circuit:"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["render_circuit_jupyter(circuit)"]},{"cell_type":"markdown","metadata":{},"source":["## Now for `qujax`\n","The `pytket.extensions.qujax.tk_to_qujax` function will generate a parameters -> statetensor function for us. However, in order to convert a symbolic circuit we first need to define a `symbol_map`, which maps each symbol key to its corresponding index. In our case, since the object `keys` contains the symbols in the correct order, we can simply construct the dictionary as follows:"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["symbol_map = {keys[i]: i for i in range(len(keys))}"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["symbol_map"]},{"cell_type":"markdown","metadata":{},"source":["Then we invoke `tk_to_qujax` with both the circuit and the symbol map."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["param_to_st = tk_to_qujax(circuit, symbol_map=symbol_map)"]},{"cell_type":"markdown","metadata":{},"source":["We also construct the expectation map for the problem Hamiltonian via qujax:"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["st_to_expectation = qujax.get_statetensor_to_expectation_func(\n","    hamiltonian_gates, hamiltonian_qubit_inds, coefficients\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["param_to_expectation = lambda param: st_to_expectation(param_to_st(param))"]},{"cell_type":"markdown","metadata":{},"source":["## Training process\n","We construct a function that, given a parameter vector, returns the value of the cost function and its gradient.\n","We also `jit` this function to avoid recompilation: the expensive `cost_and_grad` function is compiled once into a very fast XLA function, which is then executed at each iteration. Alternatively, we could get the same speedup by replacing our `for` loop with `jax.lax.scan` (a sketch of this is shown at the end of the notebook). You can read more about JIT compilation in the [JAX documentation](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html)."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["cost_and_grad = jit(value_and_grad(param_to_expectation))"]},{"cell_type":"markdown","metadata":{},"source":["For the training process we'll use vanilla gradient descent with a constant stepsize:"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["seed = 123\n","key = random.PRNGKey(seed)\n","init_param = random.uniform(key, shape=(len(symbol_map),))"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["n_steps = 150\n","stepsize = 0.01"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["param = init_param"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["cost_vals = jnp.zeros(n_steps)\n","cost_vals = cost_vals.at[0].set(param_to_expectation(init_param))"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["for step in range(1, n_steps):\n","    cost_val, cost_grad = cost_and_grad(param)\n","    cost_vals = cost_vals.at[step].set(cost_val)\n","    param = param - stepsize * cost_grad\n","    print(\"Iteration:\", step, \"\\tCost:\", cost_val, end=\"\\r\")"]},{"cell_type":"markdown","metadata":{},"source":["Let's visualise the gradient descent:"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["plt.plot(cost_vals)\n","plt.xlabel(\"Iteration\")\n","plt.ylabel(\"Cost\")"]},
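{"cell_type":"markdown","metadata":{},"source":["As mentioned in the training section, the same speedup could also be obtained by compiling the whole descent with `jax.lax.scan` instead of a Python `for` loop. The following is a minimal sketch assuming the definitions above; `gd_step`, `final_param` and `scan_cost_vals` are illustrative names that do not appear in the original tutorial, and the bookkeeping differs slightly from the loop above (here all `n_steps` iterations perform an update)."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["from jax import lax\n","\n","\n","# Sketch: the whole gradient descent as a single compiled scan. Each step\n","# evaluates the jitted cost-and-gradient function and returns the updated\n","# parameters (the carry) together with the cost value (stacked by scan).\n","def gd_step(current_param, _):\n","    cost_val, cost_grad = cost_and_grad(current_param)\n","    return current_param - stepsize * cost_grad, cost_val\n","\n","\n","final_param, scan_cost_vals = lax.scan(gd_step, init_param, xs=None, length=n_steps)\n","print(\"Final cost via scan:\", scan_cost_vals[-1])"]},{"cell_type":"markdown","metadata":{},"source":["Finally, as an optional check that is not part of the original tutorial, we can compare the optimised expectation with the exact minimum of $H_P$ computed by brute force at the beginning, and read off the most probable computational basis state from the optimised statetensor (assuming qujax's convention that the $i$-th axis of the statetensor corresponds to qubit $i$)."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# Optional wrap-up (not part of the original tutorial).\n","final_expectation = param_to_expectation(param)\n","print(\"Optimised expectation:\", final_expectation)\n","print(\"Exact minimum of H_P: \", min(energies))\n","\n","# Flattening the statetensor gives amplitudes indexed by bitstrings with\n","# qubit 0 as the most significant bit.\n","probabilities = jnp.square(jnp.abs(param_to_st(param).flatten()))\n","best_index = int(jnp.argmax(probabilities))\n","print(\"Most probable bitstring:\", format(best_index, f\"0{n_qubits}b\"))"]}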
],"metadata":{"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"}},"nbformat":4,"nbformat_minor":2}