From 228a9e0570a2954eb74b516c6d0168367ffd53cb Mon Sep 17 00:00:00 2001 From: William Galvin Date: Tue, 26 Dec 2023 17:25:35 -0800 Subject: [PATCH] Save thetas from pennylane routine --- pennylane/vqe_pennylane/main.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/pennylane/vqe_pennylane/main.py b/pennylane/vqe_pennylane/main.py index 5b81a37..6774bff 100644 --- a/pennylane/vqe_pennylane/main.py +++ b/pennylane/vqe_pennylane/main.py @@ -3,6 +3,7 @@ import argparse from pennylane import numpy as np +import numpy import pennylane as qml @@ -28,6 +29,11 @@ def main(): default=1e-6, help="Convergence threshdold" ) + + parser.add_argument( + "--output-dir", + default="." + ) args = parser.parse_args() @@ -70,13 +76,13 @@ def cost_fn(param): energy = [cost_fn(theta)] - angle = [theta] + angle = numpy.zeros((args.max_iter, theta.shape[0])) for n in range(args.max_iter): + angle[n] = theta theta, prev_energy = opt.step_and_cost(cost_fn, theta) energy.append(cost_fn(theta)) - angle.append(theta) conv = np.abs(energy[-1] - prev_energy) @@ -102,8 +108,11 @@ def cost_fn(param): plt.xticks(fontsize=12) plt.yticks(fontsize=12) - plt.savefig("plot.png") - print("\nPlot saved at plot.png") + plt.savefig(f"{args.output_dir}/plot.png") + print(f"\nPlot saved at {args.output_dir}/plot.png") + + numpy.save(f"{args.output_dir}/thetas.npy", angle) + print(f"Thetas saved at {args.output_dir}/thetas.npy") if __name__ == "__main__":