diff --git a/python/demos/demo_periodic_geometrical.py b/python/demos/demo_periodic_geometrical.py index d283c109..4560f635 100644 --- a/python/demos/demo_periodic_geometrical.py +++ b/python/demos/demo_periodic_geometrical.py @@ -38,10 +38,11 @@ NY = 100 mesh = create_unit_square(MPI.COMM_WORLD, NX, NY) V = fem.functionspace(mesh, ("Lagrange", 1, (mesh.geometry.dim, ))) +tol = 250 * np.finfo(default_scalar_type).resolution def dirichletboundary(x): - return np.logical_or(np.isclose(x[1], 0), np.isclose(x[1], 1)) + return np.logical_or(np.isclose(x[1], 0, atol=tol), np.isclose(x[1], 1, atol=tol)) # Create Dirichlet boundary condition @@ -53,11 +54,11 @@ def dirichletboundary(x): def periodic_boundary(x): - return np.isclose(x[0], 1) + return np.isclose(x[0], 1, atol=tol) def periodic_relation(x): - out_x = np.zeros(x.shape) + out_x = np.zeros_like(x) out_x[0] = 1 - x[0] out_x[1] = x[1] out_x[2] = x[2] @@ -152,7 +153,8 @@ def periodic_relation(x): with Timer("~Demo: Verification"): dolfinx_mpc.utils.compare_mpc_lhs(A_org, problem._A, mpc, root=root) dolfinx_mpc.utils.compare_mpc_rhs(L_org, problem._b, mpc, root=root) - + is_complex = np.issubdtype(default_scalar_type, np.complexfloating) # type: ignore + scipy_dtype = np.complex128 if is_complex else np.float64 # Gather LHS, RHS and solution on one process A_csr = dolfinx_mpc.utils.gather_PETScMatrix(A_org, root=root) K = dolfinx_mpc.utils.gather_transformation_matrix(mpc, root=root) @@ -160,12 +162,12 @@ def periodic_relation(x): u_mpc = dolfinx_mpc.utils.gather_PETScVector(uh.vector, root=root) if MPI.COMM_WORLD.rank == root: - KTAK = K.T * A_csr * K - reduced_L = K.T @ L_np + KTAK = K.T.astype(scipy_dtype) * A_csr.astype(scipy_dtype) * K.astype(scipy_dtype) + reduced_L = K.T.astype(scipy_dtype) @ L_np.astype(scipy_dtype) # Solve linear system d = scipy.sparse.linalg.spsolve(KTAK, reduced_L) # Back substitution to full solution vector - uh_numpy = K @ d - assert np.allclose(uh_numpy, u_mpc, atol=float(np.finfo(u_mpc.dtype).resolution)) + uh_numpy = K.astype(scipy_dtype) @ d.astype(scipy_dtype) + assert np.allclose(uh_numpy.astype(u_mpc.dtype), u_mpc, atol=tol) list_timings(MPI.COMM_WORLD, [TimingType.wall]) L_org.destroy()