Runtime Error with loss.backward()

Hello, I saw your presentation at IEEE Quantum Week and decided to start playing around with your code! I have come across an issue that I am not sure how to resolve. When inserting a quantum layer into my torch model, I get the “RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed” error. I have isolated the issue to the quantum circuit in my code but I’m not sure where it’s originating or how to resolve it. Am I missing something from my quantum layer code?

class QuantumLayer(tq.QuantumModule):
    def __init__(self):
        super().__init__()
        self.n_wires = 4
        self.q_device = tq.QuantumDevice(n_wires=self.n_wires)
        self.measure = tq.MeasureAll(tq.PauliZ)

        # Trainable gates
        self.ry01 = tq.RY(has_params=True, trainable=True)
        self.ry02 = tq.RY(has_params=True, trainable=True)
        self.ry11 = tq.RY(has_params=True, trainable=True)
        self.ry12 = tq.RY(has_params=True, trainable=True)
        self.ry21 = tq.RY(has_params=True, trainable=True)
        self.ry22 = tq.RY(has_params=True, trainable=True)
        self.ry31 = tq.RY(has_params=True, trainable=True)
        self.ry32 = tq.RY(has_params=True, trainable=True)

    @tq.static_support
    def forward(self, x):
        final_result = []
        for sample in range(x.shape[0]):
            # Feature Encoding
            tqf.h(self.q_device,wires=list(range(self.n_wires)), static=self.static_mode, parent_graph=self.graph)
            for row in range(self.n_wires):
                tqf.rz(self.q_device, wires=row, params=torch.tensor([x[sample,row]]),static=self.static_mode, parent_graph=self.graph)      
            for row in range(self.n_wires-1):
                for row2 in range(self.n_wires-row-1):
                    target = row + row2 + 1
                    rotation = x[sample,row] + x[sample,row2 + 1]
                    tqf.cx(self.q_device,wires=[row,target],static=self.static_mode, parent_graph=self.graph)
                    tqf.rz(self.q_device,wires=target,params=torch.tensor([rotation]),static=self.static_mode, parent_graph=self.graph)
                    tqf.cx(self.q_device,wires=[row,target],static=self.static_mode, parent_graph=self.graph)
            # Variational Portion
            self.ry01(self.q_device,wires=0)
            self.ry11(self.q_device,wires=1)
            self.ry21(self.q_device,wires=2)
            self.ry31(self.q_device,wires=3)
            for row in range(self.n_wires-1):
                for row2 in range(self.n_wires-row-1):
                    target = row + row2 + 1
                    tqf.cx(self.q_device,wires=[row,target],static=self.static_mode, parent_graph=self.graph)
            self.ry02(self.q_device,wires=0)
            self.ry12(self.q_device,wires=1)
            self.ry22(self.q_device,wires=2)
            self.ry32(self.q_device,wires=3)

            result = self.measure(self.q_device)
            final_result.append(result)
        final_result = torch.cat((final_result),dim=0).float()
        return final_result

Thank you!

Dear cabb,
Thank you for your feedback! Could you provide the full script for me to reproduce the bug?

Hi Hanrui,

I managed to resolve the issue by adding self.q_device.reset_states(1) at the beginning of the forward method. My assumption is that it was trying to calculate the second batch using the final qubit states of the previous batch as the starting states of the new batch. I’ve noticed that the line is included in some examples but not in others, like “mnist_example.py”. What are the criteria for requiring the line that resets the states?
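
The carried-over-state explanation can be sketched in plain PyTorch. This is not TorchQuantum code: ToyDevice, its reset_states and ry methods, and train_steps are hypothetical stand-ins, and the "qubit" is a real-valued two-component vector. The only point is that a state tensor kept from the previous step still belongs to a graph that loss.backward() has already freed.

```python
import torch


class ToyDevice:
    """Hypothetical stand-in for a device that keeps its state between calls."""

    def __init__(self):
        self.reset_states()

    def reset_states(self):
        # Fresh |0> state with no autograd history attached
        self.states = torch.tensor([1.0, 0.0])

    def ry(self, theta):
        # Real-valued RY rotation applied to the stored state
        c, s = torch.cos(theta / 2), torch.sin(theta / 2)
        gate = torch.stack([torch.stack([c, -s]), torch.stack([s, c])])
        self.states = gate @ self.states


def train_steps(reset_each_step):
    theta = torch.tensor(0.3, requires_grad=True)
    dev = ToyDevice()
    for _ in range(2):
        if reset_each_step:
            dev.reset_states()
        dev.ry(theta)
        loss = dev.states[0] ** 2
        # Without a reset, the second call reaches back into the first step's graph
        loss.backward()
    return theta.grad


def fails_without_reset():
    try:
        train_steps(reset_each_step=False)
    except RuntimeError:
        return True
    return False
```

In this toy, skipping the reset makes the second loss.backward() raise a RuntimeError, while resetting at the start of each step keeps every graph independent and the gradient simply accumulates in theta.grad.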

Thanks,

Cliff

Hi Cliff, reset_states here initializes the states in batch mode, which means the quantum device stores multiple copies of the quantum state. We do this because in some QML tasks we want to encode a batch of classical inputs, such as a batch of images, into quantum states, so there will be one encoded quantum state for each image.

reset_states should be called when you want to apply the same quantum gate to a batch of different quantum states.
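
To make the batch-mode idea concrete, here is a toy sketch in plain PyTorch. It does not use TorchQuantum's internals: encode_batch, apply_ry and expval_z are hypothetical helpers and the states are real-valued single-qubit vectors. It only shows one state per sample, with one shared gate applied to all of them.

```python
import torch


def encode_batch(angles):
    """Encode each classical value as its own single-qubit state RY(angle)|0>."""
    # Shape (bsz, 2): one state vector per sample in the batch
    return torch.stack([torch.cos(angles / 2), torch.sin(angles / 2)], dim=1)


def apply_ry(states, theta):
    """Apply one shared RY(theta) gate to every state in the batch."""
    c, s = torch.cos(theta / 2), torch.sin(theta / 2)
    gate = torch.stack([torch.stack([c, -s]), torch.stack([s, c])])
    # states is (bsz, 2); multiplying by gate.T applies the gate to each row
    return states @ gate.T


def expval_z(states):
    """Pauli-Z expectation value for each state in the batch."""
    return states[:, 0] ** 2 - states[:, 1] ** 2


angles = torch.tensor([0.0, 0.5, 1.0])          # a batch of three classical inputs
theta = torch.tensor(0.2, requires_grad=True)   # one trainable gate parameter
out = expval_z(apply_ry(encode_batch(angles), theta))
```

Here out has one expectation value per sample, and all three depend on the same trainable theta.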

In the newest version of tq, we explicitly initialize the batch size to 1, as shown here: torchquantum/devices.py at 555bd3fa52980429e9b7cc6a6c66b0c9f5d2b49d · mit-han-lab/torchquantum · GitHub

So if you don’t need batch mode, calling reset_states is not necessary.

Please let me know if you have any questions. Meanwhile, you can also join our Slack for more support:

https://join.slack.com/t/torchquantum/shared_invite/zt-1ghuf283a-OtP4mCPJREd~367VX~TaQQ