Skip to content

Commit

Permalink
#5230 - Make LLM mention extraction more robust
Browse files Browse the repository at this point in the history
- Add support for another response format
  • Loading branch information
reckart committed Jan 15, 2025
1 parent 9baef09 commit f034918
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@
import static java.util.Collections.sort;
import static java.util.Comparator.comparing;
import static org.apache.commons.lang3.StringUtils.isBlank;
import static org.apache.commons.lang3.StringUtils.isNotBlank;
import static org.apache.commons.lang3.StringUtils.normalizeSpace;

import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -40,6 +42,8 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.fasterxml.jackson.databind.node.TextNode;

import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Sentence;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngine;
import de.tudarmstadt.ukp.inception.recommendation.imls.llm.support.prompt.PromptContext;
Expand Down Expand Up @@ -115,6 +119,15 @@ Map<String, MentionsSample> generateSamples(RecommendationEngine aEngine, CAS aC
labelsSeen.addAll(sentenceAndLabels.getValue());
}

// Generate a fake sample to demonstrate how a result should look
// if (examples.isEmpty()) {
// var text = "This sentence contains a secret and a hint.";
// var sample = new MentionsSample(text);
// sample.addMention("secret", "mystery");
// sample.addMention("hint", "clue");
// examples.put(text, sample);
// }

return examples;
}

Expand Down Expand Up @@ -147,20 +160,51 @@ List<Pair<String, String>> extractMentionFromJson(String aResponse)
mentions.add(Pair.of(item.asText(), fieldEntry.getKey()));
}
if (item.isObject()) {
// Looks like this
// "politicians": [
// { "name": "President Livingston" },
// { "name": "John" },
// { "name": "Don Horny" }
// ]
var subFieldIterator = item.fields();
while (subFieldIterator.hasNext()) {
var subEntry = subFieldIterator.next();
if (subEntry.getValue().isTextual()) {
mentions.add(Pair.of(subEntry.getValue().asText(),
fieldEntry.getKey()));
var fields = toList(item.fields());
if (fields.size() == 1) {
// Looks like this
// "politicians": [
// { "name": "President Livingston" },
// { "name": "John" },
// { "name": "Don Horny" }
// ]
var nestedFieldIterator = item.fields();
while (nestedFieldIterator.hasNext()) {
var nestedEntry = nestedFieldIterator.next();
if (nestedEntry.getValue().isTextual()) {
mentions.add(Pair.of(nestedEntry.getValue().asText(),
fieldEntry.getKey()));
}
break;
}
}
else if (fields.size() >= 2) {
// Looks like this
// "politicians": [
// { "text": "President Livingston", "type"="politician" },
// { "text": "John", "type"="politician" },
// { "text": "Don Horny", "type"="politician" }
// ]
String text = null;
String label = null;
if (item.get("text") instanceof TextNode tn) {
text = tn.asText();
}
else if (item.get("name") instanceof TextNode tn) {
text = tn.asText();
}
if (item.get("type") instanceof TextNode tn) {
label = tn.asText();
}
else if (item.get("value") instanceof TextNode tn) {
label = tn.asText();
}
else if (item.get("label") instanceof TextNode tn) {
label = tn.asText();
}
if (isNotBlank(text) && isNotBlank(label)) {
mentions.add(Pair.of(text, label));
}
break;
}
}
}
Expand Down Expand Up @@ -198,6 +242,13 @@ List<Pair<String, String>> extractMentionFromJson(String aResponse)
return mentions;
}

private <T> List<T> toList(Iterator<T> aIterator)
{
var list = new ArrayList<T>();
aIterator.forEachRemaining(list::add);
return list;
}

private void mentionsToPredictions(RecommendationEngine aEngine, CAS aCas,
AnnotationFS aCandidate, List<Pair<String, String>> mentions)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import de.tudarmstadt.ukp.inception.recommendation.imls.llm.support.response.MentionsFromJsonExtractor;

class MentionsFromJsonExtractorTest
{
private MentionsFromJsonExtractor sut = new MentionsFromJsonExtractor();
Expand All @@ -33,10 +31,10 @@ class MentionsFromJsonExtractorTest
void testExtractMentionFromJson_categorizedNumbers()
{
var json = """
{
"even_numbers": [2, 4, 6]
}
""";
{
"even_numbers": [2, 4, 6]
}
""";
assertThat(sut.extractMentionFromJson(json)) //
.containsExactly( //
Pair.of("2", "even_numbers"), //
Expand All @@ -48,12 +46,12 @@ void testExtractMentionFromJson_categorizedNumbers()
void testExtractMentionFromJson_nullLabels()
{
var json = """
{
"Honolulu": null,
"Columbia University": null,
"Harvard Law Review": null
}
""";
{
"Honolulu": null,
"Columbia University": null,
"Harvard Law Review": null
}
""";
assertThat(sut.extractMentionFromJson(json)) //
.containsExactly( //
Pair.of("Honolulu", null), //
Expand All @@ -65,11 +63,11 @@ void testExtractMentionFromJson_nullLabels()
void testExtractMentionFromJson_categorizedStrings()
{
var json = """
{
"Person": ["John"],
"Location": ["diner", "Starbucks"]
}
""";
{
"Person": ["John"],
"Location": ["diner", "Starbucks"]
}
""";
assertThat(sut.extractMentionFromJson(json)) //
.containsExactly( //
Pair.of("John", "Person"), //
Expand All @@ -81,14 +79,42 @@ void testExtractMentionFromJson_categorizedStrings()
void testExtractMentionFromJson_categorizedObjects()
{
var json = """
{
"politicians": [
{ "name": "President Livingston" },
{ "name": "John" },
{ "name": "Don Horny" }
]
}
""";
{
"politicians": [
{ "name": "President Livingston" },
{ "name": "John" },
{ "name": "Don Horny" }
]
}
""";
assertThat(sut.extractMentionFromJson(json)) //
.containsExactly( //
Pair.of("President Livingston", "politicians"), //
Pair.of("John", "politicians"), //
Pair.of("Don Horny", "politicians"));
}

@Test
void testExtractMentionFromJson_structuredObjects()
{
var json = """
{
"entities": [
{
"type": "politicians",
"text": "President Livingston"
},
{
"type": "politicians",
"text": "John"
},
{
"type": "politicians",
"text": "Don Horny"
}
]
}
""";
assertThat(sut.extractMentionFromJson(json)) //
.containsExactly( //
Pair.of("President Livingston", "politicians"), //
Expand All @@ -100,12 +126,12 @@ void testExtractMentionFromJson_categorizedObjects()
void testExtractMentionFromJson_namedObjects()
{
var json = """
{
"John": {"type": "PERSON"},
"diner": {"type": "LOCATION"},
"Starbucks": {"type": "LOCATION"}
}
""";
{
"John": {"type": "PERSON"},
"diner": {"type": "LOCATION"},
"Starbucks": {"type": "LOCATION"}
}
""";
assertThat(sut.extractMentionFromJson(json)) //
.containsExactly( //
Pair.of("John", null), //
Expand All @@ -117,13 +143,13 @@ void testExtractMentionFromJson_namedObjects()
void testExtractMentionFromJson_keyValue()
{
var json = """
{
"John": "politician",
"President Livingston": "politician",
"minister of foreign affairs": "politician",
"Don Horny": "politician"
}
""";
{
"John": "politician",
"President Livingston": "politician",
"minister of foreign affairs": "politician",
"Don Horny": "politician"
}
""";
assertThat(sut.extractMentionFromJson(json)) //
.containsExactly( //
Pair.of("John", "politician"), //
Expand All @@ -139,11 +165,11 @@ void testExtractMentionFromJson_valueKey()
// We assume that the first item is the most relevant one (the
// mention) so we do not get a bad mention in cases like this:
var json = """
{
"name": "Don Horny",
"affiliation": "Lord of Darkness"
}
""";
{
"name": "Don Horny",
"affiliation": "Lord of Darkness"
}
""";
assertThat(sut.extractMentionFromJson(json)) //
.containsExactly( //
Pair.of("Don Horny", null));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import static de.tudarmstadt.ukp.inception.websocket.config.WebSocketConstants.PARAM_PROJECT;

import java.io.IOException;
import java.security.Principal;
import java.util.List;
import java.util.Objects;
Expand All @@ -34,7 +33,6 @@
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.messaging.simp.annotation.SendToUser;
import org.springframework.messaging.simp.annotation.SubscribeMapping;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.RequestMapping;

Expand Down Expand Up @@ -64,7 +62,6 @@ public SchedulerWebsocketControllerImpl(SchedulingService aSchedulingService,

@SubscribeMapping(USER_TASKS_TOPIC)
public List<MTaskStateUpdate> onSubscribeToUserTaskUpdates(Principal user)
throws AccessDeniedException
{
return schedulingService.getAllTasks().stream() //
.filter(t -> t.getParentTask() == null) //
Expand All @@ -80,7 +77,6 @@ public List<MTaskStateUpdate> onSubscribeToUserTaskUpdates(Principal user)
public List<MTaskStateUpdate> onSubscribeToProjectTaskUpdates(
SimpMessageHeaderAccessor aHeaderAccessor, Principal aPrincipal, //
@DestinationVariable(PARAM_PROJECT) long aProjectId)
throws IOException
{
return schedulingService.getAllTasks().stream() //
.filter(t -> t.getParentTask() == null) //
Expand Down

0 comments on commit f034918

Please sign in to comment.