{"cells":[{"cell_type":"markdown","metadata":{},"source":["# VQE example with `pytket-qujax`\n","\n","**Download this notebook - {nb-download}`pytket-qujax_heisenberg_vqe.ipynb`**"]},{"cell_type":"markdown","metadata":{},"source":["See the docs for [qujax](https://cqcl.github.io/qujax/) and [pytket-qujax](https://cqcl.github.io/pytket-qujax/api/index.html)."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["from jax import numpy as jnp, random, value_and_grad, jit\n","from pytket import Circuit\n","from pytket.circuit.display import render_circuit_jupyter\n","import matplotlib.pyplot as plt"]},{"cell_type":"markdown","metadata":{},"source":["## Let's start with a TKET circuit"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import qujax\n","from pytket.extensions.qujax.qujax_convert import tk_to_qujax"]},{"cell_type":"markdown","metadata":{},"source":["We place barriers to stop TKET from automatically rearranging gates, and we also store the number of circuit parameters since we'll need it later."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["def get_circuit(n_qubits, depth):\n","    n_params = 2 * n_qubits * (depth + 1)\n","    # Placeholder rotation angles, one per parametrised gate\n","    param = jnp.zeros((n_params,))\n","    circuit = Circuit(n_qubits)\n","    k = 0\n","    # Initial layer: Hadamards, then Rx and Ry rotations on every qubit\n","    for i in range(n_qubits):\n","        circuit.H(i)\n","    for i in range(n_qubits):\n","        circuit.Rx(param[k], i)\n","        k += 1\n","    for i in range(n_qubits):\n","        circuit.Ry(param[k], i)\n","        k += 1\n","    # Each subsequent layer: a CZ ladder, a barrier, then Rx and Ry rotations\n","    for _ in range(depth):\n","        for i in range(0, n_qubits - 1):\n","            circuit.CZ(i, i + 1)\n","        circuit.add_barrier(range(0, n_qubits))\n","        for i in range(n_qubits):\n","            circuit.Rx(param[k], i)\n","            k += 1\n","        for i in range(n_qubits):\n","            circuit.Ry(param[k], i)\n","            k += 1\n","    return circuit, n_params"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["n_qubits = 4\n","depth = 2\n","circuit, n_params = get_circuit(n_qubits, 
depth)\n","render_circuit_jupyter(circuit)"]},{"cell_type":"markdown","metadata":{},"source":["## Now let's invoke qujax\n","The `pytket.extensions.qujax.tk_to_qujax` function will generate a parameters -> statetensor function for us."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["param_to_st = tk_to_qujax(circuit)"]},{"cell_type":"markdown","metadata":{},"source":["Let's try it out on some random parameter values. Be aware that JAX's random number generator requires a `jax.random.PRNGKey` every time it's called; more info on that [here](https://jax.readthedocs.io/en/latest/jax.random.html).\n","Also note that we still use the TKET convention where parameters are specified as multiples of $\\pi$, that is, in $[0, 2]$."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["params = random.uniform(random.PRNGKey(0), shape=(n_params,), minval=0.0, maxval=2.0)\n","statetensor = param_to_st(params)\n","print(statetensor)\n","print(statetensor.shape)"]},{"cell_type":"markdown","metadata":{},"source":["Note that this function also has an optional second argument, `statetensor_in`, through which an initial statetensor can be provided. 
If it is not provided, it defaults to the all-zeros state (as we use here)."]},{"cell_type":"markdown","metadata":{},"source":["We can obtain the statevector by simply calling `.flatten()`."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["statevector = statetensor.flatten()\n","statevector.shape"]},{"cell_type":"markdown","metadata":{},"source":["And the sampling probabilities by squaring the absolute values of the statevector amplitudes."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["sample_probs = jnp.square(jnp.abs(statevector))\n","plt.bar(jnp.arange(statevector.size), sample_probs)"]},{"cell_type":"markdown","metadata":{},"source":["## Cost function"]},{"cell_type":"markdown","metadata":{},"source":["Now that we have our `param_to_st` function, we are free to define a cost function that acts on bitstrings (e.g. MaxCut) or integers by directly wrapping a function around `param_to_st`. However, cost functions defined via quantum Hamiltonians are a bit more involved.\n","Fortunately, we can encode a Hamiltonian in JAX via the `qujax.get_statetensor_to_expectation_func` function, which generates a statetensor -> expected value function for us.\n","It takes three arguments as input:\n","- `gate_seq_seq`: A list of string (or array) lists encoding the gates in each term of the Hamiltonian. E.g. `[['X','X'], ['Y','Y'], ['Z','Z']]` corresponds to $H = aX_iX_j + bY_kY_l + cZ_mZ_n$, with qubit indices $i,j,k,l,m,n$ specified in the second argument and coefficients $a,b,c$ specified in the third argument.\n","- `qubit_inds_seq`: A list of integer lists encoding the qubit indices to which the aforementioned gates are applied. E.g. `[[0, 1], [0, 1], [0, 1]]`. Must have the same structure as `gate_seq_seq` above.\n","- `coefficients`: A list of floats encoding the coefficient of each Hamiltonian term. E.g. `[2.3, 0.8, 1.2]` corresponds to $a=2.3,b=0.8,c=1.2$ above. 
Must have the same length as the two arguments above."]},{"cell_type":"markdown","metadata":{},"source":["More specifically, let's consider the problem of finding the ground state of the quantum Heisenberg Hamiltonian"]},{"cell_type":"markdown","metadata":{},"source":["$$\n","\\begin{equation}\n","H = \\sum_{i=1}^{n_\\text{qubits}-1} \\left( X_i X_{i+1} + Y_i Y_{i+1} + Z_i Z_{i+1} \\right).\n","\\end{equation}\n","$$\n","\n","As described, we define the Hamiltonian via its gate strings, qubit indices and coefficients."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["hamiltonian_gates = [[\"X\", \"X\"], [\"Y\", \"Y\"], [\"Z\", \"Z\"]] * (n_qubits - 1)\n","# Each nearest-neighbour pair (i, i + 1) appears three times: once each for XX, YY and ZZ\n","hamiltonian_qubit_inds = [\n","    [int(i), int(i) + 1] for i in jnp.repeat(jnp.arange(n_qubits - 1), 3)\n","]\n","coefficients = [1.0] * len(hamiltonian_qubit_inds)"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["print(\"Gates:\\t\", hamiltonian_gates)\n","print(\"Qubits:\\t\", hamiltonian_qubit_inds)\n","print(\"Coefficients:\\t\", coefficients)"]},{"cell_type":"markdown","metadata":{},"source":["Now let's get the expectation value of the Hamiltonian as a pure JAX function."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["st_to_expectation = qujax.get_statetensor_to_expectation_func(\n","    hamiltonian_gates, hamiltonian_qubit_inds, coefficients\n",")"]},{"cell_type":"markdown","metadata":{},"source":["Let's check that it works on the statetensor we've already generated."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["expected_val = st_to_expectation(statetensor)\n","expected_val"]},{"cell_type":"markdown","metadata":{},"source":["Now let's compose `param_to_st` and `st_to_expectation` to give us an all-in-one `param_to_expectation` cost function."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["param_to_expectation = lambda param: 
st_to_expectation(param_to_st(param))"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["param_to_expectation(params)"]},{"cell_type":"markdown","metadata":{},"source":["As a sanity check, a different, randomly generated set of parameters should give us a different expected value."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["new_params = random.uniform(\n","    random.PRNGKey(1), shape=(n_params,), minval=0.0, maxval=2.0\n",")\n","param_to_expectation(new_params)"]},{"cell_type":"markdown","metadata":{},"source":["## Exact gradients within a VQE algorithm\n","The `param_to_expectation` function we created is a pure JAX function and outputs a scalar. This means we can pass it to `jax.grad` (or, even better, `jax.value_and_grad`, which returns the value and the gradient together)."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["cost_and_grad = value_and_grad(param_to_expectation)"]},{"cell_type":"markdown","metadata":{},"source":["The `cost_and_grad` function returns a tuple with the exact cost value and exact gradient evaluated at the parameters."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["cost_and_grad(params)"]},{"cell_type":"markdown","metadata":{},"source":["## Now we have all the tools we need to design our VQE!\n","We'll just use vanilla gradient descent with a constant step size."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["def vqe(init_param, n_steps, stepsize):\n","    # params[step] holds the parameters after `step` gradient-descent updates\n","    params = jnp.zeros((n_steps, n_params))\n","    params = params.at[0].set(init_param)\n","    cost_vals = jnp.zeros(n_steps)\n","    cost_vals = cost_vals.at[0].set(param_to_expectation(init_param))\n","    for step in range(1, n_steps):\n","        # Exact cost and gradient at the previous iterate\n","        cost_val, cost_grad = cost_and_grad(params[step - 1])\n","        cost_vals = cost_vals.at[step].set(cost_val)\n","        # Vanilla gradient-descent update with a constant step size\n","        new_param = params[step - 1] - stepsize * cost_grad\n","        params = params.at[step].set(new_param)\n","        print(\"Iteration:\", step, \"\\tCost:\", 
cost_val, end=\"\\r\")\n","    print(\"\\n\")\n","    return params, cost_vals"]},{"cell_type":"markdown","metadata":{},"source":["OK, enough talking, let's run it (and, whilst we're at it, time it too)."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["%%time\n","vqe_params, vqe_cost_vals = vqe(params, n_steps=250, stepsize=0.01)"]},{"cell_type":"markdown","metadata":{},"source":["Let's plot the results..."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["plt.plot(vqe_cost_vals)\n","plt.xlabel(\"Iteration\")\n","plt.ylabel(\"Cost\")"]},{"cell_type":"markdown","metadata":{},"source":["Pretty good!"]},{"cell_type":"markdown","metadata":{},"source":["## `jax.jit` speedup\n","One last thing... We can significantly speed up the VQE above via `jax.jit`. In our current implementation, the expensive `cost_and_grad` function is re-traced and executed operation by operation at every call, with each primitive dispatched to [XLA](https://www.tensorflow.org/xla) separately. 
By invoking `jax.jit`, we ensure that the whole function is traced and compiled into a single XLA computation only once (on the first call for a given input shape and dtype) and is then simply executed at each subsequent call, which is much faster!"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["cost_and_grad = jit(cost_and_grad)"]},{"cell_type":"markdown","metadata":{},"source":["We'll demonstrate this using the second set of initial parameters we randomly generated (to rule out any reuse of earlier results)."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["%%time\n","new_vqe_params, new_vqe_cost_vals = vqe(new_params, n_steps=250, stepsize=0.01)"]},{"cell_type":"markdown","metadata":{},"source":["That's some speedup!\n","But let's also plot the training curve to be sure it converged correctly."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["plt.plot(new_vqe_cost_vals)\n","plt.xlabel(\"Iteration\")\n","plt.ylabel(\"Cost\")"]}],"metadata":{"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"}},"nbformat":4,"nbformat_minor":2}