Hello, I saw your presentation at IEEE Quantum Week and decided to start playing around with your code! I have come across an issue that I am not sure how to resolve. When inserting a quantum layer into my torch model, I get the “RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed” error. I have isolated the issue to the quantum circuit in my code but I’m not sure where it’s originating or how to resolve it. Am I missing something from my quantum layer code?
class QuantumLayer(tq.QuantumModule):
def __init__(self):
super().__init__()
self.n_wires = 4
self.q_device = tq.QuantumDevice(n_wires=self.n_wires)
self.measure = tq.MeasureAll(tq.PauliZ)
# Trainable gates
self.ry01 = tq.RY(has_params=True, trainable=True)
self.ry02 = tq.RY(has_params=True, trainable=True)
self.ry11 = tq.RY(has_params=True, trainable=True)
self.ry12 = tq.RY(has_params=True, trainable=True)
self.ry21 = tq.RY(has_params=True, trainable=True)
self.ry22 = tq.RY(has_params=True, trainable=True)
self.ry31 = tq.RY(has_params=True, trainable=True)
self.ry32 = tq.RY(has_params=True, trainable=True)
@tq.static_support
def forward(self, x):
final_result = []
for sample in range(x.shape[0]):
# Feature Encoding
tqf.h(self.q_device,wires=list(range(self.n_wires)), static=self.static_mode, parent_graph=self.graph)
for row in range(self.n_wires):
tqf.rz(self.q_device, wires=row, params=torch.tensor([x[sample,row]]),static=self.static_mode, parent_graph=self.graph)
for row in range(self.n_wires-1):
for row2 in range(self.n_wires-row-1):
target = row + row2 + 1
rotation = x[sample,row] + x[sample,row2 + 1]
tqf.cx(self.q_device,wires=[row,target],static=self.static_mode, parent_graph=self.graph)
tqf.rz(self.q_device,wires=target,params=torch.tensor([rotation]),static=self.static_mode, parent_graph=self.graph)
tqf.cx(self.q_device,wires=[row,target],static=self.static_mode, parent_graph=self.graph)
# Variational Portion
self.ry01(self.q_device,wires=0)
self.ry11(self.q_device,wires=1)
self.ry21(self.q_device,wires=2)
self.ry31(self.q_device,wires=3)
for row in range(self.n_wires-1):
for row2 in range(self.n_wires-row-1):
target = row + row2 + 1
tqf.cx(self.q_device,wires=[row,target],static=self.static_mode, parent_graph=self.graph)
self.ry02(self.q_device,wires=0)
self.ry12(self.q_device,wires=1)
self.ry22(self.q_device,wires=2)
self.ry32(self.q_device,wires=3)
result = self.measure(self.q_device)
final_result.append(result)
final_result = torch.cat((final_result),dim=0).float()
return final_result
Thank you!