|
1 |
| -import datetime |
2 | 1 | import inspect
|
3 | 2 | import logging
|
4 | 3 | import os
|
|
7 | 6 | import warnings
|
8 | 7 | from inspect import signature
|
9 | 8 | from pathlib import Path
|
10 |
| -from typing import Any, Dict, List, Optional, Tuple, Type, Union |
| 9 | +from typing import List, Optional, Tuple, Type, Union |
11 | 10 |
|
12 | 11 | import torch
|
13 | 12 | from torch.optim.sgd import SGD
|
|
17 | 16 | import flair.nn
|
18 | 17 | from flair.data import Corpus, Dictionary, _len_dataset
|
19 | 18 | from flair.datasets import DataLoader
|
20 |
| -from flair.nn import Model |
21 |
| -from flair.optim import ExpAnnealLR, LinearSchedulerWithWarmup |
22 | 19 | from flair.trainers.plugins import (
|
23 | 20 | CheckpointPlugin,
|
24 | 21 | LogFilePlugin,
|
|
32 | 29 | )
|
33 | 30 | from flair.trainers.plugins.functional.anneal_on_plateau import AnnealingPlugin
|
34 | 31 | from flair.trainers.plugins.functional.onecycle import OneCyclePlugin
|
35 |
| -from flair.training_utils import ( |
36 |
| - AnnealOnPlateau, |
37 |
| - identify_dynamic_embeddings, |
38 |
| - init_output_file, |
39 |
| - log_line, |
40 |
| - store_embeddings, |
41 |
| -) |
| 32 | +from flair.training_utils import identify_dynamic_embeddings, log_line, store_embeddings |
42 | 33 |
|
43 | 34 | log = logging.getLogger("flair")
|
44 | 35 |
|
@@ -394,7 +385,7 @@ def train_custom(
|
394 | 385 | # - SchedulerPlugin -> load state for anneal_with_restarts, batch_growth_annealing, logic for early stopping
|
395 | 386 | # - LossFilePlugin -> get the current epoch for loss file logging
|
396 | 387 | self.dispatch("before_training_epoch", epoch=epoch)
|
397 |
| - self.model.model_card["training_parameters"]["epoch"] = epoch # type: ignore |
| 388 | + self.model.model_card["training_parameters"]["epoch"] = epoch # type: ignore |
398 | 389 |
|
399 | 390 | current_learning_rate = [group["lr"] for group in self.optimizer.param_groups]
|
400 | 391 | momentum = [group["momentum"] if "momentum" in group else 0 for group in self.optimizer.param_groups]
|
|
0 commit comments