...
  View open merge request
Commits (3)
......@@ -74,7 +74,7 @@ public class IsolatedToolsTestMojo extends AbstractTestMojo {
runPipeline(jCas, engineDescription);
}
TestResult testResult = new TestResult(engineDescription);
TestResult testResult = new TestResult(engineDescription, getClass().getClassLoader());
measureAndAddToTestResult(testData, plainTestData, testResult, engineDescription);
return new TestResultWrapper(testResult, engineDescription);
}
......
......@@ -498,7 +498,7 @@ public class PipelinesTestMojo extends AbstractTestMojo {
}
private TestResult determineMeasures(Iterable<AnalysisEngineDescription> pipeline, List<JCas> testData, List<JCas> pipelineData) {
TestResult testResult = new TestResult(pipeline);
TestResult testResult = new TestResult(pipeline, getUrlClassLoader());
for (AnalysisEngineDescription analysisEngineDescription : pipeline) {
measureAndAddToTestResult(testData, pipelineData, testResult, analysisEngineDescription);
......
......@@ -35,16 +35,23 @@ import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import de.schrieveslaach.nlpf.plumbing.util.AnalysisEngineDescriptionNode;
import de.schrieveslaach.nlpf.plumbing.util.AnalysisEngineGraph;
import de.tudarmstadt.ukp.dkpro.core.eval.measure.FMeasure;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NonNull;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.jgrapht.graph.DefaultEdge;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.SortedSet;
import java.util.TreeSet;
......@@ -53,21 +60,24 @@ import static java.util.Arrays.asList;
@JsonDeserialize(using = TestResult.TestResultDeserializer.class)
@JsonSerialize(using = TestResult.TestResultSerializer.class)
@EqualsAndHashCode
@EqualsAndHashCode(exclude = "aeGraph")
class TestResult implements Comparable<TestResult> {
private final SortedSet<Measure> measures = new TreeSet<>();
private final List<String> pipelineDescriptors = new ArrayList<>();
private final AnalysisEngineGraph aeGraph = new AnalysisEngineGraph();
private TestResult() {
}
TestResult(AnalysisEngineDescription engineDescription) {
this(asList(engineDescription));
TestResult(AnalysisEngineDescription engineDescription, ClassLoader classLoader) {
this(asList(engineDescription), classLoader);
}
TestResult(Iterable<AnalysisEngineDescription> pipeline) {
TestResult(Iterable<AnalysisEngineDescription> pipeline, ClassLoader classLoader) {
aeGraph.init(Lists.newArrayList(pipeline), classLoader);
for (AnalysisEngineDescription aed : pipeline) {
pipelineDescriptors.add(hash(aed) + ".xml");
}
......@@ -169,6 +179,8 @@ class TestResult implements Comparable<TestResult> {
public static class TestResultSerializer extends JsonSerializer<TestResult> {
private Map<AnalysisEngineDescriptionNode, Integer> aedNodeIndexMap = Maps.newHashMap();
@Override
public void serialize(TestResult testResult, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) throws IOException {
jsonGenerator.writeStartObject();
......@@ -208,8 +220,52 @@ class TestResult implements Comparable<TestResult> {
}
jsonGenerator.writeEndArray();
writeNodes(jsonGenerator, testResult.aeGraph);
writeEdges(jsonGenerator, testResult.aeGraph);
jsonGenerator.writeEndObject();
}
private void writeNodes(JsonGenerator jsonGenerator, AnalysisEngineGraph aeGraph) throws IOException {
int i = 1;
Iterator<AnalysisEngineDescriptionNode> iterator = aeGraph.stream().iterator();
jsonGenerator.writeFieldName("nodes");
jsonGenerator.writeStartArray();
while (iterator.hasNext()) {
AnalysisEngineDescriptionNode aedNode = iterator.next();
jsonGenerator.writeStartObject();
jsonGenerator.writeNumberField("id", i);
jsonGenerator.writeStringField("label", aedNode.toString());
jsonGenerator.writeEndObject();
aedNodeIndexMap.put(aedNode, i);
i++;
}
jsonGenerator.writeEndArray();
}
private void writeEdges(JsonGenerator jsonGenerator, AnalysisEngineGraph aeGraph) throws IOException {
jsonGenerator.writeFieldName("edges");
jsonGenerator.writeStartArray();
for (DefaultEdge defaultEdge : aeGraph.edgeSet()) {
AnalysisEngineDescriptionNode edgeSource = aeGraph.getEdgeSource(defaultEdge);
AnalysisEngineDescriptionNode edgeTarget = aeGraph.getEdgeTarget(defaultEdge);
int sourceIndex = aedNodeIndexMap.get(edgeSource);
int targetIndex = aedNodeIndexMap.get(edgeTarget);
jsonGenerator.writeStartObject();
jsonGenerator.writeNumberField("from", sourceIndex);
jsonGenerator.writeNumberField("to", targetIndex);
jsonGenerator.writeEndObject();
}
jsonGenerator.writeEndArray();
}
}
public static class TestResultDeserializer extends JsonDeserializer<TestResult> {
......
......@@ -36,8 +36,10 @@ import org.mockito.junit.MockitoJUnitRunner;
import java.io.File;
import java.nio.charset.Charset;
import java.util.Collections;
import static java.util.Arrays.asList;
import static java.util.Collections.emptyList;
import static net.javacrumbs.jsonunit.JsonMatchers.jsonEquals;
import static org.apache.uima.fit.factory.AnalysisEngineFactory.createEngineDescription;
import static org.hamcrest.Matchers.is;
......@@ -66,6 +68,12 @@ public class IsolatedToolsTestMojoTest extends AbstractTestMojoTest {
.put("precision", 0.352)
.put("recall", 0.352))
))
.put("nodes", asList(
new JSONObject()
.put("id", 1)
.put("label", "de.schrieveslaach.nlpf.maven.plugin.annotators.MyPosTagger")
))
.put("edges", emptyList())
.toString();
}
......@@ -85,6 +93,12 @@ public class IsolatedToolsTestMojoTest extends AbstractTestMojoTest {
.put("precision", 0.75)
.put("recall", 0.375))
))
.put("nodes", asList(
new JSONObject()
.put("id", 1)
.put("label", "OpenNLP Named Entity Recognizer (variant = person)")
))
.put("edges", emptyList())
.toString();
}
......
......@@ -97,6 +97,18 @@ public class PipelinesTestMojoPipelinesTest extends AbstractTestMojoTest {
.put("f-measure", 0.8057)
.put("precision", 0.8615)
.put("recall", 0.7567))
))
.put("nodes", asList(new JSONObject()
.put("id", 1)
.put("label", "MySegmenter"),
new JSONObject()
.put("id", 2)
.put("label", "de.schrieveslaach.nlpf.maven.plugin.annotators.MyPosTagger")
))
.put("edges", asList(new JSONObject()
.put("from", 1)
.put("to", 2)
))
.toString();
}
......@@ -145,6 +157,24 @@ public class PipelinesTestMojoPipelinesTest extends AbstractTestMojoTest {
.put("precision", 1.0)
.put("recall", 1.0))
))
.put("nodes", asList(
new JSONObject()
.put("id", 1)
.put("label", "OpenNLP Segmenter"),
new JSONObject()
.put("id", 2)
.put("label", "OpenNLP Named Entity Recognizer (variant = person)"),
new JSONObject()
.put("id", 3)
.put("label", "OpenNLP Named Entity Recognizer (variant = organization)")
))
.put("edges", asList(new JSONObject()
.put("from", 1)
.put("to", 2),
new JSONObject()
.put("from", 1)
.put("to", 3)
))
.toString();
}
......
......@@ -44,6 +44,7 @@ import java.util.List;
import java.util.Map;
import static java.util.Arrays.asList;
import static java.util.Collections.emptyList;
import static net.javacrumbs.jsonunit.JsonMatchers.jsonEquals;
import static org.apache.uima.fit.factory.AnalysisEngineFactory.createEngineDescription;
import static org.hamcrest.Matchers.contains;
......@@ -74,6 +75,61 @@ public class TestResultTest {
.put("precision", 0.5)
.put("recall", 0.5))
))
.put("nodes", asList(new JSONObject()
.put("id", 1)
.put("label", "MySegmenter")
, new JSONObject()
.put("id", 2)
.put("label", "de.schrieveslaach.nlpf.maven.plugin.annotators.MyPosTagger")
))
.put("edges", asList(new JSONObject()
.put("from", 1)
.put("to", 2)
))
.toString();
}
public static String jsonTestResultWithNamedEntityRecognizer() {
return new JSONObject()
.put("descriptors", asList(
MySegmenter.DEFAULT_HASH_CODE + ".xml",
MyNamedEntityRecognizer.DEFAULT_HASH_CODE + ".xml",
MyPosTagger.DEFAULT_HASH_CODE + ".xml"
))
.put("measures", asList(new JSONObject()
.put("analysisEngineName", "de.schrieveslaach.nlpf.maven.plugin.annotators.MyPosTagger")
.put("outputType", "de.tudarmstadt.ukp.dkpro.core.api.lexmorph.type.pos.POS")
.put("variant", JSONObject.NULL)
.put("measure", new JSONObject()
.put("f-measure", 0.5)
.put("precision", 0.5)
.put("recall", 0.5))
, new JSONObject()
.put("analysisEngineName", "de.schrieveslaach.nlpf.maven.plugin.annotators.MyNamedEntityRecognizer")
.put("outputType", "de.tudarmstadt.ukp.dkpro.core.api.ner.type.NamedEntity")
.put("variant", "person")
.put("measure", new JSONObject()
.put("f-measure", 0.5)
.put("precision", 0.5)
.put("recall", 0.5))
))
.put("nodes", asList(new JSONObject()
.put("id", 1)
.put("label", "MySegmenter")
, new JSONObject()
.put("id", 2)
.put("label", "MyNamedEntityRecognizer (variant = person)")
, new JSONObject()
.put("id", 3)
.put("label", "de.schrieveslaach.nlpf.maven.plugin.annotators.MyPosTagger")
))
.put("edges", asList(new JSONObject()
.put("from", 1)
.put("to", 2),
new JSONObject()
.put("from", 1)
.put("to", 3)
))
.toString();
}
......@@ -111,6 +167,16 @@ public class TestResultTest {
assertThat(json, jsonEquals(jsonTestResult()));
}
@Test
public void shouldSerializeAsJson_WithVariant() throws Exception {
TestResult testResult = createTestResultWithNamedEntityRecognizer();
ObjectMapper mapper = new ObjectMapper();
String json = mapper.writeValueAsString(testResult);
assertThat(json, jsonEquals(jsonTestResultWithNamedEntityRecognizer()));
}
@Test
public void shouldComplyWithHashCodeContract_ShouldContainTestResultWithDifferentFMeasureValue() throws Exception {
Map<TestResult, String> map = createTestResultStringMap();
......@@ -323,7 +389,7 @@ public class TestResultTest {
}
private TestResult createTestResult(List<AnalysisEngineDescription> pipeline, double fScore, String outputType, String variant) throws ResourceInitializationException {
TestResult anotherTestResult = new TestResult(pipeline);
TestResult anotherTestResult = new TestResult(pipeline, getClass().getClassLoader());
FMeasure fMeasure = mock(FMeasure.class);
when(fMeasure.getFMeasure()).thenReturn(fScore);
when(fMeasure.getPrecision()).thenReturn(fScore);
......@@ -332,6 +398,33 @@ public class TestResultTest {
return anotherTestResult;
}
private TestResult createTestResultWithNamedEntityRecognizer() throws ResourceInitializationException {
List<AnalysisEngineDescription> pipeline = asList(
createEngineDescription(MySegmenter.class),
createEngineDescription(MyNamedEntityRecognizer.class),
createEngineDescription(MyPosTagger.class));
double fScore = 0.5;
TestResult testResultWithNER = new TestResult(pipeline, getClass().getClassLoader());
FMeasure fMeasure = mock(FMeasure.class);
when(fMeasure.getFMeasure()).thenReturn(fScore);
when(fMeasure.getPrecision()).thenReturn(fScore);
when(fMeasure.getRecall()).thenReturn(fScore);
testResultWithNER.add("de.schrieveslaach.nlpf.maven.plugin.annotators.MyPosTagger",
"de.tudarmstadt.ukp.dkpro.core.api.lexmorph.type.pos.POS",
null,
fMeasure
);
testResultWithNER.add("de.schrieveslaach.nlpf.maven.plugin.annotators.MyNamedEntityRecognizer",
"de.tudarmstadt.ukp.dkpro.core.api.ner.type.NamedEntity",
"person",
fMeasure
);
return testResultWithNER;
}
private Map<TestResult, String> createTestResultStringMap() throws ResourceInitializationException {
TestResult testResult = createTestResult(
asList(
......
......@@ -38,6 +38,7 @@ import org.apache.uima.jcas.JCas;
@ResourceMetaData(name = "MyNamedEntityRecognizer")
public class MyNamedEntityRecognizer extends JCasAnnotator_ImplBase {
public static final String DEFAULT_HASH_CODE = "4C7420B966C09DB213E219FB05DDD257FB585B15812EA02B770AA41015F2C31F7C6F714BD9DAE0945505DE22BB4F474AD5A434698DA13B532ECEED2E95D43692";
@Getter
private static int runs = 0;
......
......@@ -158,5 +158,4 @@ public class AnalysisEngineGraph extends DefaultDirectedGraph<AnalysisEngineDesc
return descriptions;
}
}