diff --git a/pom.xml b/pom.xml index 1335df7..26f64dd 100644 --- a/pom.xml +++ b/pom.xml @@ -82,6 +82,11 @@ dropwizard-websockets 1.3.14 + + io.dropwizard-bundles + dropwizard-redirect-bundle + 1.3.5 + diff --git a/roman.yaml b/roman.yaml index da88622..f838182 100644 --- a/roman.yaml +++ b/roman.yaml @@ -16,7 +16,7 @@ logging: swagger: resourcePackage: com.wire.bots.roman.resources schemes: - - http +# - http - https jerseyClient: diff --git a/src/main/java/com/wire/bots/roman/Application.java b/src/main/java/com/wire/bots/roman/Application.java index ed8214c..b00a5b6 100644 --- a/src/main/java/com/wire/bots/roman/Application.java +++ b/src/main/java/com/wire/bots/roman/Application.java @@ -29,6 +29,8 @@ import com.wire.xenon.MessageHandlerBase; import com.wire.xenon.factories.CryptoFactory; import com.wire.xenon.factories.StorageFactory; +import io.dropwizard.bundles.redirect.PathRedirect; +import io.dropwizard.bundles.redirect.RedirectBundle; import io.dropwizard.setup.Bootstrap; import io.dropwizard.setup.Environment; import io.dropwizard.websockets.WebsocketBundle; @@ -61,6 +63,7 @@ public void initialize(Bootstrap bootstrap) { bootstrap.addBundle(new WebsocketBundle(WebSocket.class)); bootstrap.addCommand(new UpdateCertCommand()); + bootstrap.addBundle(new RedirectBundle(new PathRedirect("/", "/swagger#/default"))); } @Override diff --git a/src/main/java/com/wire/bots/roman/DAO/BotsDAO.java b/src/main/java/com/wire/bots/roman/DAO/BotsDAO.java index e287543..9c753a3 100644 --- a/src/main/java/com/wire/bots/roman/DAO/BotsDAO.java +++ b/src/main/java/com/wire/bots/roman/DAO/BotsDAO.java @@ -1,6 +1,6 @@ package com.wire.bots.roman.DAO; -import com.wire.bots.roman.DAO.mappers.BotsMapper; +import com.wire.bots.roman.DAO.mappers.UUIDMapper; import org.jdbi.v3.sqlobject.config.RegisterColumnMapper; import org.jdbi.v3.sqlobject.customizer.Bind; import org.jdbi.v3.sqlobject.statement.SqlQuery; @@ -15,12 +15,12 @@ int insert(@Bind("bot") UUID bot, @Bind("provider") UUID provider); @SqlQuery("SELECT provider AS uuid FROM Bots WHERE id = :bot") - @RegisterColumnMapper(BotsMapper.class) + @RegisterColumnMapper(UUIDMapper.class) UUID getProviderId(@Bind("bot") UUID bot); @SqlQuery("SELECT id AS uuid FROM Bots WHERE provider = :providerId") - @RegisterColumnMapper(BotsMapper.class) + @RegisterColumnMapper(UUIDMapper.class) List getBotIds(@Bind("providerId") UUID providerId); @SqlUpdate("DELETE FROM Bots WHERE id = :botId") diff --git a/src/main/java/com/wire/bots/roman/DAO/BroadcastDAO.java b/src/main/java/com/wire/bots/roman/DAO/BroadcastDAO.java new file mode 100644 index 0000000..50cb2cb --- /dev/null +++ b/src/main/java/com/wire/bots/roman/DAO/BroadcastDAO.java @@ -0,0 +1,70 @@ +package com.wire.bots.roman.DAO; + +import com.wire.bots.roman.DAO.mappers.UUIDMapper; +import org.jdbi.v3.core.mapper.ColumnMapper; +import org.jdbi.v3.core.statement.StatementContext; +import org.jdbi.v3.sqlobject.config.RegisterColumnMapper; +import org.jdbi.v3.sqlobject.customizer.Bind; +import org.jdbi.v3.sqlobject.statement.SqlQuery; +import org.jdbi.v3.sqlobject.statement.SqlUpdate; + +import javax.annotation.Nullable; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.List; +import java.util.UUID; + +public interface BroadcastDAO { + @SqlUpdate("INSERT INTO Broadcast (broadcast_id, bot_id, provider, message_status, message_id) " + + "VALUES (:broadcastId, :botId, :provider, :status, :messageId) " + + "ON CONFLICT(broadcast_id, message_id, message_status) DO NOTHING") + int insert(@Bind("broadcastId") UUID broadcastId, + @Bind("botId") UUID botId, + @Bind("provider") UUID provider, + @Bind("messageId") UUID messageId, + @Bind("status") int status); + + @SqlUpdate("INSERT INTO Broadcast (broadcast_id, bot_id, provider, message_status, message_id) " + + "(SELECT B.broadcast_id, B.bot_id, B.provider, :status, B.message_id " + + "FROM Broadcast AS B " + + "WHERE message_id = :messageId FETCH FIRST ROW ONLY)") + int insertStatus(@Bind("messageId") UUID messageId, @Bind("status") int status); + + @SqlQuery("SELECT message_status, count(*) AS count " + + "FROM Broadcast " + + "WHERE broadcast_id = :broadcastId " + + "GROUP BY message_status") + @RegisterColumnMapper(ReportMapper.class) + List report(@Bind("broadcastId") UUID broadcastId); + + @SqlQuery("SELECT broadcast_id AS uuid " + + "FROM Broadcast " + + "WHERE provider = :provider " + + "ORDER BY created DESC " + + "FETCH FIRST ROW ONLY") + @RegisterColumnMapper(UUIDMapper.class) + @Nullable + UUID getBroadcastId(@Bind("provider") UUID provider); + + enum Type { + SENT, + DELIVERED, + READ, + FAILED + } + + class Pair { + public Type type; + public int count; + } + + class ReportMapper implements ColumnMapper { + @Override + public Pair map(ResultSet rs, int columnNumber, StatementContext ctx) throws SQLException { + Pair ret = new Pair(); + ret.type = Type.values()[rs.getInt("message_status")]; + ret.count = rs.getInt("count"); + return ret; + } + } +} diff --git a/src/main/java/com/wire/bots/roman/DAO/mappers/BotsMapper.java b/src/main/java/com/wire/bots/roman/DAO/mappers/UUIDMapper.java similarity index 86% rename from src/main/java/com/wire/bots/roman/DAO/mappers/BotsMapper.java rename to src/main/java/com/wire/bots/roman/DAO/mappers/UUIDMapper.java index cf4d132..550a2a6 100644 --- a/src/main/java/com/wire/bots/roman/DAO/mappers/BotsMapper.java +++ b/src/main/java/com/wire/bots/roman/DAO/mappers/UUIDMapper.java @@ -9,14 +9,14 @@ import java.sql.SQLException; import java.util.UUID; -public class BotsMapper implements ColumnMapper { +public class UUIDMapper implements ColumnMapper { @Override @Nullable public UUID map(ResultSet rs, int columnNumber, StatementContext ctx) { try { return getUuid(rs, "uuid"); } catch (SQLException e) { - Logger.error("BotsMapper: i: %d, e: %s", columnNumber, e); + Logger.error("UUIDMapper: i: %d, e: %s", columnNumber, e); return null; } } diff --git a/src/main/java/com/wire/bots/roman/MessageHandler.java b/src/main/java/com/wire/bots/roman/MessageHandler.java index b03a231..45dd3e9 100644 --- a/src/main/java/com/wire/bots/roman/MessageHandler.java +++ b/src/main/java/com/wire/bots/roman/MessageHandler.java @@ -3,6 +3,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.waz.model.Messages; import com.wire.bots.roman.DAO.BotsDAO; +import com.wire.bots.roman.DAO.BroadcastDAO; import com.wire.bots.roman.DAO.ProvidersDAO; import com.wire.bots.roman.model.IncomingMessage; import com.wire.bots.roman.model.OutgoingMessage; @@ -29,6 +30,8 @@ import java.util.concurrent.TimeUnit; import java.util.logging.Level; +import static com.wire.bots.roman.DAO.BroadcastDAO.Type.DELIVERED; +import static com.wire.bots.roman.DAO.BroadcastDAO.Type.READ; import static com.wire.bots.roman.Tools.generateToken; public class MessageHandler extends MessageHandlerBase { @@ -37,12 +40,14 @@ public class MessageHandler extends MessageHandlerBase { private final Client jerseyClient; private final ProvidersDAO providersDAO; private final BotsDAO botsDAO; + private final BroadcastDAO broadcastDAO; private Sender sender; MessageHandler(Jdbi jdbi, Client jerseyClient) { this.jerseyClient = jerseyClient; providersDAO = jdbi.onDemand(ProvidersDAO.class); botsDAO = jdbi.onDemand(BotsDAO.class); + broadcastDAO = jdbi.onDemand(BroadcastDAO.class); } @Override @@ -183,6 +188,18 @@ public void onEvent(WireClient client, UUID userId, Messages.GenericMessage even } } + @Override + public void onConfirmation(WireClient client, ConfirmationMessage msg) { + try { + final UUID messageId = msg.getConfirmationMessageId(); + final ConfirmationMessage.Type type = msg.getType(); + + broadcastDAO.insertStatus(messageId, type == ConfirmationMessage.Type.DELIVERED ? DELIVERED.ordinal() : READ.ordinal()); + } catch (Exception e) { + Logger.error("onConfirmation: %s %s", client.getId(), e); + } + } + private void onComposite(UUID botId, UUID userId, Messages.GenericMessage event) { final Messages.Composite composite = event.getComposite(); final UUID messageId = UUID.fromString(event.getMessageId()); diff --git a/src/main/java/com/wire/bots/roman/Sender.java b/src/main/java/com/wire/bots/roman/Sender.java index 60868f3..88550bd 100644 --- a/src/main/java/com/wire/bots/roman/Sender.java +++ b/src/main/java/com/wire/bots/roman/Sender.java @@ -24,20 +24,21 @@ public Sender(ClientRepo repo) { this.repo = repo; } + @Nullable public UUID send(IncomingMessage message, UUID botId) throws Exception { try (WireClient client = repo.getClient(botId)) { + if (client == null) + return null; + return send(message, client); } } - @Nullable - private UUID send(IncomingMessage message, @Nullable WireClient client) throws Exception { - if (client == null) - return null; - + private UUID send(IncomingMessage message, WireClient client) throws Exception { switch (message.type) { case "text": { MessageText text = new MessageText(message.text.data); + text.setExpectsReadConfirmation(true); if (message.text.mentions != null) { for (Mention mention : message.text.mentions) text.addMention(mention.userId, mention.offset, mention.length); diff --git a/src/main/java/com/wire/bots/roman/model/Report.java b/src/main/java/com/wire/bots/roman/model/Report.java new file mode 100644 index 0000000..a4d713e --- /dev/null +++ b/src/main/java/com/wire/bots/roman/model/Report.java @@ -0,0 +1,18 @@ +package com.wire.bots.roman.model; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.wire.bots.roman.DAO.BroadcastDAO; + +import javax.validation.constraints.NotNull; +import java.util.List; +import java.util.UUID; + +public class Report { + @JsonProperty + @NotNull + public UUID broadcastId; + + @JsonProperty + @NotNull + public List report; +} diff --git a/src/main/java/com/wire/bots/roman/resources/BroadcastResource.java b/src/main/java/com/wire/bots/roman/resources/BroadcastResource.java index 607158a..061c692 100644 --- a/src/main/java/com/wire/bots/roman/resources/BroadcastResource.java +++ b/src/main/java/com/wire/bots/roman/resources/BroadcastResource.java @@ -1,26 +1,24 @@ package com.wire.bots.roman.resources; import com.codahale.metrics.annotation.Metered; +import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.wire.bots.roman.DAO.BotsDAO; -import com.wire.bots.roman.DAO.ProvidersDAO; +import com.wire.bots.roman.DAO.BroadcastDAO; import com.wire.bots.roman.Sender; import com.wire.bots.roman.filters.ServiceTokenAuthorization; import com.wire.bots.roman.model.IncomingMessage; -import com.wire.bots.roman.model.Provider; +import com.wire.bots.roman.model.Report; import com.wire.xenon.backend.models.ErrorMessage; import com.wire.xenon.exceptions.MissingStateException; import com.wire.xenon.tools.Logger; -import io.jsonwebtoken.JwtException; import io.swagger.annotations.*; import org.jdbi.v3.core.Jdbi; +import javax.annotation.Nullable; import javax.validation.Valid; import javax.validation.constraints.NotNull; -import javax.ws.rs.HeaderParam; -import javax.ws.rs.POST; -import javax.ws.rs.Path; -import javax.ws.rs.Produces; +import javax.ws.rs.*; import javax.ws.rs.container.ContainerRequestContext; import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; @@ -36,20 +34,20 @@ @Path("/broadcast") @Produces(MediaType.APPLICATION_JSON) public class BroadcastResource { - private final Jdbi jdbi; private final Sender sender; + private final BotsDAO botsDAO; + private final BroadcastDAO broadcastDAO; public BroadcastResource(Jdbi jdbi, Sender sender) { - this.jdbi = jdbi; this.sender = sender; + + botsDAO = jdbi.onDemand(BotsDAO.class); + broadcastDAO = jdbi.onDemand(BroadcastDAO.class); } @POST @ApiOperation(value = "Broadcast message on Wire", authorizations = {@Authorization(value = "Bearer")}) - @ApiResponses(value = { - @ApiResponse(code = 403, message = "Not authenticated"), - @ApiResponse(code = 404, message = "Unknown access token") - }) + @ApiResponses(value = {@ApiResponse(code = 403, message = "Not authenticated", response = Report.class)}) @Metered @ServiceTokenAuthorization public Response post(@Context ContainerRequestContext context, @@ -60,38 +58,68 @@ public Response post(@Context ContainerRequestContext context, UUID providerId = (UUID) context.getProperty(PROVIDER_ID); - ProvidersDAO providersDAO = jdbi.onDemand(ProvidersDAO.class); - Provider provider = providersDAO.get(providerId); - if (provider == null) { - return Response. - ok(new ErrorMessage("Unknown access token")). - status(404). - build(); - } - - Logger.info("BroadcastResource: `%s` provider: %s", message.type, providerId); - - BotsDAO botsDAO = jdbi.onDemand(BotsDAO.class); + Logger.info("BroadcastResource.post: `%s` provider: %s", message.type, providerId); List botIds = botsDAO.getBotIds(providerId); - int ret = 0; + final UUID broadcastId = UUID.randomUUID(); + for (UUID botId : botIds) { - if (send(botId, message)) - ret++; + final UUID messageId = send(botId, message); + if (messageId != null) { + broadcastDAO.insert(broadcastId, botId, providerId, messageId, BroadcastDAO.Type.SENT.ordinal()); + } } + Report ret = new Report(); + ret.broadcastId = broadcastId; + ret.report = broadcastDAO.report(broadcastId); + return Response. - ok(new ErrorMessage(String.format("%d messages sent", ret))). + ok(ret). build(); - } catch (JwtException e) { - Logger.warning("BroadcastResource %s", e); + } catch (Exception e) { + Logger.error("BroadcastResource.post: %s", e); + e.printStackTrace(); + return Response + .ok(new ErrorMessage(e.getMessage())) + .status(500) + .build(); + } + } + + @GET + @ApiOperation(value = "Get latest broadcast report", authorizations = {@Authorization(value = "Bearer")}) + @ApiResponses(value = {@ApiResponse(code = 404, message = "Unknown broadcastId", response = Report.class)}) + @Metered + @ServiceTokenAuthorization + public Response get(@Context ContainerRequestContext context, + @ApiParam @HeaderParam(APP_KEY) String token, + @ApiParam @QueryParam("id") UUID broadcastId) { + try { + final UUID providerId = (UUID) context.getProperty(PROVIDER_ID); + + if (broadcastId == null) { + broadcastId = broadcastDAO.getBroadcastId(providerId); + } + + if (broadcastId == null) { + return Response. + status(404). + build(); + } + + Logger.info("BroadcastResource.get: broadcast: %s provider: %s", broadcastId, providerId); + + Report ret = new Report(); + ret.broadcastId = broadcastId; + ret.report = broadcastDAO.report(broadcastId); + return Response. - ok(new ErrorMessage("Invalid Authorization token")). - status(403). + ok(ret). build(); } catch (Exception e) { - Logger.error("BroadcastResource: %s", e); + Logger.error("BroadcastResource.get: %s", e); e.printStackTrace(); return Response .ok(new ErrorMessage(e.getMessage())) @@ -100,29 +128,24 @@ public Response post(@Context ContainerRequestContext context, } } - private boolean send(UUID botId, IncomingMessage message) { + @Nullable + private UUID send(UUID botId, IncomingMessage message) { try { - final UUID messageId = sender.send(message, botId); - return messageId != null; + return sender.send(message, botId); } catch (MissingStateException e) { - Logger.warning("BroadcastResource: bot: %s, e: %s", botId, e); - jdbi.onDemand(BotsDAO.class).remove(botId); + Logger.warning("BroadcastResource.send: bot: %s, e: %s", botId, e); + botsDAO.remove(botId); } catch (Exception e) { e.printStackTrace(); - Logger.warning("BroadcastResource: bot: %s, e: %s", botId, e); + Logger.error("BroadcastResource.send: bot: %s, e: %s", botId, e); } - - return false; + return null; } - private void trace(IncomingMessage message) { - try { - if (Logger.getLevel() == Level.FINE) { - ObjectMapper objectMapper = new ObjectMapper(); - Logger.debug(objectMapper.writeValueAsString(message)); - } - } catch (Exception ignore) { - + private void trace(IncomingMessage message) throws JsonProcessingException { + if (Logger.getLevel() == Level.FINE) { + ObjectMapper objectMapper = new ObjectMapper(); + Logger.debug(objectMapper.writeValueAsString(message)); } } } \ No newline at end of file diff --git a/src/main/resources/db/migration/V102__broadcast.sql b/src/main/resources/db/migration/V102__broadcast.sql new file mode 100644 index 0000000..a4dda1b --- /dev/null +++ b/src/main/resources/db/migration/V102__broadcast.sql @@ -0,0 +1,9 @@ +CREATE TABLE Broadcast ( + broadcast_id UUID NOT NULL, + bot_id UUID NOT NULL, + provider UUID NOT NULL, + message_id UUID NOT NULL, + message_status INTEGER NOT NULL, + created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY(broadcast_id, message_id, message_status) +); \ No newline at end of file diff --git a/src/test/java/com/wire/bots/roman/integrations/DatabaseTest.java b/src/test/java/com/wire/bots/roman/integrations/DatabaseTest.java index e86ad68..9bb848a 100644 --- a/src/test/java/com/wire/bots/roman/integrations/DatabaseTest.java +++ b/src/test/java/com/wire/bots/roman/integrations/DatabaseTest.java @@ -1,6 +1,7 @@ package com.wire.bots.roman.integrations; import com.wire.bots.roman.Application; +import com.wire.bots.roman.DAO.BroadcastDAO; import com.wire.bots.roman.DAO.ProvidersDAO; import com.wire.bots.roman.model.Config; import com.wire.bots.roman.model.Provider; @@ -11,6 +12,7 @@ import org.junit.Before; import org.junit.Test; +import java.util.List; import java.util.UUID; public class DatabaseTest { @@ -36,20 +38,20 @@ public void afterClass() { public void testProviderDAO() { final ProvidersDAO providersDAO = jdbi.onDemand(ProvidersDAO.class); - final UUID id = UUID.randomUUID(); + final UUID providerId = UUID.randomUUID(); final String name = "name"; final String email = "email@wire.com"; final String hash = "hash"; final String password = "password"; - final int insert = providersDAO.insert(name, id, email, hash, password); + final int insert = providersDAO.insert(name, providerId, email, hash, password); assert insert == 1; - Provider provider = providersDAO.get(id); + Provider provider = providersDAO.get(providerId); assert provider != null; assert provider.name.equals(name); assert provider.hash.equals(hash); assert provider.password.equals(password); - assert provider.id.equals(id); + assert provider.id.equals(providerId); assert provider.email.equals(email); provider = providersDAO.get(email); @@ -57,14 +59,14 @@ public void testProviderDAO() { assert provider.name.equals(name); assert provider.hash.equals(hash); assert provider.password.equals(password); - assert provider.id.equals(id); + assert provider.id.equals(providerId); assert provider.email.equals(email); final String url = "url"; final String auth = "auth"; final UUID serviceId = UUID.randomUUID(); final String service_name = "service name"; - int update = providersDAO.update(id, url, auth, serviceId, service_name); + int update = providersDAO.update(providerId, url, auth, serviceId, service_name); assert update == 1; provider = providersDAO.getByAuth(auth); @@ -75,19 +77,55 @@ public void testProviderDAO() { assert provider.serviceName.equals(service_name); final String newURL = "newURL"; - update = providersDAO.updateUrl(id, newURL); + update = providersDAO.updateUrl(providerId, newURL); assert update == 1; - provider = providersDAO.get(id); + provider = providersDAO.get(providerId); assert provider != null; assert provider.serviceUrl.equals(newURL); final String newName = "new service name"; - update = providersDAO.updateServiceName(id, newName); + update = providersDAO.updateServiceName(providerId, newName); assert update == 1; - provider = providersDAO.get(id); + provider = providersDAO.get(providerId); assert provider != null; assert provider.serviceName.equals(newName); } + + @Test + public void testBroadcastDAO() { + final BroadcastDAO broadcastDAO = jdbi.onDemand(BroadcastDAO.class); + + final UUID providerId = UUID.randomUUID(); + final UUID broadcastId = UUID.randomUUID(); + final UUID botId = UUID.randomUUID(); + final UUID messageId = UUID.randomUUID(); + + final int insert1 = broadcastDAO.insert(broadcastId, botId, providerId, messageId, 0); + assert insert1 == 1; + + int insertStatus = broadcastDAO.insertStatus(messageId, 1); + assert insertStatus == 1; + insertStatus = broadcastDAO.insertStatus(messageId, 2); + assert insertStatus == 1; + insertStatus = broadcastDAO.insertStatus(messageId, 3); + assert insertStatus == 1; + + final UUID get = broadcastDAO.getBroadcastId(providerId); + assert get != null; + assert get.equals(broadcastId); + + final List report = broadcastDAO.report(broadcastId); + + final UUID broadcastId2 = UUID.randomUUID(); + final UUID botId2 = UUID.randomUUID(); + final UUID messageId2 = UUID.randomUUID(); + final int insert2 = broadcastDAO.insert(broadcastId2, botId2, providerId, messageId2, 0); + assert insert2 == 1; + + final UUID get2 = broadcastDAO.getBroadcastId(providerId); + assert get2 != null; + assert get2.equals(broadcastId2); + } }