Source code for medacy.pipeline_components.learners.bilstm_crf_learner

"""
BiLSTM+CRF PyTorch network and model.
"""
import logging
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from medacy.pipeline_components.learners.nn.bilstm_crf import BiLstmCrf
from medacy.pipeline_components.learners.nn.vectorizer import Vectorizer


class BiLstmCrfLearner:
    """
    BiLSTM-CRF model class for using the network. Currently handles all vectorization as well.

    :ivar device: PyTorch device to use.
    :ivar vectorizer: Vectorizer bound to ``device``; it owns the loaded word vectors and the
        tag-to-index mapping used by ``fit`` and ``predict``.
    :ivar model: BiLstmCrf network instance; ``None`` until ``fit`` or ``load`` has run.
    :ivar word_embeddings_file: File to load word embeddings from.
    :ivar word_vectors: Initialised to ``None`` and not otherwise set by this class; the loaded
        vectors are read from ``self.vectorizer.word_vectors``.
    :ivar learning_rate: Learning rate handed to the SGD optimizer.
    :ivar epochs: Total number of training epochs.
    :ivar crf_delay: Number of initial epochs trained with a per-token NLL loss before
        switching to the CRF loss.
    """

    def __init__(self, word_embeddings, cuda_device):
        """Init BiLstmCrfLearner object.

        :param word_embeddings: Path to word embeddings file to use.
        :param cuda_device: Index of cuda device to use. Use -1 to use CPU.
        :raises ValueError: If ``word_embeddings`` is None.
        """
        # Validate first so a bad call fails fast, before the global RNG is reseeded
        # or any helper object is constructed.
        if word_embeddings is None:
            raise ValueError('BiLSTM-CRF requires word embeddings.')

        # Fixed seed for reproducible parameter initialisation.
        torch.manual_seed(1)

        device_string = 'cuda:%d' % cuda_device if cuda_device >= 0 else 'cpu'
        self.device = torch.device(device_string)

        self.vectorizer = Vectorizer(self.device)
        self.word_embeddings_file = word_embeddings
        self.word_vectors = None

        # Other instance attributes
        self.model = None
        self.learning_rate = 0.01

        # TODO: Implement cleaner way to handle this
        # The bundled test embeddings get a much shorter training schedule.
        if word_embeddings.endswith('test_word_embeddings.txt'):
            self.epochs = 2
            self.crf_delay = 1
        else:
            self.epochs = 40
            self.crf_delay = 20

    def fit(self, x_data, y_data):
        """Fully train model based on x and y data. self.model is set to trained model.

        For the first ``crf_delay`` epochs the network is trained with an NLL loss over the
        log-softmax of its emissions; the remaining epochs minimise the negated output of
        ``model.crf``. The trained model is left in evaluation mode, matching ``load``.

        :param x_data: List of list of tokens.
        :param y_data: List of list of correct labels for the tokens.
        """
        # Word embeddings are loaded lazily, the first time they are needed.
        if self.vectorizer.word_vectors is None:
            self.vectorizer.load_word_embeddings(self.word_embeddings_file)

        data = self.vectorizer.vectorize_dataset(x_data, y_data)

        # Create network
        # NOTE(review): the second argument is the length of the first token's vector minus
        # its first two entries -- presumably the count of non-embedding features; confirm
        # against Vectorizer.
        model = BiLstmCrf(
            self.vectorizer.word_vectors,
            len(data[0][0][0][2:]),
            len(self.vectorizer.tag_to_index),
            self.device
        )

        # Move to GPU if possible
        if self.device.type != 'cpu':
            logging.info('CUDA available. Moving model to GPU.')
            model = model.to(self.device)

        # Setup optimizer and loss function
        optimizer = optim.SGD(model.parameters(), lr=self.learning_rate)
        loss_function = nn.NLLLoss()

        logging.info('Training BiLSTM-CRF...')
        model.train()

        # Training loop
        for epoch in range(1, self.epochs + 1):
            random.shuffle(data)
            epoch_losses = []

            for sentence, sentence_tags in data:
                if epoch <= self.crf_delay:
                    # Warm-up phase: per-token classification loss, CRF not involved.
                    emissions = model(sentence)
                    predictions = F.log_softmax(emissions, dim=1)
                    loss = loss_function(predictions, sentence_tags)
                else:
                    # CRF phase: an extra dimension is inserted at index 1 on both tensors
                    # before they are handed to the CRF layer.
                    emissions = model(sentence).unsqueeze(1)
                    loss = -model.crf(emissions, sentence_tags.unsqueeze(1))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Keep only the Python float: holding on to the tensor would retain each
                # sentence's autograd graph until the end of the epoch.
                epoch_losses.append(loss.item())

            average_loss = sum(epoch_losses) / len(epoch_losses)
            logging.info('Epoch %d average loss: %f', epoch, average_loss)

        # Leave the model in inference mode so predict() behaves the same after fit()
        # as it does after load().
        model.eval()
        self.model = model

    def predict(self, sequences):
        """Use model to make predictions over a given dataset.

        :param sequences: Sequences to predict labels for.
        :return: List of list of predicted labels.
        :raises RuntimeError: If word embeddings have not been loaded.
        """
        # Same "is None" test as fit(); avoids relying on the truthiness of the
        # embeddings container.
        if self.vectorizer.word_vectors is None:
            raise RuntimeError('Loading word embeddings is required.')

        predictions = []
        with torch.no_grad():
            for sequence in sequences:
                vectorized_tokens = self.vectorizer.vectorize_tokens(sequence)
                emissions = self.model(vectorized_tokens).unsqueeze(1)
                tag_indices = self.model.crf.decode(emissions)
                # Only the first decoded sequence is used: one sequence is decoded per call.
                predictions.append(self.vectorizer.devectorize_tag(tag_indices[0]))

        return predictions

    def save(self, path):
        """Save model and other required values.

        The whole model object is serialised together with the vectorizer's values.
        ``.pt`` is appended to ``path`` when it is missing.

        :param path: Path to save model to.
        """
        properties = {
            'model': self.model,
            'vectorizer_values': self.vectorizer.get_values(),
        }

        if not path.endswith('.pt'):
            path += '.pt'

        torch.save(properties, path)

    def load(self, path):
        """Load model and other required values from given path.

        The restored model is put in evaluation mode and moved to ``self.device``; the word
        embeddings are then (re)loaded from ``self.word_embeddings_file``.

        :param path: Path of saved model. Used as given -- unlike ``save``, no ``.pt``
            suffix is appended.
        """
        # SECURITY: torch.load unpickles the file, which can execute arbitrary code.
        # Only load model files from trusted sources.
        # NOTE(review): recent PyTorch releases may require weights_only=False to load a
        # fully pickled module like this one -- confirm against the targeted version.
        saved_data = torch.load(path, map_location=self.device)

        self.vectorizer.load_values(saved_data['vectorizer_values'])

        model = saved_data['model']
        # The device stored on the pickled model may not exist on this machine.
        model.device = self.device
        model.eval()
        self.model = model.to(self.device)

        self.vectorizer.load_word_embeddings(self.word_embeddings_file)