Commit 6bc5425a authored by WUUUGI's avatar WUUUGI

Added check for asterisks in excluded trainer configuration.

parent dc097e84
Pipeline #42557971 failed with stages
in 28 minutes and 55 seconds
......@@ -23,12 +23,15 @@ package de.schrieveslaach.nlpf.maven.plugin;
*/
import org.apache.maven.plugin.testing.MojoRule;
import org.apache.maven.project.MavenProject;
import org.junit.Rule;
import org.junit.Test;
import java.io.File;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.io.FileMatchers.anExistingDirectory;
import static org.hamcrest.io.FileMatchers.anExistingFile;
import static org.junit.Assert.assertThat;
......@@ -56,4 +59,14 @@ public class TrainMojoIT extends BaseMojoIT {
assertThat(modelFile, is(anExistingFile()));
}
/**
 * Runs the {@code train} goal on the sample project whose POM excludes a trainer and
 * checks that no model directory was created for the excluded Stanford NLP artifact.
 */
@Test
public void shouldNotTrain_TrainerExcludedInPom() throws Exception {
    // Arrange: sample project plus the example CAS files the training goal reads.
    File projectDir = getTestProjectBaseDir("/sample-project-with-excluded-trainer");
    storeExampleCasFiles(projectDir);

    // Act
    rule.executeMojo(projectDir, "train");

    // Assert: nothing was written below the excluded artifact's model directory.
    String excludedModelPath = "target/models/de.company/domain-specific-corpus/de.tudarmstadt.ukp.dkpro.core.stanfordnlp-gpl/";
    File excludedModelDir = new File(projectDir, excludedModelPath);
    assertThat(excludedModelDir, is(not(anExistingDirectory())));
}
}
<!--
========================LICENSE_START=================================
nlp-maven-plugin
%%
Copyright (C) 2017 Schrieveslaach
%%
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Lesser Public License for more details.
You should have received a copy of the GNU General Lesser Public
License along with this program. If not, see
<http://www.gnu.org/licenses/lgpl-3.0.html>.
=========================LICENSE_END==================================
-->
<project>
    <modelVersion>4.0.0</modelVersion>
    <groupId>de.company</groupId>
    <artifactId>domain-specific-corpus</artifactId>
    <version>1.0-SNAPSHOT</version>
    <properties>
        <language>en</language>
        <named.entity.types>person,organization</named.entity.types>
    </properties>
    <build>
        <plugins>
            <plugin>
                <groupId>de.schrieveslaach.nlpf</groupId>
                <artifactId>nlp-maven-plugin</artifactId>
                <version>1.1.0-SNAPSHOT</version>
                <extensions>true</extensions>
                <configuration>
                    <!--
                        Exclusions are compared with the trainer's implementation class name,
                        so the pattern must use the Java package of the Stanford NLP classes
                        (de.tudarmstadt.ukp.dkpro.core.stanfordnlp), not the Maven artifact id
                        (which ends in stanfordnlp-gpl and never occurs in a class name).
                    -->
                    <excludedTrainers>
                        <excludedTrainer>de.tudarmstadt.ukp.dkpro.core.stanfordnlp.*</excludedTrainer>
                    </excludedTrainers>
                </configuration>
            </plugin>
        </plugins>
    </build>
    <dependencies>
        <dependency>
            <groupId>de.tudarmstadt.ukp.dkpro.core</groupId>
            <artifactId>de.tudarmstadt.ukp.dkpro.core.io.xmi-asl</artifactId>
            <version>1.10.0</version>
        </dependency>
        <dependency>
            <groupId>de.tudarmstadt.ukp.dkpro.core</groupId>
            <artifactId>de.tudarmstadt.ukp.dkpro.core.opennlp-asl</artifactId>
            <version>1.10.0</version>
        </dependency>
        <dependency>
            <groupId>de.tudarmstadt.ukp.dkpro.core</groupId>
            <artifactId>de.tudarmstadt.ukp.dkpro.core.stanfordnlp-gpl</artifactId>
            <version>1.10.0</version>
        </dependency>
    </dependencies>
</project>
package de.company;
/*-
* ========================LICENSE_START=================================
* nlp-maven-plugin
* %%
* Copyright (C) 2017 Schrieveslaach
* %%
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Lesser Public License for more details.
*
* You should have received a copy of the GNU General Lesser Public
* License along with this program. If not, see
* <http://www.gnu.org/licenses/lgpl-3.0.html>.
* =========================LICENSE_END==================================
*/
import de.tudarmstadt.ukp.dkpro.core.stanfordnlp.StanfordSegmenter;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import static org.apache.uima.fit.factory.AnalysisEngineFactory.createEngineDescription;
/**
 * Sample-project factory that supplies a custom analysis engine description.
 */
public class CustomAnalysisEngineFactory {

    /**
     * Builds an engine description for {@link StanfordSegmenter} with its language
     * fallback parameter set to {@code "en"}.
     *
     * @return the segmenter's analysis engine description
     * @throws Exception if the description cannot be created
     */
    public AnalysisEngineDescription createStanfordSegmenter() throws Exception {
        AnalysisEngineDescription segmenterDescription = createEngineDescription(
                StanfordSegmenter.class,
                StanfordSegmenter.PARAM_LANGUAGE_FALLBACK, "en");
        return segmenterDescription;
    }
}
......@@ -25,6 +25,7 @@
<groupId>de.company</groupId>
<artifactId>domain-specific-corpus</artifactId>
<version>1.0-SNAPSHOT</version>
<packaging>nlp-models</packaging>
<properties>
<language>en</language>
......
......@@ -52,6 +52,7 @@ import javax.inject.Inject;
import java.util.Collection;
import java.util.List;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import static de.schrieveslaach.nlpf.maven.plugin.JCasDataUtil.copyJCas;
......@@ -119,7 +120,7 @@ public class TrainMojo extends AbstractMojo {
/**
 * Runs every non-excluded trainer on the given training data.
 *
 * <p>The trainer descriptions come from {@code analysisEngineService} and are processed
 * as a parallel stream. Exclusion is decided solely by
 * {@link #filterExcludedTrainers(AnalysisEngineDescription)}; the former inline exact-match
 * check duplicated that logic and dereferenced the description before the helper's
 * null guard could apply.
 *
 * @param trainingData the CASes handed to each trainer
 * @return the descriptions for which {@code train} returned {@code true}
 */
private List<AnalysisEngineDescription> runTrainerPipeline(List<JCas> trainingData) {
    return analysisEngineService.findTrainerDescriptions()
            .parallelStream()
            .filter(this::filterExcludedTrainers)
            .filter(aed -> train(trainingData, aed))
            .collect(Collectors.toList());
}
......@@ -127,10 +128,15 @@ public class TrainMojo extends AbstractMojo {
private boolean filterExcludedTrainers(AnalysisEngineDescription aed) {
if (aed == null) {
return false;
} else {
//TODO: Regex to match classpath to filter with asteriks de.company.*
return !excludedTrainers.contains(aed.getAnnotatorImplementationName());
} else if (excludedTrainers != null && !excludedTrainers.isEmpty()) {
for (String excludedTrainer : excludedTrainers) {
String[] split = excludedTrainer.split("\\*");
if (aed.getAnnotatorImplementationName().contains(split[0])) {
return false;
}
}
}
return true;
}
private boolean train(List<JCas> trainingData, AnalysisEngineDescription aed) {
......
......@@ -165,36 +165,6 @@ public class IsolatedToolsTestMojoTest extends AbstractTestMojoTest {
assertThat(json, jsonEquals(jsonOpenNlpPersonTestResult()).withTolerance(0.001));
}
/**
 * Verifies that an engine whose class name is configured as an exclusion is not among
 * the descriptions returned by {@code mojo.loadEngineDescriptions()}.
 */
@Test
public void shouldFilterEngineDescriptions_ExclusionConfigured() throws Exception {
    String excludedClass = "de.tudarmstadt.ukp.dkpro.core.opennlp.OpenNlpPosTagger";
    createExclusionInConfiguration(excludedClass);
    mockEngineDescriptionsInDescriptorsDirectory(
            OPEN_NLP_PERSON_DESCRIPTION,
            MY_PERSON_DESCRIPTION,
            OPEN_NLP_ORGANIZATION_DESCRIPTION,
            MY_ORGANIZATION_DESCRIPTION,
            createEngineDescription(OpenNlpPosTagger.class)
    );
    provideTestJCas();

    List<AnalysisEngineDescription> engineDescriptions = mojo.loadEngineDescriptions();

    Set<String> implNames = engineDescriptions.stream()
            .map(AnalysisEngineDescription::getAnnotatorImplementationName)
            .collect(Collectors.toSet());
    // Hamcrest's contains(x) means "consists of exactly x", so not(contains(x)) passed for
    // almost any set and the test could not fail. Check set membership directly instead.
    assertThat(implNames.contains(excludedClass), not(true));
}
/**
 * Injects the given class names into the mojo's {@code exclusions} field.
 *
 * @param classNames the values to store in the field, as a list
 */
private void createExclusionInConfiguration(String... classNames) {
    // The configuration parameter is written reflectively, directly into the field.
    Field field = ReflectionUtils.findField(mojo.getClass(), "exclusions");
    field.setAccessible(true);

    List<String> configuredExclusions = Arrays.asList(classNames);
    ReflectionUtils.setField(field, mojo, configuredExclusions);
}
@Override
protected AbstractTestMojo getMojoUnderTest() {
return mojo;
......
......@@ -198,6 +198,18 @@ public class TrainMojoTest {
assertThat(jCasOfNamedEntityTrainer, is(notNullValue()));
}
/**
 * Configures a wildcard exclusion covering the plugin's own package and checks that
 * neither test trainer received a CAS after the mojo ran.
 */
@Test
public void shouldNotTrain_TrainerExcludedInConfiguration_WithAsterik() throws Exception {
    // Arrange: exclude by package wildcard, then register both test trainers.
    String packageWildcard = "de.schrieveslaach.nlpf.maven.plugin.*";
    createExclusionInConfiguration(packageWildcard);
    mockTrainingPipeline(TestPosTaggerTrainer.class, TestNerTrainer.class);

    // Act
    mojo.execute();

    // Assert: both captured CAS references are still unset.
    assertThat(jCasOfPosTaggerTrainer, is(nullValue()));
    assertThat(jCasOfNamedEntityTrainer, is(nullValue()));
}
private void createExclusionInConfiguration(String... classNames) {
//Use reflection to set the configuration parameter
Field excludedTrainersField = ReflectionUtils.findField(mojo.getClass(), "excludedTrainers");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment