### Install Project Dependencies

Source: https://github.com/salcc/quantumtransformers/blob/main/README.md

Installs the project and its dependencies using pip. This command should be run in the root folder of the project.

```bash
pip install -e .
```

--------------------------------

### JAX MNIST Dataloaders and Configuration Setup

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/vit.ipynb

Imports JAX, necessary libraries, and the quantum_transformers JAX backend. It configures JAX for 64-bit precision, disables XLA GPU memory preallocation, and then loads the MNIST dataloaders. This setup is foundational for the JAX examples.

```python
import traceback
import os

os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'  # See https://github.com/google/jax/issues/12461#issuecomment-1256266598

import jaxlib
from jax.config import config
config.update("jax_enable_x64", True)
import catalyst

import quantum_transformers.qmlperfcomp.jax_backend as qpcjax

train_dataloader, valid_dataloader = qpcjax.data.get_mnist_dataloaders(batch_size=64)
```

--------------------------------

### PyTorch MNIST Dataloaders and Device Setup

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/vit.ipynb

Imports PyTorch and the quantum_transformers library, configures the computation device (CPU/GPU), and retrieves MNIST dataloaders for training and validation. This setup is common to all PyTorch examples.
```python
import torch

import quantum_transformers.qmlperfcomp.torch_backend as qpctorch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

train_dataloader, valid_dataloader = qpctorch.data.get_mnist_dataloaders(batch_size=64, num_workers=4, pin_memory=True)
```

--------------------------------

### JAX Data Loading and Setup

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/mlp.ipynb

Initializes JAX, configures it for 64-bit precision, and prints the available JAX devices. It then loads the Swiss roll dataset into dataloaders for training and validation using the JAX backend.

```python
import os

os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'  # See https://github.com/google/jax/issues/12461#issuecomment-1256266598

import jaxlib
import jax
from jax.config import config
config.update("jax_enable_x64", True)
print(jax.devices())

import quantum_transformers.qmlperfcomp.jax_backend as qpcjax

train_dataloader, valid_dataloader = qpcjax.data.get_swiss_roll_dataloaders(batch_size=4)
```

--------------------------------

### PyTorch Data Loading and Setup

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/mlp.ipynb

Initializes PyTorch, sets the device (CPU/GPU), and loads the Swiss roll dataset into dataloaders for training and validation. It specifies the batch size, number of workers, and memory pinning.
```python
import traceback

import pennylane as qml
import torch

import quantum_transformers.qmlperfcomp.torch_backend as qpctorch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

train_dataloader, valid_dataloader = qpctorch.data.get_swiss_roll_dataloaders(batch_size=4, num_workers=4, pin_memory=True)
```

--------------------------------

### Project Setup and Helper Functions

Source: https://github.com/salcc/quantumtransformers/blob/main/notebooks/quantum/qvit_cerrat_et_al.ipynb

Initializes the project environment with necessary libraries like JAX, Flax, and TensorCircuit. Configures JAX for 64-bit precision and TensorCircuit for complex 128-bit data types. Includes a custom assertion function for comparing numerical results with a specified tolerance.

```python
from collections.abc import Sequence

import matplotlib as mpl
import jax
from jax import Array
import jax.numpy as jnp
import flax.linen as nn

jax.config.update("jax_enable_x64", True)

import tensorcircuit as tc
tc.set_dtype("complex128")

from quantum_transformers.datasets import get_mnist_dataloaders
from quantum_transformers.training import train_and_evaluate

data_dir = '/global/cfs/cdirs/m4392/salcc/data'

mpl.rcParams['figure.dpi'] = 50  # To make circuit plots smaller easily
jnp.set_printoptions(linewidth=150)

K = tc.set_backend("jax")

key = jax.random.PRNGKey(0)

def assert_allclose(actual, desired, verbose=True, atol=1e-2):
    if verbose:
        print(actual, desired)
    assert jnp.allclose(actual, desired, atol=atol)
```

--------------------------------

### Import Quantum Transformers Library

Source: https://github.com/salcc/quantumtransformers/blob/main/README.md

Demonstrates how to import the quantum_transformers library into your Python code after installation.
```python
import quantum_transformers
```

--------------------------------

### Generate and Draw Quantum Layer Circuit (Various Dimensions)

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/tensorcircuit.ipynb

Demonstrates creating a quantum layer circuit using `get_quantum_layer_circuit` with NumPy arrays for inputs and weights. The generated circuit is then drawn using matplotlib. The examples show variations in input and weight dimensions.

```python
import numpy as np

from quantum_transformers.qmlperfcomp.tc_common import get_quantum_layer_circuit

# Example 1: 5 inputs, 2x5 weights
inputs = np.random.uniform(size=(5,))
weights = np.random.uniform(size=(2, 5))
c = get_quantum_layer_circuit(inputs, weights)
c.draw(output='mpl')
```

```python
import numpy as np

from quantum_transformers.qmlperfcomp.tc_common import get_quantum_layer_circuit

# Example 2: 2 inputs, 5x2 weights
inputs = np.random.uniform(size=(2,))
weights = np.random.uniform(size=(5, 2))
c = get_quantum_layer_circuit(inputs, weights)
c.draw(output='mpl')
```

```python
import numpy as np

from quantum_transformers.qmlperfcomp.tc_common import get_quantum_layer_circuit

# Example 3: 1 input, 1x1 weights
inputs = np.random.uniform(size=(1,))
weights = np.random.uniform(size=(1, 1))
c = get_quantum_layer_circuit(inputs, weights)
c.draw(output='mpl')
```

--------------------------------

### Import Libraries and Setup Data Directory (Python)

Source: https://github.com/salcc/quantumtransformers/blob/main/notebooks/quantum/quark-gluon.ipynb

Imports the necessary JAX and custom quantum transformer libraries for the task. Sets the data directory path for dataset operations.
```python
import jax

from quantum_transformers.utils import plot_image
from quantum_transformers.datasets import get_quark_gluon_dataloaders
from quantum_transformers.training import train_and_evaluate
from quantum_transformers.transformers import VisionTransformer
from quantum_transformers.quantum_layer import get_circuit

data_dir = '/global/cfs/cdirs/m4392/salcc/data'
```

--------------------------------

### Setup and Configuration

Source: https://github.com/salcc/quantumtransformers/blob/main/notebooks/classical/mnist.ipynb

Imports necessary libraries like JAX and TensorFlow, configures TensorFlow to avoid GPU memory allocation, and sets a random seed for reproducible results. This block prepares the environment for the subsequent operations.

```python
import jax
import tensorflow as tf

tf.config.set_visible_devices([], device_type='GPU')  # Ensure TF does not see the GPU and grab all GPU memory.
tf.random.set_seed(42)  # For reproducibility.

from quantum_transformers.utils import plot_image
from quantum_transformers.datasets import get_mnist_dataloaders
from quantum_transformers.training import train_and_evaluate
from quantum_transformers.transformers import VisionTransformer

data_dir = '/global/cfs/cdirs/m4392/salcc/data'
```

--------------------------------

### JAX Quantum Vision Transformer with PennyLane and Lightning

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/vit.ipynb

Attempts to initialize and train a quantum Vision Transformer using PennyLane with a standard Lightning device in JAX. This configuration results in the same error as the Lightning-GPU setup.
```python
try:
    model = qpcjax.quantum.VisionTransformer(num_classes=10, patch_size=14, hidden_size=6, num_heads=2,
                                             num_transformer_blocks=4, mlp_hidden_size=3, qdevice="lightning.qubit")
    qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10)
except jaxlib.xla_extension.XlaRuntimeError as e:
    print(traceback.format_exc())
```

--------------------------------

### Load IMDb Dataset and Tokenizer

Source: https://github.com/salcc/quantumtransformers/blob/main/notebooks/classical/imdb.ipynb

Loads the IMDb dataset, including training, validation, and test dataloaders, along with the vocabulary and tokenizer. It prints the vocabulary size and an example of a tokenized review.

```python
(imdb_train_dataloader, imdb_val_dataloader, imdb_test_dataloader), vocab, tokenizer = get_imdb_dataloaders(batch_size=32, data_dir=data_dir, max_vocab_size=20_000, max_seq_len=512)
print(f"Vocabulary size: {len(vocab)}")

first_batch = next(iter(imdb_train_dataloader))
print(first_batch[0][0])
print(' '.join(map(bytes.decode, tokenizer.detokenize(first_batch[0])[0].numpy().tolist())))
```

--------------------------------

### Python JAX Vectorization Examples

Source: https://github.com/salcc/quantumtransformers/blob/main/notebooks/quantum/qvit_cerrat_et_al.ipynb

Demonstrates various JAX vectorization techniques (`vmap`, `scan`) for a sample function `f` operating on arrays. It compares manual, fully vectorized, batch vectorized, and optimized batch vectorized approaches for efficient computation.
```python
# Example function that operates on a pair of 1D vectors and an additional parameter and returns a scalar
def f(a, b, c):
    return jnp.dot(a, b) / c

# Create a batch of sequences of 1D vectors
batch_sequence = [jnp.array([[1, 2], [3, 4], [5, 6]]), jnp.array([[7, 8], [9, 10], [11, 12]])]

# Convert the batch of sequences of 1D vectors into a single 3D array
batch_sequence_array = jnp.stack(batch_sequence)
print(batch_sequence_array.shape)

# Define the additional parameter that is the same for all batches
c = 2

def manual_vectorized_f(batch_sequence_array, c):
    result = jnp.empty((batch_sequence_array.shape[0], batch_sequence_array.shape[1], batch_sequence_array.shape[1]))
    for b in range(batch_sequence_array.shape[0]):
        for i in range(batch_sequence_array.shape[1]):
            for j in range(batch_sequence_array.shape[1]):
                result = result.at[b, i, j].set(f(batch_sequence_array[b, i], batch_sequence_array[b, j], c))
    return result

def fully_vectorized_f(batch_sequence_array, c):
    vectorized_f = jax.vmap(jax.vmap(jax.vmap(f, in_axes=(0, None, None)), in_axes=(None, 0, None)), in_axes=(0, 0, None))
    return vectorized_f(batch_sequence_array, batch_sequence_array, c)

def batch_vectorized_f(batch_sequence_array, c):
    vectorized_f = jax.vmap(f, in_axes=(0, 0, None))
    result = jnp.empty((batch_sequence_array.shape[0], batch_sequence_array.shape[1], batch_sequence_array.shape[1]))
    for i in range(batch_sequence_array.shape[1]):
        for j in range(batch_sequence_array.shape[1]):
            result = result.at[:, i, j].set(vectorized_f(batch_sequence_array[:, i], batch_sequence_array[:, j], c))
    return result

def optimized_batch_vectorized_f(batch_sequence_array, c):
    def body_fun(carry, inputs):
        i, j = inputs
        result_slice = jax.vmap(f, in_axes=(0, 0, None))(batch_sequence_array[:, i], batch_sequence_array[:, j], c)
        return carry.at[:, i, j].set(result_slice), None

    n = batch_sequence_array.shape[1]
    indices = jnp.array([(i, j) for i in range(n) for j in range(n)])
    result_init = jnp.zeros((batch_sequence_array.shape[0], n, n))
    result, _ = jax.lax.scan(body_fun, result_init, indices)
    return result

# Compute the results
result_manual = manual_vectorized_f(batch_sequence_array, c)
result_full = fully_vectorized_f(batch_sequence_array, c)
result_batch = batch_vectorized_f(batch_sequence_array, c)
result_optimized_batch = optimized_batch_vectorized_f(batch_sequence_array, c)
```

--------------------------------

### Load IMDb Dataset and Tokenizer

Source: https://github.com/salcc/quantumtransformers/blob/main/notebooks/quantum/imdb.ipynb

Loads the IMDb reviews dataset, including training, validation, and test dataloaders, along with the vocabulary and tokenizer. It also prints the vocabulary size and an example of a tokenized review.

```python
(imdb_train_dataloader, imdb_valid_dataloader, imdb_test_dataloader), vocab, tokenizer = get_imdb_dataloaders(batch_size=32, data_dir=data_dir, max_vocab_size=20_000, max_seq_len=512)
print(f"Vocabulary size: {len(vocab)}")

first_batch = next(iter(imdb_train_dataloader))
print(first_batch[0][0])
print(' '.join(map(bytes.decode, tokenizer.detokenize(first_batch[0])[0].numpy().tolist())))
```

--------------------------------

### Run Hyperparameter Optimization

Source: https://github.com/salcc/quantumtransformers/blob/main/hpopt/README.md

Executes the `submit-ray-cluster.sh` script for hyperparameter optimization. It requires the dataset name and the number of trials, and accepts an optional `--quantum` flag. The script may need modification for specific system setups.

```bash
bash submit-ray-cluster.sh mnist 50
```

--------------------------------

### List Available JAX Devices

Source: https://github.com/salcc/quantumtransformers/blob/main/notebooks/quantum/imdb.ipynb

Prints information about the available JAX devices and their respective kinds. This helps in understanding the hardware configuration for computation.
```python
for d in jax.devices():
    print(d, d.device_kind)
```

--------------------------------

### Print Available JAX Devices

Source: https://github.com/salcc/quantumtransformers/blob/main/notebooks/classical/imdb.ipynb

Iterates through and prints the available JAX devices and their respective device kinds, providing insight into the computational hardware being used.

```python
for d in jax.devices():
    print(d, d.device_kind)
```

--------------------------------

### Notebooks for Model Evaluation and Usage

Source: https://github.com/salcc/quantumtransformers/blob/main/README.md

The project includes a 'notebooks' directory containing various Jupyter notebooks for evaluating models, demonstrating usage, and showcasing performance. These are organized by dataset and include classical baselines and quantum implementations.

```text
notebooks: the notebooks used for evaluating the models and showing their usage and performance. Each notebook is named after the dataset it uses.
  visualizations.ipynb: notebook visualizing the image datasets.
  classical/: classical counterparts as baselines.
  quantum/: the quantum transformers.
    qvit_cerrat_et_al.ipynb: notebook trying to reproduce the results of ["Quantum Vision Transformers" by Cerrat et al.](https://arxiv.org/abs/2106.03173), although unsuccessfully.
```

--------------------------------

### Load Data, Get Sample, and Plot Image (Python)

Source: https://github.com/salcc/quantumtransformers/blob/main/notebooks/quantum/quark-gluon.ipynb

Loads the quark-gluon dataset using a custom dataloader. It then extracts the first image from the training set and displays it using a plotting utility.
```python
qg_train_dataloader, qg_val_dataloader, qg_test_dataloader = get_quark_gluon_dataloaders(batch_size=256, data_dir=data_dir)

first_image = next(iter(qg_train_dataloader))[0][0]
print(first_image.shape)
plot_image(first_image, abs_log=True)
```

--------------------------------

### JAX Quantum MLP (lightning.gpu, 5 features)

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/mlp.ipynb

Trains and evaluates a quantum MLP with 5 features using PennyLane's 'lightning.gpu' device with JAX. Includes error handling for XlaRuntimeError, referencing a known compatibility issue.

```python
try:
    model = qpcjax.quantum.MLP(5, qdevice="lightning.gpu")
    qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=2, num_epochs=50)
except jaxlib.xla_extension.XlaRuntimeError as e:
    # The error is already printed.
    # See https://discuss.pennylane.ai/t/incompatible-function-arguments-error-on-lightning-qubit-with-jax/2900.
    pass
```

--------------------------------

### Hyperparameter Optimization Scripts

Source: https://github.com/salcc/quantumtransformers/blob/main/README.md

The 'hpopt/' directory contains scripts dedicated to hyperparameter optimization for the quantum transformer models. It includes a README file with instructions on how to execute these scripts.

```text
hpopt/: hyperparameter optimization scripts. The folder contains a README with instructions on how to run them.
```

--------------------------------

### List Available JAX Devices

Source: https://github.com/salcc/quantumtransformers/blob/main/notebooks/classical/quark-gluon-resnet.ipynb

Prints the available JAX devices and their respective kinds. This is useful for verifying the execution environment and ensuring JAX is utilizing the intended hardware.
```python
for d in jax.devices():
    print(d, d.device_kind)
```

--------------------------------

### PennyLane Device Configurations

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/README.md

Details specific device configurations used for PennyLane quantum computations, including the default qubit simulator and GPU-accelerated devices.

```python
# Default qubit simulator device
# dev = qml.device('default.qubit', wires=num_qubits)

# Lightning GPU device (requires specific backend setup)
# dev = qml.device('lightning.gpu', wires=num_qubits)
```

--------------------------------

### JAX Classical Vision Transformer Training

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/vit.ipynb

Initializes a classical Vision Transformer model within the JAX framework and trains it using the provided dataloaders. This demonstrates the JAX implementation of a classical model.

```python
model = qpcjax.classical.VisionTransformer(num_classes=10, patch_size=14, hidden_size=6, num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3)
qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10)
```

--------------------------------

### Get Semidiagonal Thetas (Python)

Source: https://github.com/salcc/quantumtransformers/blob/main/notebooks/quantum/qvit_cerrat_et_al.ipynb

Calculates the parameters (thetas) required for a semi-diagonal quantum circuit. This function processes an input array `x` to derive these parameters, handling specific indexing and trigonometric operations.
```python
def get_thetas_semidiagonal(x: Array):
    t = jnp.empty(len(x) - 1)
    t = t.at[1].set(jnp.arctan2(x[0], x[1]))
    t = t.at[-1].set(jnp.arctan2(x[-1], x[-2]))
    m = len(x) // 2 - 1
    for i in range(2, m + 1):
        t = t.at[i].set(jnp.arctan2(x[i - 1], x[i] * jnp.cos(t[i - 1])))
    for i in range(len(x) - 3, m, -1):
        t = t.at[i].set(jnp.arctan2(x[i + 1], x[i] * jnp.cos(t[i + 1])))
    t = t.at[0].set(jnp.arccos(x[m] / jnp.cos(t[m])))
    thetas = -jnp.ones_like(t)
    thetas = thetas.at[0].set(t[0])
    for i in range(1, len(x) // 2):
        thetas = thetas.at[-2 * i].set(t[i])
        thetas = thetas.at[-2 * i + 1].set(t[-i])
    return thetas

# `normalize` and `n` are defined earlier in the notebook.
key, subkey = jax.random.split(key)
x = normalize(jax.random.uniform(subkey, (n,)))
print(x)
x_thetas_semidiagonal = get_thetas_semidiagonal(x)
print(x_thetas_semidiagonal)
```

--------------------------------

### Initialize QuantumOrthogonalTransformer (No Attention)

Source: https://github.com/salcc/quantumtransformers/blob/main/notebooks/quantum/qvit_cerrat_et_al.ipynb

Demonstrates initializing the QuantumOrthogonalTransformer model with attention disabled. It shows preparing input data, splitting JAX random keys, initializing model variables, and printing their shapes. This is useful for understanding the model's structure before training.

```python
x = jnp.ones((64, 28, 28, 3))
patch_wise_neural_network_model = QuantumOrthogonalTransformer(7, 10, attention=False)
key, params_key = jax.random.split(key=key)
variables = patch_wise_neural_network_model.init(params_key, x)
print(jax.tree_map(lambda x: x.shape, variables))
```

--------------------------------

### Setup JAX and TensorFlow for Quark-Gluon Classification

Source: https://github.com/salcc/quantumtransformers/blob/main/notebooks/classical/quark-gluon.ipynb

Imports necessary libraries like JAX, TensorFlow, and custom modules for data loading, training, and transformers. It configures TensorFlow to not utilize the GPU and sets a random seed for reproducibility. The data directory is also defined.
```python
import jax
import tensorflow as tf

tf.config.set_visible_devices([], device_type='GPU')  # Ensure TF does not see the GPU and grab all GPU memory.
tf.random.set_seed(42)  # For reproducibility.

from quantum_transformers.utils import plot_image
from quantum_transformers.datasets import get_quark_gluon_dataloaders
from quantum_transformers.training import train_and_evaluate
from quantum_transformers.transformers import VisionTransformer

data_dir = '/global/cfs/cdirs/m4392/salcc/data'
```

--------------------------------

### Initialize QuantumOrthogonalTransformer (Vectorized Attention Disabled)

Source: https://github.com/salcc/quantumtransformers/blob/main/notebooks/quantum/qvit_cerrat_et_al.ipynb

Shows how to initialize the QuantumOrthogonalTransformer model with vectorized attention disabled, which can help mitigate out-of-memory errors. The snippet includes model initialization, variable setup, and printing variable shapes, highlighting a memory-efficient configuration.

```python
# Vectorized attention gives an out-of-memory error.
# Not using vectorized attention is (probably) slower, but uses less memory.
orthogonal_transformer_model = QuantumOrthogonalTransformer(7, 10, vectorized_attention=False)
key, params_key = jax.random.split(key=key)
variables = orthogonal_transformer_model.init(params_key, x)
print(jax.tree_map(lambda x: x.shape, variables))
```

--------------------------------

### JAX Quantum Vision Transformer with PennyLane

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/vit.ipynb

Initializes a quantum Vision Transformer model using PennyLane within the JAX backend and trains it. This showcases a standard quantum approach in JAX.
```python
model = qpcjax.quantum.VisionTransformer(num_classes=10, patch_size=14, hidden_size=6, num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3)
qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10)
```

--------------------------------

### JAX Quantum MLP (TensorCircuit, 5 features)

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/mlp.ipynb

Trains and evaluates a quantum MLP with 5 features using the TensorCircuit backend with JAX. The training process runs for 50 epochs.

```python
model = qpcjax.quantum.MLP(5, qml_backend="tensorcircuit")
qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=2, num_epochs=50)
```

--------------------------------

### Re-train and Evaluate QuantumOrthogonalTransformer Model (JIT Disabled)

Source: https://github.com/salcc/quantumtransformers/blob/main/notebooks/quantum/qvit_cerrat_et_al.ipynb

Demonstrates training and evaluation of the QuantumOrthogonalTransformer model with JAX's JIT compilation disabled. This setup is typically used for debugging purposes to inspect intermediate values or execution flow more easily, despite the performance impact.

```python
model = QuantumOrthogonalTransformer(7, 10, attention=False)
train_and_evaluate(model, mnist_train_dataloader, mnist_valid_dataloader, mnist_test_dataloader, num_classes=10, num_epochs=5, debug=True)
```

--------------------------------

### JAX Quantum MLP (TensorCircuit, 20 features)

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/mlp.ipynb

Trains and evaluates a quantum MLP with 20 features using the TensorCircuit backend with JAX. The training process runs for 50 epochs.
```python
model = qpcjax.quantum.MLP(20, qml_backend="tensorcircuit")
qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=2, num_epochs=50)
```

--------------------------------

### Import Libraries and Set Data Directory

Source: https://github.com/salcc/quantumtransformers/blob/main/notebooks/quantum/imdb.ipynb

Imports essential JAX and quantum transformer libraries, including data loading, training, and transformer components. It also defines the path to the dataset directory.

```python
import jax

from quantum_transformers.datasets import get_imdb_dataloaders
from quantum_transformers.training import train_and_evaluate
from quantum_transformers.transformers import Transformer
from quantum_transformers.quantum_layer import get_circuit

data_dir = '/global/cfs/cdirs/m4392/salcc/data'
```

--------------------------------

### Train Vision Transformer with PennyLane/Lightning/Catalyst

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/vit.ipynb

This snippet demonstrates training a quantum Vision Transformer model using PennyLane with the Lightning backend and Catalyst for compilation. It includes error handling for the training process, and requires the pre-defined dataloaders and the qpcjax backend module.
```python
import traceback

import quantum_transformers.qmlperfcomp.jax_backend as qpcjax

try:
    model = qpcjax.quantum.VisionTransformer(num_classes=10, patch_size=14, hidden_size=6, num_heads=2,
                                             num_transformer_blocks=4, mlp_hidden_size=3, qdevice="lightning.qubit", use_catalyst=True)
    qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10)
except Exception as e:
    print(traceback.format_exc())
```

--------------------------------

### Quantum ML Framework Performance Comparison

Source: https://github.com/salcc/quantumtransformers/blob/main/README.md

This section highlights the subproject focused on comparing the performance of different quantum machine learning frameworks, specifically PennyLane and TensorCircuit. The evaluation indicates that TensorCircuit offers significantly better performance.

```text
quantum_transformers/qmlperfcomp/: subproject to compare the performance of different quantum machine learning frameworks. In particular, I evaluated [PennyLane](https://pennylane.ai/) and [TensorCircuit](https://tensorcircuit.readthedocs.io/) (spoiler: TensorCircuit is much faster).
```

--------------------------------

### JAX Quantum MLP (default.qubit.jax, 5 features)

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/mlp.ipynb

Trains and evaluates a quantum Multi-Layer Perceptron (MLP) model with 5 input features using PennyLane's 'default.qubit.jax' device. The training runs for 50 epochs.

```python
model = qpcjax.quantum.MLP(5)
qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=2, num_epochs=50)
```

--------------------------------

### PyTorch Quantum MLP (lightning.gpu, 20 features)

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/mlp.ipynb

Trains and evaluates a quantum Multi-Layer Perceptron (MLP) model with 20 input features using PennyLane's 'lightning.gpu' device. The training runs for 50 epochs.
```python
model = qpctorch.quantum.MLP(20, qdevice="lightning.gpu")
qpctorch.training.train_and_evaluate(model, train_dataloader, valid_dataloader, device=device, num_classes=2, num_epochs=50)
```

--------------------------------

### JAX Quantum MLP (default.qubit.jax, 20 features)

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/mlp.ipynb

Trains and evaluates a quantum Multi-Layer Perceptron (MLP) model with 20 input features using PennyLane's 'default.qubit.jax' device. The training runs for 50 epochs.

```python
model = qpcjax.quantum.MLP(20)
qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=2, num_epochs=50)
```

--------------------------------

### PyTorch Quantum MLP (lightning.gpu, 5 features)

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/mlp.ipynb

Trains and evaluates a quantum Multi-Layer Perceptron (MLP) model with 5 input features using PennyLane's 'lightning.gpu' device. The training runs for 50 epochs.

```python
model = qpctorch.quantum.MLP(5, qdevice="lightning.gpu")
qpctorch.training.train_and_evaluate(model, train_dataloader, valid_dataloader, device=device, num_classes=2, num_epochs=50)
```

--------------------------------

### Display JAX Device Information

Source: https://github.com/salcc/quantumtransformers/blob/main/notebooks/classical/quark-gluon.ipynb

Iterates through all available JAX devices and prints their names and device kinds. This helps in understanding the computational hardware available for training.

```python
for d in jax.devices():
    print(d, d.device_kind)
```

--------------------------------

### JAX Classical MLP (5 features)

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/mlp.ipynb

Trains and evaluates a classical Multi-Layer Perceptron (MLP) model with 5 input features using JAX. The training process runs for 50 epochs.
```python
model = qpcjax.classical.MLP(5)
qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=2, num_epochs=50)
```

--------------------------------

### PyTorch Quantum Vision Transformer with PennyLane

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/vit.ipynb

Initializes a quantum Vision Transformer model using PennyLane as the quantum backend and trains it with the PyTorch dataloaders. This demonstrates a standard quantum implementation.

```python
model = qpctorch.quantum.VisionTransformer(img_size=28, num_channels=1, num_classes=10, patch_size=14, hidden_size=6, num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3)
qpctorch.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10, device=device)
```

--------------------------------

### JAX Classical MLP (20 features)

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/mlp.ipynb

Trains and evaluates a classical Multi-Layer Perceptron (MLP) model with 20 input features using JAX. The training process runs for 50 epochs.

```python
model = qpcjax.classical.MLP(20)
qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=2, num_epochs=50)
```

--------------------------------

### PyTorch Classical Vision Transformer Training

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/vit.ipynb

Initializes a classical Vision Transformer model with specified parameters and trains it using the prepared PyTorch dataloaders. This serves as a baseline for the quantum implementations.
```python
model = qpctorch.classical.VisionTransformer(img_size=28, num_channels=1, num_classes=10, patch_size=14, hidden_size=6, num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3)
qpctorch.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10, device=device)
```

--------------------------------

### PyTorch Quantum Vision Transformer with TensorCircuit

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/vit.ipynb

Initializes a quantum Vision Transformer model specifying TensorCircuit as the quantum backend and trains it using PyTorch. This explores an alternative quantum simulation framework.

```python
model = qpctorch.quantum.VisionTransformer(img_size=28, num_channels=1, num_classes=10, patch_size=14, hidden_size=6, num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3, qml_backend="tensorcircuit")
qpctorch.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10, device=device)
```

--------------------------------

### JAX Quantum Vision Transformer with PennyLane and Lightning-GPU

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/vit.ipynb

Attempts to initialize and train a quantum Vision Transformer using PennyLane with a Lightning-GPU device in JAX. This configuration is noted as not working due to compatibility issues.
```python
try:
    model = qpcjax.quantum.VisionTransformer(num_classes=10, patch_size=14, hidden_size=6, num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3, qdevice="lightning.gpu")
    qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10)
except jaxlib.xla_extension.XlaRuntimeError:
    print(traceback.format_exc())
```

--------------------------------

### Instantiate and Train ResNet15 (Version 2)

Source: https://github.com/salcc/quantumtransformers/blob/main/notebooks/classical/quark-gluon-resnet.ipynb

Instantiates the second version of the ResNet15 model with 2 output classes and begins the training process. It uses the same dataloaders and training utility as the first version, running for 30 epochs.

```python
model = ResNet15(num_classes=2)
train_and_evaluate(model, qg_train_dataloader, qg_val_dataloader, qg_test_dataloader, num_classes=2, num_epochs=30)
```

--------------------------------

### List Available JAX Devices

Source: https://github.com/salcc/quantumtransformers/blob/main/notebooks/quantum/electron-photon.ipynb

Iterates through all available JAX devices and prints their type and kind. This is useful for understanding the hardware environment where computations will be performed.

```python
for d in jax.devices():
    print(d, d.device_kind)
```

--------------------------------

### Load and Plot Datasets

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/plots.ipynb

Loads the Swiss Roll and MNIST datasets and plots them. This snippet demonstrates the data preparation used for the performance benchmarks.
```python
import matplotlib.pyplot as plt

import quantum_transformers.qmlperfcomp.jax_backend as qpcjax

# Load and plot the Swiss Roll dataset
qpcjax.data.get_swiss_roll_dataloaders(plot=True)

# Load and plot the MNIST dataset
qpcjax.data.get_mnist_dataloaders(plot=True)
```

--------------------------------

### Instantiate and Train ResNet15 (Version 1)

Source: https://github.com/salcc/quantumtransformers/blob/main/notebooks/classical/quark-gluon-resnet.ipynb

Instantiates the ResNet15 model with 2 output classes and initiates the training process using the prepared dataloaders and the `train_and_evaluate` utility for 30 epochs.

```python
model = ResNet15(num_classes=2)
train_and_evaluate(model, qg_train_dataloader, qg_val_dataloader, qg_test_dataloader, num_classes=2, num_epochs=30)
```

--------------------------------

### Project Library Code

Source: https://github.com/salcc/quantumtransformers/blob/main/README.md

The core library code for the quantum transformers is located in the `quantum_transformers/` directory. This includes modules for loading datasets and training models.

```text
quantum_transformers/: the library code for the quantum transformers, as well as for loading the data (datasets.py) and training the models (training.py).
```

--------------------------------

### PyTorch Quantum MLP (TensorCircuit, 5 features)

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/mlp.ipynb

Trains and evaluates a quantum MLP with 5 input features using the TensorCircuit backend. The training process runs for 50 epochs.
```python
model = qpctorch.quantum.MLP(5, qml_backend="tensorcircuit")
qpctorch.training.train_and_evaluate(model, train_dataloader, valid_dataloader, device=device, num_classes=2, num_epochs=50)
```

--------------------------------

### PyTorch Quantum MLP (TensorCircuit, 20 features)

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/mlp.ipynb

Trains and evaluates a quantum MLP with 20 input features using the TensorCircuit backend. The training process runs for 50 epochs.

```python
model = qpctorch.quantum.MLP(20, qml_backend="tensorcircuit")
qpctorch.training.train_and_evaluate(model, train_dataloader, valid_dataloader, device=device, num_classes=2, num_epochs=50)
```

--------------------------------

### Display JAX Devices (Python)

Source: https://github.com/salcc/quantumtransformers/blob/main/notebooks/quantum/quark-gluon.ipynb

Iterates through the available JAX devices and prints their type and kind. This helps in understanding the computational resources being used.

```python
for d in jax.devices():
    print(d, d.device_kind)
```

--------------------------------

### JAX Quantum Vision Transformer with PennyLane, Lightning-GPU, and Catalyst

Source: https://github.com/salcc/quantumtransformers/blob/main/quantum_transformers/qmlperfcomp/vit.ipynb

Attempts to initialize and train a quantum Vision Transformer using PennyLane with Lightning-GPU and Catalyst enabled in JAX. This combination is explicitly stated as not supported, so the expected compilation error is caught and printed.
```python
try:
    model = qpcjax.quantum.VisionTransformer(num_classes=10, patch_size=14, hidden_size=6, num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3, qdevice="lightning.gpu", use_catalyst=True)
    qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10)
except catalyst.CompileError:
    print(traceback.format_exc())
```

--------------------------------

### Initialize JAX and TensorFlow

Source: https://github.com/salcc/quantumtransformers/blob/main/notebooks/classical/electron-photon.ipynb

Initializes the JAX and TensorFlow environments for the electron-photon classification task. It configures TensorFlow not to use the GPU and sets a random seed for reproducibility. The necessary modules from quantum_transformers are also imported.

```python
import jax
import tensorflow as tf

tf.config.set_visible_devices([], device_type='GPU')  # Ensure TF does not see the GPU and grab all GPU memory.
tf.random.set_seed(42)  # For reproducibility.

from quantum_transformers.utils import plot_image
from quantum_transformers.datasets import get_electron_photon_dataloaders
from quantum_transformers.training import train_and_evaluate
from quantum_transformers.transformers import VisionTransformer

data_dir = '/global/cfs/cdirs/m4392/salcc/data'
```

--------------------------------

### Import Libraries and Set Data Directory

Source: https://github.com/salcc/quantumtransformers/blob/main/notebooks/quantum/electron-photon.ipynb

Imports the necessary JAX and quantum transformer libraries, including utilities for plotting, datasets, training, and transformer models. It also defines the path to the dataset directory.
```python
import jax

from quantum_transformers.utils import plot_image
from quantum_transformers.datasets import get_electron_photon_dataloaders
from quantum_transformers.training import train_and_evaluate
from quantum_transformers.transformers import VisionTransformer
from quantum_transformers.quantum_layer import get_circuit

data_dir = '/global/cfs/cdirs/m4392/salcc/data'
```

--------------------------------

### Import necessary libraries for Quantum Transformers

Source: https://github.com/salcc/quantumtransformers/blob/main/notebooks/quantum/mnist.ipynb

Imports the core JAX libraries and specific modules from the quantum_transformers package for data loading, model definition, and training. It also imports a utility for plotting images and defines the data directory.

```python
import jax

from quantum_transformers.utils import plot_image
from quantum_transformers.datasets import get_mnist_dataloaders
from quantum_transformers.training import train_and_evaluate
from quantum_transformers.transformers import VisionTransformer
from quantum_transformers.quantum_layer import get_circuit

data_dir = '/global/cfs/cdirs/m4392/salcc/data'
```
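The Vision Transformer snippets above all pass `img_size=28` and `patch_size=14`. As a rough illustration of what those hyperparameters mean, the following standalone NumPy sketch (not part of the quantum_transformers library; the reshape/transpose recipe is a common ViT patch-embedding idiom, not necessarily the library's exact implementation) splits one 28×28 single-channel MNIST-sized image into the (28/14)² = 4 flattened patch tokens a ViT would embed before its transformer blocks.

```python
import numpy as np

# ViT hyperparameters matching the snippets above.
img_size, patch_size, num_channels = 28, 14, 1

# A dummy image whose pixel values encode their position, for easy inspection.
image = np.arange(img_size * img_size * num_channels, dtype=np.float64).reshape(
    img_size, img_size, num_channels)

# Split into non-overlapping patches: (row blocks, rows, col blocks, cols, channels),
# then gather the block axes together and flatten each patch into one token.
n = img_size // patch_size  # 2 patches per side
patches = image.reshape(n, patch_size, n, patch_size, num_channels)
patches = patches.transpose(0, 2, 1, 3, 4).reshape(n * n, patch_size * patch_size * num_channels)

print(patches.shape)  # (4, 196): 4 tokens of 14*14*1 = 196 features each
```

Each of the 4 tokens would then be projected to `hidden_size=6` dimensions by a learned linear layer, giving the short sequence the `num_transformer_blocks=4` blocks operate on.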