Skip to content

Commit 187a25a

Browse files
author
Benedikt Fuchs
committed
add tars tests
1 parent 5d210c1 commit 187a25a

9 files changed

+265
-41
lines changed

flair/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
cache_root = Path(os.getenv("FLAIR_CACHE_ROOT", Path(Path.home(), ".flair")))
1010

1111
# global variable: device
12-
if torch.cuda.is_available():
12+
if torch.cuda.is_available() and False:
1313
device = torch.device("cuda:0")
1414
else:
1515
device = torch.device("cpu")

flair/models/tars_model.py

+1
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ def predict_zero_shot(
298298
label_dictionary=label_dictionary,
299299
label_type="-".join(label_dictionary.get_items()),
300300
multi_label=multi_label,
301+
force_switch=True, # overwrite any older configuration
301302
)
302303

303304
try:

tests/model_test_utils.py

+16-15
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44

55
import flair
6-
from flair.data import Sentence, Dictionary
6+
from flair.data import Dictionary, Sentence
77
from flair.nn import Model
88
from flair.trainers import ModelTrainer
99

@@ -72,16 +72,18 @@ def build_model(self, embeddings, label_dict, **kwargs):
7272
def has_embedding(self, sentence):
7373
return sentence.get_embedding().cpu().numpy().size > 0
7474

75-
@pytest.mark.integration
76-
def test_load_use_model(self, example_sentence):
75+
@pytest.fixture
76+
def loaded_pretrained_model(self):
7777
if self.pretrained_model is None:
7878
pytest.skip("For this test `pretrained_model` needs to be set.")
79-
loaded_model = self.model_cls.load(self.pretrained_model)
79+
yield self.model_cls.load(self.pretrained_model)
8080

81-
loaded_model.predict(example_sentence)
82-
loaded_model.predict([example_sentence, self.empty_sentence])
83-
loaded_model.predict([self.empty_sentence])
84-
del loaded_model
81+
@pytest.mark.integration
82+
def test_load_use_model(self, example_sentence, loaded_pretrained_model):
83+
loaded_pretrained_model.predict(example_sentence)
84+
loaded_pretrained_model.predict([example_sentence, self.empty_sentence])
85+
loaded_pretrained_model.predict([self.empty_sentence])
86+
del loaded_pretrained_model
8587

8688
example_sentence.clear_embeddings()
8789
self.empty_sentence.clear_embeddings()
@@ -119,7 +121,9 @@ def test_train_load_use_model(self, results_base_path, corpus, embeddings, examp
119121
del loaded_model
120122

121123
@pytest.mark.integration
122-
def test_train_load_use_model_multi_corpus(self, results_base_path, multi_corpus, embeddings, example_sentence, train_test_sentence):
124+
def test_train_load_use_model_multi_corpus(
125+
self, results_base_path, multi_corpus, embeddings, example_sentence, train_test_sentence
126+
):
123127
flair.set_seed(123)
124128
label_dict = multi_corpus.make_label_dictionary(label_type=self.train_label_type)
125129

@@ -190,16 +194,13 @@ def test_forward_loss(self, labeled_sentence, embeddings):
190194
assert loss.size() == ()
191195
assert count == len(labeled_sentence.get_labels(self.train_label_type))
192196

193-
def test_load_use_model_keep_embedding(self, example_sentence):
194-
if self.pretrained_model is None:
195-
pytest.skip("For this test `pretrained_model` needs to be set.")
196-
loaded_model = self.model_cls.load(self.pretrained_model)
197+
def test_load_use_model_keep_embedding(self, example_sentence, loaded_pretrained_model):
197198

198199
assert not self.has_embedding(example_sentence)
199200

200-
loaded_model.predict(example_sentence, embedding_storage_mode="cpu")
201+
loaded_pretrained_model.predict(example_sentence, embedding_storage_mode="cpu")
201202
assert self.has_embedding(example_sentence)
202-
del loaded_model
203+
del loaded_pretrained_model
203204

204205
def test_train_load_use_model_multi_label(
205206
self, results_base_path, multi_class_corpus, embeddings, example_sentence, multiclass_train_test_sentence

tests/models/test_entity_linker.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@
1010
class TestEntityLinker(BaseModelTest):
1111
model_cls = EntityLinker
1212
train_label_type = "nel"
13-
training_args = dict(
14-
max_epochs=2
15-
)
13+
training_args = dict(max_epochs=2)
1614

1715
@pytest.fixture
1816
def embeddings(self):

tests/models/test_sequence_tagger.py

+29-14
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22

33
import flair
4-
from flair.embeddings import WordEmbeddings, FlairEmbeddings
4+
from flair.embeddings import FlairEmbeddings, WordEmbeddings
55
from flair.models import SequenceTagger
66
from flair.trainers import ModelTrainer
77
from tests.model_test_utils import BaseModelTest
@@ -27,31 +27,44 @@ def has_embedding(self, sentence):
2727
return False
2828
return True
2929

30+
def build_model(self, embeddings, label_dict, **kwargs):
31+
model_args = dict(self.model_args)
32+
for k in kwargs.keys():
33+
if k in model_args:
34+
del model_args[k]
35+
return self.model_cls(
36+
embeddings=embeddings,
37+
tag_dictionary=label_dict,
38+
tag_type=self.train_label_type,
39+
**model_args,
40+
**kwargs,
41+
)
42+
3043
@pytest.fixture
3144
def embeddings(self):
3245
yield WordEmbeddings("turian")
3346

3447
@pytest.fixture
3548
def corpus(self, tasks_base_path):
36-
yield flair.datasets.ColumnCorpus(data_folder=tasks_base_path / "fashion", column_format={0: "text", 2: "ner"})
49+
yield flair.datasets.ColumnCorpus(data_folder=tasks_base_path / "fashion", column_format={0: "text", 3: "ner"})
3750

3851
@pytest.mark.integration
39-
def test_all_tag_proba_embedding(self, example_sentence):
40-
model = self.model_cls.load(self.pretrained_model)
52+
def test_all_tag_proba_embedding(self, example_sentence, loaded_pretrained_model):
4153

42-
model.predict(example_sentence, return_probabilities_for_all_classes=True)
54+
loaded_pretrained_model.predict(example_sentence, return_probabilities_for_all_classes=True)
4355
for token in example_sentence:
44-
assert len(token.get_tags_proba_dist(model.label_type)) == len(model.label_dictionary)
56+
assert len(token.get_tags_proba_dist(loaded_pretrained_model.label_type)) == len(
57+
loaded_pretrained_model.label_dictionary
58+
)
4559
score_sum = 0.0
46-
for label in token.get_tags_proba_dist(model.label_type):
60+
for label in token.get_tags_proba_dist(loaded_pretrained_model.label_type):
4761
assert label.data_point == token
4862
score_sum += label.score
4963
assert abs(score_sum - 1.0) < 1.0e-5
5064

5165
@pytest.mark.integration
52-
def test_force_token_predictions(self, example_sentence):
53-
model = self.model_cls.load(self.pretrained_model)
54-
model.predict(example_sentence, force_token_predictions=True)
66+
def test_force_token_predictions(self, example_sentence, loaded_pretrained_model):
67+
loaded_pretrained_model.predict(example_sentence, force_token_predictions=True)
5568
assert example_sentence.get_token(3).text == "Berlin"
5669
assert example_sentence.get_token(3).tag == "S-LOC"
5770

@@ -73,13 +86,15 @@ def test_train_load_use_tagger_flair_embeddings(self, results_base_path, corpus,
7386
del loaded_model
7487

7588
@pytest.mark.integration
76-
def test_train_load_use_tagger_disjunct_tags(self, results_base_path, tasks_base_path, embeddings, example_sentence):
89+
def test_train_load_use_tagger_disjunct_tags(
90+
self, results_base_path, tasks_base_path, embeddings, example_sentence
91+
):
7792
corpus = flair.datasets.ColumnCorpus(
7893
data_folder=tasks_base_path / "fashion_disjunct",
7994
column_format={0: "text", 3: "ner"},
8095
)
81-
tag_dictionary = corpus.make_label_dictionary("ner", add_unk=False)
82-
model = self.build_model(embeddings, tag_dictionary)
96+
tag_dictionary = corpus.make_label_dictionary("ner", add_unk=True)
97+
model = self.build_model(embeddings, tag_dictionary, allow_unk_predictions=True)
8398
trainer = ModelTrainer(model, corpus)
8499

85100
trainer.train(results_base_path, shuffle=False, **self.training_args)
@@ -90,4 +105,4 @@ def test_train_load_use_tagger_disjunct_tags(self, results_base_path, tasks_base
90105
loaded_model.predict(example_sentence)
91106
loaded_model.predict([example_sentence, self.empty_sentence])
92107
loaded_model.predict([self.empty_sentence])
93-
del loaded_model
108+
del loaded_model

tests/models/test_tars_classifier.py

+107
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import pytest
2+
3+
from flair.data import Sentence
4+
from flair.datasets import ClassificationCorpus
5+
from flair.embeddings import TransformerDocumentEmbeddings
6+
from flair.models import TARSClassifier
7+
from tests.model_test_utils import BaseModelTest
8+
9+
10+
class TestTarsClassifier(BaseModelTest):
11+
model_cls = TARSClassifier
12+
train_label_type = "class"
13+
model_args = dict(task_name="2_CLASS")
14+
training_args = dict(mini_batch_size=1, max_epochs=2)
15+
pretrained_model = "tars-base"
16+
17+
@pytest.fixture
18+
def corpus(self, tasks_base_path):
19+
yield ClassificationCorpus(tasks_base_path / "imdb_underscore")
20+
21+
@pytest.fixture
22+
def embeddings(self):
23+
yield TransformerDocumentEmbeddings("distilbert-base-uncased")
24+
25+
@pytest.fixture
26+
def example_sentence(self):
27+
yield Sentence("This is great!")
28+
29+
def build_model(self, embeddings, label_dict, **kwargs):
30+
model_args = dict(self.model_args)
31+
for k in kwargs.keys():
32+
if k in model_args:
33+
del model_args[k]
34+
return self.model_cls(
35+
embeddings=embeddings,
36+
label_type=self.train_label_type,
37+
**model_args,
38+
**kwargs,
39+
)
40+
41+
def transform_corpus(self, model, corpus):
42+
model.add_and_switch_to_new_task(
43+
task_name="2_CLASS",
44+
label_dictionary=corpus.make_label_dictionary(self.train_label_type),
45+
label_type=self.train_label_type,
46+
)
47+
return corpus
48+
49+
@pytest.mark.integration
50+
def test_predict_zero_shot(self, loaded_pretrained_model):
51+
sentence = Sentence("I am so glad you liked it!")
52+
loaded_pretrained_model.predict_zero_shot(sentence, ["happy", "sad"])
53+
assert len(sentence.get_labels(loaded_pretrained_model.label_type)) == 1
54+
assert sentence.get_labels(loaded_pretrained_model.label_type)[0].value == "happy"
55+
56+
@pytest.mark.integration
57+
def test_predict_zero_shot_single_label_always_predicts(self, loaded_pretrained_model):
58+
sentence = Sentence("I hate it")
59+
loaded_pretrained_model.predict_zero_shot(sentence, ["happy", "sad"])
60+
# Ensure this is an example that predicts no classes in multilabel
61+
assert len(sentence.get_labels(loaded_pretrained_model.label_type)) == 0
62+
loaded_pretrained_model.predict_zero_shot(sentence, ["happy", "sad"], multi_label=False)
63+
assert len(sentence.get_labels(loaded_pretrained_model.label_type)) == 1
64+
assert sentence.get_labels(loaded_pretrained_model.label_type)[0].value == "sad"
65+
66+
@pytest.mark.integration
67+
def test_init_tars_and_switch(self, tasks_base_path, corpus):
68+
tars = TARSClassifier(
69+
task_name="2_CLASS",
70+
label_dictionary=corpus.make_label_dictionary(label_type="class"),
71+
label_type="class",
72+
)
73+
74+
# check if right number of classes
75+
assert len(tars.get_current_label_dictionary()) == 2
76+
77+
# switch to task with only one label
78+
tars.add_and_switch_to_new_task("1_CLASS", "one class", "testlabel")
79+
80+
# check if right number of classes
81+
assert len(tars.get_current_label_dictionary()) == 1
82+
83+
# switch to task with three labels provided as list
84+
tars.add_and_switch_to_new_task("3_CLASS", ["list 1", "list 2", "list 3"], "testlabel")
85+
86+
# check if right number of classes
87+
assert len(tars.get_current_label_dictionary()) == 3
88+
89+
# switch to task with four labels provided as set
90+
tars.add_and_switch_to_new_task("4_CLASS", {"set 1", "set 2", "set 3", "set 4"}, "testlabel")
91+
92+
# check if right number of classes
93+
assert len(tars.get_current_label_dictionary()) == 4
94+
95+
# switch to task with two labels provided as Dictionary
96+
tars.add_and_switch_to_new_task("2_CLASS_AGAIN", corpus.make_label_dictionary(label_type="class"), "testlabel")
97+
98+
# check if right number of classes
99+
assert len(tars.get_current_label_dictionary()) == 2
100+
101+
@pytest.mark.skip("embeddings are not supported in tars")
102+
def test_load_use_model_keep_embedding(self):
103+
pass
104+
105+
@pytest.mark.skip("tars needs additional setup after loading")
106+
def test_load_use_model(self):
107+
pass

tests/models/test_tars_ner.py

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import pytest
2+
3+
import flair
4+
from flair.data import Sentence
5+
from flair.embeddings import TransformerWordEmbeddings
6+
from flair.models import TARSTagger
7+
from tests.model_test_utils import BaseModelTest
8+
9+
10+
class TestTarsTagger(BaseModelTest):
11+
model_cls = TARSTagger
12+
train_label_type = "ner"
13+
model_args = dict(task_name="2_NER")
14+
training_args = dict(mini_batch_size=1, max_epochs=2)
15+
pretrained_model = "tars-ner"
16+
17+
@pytest.fixture
18+
def corpus(self, tasks_base_path):
19+
yield flair.datasets.ColumnCorpus(data_folder=tasks_base_path / "fashion", column_format={0: "text", 3: "ner"})
20+
21+
@pytest.fixture
22+
def embeddings(self):
23+
yield TransformerWordEmbeddings("distilbert-base-uncased")
24+
25+
@pytest.fixture
26+
def example_sentence(self):
27+
yield Sentence("George Washington was born in Washington")
28+
29+
def build_model(self, embeddings, label_dict, **kwargs):
30+
model_args = dict(self.model_args)
31+
for k in kwargs.keys():
32+
if k in model_args:
33+
del model_args[k]
34+
return self.model_cls(
35+
embeddings=embeddings,
36+
label_type=self.train_label_type,
37+
**model_args,
38+
**kwargs,
39+
)
40+
41+
def transform_corpus(self, model, corpus):
42+
model.add_and_switch_to_new_task(
43+
task_name="2_NER",
44+
label_dictionary=corpus.make_label_dictionary(self.train_label_type),
45+
label_type=self.train_label_type,
46+
)
47+
return corpus
48+
49+
@pytest.mark.integration
50+
def test_predict_zero_shot(self, loaded_pretrained_model):
51+
sentence = Sentence("George Washington was born in Washington")
52+
loaded_pretrained_model.predict_zero_shot(sentence, ["location", "person"])
53+
assert len(sentence.get_labels("location-person")) == 2
54+
assert sorted([label.value for label in sentence.get_labels("location-person")]) == [
55+
"location",
56+
"person",
57+
]
58+
59+
@pytest.mark.integration
60+
def test_init_tars_and_switch(self, tasks_base_path, corpus):
61+
tars = TARSTagger(
62+
task_name="2_NER",
63+
label_dictionary=corpus.make_label_dictionary(label_type="ner"),
64+
label_type="ner",
65+
)
66+
67+
# check if right number of classes
68+
assert len(tars.get_current_label_dictionary()) == 10
69+
70+
# switch to task with only one label
71+
tars.add_and_switch_to_new_task("1_CLASS", "one class", "testlabel")
72+
73+
# check if right number of classes
74+
assert len(tars.get_current_label_dictionary()) == 1
75+
76+
# switch to task with three labels provided as list
77+
tars.add_and_switch_to_new_task("3_CLASS", ["list 1", "list 2", "list 3"], "testlabel")
78+
79+
# check if right number of classes
80+
assert len(tars.get_current_label_dictionary()) == 3
81+
82+
# switch to task with four labels provided as set
83+
tars.add_and_switch_to_new_task("4_CLASS", {"set 1", "set 2", "set 3", "set 4"}, "testlabel")
84+
85+
# check if right number of classes
86+
assert len(tars.get_current_label_dictionary()) == 4
87+
88+
# switch to task with two labels provided as Dictionary
89+
tars.add_and_switch_to_new_task("2_CLASS_AGAIN", corpus.make_label_dictionary(label_type="ner"), "testlabel")
90+
91+
# check if right number of classes
92+
assert len(tars.get_current_label_dictionary()) == 10
93+
94+
@pytest.mark.skip("embeddings are not supported in tars")
95+
def test_load_use_model_keep_embedding(self):
96+
pass
97+
98+
@pytest.mark.skip("tars needs additional setup after loading")
99+
def test_load_use_model(self):
100+
pass

0 commit comments

Comments
 (0)