|
1 | 1 | import os
|
| 2 | +import sys |
2 | 3 | import numpy as np
|
3 | 4 | import scipy as sp
|
4 | 5 | import dimod
|
5 | 6 | import minorminer
|
6 |
| -try: |
7 |
| - import gurobipy |
8 |
| -except ImportError: |
9 |
| - pass |
10 | 7 | from tqdm import tqdm
|
11 | 8 | from concurrent.futures import ThreadPoolExecutor
|
12 | 9 | from dwave.samplers import SimulatedAnnealingSampler
|
13 | 10 | from dwave.system import LeapHybridSampler
|
14 | 11 | from dwave.system import DWaveSampler, FixedEmbeddingComposite
|
15 | 12 |
|
| 13 | +try: |
| 14 | + import gurobipy |
| 15 | +except ImportError: |
| 16 | + pass |
| 17 | + |
16 | 18 |
|
17 | 19 | class QUnfoldQUBO:
|
18 | 20 | def __init__(self, response, measured, lam=0.0):
|
@@ -180,53 +182,42 @@ def compute_energy(self, x):
|
180 | 182 | energy = self.dwave_bqm.energy(sample=xbin)
|
181 | 183 | return energy
|
182 | 184 |
|
183 |
| - def solve_gurobi_integer(self): |
184 |
| - try: |
| 185 | + if "gurobipy" in sys.modules: |
| 186 | + |
| 187 | + def solve_gurobi_integer(self): |
185 | 188 | model = gurobipy.Model()
|
186 |
| - except Exception as e: |
187 |
| - import traceback |
188 |
| - import sys |
189 |
| - print(traceback.format_exc()) |
190 |
| - print("\n\nGurobi not installed in this env, to install use pip install QUnfold[guro]\n\n") |
191 |
| - sys.exit(1) |
192 |
| - model.setParam("OutputFlag", 0) |
193 |
| - x = [ |
194 |
| - model.addVar(vtype=gurobipy.GRB.INTEGER, lb=0, ub=2**b - 1) |
195 |
| - for b in self.num_bits |
196 |
| - ] |
197 |
| - a = self.linear_coeffs |
198 |
| - B = self.quadratic_coeffs |
199 |
| - model.setObjective(a @ x + x @ B @ x, sense=gurobipy.GRB.MINIMIZE) |
200 |
| - model.optimize() |
201 |
| - sol = np.array([var.x for var in x]) |
202 |
| - err = np.sqrt(sol) |
203 |
| - cov = np.diag(sol) |
204 |
| - return sol, err, cov |
| 189 | + model.setParam("OutputFlag", 0) |
| 190 | + x = [ |
| 191 | + model.addVar(vtype=gurobipy.GRB.INTEGER, lb=0, ub=2**b - 1) |
| 192 | + for b in self.num_bits |
| 193 | + ] |
| 194 | + a = self.linear_coeffs |
| 195 | + B = self.quadratic_coeffs |
| 196 | + model.setObjective(a @ x + x @ B @ x, sense=gurobipy.GRB.MINIMIZE) |
| 197 | + model.optimize() |
| 198 | + sol = np.array([var.x for var in x]) |
| 199 | + err = np.sqrt(sol) |
| 200 | + cov = np.diag(sol) |
| 201 | + return sol, err, cov |
205 | 202 |
|
206 |
| - def solve_gurobi_binary(self): |
207 |
| - try: |
| 203 | + def solve_gurobi_binary(self): |
208 | 204 | model = gurobipy.Model()
|
209 |
| - except Exception as e: |
210 |
| - import traceback |
211 |
| - import sys |
212 |
| - print(traceback.format_exc()) |
213 |
| - print("\n\nGurobi not installed in this env, to install use pip install QUnfold[guro]\n\n") |
214 |
| - sys.exit(1) |
215 |
| - model.setParam("OutputFlag", 0) |
216 |
| - num_bits = self.num_bits |
217 |
| - x = [ |
218 |
| - model.addVar(vtype=gurobipy.GRB.BINARY) |
219 |
| - for i in range(self.num_bins) |
220 |
| - for _ in range(num_bits[i]) |
221 |
| - ] |
222 |
| - Q = self.qubo_matrix |
223 |
| - model.setObjective(x @ Q @ x, sense=gurobipy.GRB.MINIMIZE) |
224 |
| - model.optimize() |
225 |
| - bitstr = np.array([var.x for var in x], dtype=int) |
226 |
| - arrays = np.split(bitstr, np.cumsum(num_bits[:-1])) |
227 |
| - sol = np.array( |
228 |
| - [int("".join(arr.astype(str))[::-1], base=2) for arr in arrays], dtype=float |
229 |
| - ) |
230 |
| - err = np.sqrt(sol) |
231 |
| - cov = np.diag(sol) |
232 |
| - return sol, err, cov |
| 205 | + model.setParam("OutputFlag", 0) |
| 206 | + num_bits = self.num_bits |
| 207 | + x = [ |
| 208 | + model.addVar(vtype=gurobipy.GRB.BINARY) |
| 209 | + for i in range(self.num_bins) |
| 210 | + for _ in range(num_bits[i]) |
| 211 | + ] |
| 212 | + Q = self.qubo_matrix |
| 213 | + model.setObjective(x @ Q @ x, sense=gurobipy.GRB.MINIMIZE) |
| 214 | + model.optimize() |
| 215 | + bitstr = np.array([var.x for var in x], dtype=int) |
| 216 | + arrays = np.split(bitstr, np.cumsum(num_bits[:-1])) |
| 217 | + sol = np.array( |
| 218 | + [int("".join(arr.astype(str))[::-1], base=2) for arr in arrays], |
| 219 | + dtype=float, |
| 220 | + ) |
| 221 | + err = np.sqrt(sol) |
| 222 | + cov = np.diag(sol) |
| 223 | + return sol, err, cov |
0 commit comments