Skip to content

Commit

Permalink
More demo fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgensd committed Nov 20, 2023
1 parent e0f91b7 commit 61b690f
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions python/demos/demo_periodic_geometrical.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@
NY = 100
mesh = create_unit_square(MPI.COMM_WORLD, NX, NY)
V = fem.functionspace(mesh, ("Lagrange", 1, (mesh.geometry.dim, )))
tol = 250 * np.finfo(default_scalar_type).resolution


def dirichletboundary(x):
return np.logical_or(np.isclose(x[1], 0), np.isclose(x[1], 1))
return np.logical_or(np.isclose(x[1], 0, atol=tol), np.isclose(x[1], 1, atol=tol))


# Create Dirichlet boundary condition
Expand All @@ -53,11 +54,11 @@ def dirichletboundary(x):


def periodic_boundary(x):
return np.isclose(x[0], 1)
return np.isclose(x[0], 1, atol=tol)


def periodic_relation(x):
out_x = np.zeros(x.shape)
out_x = np.zeros_like(x)
out_x[0] = 1 - x[0]
out_x[1] = x[1]
out_x[2] = x[2]
Expand Down Expand Up @@ -152,20 +153,21 @@ def periodic_relation(x):
with Timer("~Demo: Verification"):
dolfinx_mpc.utils.compare_mpc_lhs(A_org, problem._A, mpc, root=root)
dolfinx_mpc.utils.compare_mpc_rhs(L_org, problem._b, mpc, root=root)

is_complex = np.issubdtype(default_scalar_type, np.complexfloating) # type: ignore
scipy_dtype = np.complex128 if is_complex else np.float64
# Gather LHS, RHS and solution on one process
A_csr = dolfinx_mpc.utils.gather_PETScMatrix(A_org, root=root)
K = dolfinx_mpc.utils.gather_transformation_matrix(mpc, root=root)
L_np = dolfinx_mpc.utils.gather_PETScVector(L_org, root=root)
u_mpc = dolfinx_mpc.utils.gather_PETScVector(uh.vector, root=root)

if MPI.COMM_WORLD.rank == root:
KTAK = K.T * A_csr * K
reduced_L = K.T @ L_np
KTAK = K.T.astype(scipy_dtype) * A_csr.astype(scipy_dtype) * K.astype(scipy_dtype)
reduced_L = K.T.astype(scipy_dtype) @ L_np.astype(scipy_dtype)
# Solve linear system
d = scipy.sparse.linalg.spsolve(KTAK, reduced_L)
# Back substitution to full solution vector
uh_numpy = K @ d
assert np.allclose(uh_numpy, u_mpc, atol=float(np.finfo(u_mpc.dtype).resolution))
uh_numpy = K.astype(scipy_dtype) @ d.astype(scipy_dtype)
assert np.allclose(uh_numpy.astype(u_mpc.dtype), u_mpc, atol=tol)
list_timings(MPI.COMM_WORLD, [TimingType.wall])
L_org.destroy()

0 comments on commit 61b690f

Please sign in to comment.