From d427160a288626c523059d3c739b4ddcb89ac5f7 Mon Sep 17 00:00:00 2001 From: mhucka Date: Fri, 12 Jun 2026 03:39:40 +0000 Subject: [PATCH 1/2] Add test cases for untested ops/batch_util.py functions This is a straightforward addition of a few unit test cases to increase test coverage. --- .../core/ops/batch_util_test.py | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/tensorflow_quantum/core/ops/batch_util_test.py b/tensorflow_quantum/core/ops/batch_util_test.py index c702338d6..31aec9aa2 100644 --- a/tensorflow_quantum/core/ops/batch_util_test.py +++ b/tensorflow_quantum/core/ops/batch_util_test.py @@ -330,6 +330,66 @@ def test_pauli_sum_collector_collect(self): self.assertAlmostEqual(collector.estimated_energy(), expected_energy) + def test_pauli_sum_collector_next_job(self): + """Test the next_job method of TFQPauliSumCollector.""" + qubit = cirq.GridQubit(0, 0) + circuit = cirq.Circuit(cirq.X(qubit)) + observable = cirq.Z(qubit) + 2.0 * cirq.X(qubit) + + collector = batch_util.TFQPauliSumCollector( + circuit, observable, samples_per_term=10) + + job1 = collector.next_job() + self.assertIsNotNone(job1) + self.assertEqual(job1.repetitions, 10) + + job2 = collector.next_job() + self.assertIsNotNone(job2) + self.assertEqual(job2.repetitions, 10) + + job3 = collector.next_job() + self.assertIsNone(job3) + + def test_pauli_sum_collector_on_job_result(self): + """Test the on_job_result method of TFQPauliSumCollector.""" + qubit = cirq.GridQubit(0, 0) + circuit = cirq.Circuit(cirq.X(qubit)) + observable = cirq.Z(qubit) + + collector = batch_util.TFQPauliSumCollector( + circuit, observable, samples_per_term=5) + + job = collector.next_job() + + class FakeResult: + def histogram(self, key, fold_func): + return {0: 2, 1: 3} + + collector.on_job_result(job, FakeResult()) + + job_id = job.tag + self.assertEqual(collector._zeros[job_id], 2) + self.assertEqual(collector._ones[job_id], 3) + + def test_pauli_sum_collector_estimated_energy(self): + """Test the estimated_energy method of TFQPauliSumCollector.""" + qubit = cirq.GridQubit(0, 0) + circuit = cirq.Circuit(cirq.X(qubit)) + observable = 3.0 * cirq.Z(qubit) + 2.0 * cirq.I(qubit) + + collector = batch_util.TFQPauliSumCollector( + circuit, observable, samples_per_term=5) + + job = collector.next_job() + + class FakeResult: + def histogram(self, key, fold_func): + return {0: 2, 1: 3} + + collector.on_job_result(job, FakeResult()) + + self.assertAlmostEqual(collector.estimated_energy(), 1.4) + if __name__ == '__main__': tf.test.main() From c59f1edfc9fa7f7707902506acf3413a0dc12b75 Mon Sep 17 00:00:00 2001 From: mhucka Date: Fri, 12 Jun 2026 03:42:13 +0000 Subject: [PATCH 2/2] Run through scripts/format_all.sh --- tensorflow_quantum/core/ops/batch_util_test.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/tensorflow_quantum/core/ops/batch_util_test.py b/tensorflow_quantum/core/ops/batch_util_test.py index 31aec9aa2..0778dc881 100644 --- a/tensorflow_quantum/core/ops/batch_util_test.py +++ b/tensorflow_quantum/core/ops/batch_util_test.py @@ -336,8 +336,9 @@ def test_pauli_sum_collector_next_job(self): circuit = cirq.Circuit(cirq.X(qubit)) observable = cirq.Z(qubit) + 2.0 * cirq.X(qubit) - collector = batch_util.TFQPauliSumCollector( - circuit, observable, samples_per_term=10) + collector = batch_util.TFQPauliSumCollector(circuit, + observable, + samples_per_term=10) job1 = collector.next_job() self.assertIsNotNone(job1) @@ -356,12 +357,14 @@ def test_pauli_sum_collector_on_job_result(self): circuit = cirq.Circuit(cirq.X(qubit)) observable = cirq.Z(qubit) - collector = batch_util.TFQPauliSumCollector( - circuit, observable, samples_per_term=5) + collector = batch_util.TFQPauliSumCollector(circuit, + observable, + samples_per_term=5) job = collector.next_job() class FakeResult: + def histogram(self, key, fold_func): return {0: 2, 1: 3} @@ -377,12 +380,14 @@ def test_pauli_sum_collector_estimated_energy(self): circuit = cirq.Circuit(cirq.X(qubit)) observable = 3.0 * cirq.Z(qubit) + 2.0 * cirq.I(qubit) - collector = batch_util.TFQPauliSumCollector( - circuit, observable, samples_per_term=5) + collector = batch_util.TFQPauliSumCollector(circuit, + observable, + samples_per_term=5) job = collector.next_job() class FakeResult: + def histogram(self, key, fold_func): return {0: 2, 1: 3}