Fine-tuning a Representation Model for Binary Sentiment Classification

This project involves fine-tuning the bert-base-cased representation model for binary sentiment classification on IMDb movie reviews, and comparing full fine-tuning against partially and completely frozen variants of the model.

IMDb dataset

We will use the IMDb dataset from Hugging Face’s datasets library to fine-tune our model for binary sentiment classification. The dataset consists of movie reviews labeled as either positive or negative. It has balanced train and test splits of 25,000 labeled reviews each, plus an additional 50,000 unlabeled reviews intended for unsupervised learning.

Let’s load and explore the original IMDb dataset.

Code
from datasets import load_dataset, concatenate_datasets

# Load the IMDb dataset
imdb_dataset = load_dataset("imdb")
Code
# Print the IMDb dataset

print(imdb_dataset)
DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})
Code
# Inspecting the dataset structure

print(imdb_dataset["train"].features)
{'text': Value(dtype='string', id=None), 'label': ClassLabel(names=['neg', 'pos'], id=None)}
Code
# Checking the first review to see what it looks like

print(imdb_dataset["train"][0]["text"])
I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far between, even then it's not shot like some cheaply made porno. While my countrymen mind find it shocking, in reality sex and nudity are a major staple in Swedish cinema. Even Ingmar Bergman, arguably their answer to good old boy John Ford, had sex scenes in his films.<br /><br />I do commend the filmmakers for the fact that any sex shown in the film is shown for artistic purposes rather than just to shock people and make money to be shown in pornographic theaters in America. I AM CURIOUS-YELLOW is a good film for anyone wanting to study the meat and potatoes (no pun intended) of Swedish cinema. But really, this film doesn't have much of a plot.
Code
# Checking the label of the first review

print(imdb_dataset["train"][0]["label"])
0

Creating a Smaller Dataset

Here, we will use only a subset of the IMDb dataset to reduce computational requirements (training on the full dataset previously took me approximately 40 minutes on Google Colab’s T4 GPU).

We will create a new, smaller IMDb dataset consisting of a balanced training set of 8,000 reviews and a balanced test set of 2,000 reviews, both drawn from the original dataset.

Code
# Define a function to balance the dataset
def balance_dataset(dataset, label_col, num_samples):
    # Filter positive and negative examples
    positive_samples = dataset.filter(lambda example: example[label_col] == 1)
    negative_samples = dataset.filter(lambda example: example[label_col] == 0)

    # Subsample both to the desired number
    positive_samples = positive_samples.shuffle(seed=42).select(range(num_samples // 2))
    negative_samples = negative_samples.shuffle(seed=42).select(range(num_samples // 2))

    # Concatenate positive and negative examples to form a balanced dataset
    balanced_dataset = concatenate_datasets([positive_samples, negative_samples]).shuffle(seed=42)

    return balanced_dataset
Code
# Create a balanced train and test dataset
train_data = balance_dataset(imdb_dataset["train"], "label", 8000)
test_data = balance_dataset(imdb_dataset["test"], "label", 2000)

# Checking the datasets
print(f"train_data:\n {train_data}", end='\n\n')
print(f"test_data:\n {test_data}")
train_data:
 Dataset({
    features: ['text', 'label'],
    num_rows: 8000
})

test_data:
 Dataset({
    features: ['text', 'label'],
    num_rows: 2000
})
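
As an optional sanity check (a small sketch using Python’s collections module), we can confirm that both splits are indeed balanced:

Code
from collections import Counter

# Each split should contain an equal number of positive (1) and negative (0) reviews
print(Counter(train_data["label"]))  # expected: Counter({1: 4000, 0: 4000})
print(Counter(test_data["label"]))   # expected: Counter({1: 1000, 0: 1000})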

1 Fine-tuning the bert-base-cased model with a classification head

We will use the bert-base-cased model, a variant of the pre-trained BERT model, for our binary sentiment classification task. This model was pre-trained on a large text corpus with masked language modeling and next sentence prediction objectives, but it is not trained for sentiment classification. Therefore, we fine-tune it to make it usable for our specific task: movie review sentiment classification.

We will fine-tune the bert-base-cased model along with the classification head (aka classifier) as a single architecture. During training, the weights of both the pre-trained BERT layers and the classification layer are updated using backpropagation. This fine-tuning process adapts the general language understanding of BERT to the specific task of sentiment classification by learning from the IMDb dataset.

1.0.1 Loading the Model and Tokenizer

Code
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import DataCollatorWithPadding

# Load Model and Tokenizer
model_id = "bert-base-cased"
model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=2)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Pad to the longest sequence in the batch
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

The AutoModelForSequenceClassification class adds a classification head on top of the BERT model. This head is typically a simple fully connected (dense) layer, which takes the output from BERT and maps it to the desired number of classes. In this case, there are 2 classes: positive (1) or negative (0). This is why we pass the argument num_labels=2.
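
As a quick check (a minimal sketch; the exact printed representation may vary across transformers versions), we can inspect the classification head that was added on top of BERT:

Code
# The classification head maps BERT's 768-dimensional pooled output to 2 logits
print(model.classifier)
# e.g. Linear(in_features=768, out_features=2, bias=True)
print(model.config.num_labels)
# 2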

The DataCollatorWithPadding ensures that all sequences in a batch are padded to the same length during training, making it easier to process batches of text data.
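
To illustrate what the collator does, here is a small sketch (using the tokenizer and data_collator defined above, with made-up example sentences) that pads two reviews of different lengths into a single batch:

Code
# Tokenize two reviews of different lengths (no padding yet)
examples = [tokenizer("Great movie!"), tokenizer("One of the worst films I have ever seen.")]

# The collator pads both to the length of the longest sequence in the batch
batch = data_collator(examples)
print(batch["input_ids"].shape)   # e.g. torch.Size([2, 12])
print(batch["attention_mask"])    # padded positions have attention mask 0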

1.0.2 Tokenizing the Data

We will define preprocess_function(), which uses the tokenizer to tokenize the input text. The truncation=True option ensures that sequences exceeding the model’s maximum token length are truncated.

Code
# Define preprocessing function
def preprocess_function(examples):
   """Tokenize input data"""
   return tokenizer(examples["text"], truncation=True)

# Tokenize train and test data
tokenized_train = train_data.map(preprocess_function, batched=True)
tokenized_test = test_data.map(preprocess_function, batched=True)
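
The map() call keeps the original columns and adds the tokenizer outputs as new ones; a quick way to confirm this (a small optional check):

Code
# Tokenization adds input_ids, token_type_ids and attention_mask columns
print(tokenized_train.column_names)
# e.g. ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask']
print(len(tokenized_train[0]["input_ids"]))  # number of tokens in the first review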

1.0.3 Define F1 Score Evaluation Metric

We will define a custom compute_metrics() function that will later be passed to the Trainer class to evaluate the model’s performance. Here, we evaluate the model with the F1 score, computed using the evaluate library’s F1 metric and returned as a dictionary.

Code
import numpy as np
import evaluate


def compute_metrics(eval_pred):
    """Calculate F1 score"""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    load_f1 = evaluate.load("f1")
    f1 = load_f1.compute(predictions=predictions, references=labels)["f1"]
    return {"f1": f1}

1.1 Train and Evaluate the model

Now, we define the training hyperparameters in the TrainingArguments class. The first positional argument, "model", is the output directory where checkpoints are saved. The remaining arguments:

  • learning_rate=2e-5: The learning rate used for updating model parameters.
  • per_device_train_batch_size=16 and per_device_eval_batch_size=16: Batch size during training and evaluation.
  • num_train_epochs=1: The number of epochs (complete passes through the dataset).
  • weight_decay=0.01: A regularization term to prevent overfitting.
  • save_strategy="epoch": The model will be saved at the end of each epoch.
  • report_to="none": Disables reporting to external logging services.
Code
from transformers import TrainingArguments, Trainer

# Training arguments for parameter tuning
training_args = TrainingArguments(
   "model",
   learning_rate=2e-5,
   per_device_train_batch_size=16,
   per_device_eval_batch_size=16,
   num_train_epochs=1,
   weight_decay=0.01,
   save_strategy="epoch",
   report_to="none"
)

# Trainer which executes the training process
trainer = Trainer(
   model=model,
   args=training_args,
   train_dataset=tokenized_train,
   eval_dataset=tokenized_test,
   tokenizer=tokenizer,
   data_collator=data_collator,
   compute_metrics=compute_metrics,
)

We will perform model training and evaluation using the Trainer class.

Code
trainer.train()
TrainOutput(global_step=500, training_loss=0.31193878173828127, metrics={'train_runtime': 742.959, 'train_samples_per_second': 10.768, 'train_steps_per_second': 0.673, 'total_flos': 2084785113806400.0, 'train_loss': 0.31193878173828127, 'epoch': 1.0})

The output of trainer.evaluate() will be a dictionary containing the default evaluation metrics, along with the F1 score computed by the compute_metrics() function we passed to the Trainer.

Code
trainer.evaluate()
{'eval_loss': 0.19807493686676025,
 'eval_f1': 0.9257425742574258,
 'eval_runtime': 59.8717,
 'eval_samples_per_second': 33.405,
 'eval_steps_per_second': 2.088,
 'epoch': 1.0}

From the results, we can observe that full fine-tuning took approximately 743 seconds of training time, and the model achieved an F1 score of about 0.926.
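
Before moving on, we could sanity-check the fine-tuned model on a new review. The sketch below (not part of the original runs) wraps the model and tokenizer in a transformers pipeline; because we did not set id2label, predictions are reported as LABEL_0 (negative) and LABEL_1 (positive).

Code
from transformers import pipeline

# Wrap the fine-tuned model and its tokenizer in a text-classification pipeline
clf = pipeline("text-classification", model=model, tokenizer=tokenizer)

print(clf("A beautifully shot film with a moving story and great performances."))
# e.g. [{'label': 'LABEL_1', 'score': ...}]  -> LABEL_1 corresponds to positive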

2 Freezing some layers in the main BERT and fine-tuning

In the previous section, fine-tuning took roughly 12 minutes of training time. To reduce this, we can freeze certain layers in the model and fine-tune only specific layers while maintaining reasonable performance.

Code
# Load Model and Tokenizer

model_id = "bert-base-cased"
model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=2)
tokenizer = AutoTokenizer.from_pretrained(model_id)

Note that we reload the pre-trained model rather than reusing the previously fine-tuned one: freezing layers in a model that has already been fine-tuned on this task would defeat the purpose of the comparison.

Now, let’s print out the layers of the model.

Code
# Print layer names
for name, param in model.named_parameters():
    print(name)
bert.embeddings.word_embeddings.weight
bert.embeddings.position_embeddings.weight
bert.embeddings.token_type_embeddings.weight
bert.embeddings.LayerNorm.weight
bert.embeddings.LayerNorm.bias
bert.encoder.layer.0.attention.self.query.weight
bert.encoder.layer.0.attention.self.query.bias
bert.encoder.layer.0.attention.self.key.weight
bert.encoder.layer.0.attention.self.key.bias
bert.encoder.layer.0.attention.self.value.weight
bert.encoder.layer.0.attention.self.value.bias
bert.encoder.layer.0.attention.output.dense.weight
bert.encoder.layer.0.attention.output.dense.bias
bert.encoder.layer.0.attention.output.LayerNorm.weight
bert.encoder.layer.0.attention.output.LayerNorm.bias
bert.encoder.layer.0.intermediate.dense.weight
bert.encoder.layer.0.intermediate.dense.bias
bert.encoder.layer.0.output.dense.weight
bert.encoder.layer.0.output.dense.bias
bert.encoder.layer.0.output.LayerNorm.weight
bert.encoder.layer.0.output.LayerNorm.bias
bert.encoder.layer.1.attention.self.query.weight
bert.encoder.layer.1.attention.self.query.bias
bert.encoder.layer.1.attention.self.key.weight
bert.encoder.layer.1.attention.self.key.bias
bert.encoder.layer.1.attention.self.value.weight
bert.encoder.layer.1.attention.self.value.bias
bert.encoder.layer.1.attention.output.dense.weight
bert.encoder.layer.1.attention.output.dense.bias
bert.encoder.layer.1.attention.output.LayerNorm.weight
bert.encoder.layer.1.attention.output.LayerNorm.bias
bert.encoder.layer.1.intermediate.dense.weight
bert.encoder.layer.1.intermediate.dense.bias
bert.encoder.layer.1.output.dense.weight
bert.encoder.layer.1.output.dense.bias
bert.encoder.layer.1.output.LayerNorm.weight
bert.encoder.layer.1.output.LayerNorm.bias
bert.encoder.layer.2.attention.self.query.weight
bert.encoder.layer.2.attention.self.query.bias
bert.encoder.layer.2.attention.self.key.weight
bert.encoder.layer.2.attention.self.key.bias
bert.encoder.layer.2.attention.self.value.weight
bert.encoder.layer.2.attention.self.value.bias
bert.encoder.layer.2.attention.output.dense.weight
bert.encoder.layer.2.attention.output.dense.bias
bert.encoder.layer.2.attention.output.LayerNorm.weight
bert.encoder.layer.2.attention.output.LayerNorm.bias
bert.encoder.layer.2.intermediate.dense.weight
bert.encoder.layer.2.intermediate.dense.bias
bert.encoder.layer.2.output.dense.weight
bert.encoder.layer.2.output.dense.bias
bert.encoder.layer.2.output.LayerNorm.weight
bert.encoder.layer.2.output.LayerNorm.bias
bert.encoder.layer.3.attention.self.query.weight
bert.encoder.layer.3.attention.self.query.bias
bert.encoder.layer.3.attention.self.key.weight
bert.encoder.layer.3.attention.self.key.bias
bert.encoder.layer.3.attention.self.value.weight
bert.encoder.layer.3.attention.self.value.bias
bert.encoder.layer.3.attention.output.dense.weight
bert.encoder.layer.3.attention.output.dense.bias
bert.encoder.layer.3.attention.output.LayerNorm.weight
bert.encoder.layer.3.attention.output.LayerNorm.bias
bert.encoder.layer.3.intermediate.dense.weight
bert.encoder.layer.3.intermediate.dense.bias
bert.encoder.layer.3.output.dense.weight
bert.encoder.layer.3.output.dense.bias
bert.encoder.layer.3.output.LayerNorm.weight
bert.encoder.layer.3.output.LayerNorm.bias
bert.encoder.layer.4.attention.self.query.weight
bert.encoder.layer.4.attention.self.query.bias
bert.encoder.layer.4.attention.self.key.weight
bert.encoder.layer.4.attention.self.key.bias
bert.encoder.layer.4.attention.self.value.weight
bert.encoder.layer.4.attention.self.value.bias
bert.encoder.layer.4.attention.output.dense.weight
bert.encoder.layer.4.attention.output.dense.bias
bert.encoder.layer.4.attention.output.LayerNorm.weight
bert.encoder.layer.4.attention.output.LayerNorm.bias
bert.encoder.layer.4.intermediate.dense.weight
bert.encoder.layer.4.intermediate.dense.bias
bert.encoder.layer.4.output.dense.weight
bert.encoder.layer.4.output.dense.bias
bert.encoder.layer.4.output.LayerNorm.weight
bert.encoder.layer.4.output.LayerNorm.bias
bert.encoder.layer.5.attention.self.query.weight
bert.encoder.layer.5.attention.self.query.bias
bert.encoder.layer.5.attention.self.key.weight
bert.encoder.layer.5.attention.self.key.bias
bert.encoder.layer.5.attention.self.value.weight
bert.encoder.layer.5.attention.self.value.bias
bert.encoder.layer.5.attention.output.dense.weight
bert.encoder.layer.5.attention.output.dense.bias
bert.encoder.layer.5.attention.output.LayerNorm.weight
bert.encoder.layer.5.attention.output.LayerNorm.bias
bert.encoder.layer.5.intermediate.dense.weight
bert.encoder.layer.5.intermediate.dense.bias
bert.encoder.layer.5.output.dense.weight
bert.encoder.layer.5.output.dense.bias
bert.encoder.layer.5.output.LayerNorm.weight
bert.encoder.layer.5.output.LayerNorm.bias
bert.encoder.layer.6.attention.self.query.weight
bert.encoder.layer.6.attention.self.query.bias
bert.encoder.layer.6.attention.self.key.weight
bert.encoder.layer.6.attention.self.key.bias
bert.encoder.layer.6.attention.self.value.weight
bert.encoder.layer.6.attention.self.value.bias
bert.encoder.layer.6.attention.output.dense.weight
bert.encoder.layer.6.attention.output.dense.bias
bert.encoder.layer.6.attention.output.LayerNorm.weight
bert.encoder.layer.6.attention.output.LayerNorm.bias
bert.encoder.layer.6.intermediate.dense.weight
bert.encoder.layer.6.intermediate.dense.bias
bert.encoder.layer.6.output.dense.weight
bert.encoder.layer.6.output.dense.bias
bert.encoder.layer.6.output.LayerNorm.weight
bert.encoder.layer.6.output.LayerNorm.bias
bert.encoder.layer.7.attention.self.query.weight
bert.encoder.layer.7.attention.self.query.bias
bert.encoder.layer.7.attention.self.key.weight
bert.encoder.layer.7.attention.self.key.bias
bert.encoder.layer.7.attention.self.value.weight
bert.encoder.layer.7.attention.self.value.bias
bert.encoder.layer.7.attention.output.dense.weight
bert.encoder.layer.7.attention.output.dense.bias
bert.encoder.layer.7.attention.output.LayerNorm.weight
bert.encoder.layer.7.attention.output.LayerNorm.bias
bert.encoder.layer.7.intermediate.dense.weight
bert.encoder.layer.7.intermediate.dense.bias
bert.encoder.layer.7.output.dense.weight
bert.encoder.layer.7.output.dense.bias
bert.encoder.layer.7.output.LayerNorm.weight
bert.encoder.layer.7.output.LayerNorm.bias
bert.encoder.layer.8.attention.self.query.weight
bert.encoder.layer.8.attention.self.query.bias
bert.encoder.layer.8.attention.self.key.weight
bert.encoder.layer.8.attention.self.key.bias
bert.encoder.layer.8.attention.self.value.weight
bert.encoder.layer.8.attention.self.value.bias
bert.encoder.layer.8.attention.output.dense.weight
bert.encoder.layer.8.attention.output.dense.bias
bert.encoder.layer.8.attention.output.LayerNorm.weight
bert.encoder.layer.8.attention.output.LayerNorm.bias
bert.encoder.layer.8.intermediate.dense.weight
bert.encoder.layer.8.intermediate.dense.bias
bert.encoder.layer.8.output.dense.weight
bert.encoder.layer.8.output.dense.bias
bert.encoder.layer.8.output.LayerNorm.weight
bert.encoder.layer.8.output.LayerNorm.bias
bert.encoder.layer.9.attention.self.query.weight
bert.encoder.layer.9.attention.self.query.bias
bert.encoder.layer.9.attention.self.key.weight
bert.encoder.layer.9.attention.self.key.bias
bert.encoder.layer.9.attention.self.value.weight
bert.encoder.layer.9.attention.self.value.bias
bert.encoder.layer.9.attention.output.dense.weight
bert.encoder.layer.9.attention.output.dense.bias
bert.encoder.layer.9.attention.output.LayerNorm.weight
bert.encoder.layer.9.attention.output.LayerNorm.bias
bert.encoder.layer.9.intermediate.dense.weight
bert.encoder.layer.9.intermediate.dense.bias
bert.encoder.layer.9.output.dense.weight
bert.encoder.layer.9.output.dense.bias
bert.encoder.layer.9.output.LayerNorm.weight
bert.encoder.layer.9.output.LayerNorm.bias
bert.encoder.layer.10.attention.self.query.weight
bert.encoder.layer.10.attention.self.query.bias
bert.encoder.layer.10.attention.self.key.weight
bert.encoder.layer.10.attention.self.key.bias
bert.encoder.layer.10.attention.self.value.weight
bert.encoder.layer.10.attention.self.value.bias
bert.encoder.layer.10.attention.output.dense.weight
bert.encoder.layer.10.attention.output.dense.bias
bert.encoder.layer.10.attention.output.LayerNorm.weight
bert.encoder.layer.10.attention.output.LayerNorm.bias
bert.encoder.layer.10.intermediate.dense.weight
bert.encoder.layer.10.intermediate.dense.bias
bert.encoder.layer.10.output.dense.weight
bert.encoder.layer.10.output.dense.bias
bert.encoder.layer.10.output.LayerNorm.weight
bert.encoder.layer.10.output.LayerNorm.bias
bert.encoder.layer.11.attention.self.query.weight
bert.encoder.layer.11.attention.self.query.bias
bert.encoder.layer.11.attention.self.key.weight
bert.encoder.layer.11.attention.self.key.bias
bert.encoder.layer.11.attention.self.value.weight
bert.encoder.layer.11.attention.self.value.bias
bert.encoder.layer.11.attention.output.dense.weight
bert.encoder.layer.11.attention.output.dense.bias
bert.encoder.layer.11.attention.output.LayerNorm.weight
bert.encoder.layer.11.attention.output.LayerNorm.bias
bert.encoder.layer.11.intermediate.dense.weight
bert.encoder.layer.11.intermediate.dense.bias
bert.encoder.layer.11.output.dense.weight
bert.encoder.layer.11.output.dense.bias
bert.encoder.layer.11.output.LayerNorm.weight
bert.encoder.layer.11.output.LayerNorm.bias
bert.pooler.dense.weight
bert.pooler.dense.bias
classifier.weight
classifier.bias

We can see that the model has 12 encoder blocks (0–11), each consisting of self-attention, feed-forward (dense) layers, and layer normalization, followed by the pooler and the classification head.

So, let’s freeze encoder blocks 0–9 and leave only the last two encoder blocks, the pooler, and the classification head trainable. This reduces computation by updating only part of the pre-trained model.

Code
# Encoder block 10 starts at index 165
# So we freeze everything before that block using the index
for index, (name, param) in enumerate(model.named_parameters()):
    if index < 165:
        param.requires_grad = False
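
The hard-coded index 165 corresponds to bert-base-cased (5 embedding parameters plus 10 encoder blocks of 16 parameters each), so it would break if the architecture changed. As an alternative sketch of my own (not part of the original code), the same layers can be frozen by parameter name:

Code
# Alternative: freeze by parameter name instead of by index.
# Only encoder blocks 10-11, the pooler, and the classifier stay trainable.
trainable_prefixes = (
    "bert.encoder.layer.10.",
    "bert.encoder.layer.11.",
    "bert.pooler.",
    "classifier.",
)

for name, param in model.named_parameters():
    param.requires_grad = name.startswith(trainable_prefixes)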
Code
# Checking whether the model was correctly updated
for index, (name, param) in enumerate(model.named_parameters()):
     print(f"Parameter: {index}{name} ----- {param.requires_grad}")
Parameter: 0bert.embeddings.word_embeddings.weight ----- False
Parameter: 1bert.embeddings.position_embeddings.weight ----- False
Parameter: 2bert.embeddings.token_type_embeddings.weight ----- False
Parameter: 3bert.embeddings.LayerNorm.weight ----- False
Parameter: 4bert.embeddings.LayerNorm.bias ----- False
Parameter: 5bert.encoder.layer.0.attention.self.query.weight ----- False
Parameter: 6bert.encoder.layer.0.attention.self.query.bias ----- False
Parameter: 7bert.encoder.layer.0.attention.self.key.weight ----- False
Parameter: 8bert.encoder.layer.0.attention.self.key.bias ----- False
Parameter: 9bert.encoder.layer.0.attention.self.value.weight ----- False
Parameter: 10bert.encoder.layer.0.attention.self.value.bias ----- False
Parameter: 11bert.encoder.layer.0.attention.output.dense.weight ----- False
Parameter: 12bert.encoder.layer.0.attention.output.dense.bias ----- False
Parameter: 13bert.encoder.layer.0.attention.output.LayerNorm.weight ----- False
Parameter: 14bert.encoder.layer.0.attention.output.LayerNorm.bias ----- False
Parameter: 15bert.encoder.layer.0.intermediate.dense.weight ----- False
Parameter: 16bert.encoder.layer.0.intermediate.dense.bias ----- False
Parameter: 17bert.encoder.layer.0.output.dense.weight ----- False
Parameter: 18bert.encoder.layer.0.output.dense.bias ----- False
Parameter: 19bert.encoder.layer.0.output.LayerNorm.weight ----- False
Parameter: 20bert.encoder.layer.0.output.LayerNorm.bias ----- False
Parameter: 21bert.encoder.layer.1.attention.self.query.weight ----- False
Parameter: 22bert.encoder.layer.1.attention.self.query.bias ----- False
Parameter: 23bert.encoder.layer.1.attention.self.key.weight ----- False
Parameter: 24bert.encoder.layer.1.attention.self.key.bias ----- False
Parameter: 25bert.encoder.layer.1.attention.self.value.weight ----- False
Parameter: 26bert.encoder.layer.1.attention.self.value.bias ----- False
Parameter: 27bert.encoder.layer.1.attention.output.dense.weight ----- False
Parameter: 28bert.encoder.layer.1.attention.output.dense.bias ----- False
Parameter: 29bert.encoder.layer.1.attention.output.LayerNorm.weight ----- False
Parameter: 30bert.encoder.layer.1.attention.output.LayerNorm.bias ----- False
Parameter: 31bert.encoder.layer.1.intermediate.dense.weight ----- False
Parameter: 32bert.encoder.layer.1.intermediate.dense.bias ----- False
Parameter: 33bert.encoder.layer.1.output.dense.weight ----- False
Parameter: 34bert.encoder.layer.1.output.dense.bias ----- False
Parameter: 35bert.encoder.layer.1.output.LayerNorm.weight ----- False
Parameter: 36bert.encoder.layer.1.output.LayerNorm.bias ----- False
Parameter: 37bert.encoder.layer.2.attention.self.query.weight ----- False
Parameter: 38bert.encoder.layer.2.attention.self.query.bias ----- False
Parameter: 39bert.encoder.layer.2.attention.self.key.weight ----- False
Parameter: 40bert.encoder.layer.2.attention.self.key.bias ----- False
Parameter: 41bert.encoder.layer.2.attention.self.value.weight ----- False
Parameter: 42bert.encoder.layer.2.attention.self.value.bias ----- False
Parameter: 43bert.encoder.layer.2.attention.output.dense.weight ----- False
Parameter: 44bert.encoder.layer.2.attention.output.dense.bias ----- False
Parameter: 45bert.encoder.layer.2.attention.output.LayerNorm.weight ----- False
Parameter: 46bert.encoder.layer.2.attention.output.LayerNorm.bias ----- False
Parameter: 47bert.encoder.layer.2.intermediate.dense.weight ----- False
Parameter: 48bert.encoder.layer.2.intermediate.dense.bias ----- False
Parameter: 49bert.encoder.layer.2.output.dense.weight ----- False
Parameter: 50bert.encoder.layer.2.output.dense.bias ----- False
Parameter: 51bert.encoder.layer.2.output.LayerNorm.weight ----- False
Parameter: 52bert.encoder.layer.2.output.LayerNorm.bias ----- False
Parameter: 53bert.encoder.layer.3.attention.self.query.weight ----- False
Parameter: 54bert.encoder.layer.3.attention.self.query.bias ----- False
Parameter: 55bert.encoder.layer.3.attention.self.key.weight ----- False
Parameter: 56bert.encoder.layer.3.attention.self.key.bias ----- False
Parameter: 57bert.encoder.layer.3.attention.self.value.weight ----- False
Parameter: 58bert.encoder.layer.3.attention.self.value.bias ----- False
Parameter: 59bert.encoder.layer.3.attention.output.dense.weight ----- False
Parameter: 60bert.encoder.layer.3.attention.output.dense.bias ----- False
Parameter: 61bert.encoder.layer.3.attention.output.LayerNorm.weight ----- False
Parameter: 62bert.encoder.layer.3.attention.output.LayerNorm.bias ----- False
Parameter: 63bert.encoder.layer.3.intermediate.dense.weight ----- False
Parameter: 64bert.encoder.layer.3.intermediate.dense.bias ----- False
Parameter: 65bert.encoder.layer.3.output.dense.weight ----- False
Parameter: 66bert.encoder.layer.3.output.dense.bias ----- False
Parameter: 67bert.encoder.layer.3.output.LayerNorm.weight ----- False
Parameter: 68bert.encoder.layer.3.output.LayerNorm.bias ----- False
Parameter: 69bert.encoder.layer.4.attention.self.query.weight ----- False
Parameter: 70bert.encoder.layer.4.attention.self.query.bias ----- False
Parameter: 71bert.encoder.layer.4.attention.self.key.weight ----- False
Parameter: 72bert.encoder.layer.4.attention.self.key.bias ----- False
Parameter: 73bert.encoder.layer.4.attention.self.value.weight ----- False
Parameter: 74bert.encoder.layer.4.attention.self.value.bias ----- False
Parameter: 75bert.encoder.layer.4.attention.output.dense.weight ----- False
Parameter: 76bert.encoder.layer.4.attention.output.dense.bias ----- False
Parameter: 77bert.encoder.layer.4.attention.output.LayerNorm.weight ----- False
Parameter: 78bert.encoder.layer.4.attention.output.LayerNorm.bias ----- False
Parameter: 79bert.encoder.layer.4.intermediate.dense.weight ----- False
Parameter: 80bert.encoder.layer.4.intermediate.dense.bias ----- False
Parameter: 81bert.encoder.layer.4.output.dense.weight ----- False
Parameter: 82bert.encoder.layer.4.output.dense.bias ----- False
Parameter: 83bert.encoder.layer.4.output.LayerNorm.weight ----- False
Parameter: 84bert.encoder.layer.4.output.LayerNorm.bias ----- False
Parameter: 85bert.encoder.layer.5.attention.self.query.weight ----- False
Parameter: 86bert.encoder.layer.5.attention.self.query.bias ----- False
Parameter: 87bert.encoder.layer.5.attention.self.key.weight ----- False
Parameter: 88bert.encoder.layer.5.attention.self.key.bias ----- False
Parameter: 89bert.encoder.layer.5.attention.self.value.weight ----- False
Parameter: 90bert.encoder.layer.5.attention.self.value.bias ----- False
Parameter: 91bert.encoder.layer.5.attention.output.dense.weight ----- False
Parameter: 92bert.encoder.layer.5.attention.output.dense.bias ----- False
Parameter: 93bert.encoder.layer.5.attention.output.LayerNorm.weight ----- False
Parameter: 94bert.encoder.layer.5.attention.output.LayerNorm.bias ----- False
Parameter: 95bert.encoder.layer.5.intermediate.dense.weight ----- False
Parameter: 96bert.encoder.layer.5.intermediate.dense.bias ----- False
Parameter: 97bert.encoder.layer.5.output.dense.weight ----- False
Parameter: 98bert.encoder.layer.5.output.dense.bias ----- False
Parameter: 99bert.encoder.layer.5.output.LayerNorm.weight ----- False
Parameter: 100bert.encoder.layer.5.output.LayerNorm.bias ----- False
Parameter: 101bert.encoder.layer.6.attention.self.query.weight ----- False
Parameter: 102bert.encoder.layer.6.attention.self.query.bias ----- False
Parameter: 103bert.encoder.layer.6.attention.self.key.weight ----- False
Parameter: 104bert.encoder.layer.6.attention.self.key.bias ----- False
Parameter: 105bert.encoder.layer.6.attention.self.value.weight ----- False
Parameter: 106bert.encoder.layer.6.attention.self.value.bias ----- False
Parameter: 107bert.encoder.layer.6.attention.output.dense.weight ----- False
Parameter: 108bert.encoder.layer.6.attention.output.dense.bias ----- False
Parameter: 109bert.encoder.layer.6.attention.output.LayerNorm.weight ----- False
Parameter: 110bert.encoder.layer.6.attention.output.LayerNorm.bias ----- False
Parameter: 111bert.encoder.layer.6.intermediate.dense.weight ----- False
Parameter: 112bert.encoder.layer.6.intermediate.dense.bias ----- False
Parameter: 113bert.encoder.layer.6.output.dense.weight ----- False
Parameter: 114bert.encoder.layer.6.output.dense.bias ----- False
Parameter: 115bert.encoder.layer.6.output.LayerNorm.weight ----- False
Parameter: 116bert.encoder.layer.6.output.LayerNorm.bias ----- False
Parameter: 117bert.encoder.layer.7.attention.self.query.weight ----- False
Parameter: 118bert.encoder.layer.7.attention.self.query.bias ----- False
Parameter: 119bert.encoder.layer.7.attention.self.key.weight ----- False
Parameter: 120bert.encoder.layer.7.attention.self.key.bias ----- False
Parameter: 121bert.encoder.layer.7.attention.self.value.weight ----- False
Parameter: 122bert.encoder.layer.7.attention.self.value.bias ----- False
Parameter: 123bert.encoder.layer.7.attention.output.dense.weight ----- False
Parameter: 124bert.encoder.layer.7.attention.output.dense.bias ----- False
Parameter: 125bert.encoder.layer.7.attention.output.LayerNorm.weight ----- False
Parameter: 126bert.encoder.layer.7.attention.output.LayerNorm.bias ----- False
Parameter: 127bert.encoder.layer.7.intermediate.dense.weight ----- False
Parameter: 128bert.encoder.layer.7.intermediate.dense.bias ----- False
Parameter: 129bert.encoder.layer.7.output.dense.weight ----- False
Parameter: 130bert.encoder.layer.7.output.dense.bias ----- False
Parameter: 131bert.encoder.layer.7.output.LayerNorm.weight ----- False
Parameter: 132bert.encoder.layer.7.output.LayerNorm.bias ----- False
Parameter: 133bert.encoder.layer.8.attention.self.query.weight ----- False
Parameter: 134bert.encoder.layer.8.attention.self.query.bias ----- False
Parameter: 135bert.encoder.layer.8.attention.self.key.weight ----- False
Parameter: 136bert.encoder.layer.8.attention.self.key.bias ----- False
Parameter: 137bert.encoder.layer.8.attention.self.value.weight ----- False
Parameter: 138bert.encoder.layer.8.attention.self.value.bias ----- False
Parameter: 139bert.encoder.layer.8.attention.output.dense.weight ----- False
Parameter: 140bert.encoder.layer.8.attention.output.dense.bias ----- False
Parameter: 141bert.encoder.layer.8.attention.output.LayerNorm.weight ----- False
Parameter: 142bert.encoder.layer.8.attention.output.LayerNorm.bias ----- False
Parameter: 143bert.encoder.layer.8.intermediate.dense.weight ----- False
Parameter: 144bert.encoder.layer.8.intermediate.dense.bias ----- False
Parameter: 145bert.encoder.layer.8.output.dense.weight ----- False
Parameter: 146bert.encoder.layer.8.output.dense.bias ----- False
Parameter: 147bert.encoder.layer.8.output.LayerNorm.weight ----- False
Parameter: 148bert.encoder.layer.8.output.LayerNorm.bias ----- False
Parameter: 149bert.encoder.layer.9.attention.self.query.weight ----- False
Parameter: 150bert.encoder.layer.9.attention.self.query.bias ----- False
Parameter: 151bert.encoder.layer.9.attention.self.key.weight ----- False
Parameter: 152bert.encoder.layer.9.attention.self.key.bias ----- False
Parameter: 153bert.encoder.layer.9.attention.self.value.weight ----- False
Parameter: 154bert.encoder.layer.9.attention.self.value.bias ----- False
Parameter: 155bert.encoder.layer.9.attention.output.dense.weight ----- False
Parameter: 156bert.encoder.layer.9.attention.output.dense.bias ----- False
Parameter: 157bert.encoder.layer.9.attention.output.LayerNorm.weight ----- False
Parameter: 158bert.encoder.layer.9.attention.output.LayerNorm.bias ----- False
Parameter: 159bert.encoder.layer.9.intermediate.dense.weight ----- False
Parameter: 160bert.encoder.layer.9.intermediate.dense.bias ----- False
Parameter: 161bert.encoder.layer.9.output.dense.weight ----- False
Parameter: 162bert.encoder.layer.9.output.dense.bias ----- False
Parameter: 163bert.encoder.layer.9.output.LayerNorm.weight ----- False
Parameter: 164bert.encoder.layer.9.output.LayerNorm.bias ----- False
Parameter: 165bert.encoder.layer.10.attention.self.query.weight ----- True
Parameter: 166bert.encoder.layer.10.attention.self.query.bias ----- True
Parameter: 167bert.encoder.layer.10.attention.self.key.weight ----- True
Parameter: 168bert.encoder.layer.10.attention.self.key.bias ----- True
Parameter: 169bert.encoder.layer.10.attention.self.value.weight ----- True
Parameter: 170bert.encoder.layer.10.attention.self.value.bias ----- True
Parameter: 171bert.encoder.layer.10.attention.output.dense.weight ----- True
Parameter: 172bert.encoder.layer.10.attention.output.dense.bias ----- True
Parameter: 173bert.encoder.layer.10.attention.output.LayerNorm.weight ----- True
Parameter: 174bert.encoder.layer.10.attention.output.LayerNorm.bias ----- True
Parameter: 175bert.encoder.layer.10.intermediate.dense.weight ----- True
Parameter: 176bert.encoder.layer.10.intermediate.dense.bias ----- True
Parameter: 177bert.encoder.layer.10.output.dense.weight ----- True
Parameter: 178bert.encoder.layer.10.output.dense.bias ----- True
Parameter: 179bert.encoder.layer.10.output.LayerNorm.weight ----- True
Parameter: 180bert.encoder.layer.10.output.LayerNorm.bias ----- True
Parameter: 181bert.encoder.layer.11.attention.self.query.weight ----- True
Parameter: 182bert.encoder.layer.11.attention.self.query.bias ----- True
Parameter: 183bert.encoder.layer.11.attention.self.key.weight ----- True
Parameter: 184bert.encoder.layer.11.attention.self.key.bias ----- True
Parameter: 185bert.encoder.layer.11.attention.self.value.weight ----- True
Parameter: 186bert.encoder.layer.11.attention.self.value.bias ----- True
Parameter: 187bert.encoder.layer.11.attention.output.dense.weight ----- True
Parameter: 188bert.encoder.layer.11.attention.output.dense.bias ----- True
Parameter: 189bert.encoder.layer.11.attention.output.LayerNorm.weight ----- True
Parameter: 190bert.encoder.layer.11.attention.output.LayerNorm.bias ----- True
Parameter: 191bert.encoder.layer.11.intermediate.dense.weight ----- True
Parameter: 192bert.encoder.layer.11.intermediate.dense.bias ----- True
Parameter: 193bert.encoder.layer.11.output.dense.weight ----- True
Parameter: 194bert.encoder.layer.11.output.dense.bias ----- True
Parameter: 195bert.encoder.layer.11.output.LayerNorm.weight ----- True
Parameter: 196bert.encoder.layer.11.output.LayerNorm.bias ----- True
Parameter: 197bert.pooler.dense.weight ----- True
Parameter: 198bert.pooler.dense.bias ----- True
Parameter: 199classifier.weight ----- True
Parameter: 200classifier.bias ----- True
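
As a quick optional check on how much of the model we are actually training, we can count the trainable parameters:

Code
# Count how many parameters will actually be updated during training
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable:,} / {total:,} ({100 * trainable / total:.1f}%)")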

2.1 Train and Evaluate the partially frozen model

Code
trainer = Trainer(
   model=model,
   args=training_args,
   train_dataset=tokenized_train,
   eval_dataset=tokenized_test,
   tokenizer=tokenizer,
   data_collator=data_collator,
   compute_metrics=compute_metrics,
)


trainer.train()
TrainOutput(global_step=500, training_loss=0.389836181640625, metrics={'train_runtime': 322.8201, 'train_samples_per_second': 24.782, 'train_steps_per_second': 1.549, 'total_flos': 2084785113806400.0, 'train_loss': 0.389836181640625, 'epoch': 1.0})
Code
trainer.evaluate()
{'eval_loss': 0.244190514087677,
 'eval_f1': 0.9073610415623435,
 'eval_runtime': 59.2922,
 'eval_samples_per_second': 33.731,
 'eval_steps_per_second': 2.108,
 'epoch': 1.0}

We observed that the training time dropped to 322.82 seconds, less than half the time of full fine-tuning, while maintaining good performance with an F1 score of about 0.91.

3 Freezing all layers in the main BERT and fine-tuning

For experimentation, we will now freeze all the layers in the BERT model, allowing only the classification head to be trainable. We will then evaluate how the model performs.

Code
# Load Model and Tokenizer

model_id = "bert-base-cased"
model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=2)
tokenizer = AutoTokenizer.from_pretrained(model_id)
Code
for name, param in model.named_parameters():
    # Trainable classification head
    if name.startswith("classifier"):
        param.requires_grad = True
    # Freeze everything else
    else:
        param.requires_grad = False
Code
# We can check whether the model was correctly updated

for name, param in model.named_parameters():
     print(f"Parameter: {name} ----- {param.requires_grad}")
Parameter: bert.embeddings.word_embeddings.weight ----- False
Parameter: bert.embeddings.position_embeddings.weight ----- False
Parameter: bert.embeddings.token_type_embeddings.weight ----- False
Parameter: bert.embeddings.LayerNorm.weight ----- False
Parameter: bert.embeddings.LayerNorm.bias ----- False
Parameter: bert.encoder.layer.0.attention.self.query.weight ----- False
Parameter: bert.encoder.layer.0.attention.self.query.bias ----- False
Parameter: bert.encoder.layer.0.attention.self.key.weight ----- False
Parameter: bert.encoder.layer.0.attention.self.key.bias ----- False
Parameter: bert.encoder.layer.0.attention.self.value.weight ----- False
Parameter: bert.encoder.layer.0.attention.self.value.bias ----- False
Parameter: bert.encoder.layer.0.attention.output.dense.weight ----- False
Parameter: bert.encoder.layer.0.attention.output.dense.bias ----- False
Parameter: bert.encoder.layer.0.attention.output.LayerNorm.weight ----- False
Parameter: bert.encoder.layer.0.attention.output.LayerNorm.bias ----- False
Parameter: bert.encoder.layer.0.intermediate.dense.weight ----- False
Parameter: bert.encoder.layer.0.intermediate.dense.bias ----- False
Parameter: bert.encoder.layer.0.output.dense.weight ----- False
Parameter: bert.encoder.layer.0.output.dense.bias ----- False
Parameter: bert.encoder.layer.0.output.LayerNorm.weight ----- False
Parameter: bert.encoder.layer.0.output.LayerNorm.bias ----- False
Parameter: bert.encoder.layer.1.attention.self.query.weight ----- False
Parameter: bert.encoder.layer.1.attention.self.query.bias ----- False
Parameter: bert.encoder.layer.1.attention.self.key.weight ----- False
Parameter: bert.encoder.layer.1.attention.self.key.bias ----- False
Parameter: bert.encoder.layer.1.attention.self.value.weight ----- False
Parameter: bert.encoder.layer.1.attention.self.value.bias ----- False
Parameter: bert.encoder.layer.1.attention.output.dense.weight ----- False
Parameter: bert.encoder.layer.1.attention.output.dense.bias ----- False
Parameter: bert.encoder.layer.1.attention.output.LayerNorm.weight ----- False
Parameter: bert.encoder.layer.1.attention.output.LayerNorm.bias ----- False
Parameter: bert.encoder.layer.1.intermediate.dense.weight ----- False
Parameter: bert.encoder.layer.1.intermediate.dense.bias ----- False
Parameter: bert.encoder.layer.1.output.dense.weight ----- False
Parameter: bert.encoder.layer.1.output.dense.bias ----- False
Parameter: bert.encoder.layer.1.output.LayerNorm.weight ----- False
Parameter: bert.encoder.layer.1.output.LayerNorm.bias ----- False
Parameter: bert.encoder.layer.2.attention.self.query.weight ----- False
Parameter: bert.encoder.layer.2.attention.self.query.bias ----- False
Parameter: bert.encoder.layer.2.attention.self.key.weight ----- False
Parameter: bert.encoder.layer.2.attention.self.key.bias ----- False
Parameter: bert.encoder.layer.2.attention.self.value.weight ----- False
Parameter: bert.encoder.layer.2.attention.self.value.bias ----- False
Parameter: bert.encoder.layer.2.attention.output.dense.weight ----- False
Parameter: bert.encoder.layer.2.attention.output.dense.bias ----- False
Parameter: bert.encoder.layer.2.attention.output.LayerNorm.weight ----- False
Parameter: bert.encoder.layer.2.attention.output.LayerNorm.bias ----- False
Parameter: bert.encoder.layer.2.intermediate.dense.weight ----- False
Parameter: bert.encoder.layer.2.intermediate.dense.bias ----- False
Parameter: bert.encoder.layer.2.output.dense.weight ----- False
Parameter: bert.encoder.layer.2.output.dense.bias ----- False
Parameter: bert.encoder.layer.2.output.LayerNorm.weight ----- False
Parameter: bert.encoder.layer.2.output.LayerNorm.bias ----- False
Parameter: bert.encoder.layer.3.attention.self.query.weight ----- False
Parameter: bert.encoder.layer.3.attention.self.query.bias ----- False
Parameter: bert.encoder.layer.3.attention.self.key.weight ----- False
Parameter: bert.encoder.layer.3.attention.self.key.bias ----- False
Parameter: bert.encoder.layer.3.attention.self.value.weight ----- False
Parameter: bert.encoder.layer.3.attention.self.value.bias ----- False
Parameter: bert.encoder.layer.3.attention.output.dense.weight ----- False
Parameter: bert.encoder.layer.3.attention.output.dense.bias ----- False
Parameter: bert.encoder.layer.3.attention.output.LayerNorm.weight ----- False
Parameter: bert.encoder.layer.3.attention.output.LayerNorm.bias ----- False
Parameter: bert.encoder.layer.3.intermediate.dense.weight ----- False
Parameter: bert.encoder.layer.3.intermediate.dense.bias ----- False
Parameter: bert.encoder.layer.3.output.dense.weight ----- False
Parameter: bert.encoder.layer.3.output.dense.bias ----- False
Parameter: bert.encoder.layer.3.output.LayerNorm.weight ----- False
Parameter: bert.encoder.layer.3.output.LayerNorm.bias ----- False
Parameter: bert.encoder.layer.4.attention.self.query.weight ----- False
Parameter: bert.encoder.layer.4.attention.self.query.bias ----- False
Parameter: bert.encoder.layer.4.attention.self.key.weight ----- False
Parameter: bert.encoder.layer.4.attention.self.key.bias ----- False
Parameter: bert.encoder.layer.4.attention.self.value.weight ----- False
Parameter: bert.encoder.layer.4.attention.self.value.bias ----- False
Parameter: bert.encoder.layer.4.attention.output.dense.weight ----- False
Parameter: bert.encoder.layer.4.attention.output.dense.bias ----- False
Parameter: bert.encoder.layer.4.attention.output.LayerNorm.weight ----- False
Parameter: bert.encoder.layer.4.attention.output.LayerNorm.bias ----- False
Parameter: bert.encoder.layer.4.intermediate.dense.weight ----- False
Parameter: bert.encoder.layer.4.intermediate.dense.bias ----- False
Parameter: bert.encoder.layer.4.output.dense.weight ----- False
Parameter: bert.encoder.layer.4.output.dense.bias ----- False
Parameter: bert.encoder.layer.4.output.LayerNorm.weight ----- False
Parameter: bert.encoder.layer.4.output.LayerNorm.bias ----- False
Parameter: bert.encoder.layer.5.attention.self.query.weight ----- False
Parameter: bert.encoder.layer.5.attention.self.query.bias ----- False
Parameter: bert.encoder.layer.5.attention.self.key.weight ----- False
Parameter: bert.encoder.layer.5.attention.self.key.bias ----- False
Parameter: bert.encoder.layer.5.attention.self.value.weight ----- False
Parameter: bert.encoder.layer.5.attention.self.value.bias ----- False
Parameter: bert.encoder.layer.5.attention.output.dense.weight ----- False
Parameter: bert.encoder.layer.5.attention.output.dense.bias ----- False
Parameter: bert.encoder.layer.5.attention.output.LayerNorm.weight ----- False
Parameter: bert.encoder.layer.5.attention.output.LayerNorm.bias ----- False
Parameter: bert.encoder.layer.5.intermediate.dense.weight ----- False
Parameter: bert.encoder.layer.5.intermediate.dense.bias ----- False
Parameter: bert.encoder.layer.5.output.dense.weight ----- False
Parameter: bert.encoder.layer.5.output.dense.bias ----- False
Parameter: bert.encoder.layer.5.output.LayerNorm.weight ----- False
Parameter: bert.encoder.layer.5.output.LayerNorm.bias ----- False
Parameter: bert.encoder.layer.6.attention.self.query.weight ----- False
Parameter: bert.encoder.layer.6.attention.self.query.bias ----- False
Parameter: bert.encoder.layer.6.attention.self.key.weight ----- False
Parameter: bert.encoder.layer.6.attention.self.key.bias ----- False
Parameter: bert.encoder.layer.6.attention.self.value.weight ----- False
Parameter: bert.encoder.layer.6.attention.self.value.bias ----- False
Parameter: bert.encoder.layer.6.attention.output.dense.weight ----- False
Parameter: bert.encoder.layer.6.attention.output.dense.bias ----- False
Parameter: bert.encoder.layer.6.attention.output.LayerNorm.weight ----- False
Parameter: bert.encoder.layer.6.attention.output.LayerNorm.bias ----- False
Parameter: bert.encoder.layer.6.intermediate.dense.weight ----- False
Parameter: bert.encoder.layer.6.intermediate.dense.bias ----- False
Parameter: bert.encoder.layer.6.output.dense.weight ----- False
Parameter: bert.encoder.layer.6.output.dense.bias ----- False
Parameter: bert.encoder.layer.6.output.LayerNorm.weight ----- False
Parameter: bert.encoder.layer.6.output.LayerNorm.bias ----- False
Parameter: bert.encoder.layer.7.attention.self.query.weight ----- False
Parameter: bert.encoder.layer.7.attention.self.query.bias ----- False
Parameter: bert.encoder.layer.7.attention.self.key.weight ----- False
Parameter: bert.encoder.layer.7.attention.self.key.bias ----- False
Parameter: bert.encoder.layer.7.attention.self.value.weight ----- False
Parameter: bert.encoder.layer.7.attention.self.value.bias ----- False
Parameter: bert.encoder.layer.7.attention.output.dense.weight ----- False
Parameter: bert.encoder.layer.7.attention.output.dense.bias ----- False
Parameter: bert.encoder.layer.7.attention.output.LayerNorm.weight ----- False
Parameter: bert.encoder.layer.7.attention.output.LayerNorm.bias ----- False
Parameter: bert.encoder.layer.7.intermediate.dense.weight ----- False
Parameter: bert.encoder.layer.7.intermediate.dense.bias ----- False
Parameter: bert.encoder.layer.7.output.dense.weight ----- False
Parameter: bert.encoder.layer.7.output.dense.bias ----- False
Parameter: bert.encoder.layer.7.output.LayerNorm.weight ----- False
Parameter: bert.encoder.layer.7.output.LayerNorm.bias ----- False
Parameter: bert.encoder.layer.8.attention.self.query.weight ----- False
Parameter: bert.encoder.layer.8.attention.self.query.bias ----- False
Parameter: bert.encoder.layer.8.attention.self.key.weight ----- False
Parameter: bert.encoder.layer.8.attention.self.key.bias ----- False
Parameter: bert.encoder.layer.8.attention.self.value.weight ----- False
Parameter: bert.encoder.layer.8.attention.self.value.bias ----- False
Parameter: bert.encoder.layer.8.attention.output.dense.weight ----- False
Parameter: bert.encoder.layer.8.attention.output.dense.bias ----- False
Parameter: bert.encoder.layer.8.attention.output.LayerNorm.weight ----- False
Parameter: bert.encoder.layer.8.attention.output.LayerNorm.bias ----- False
Parameter: bert.encoder.layer.8.intermediate.dense.weight ----- False
Parameter: bert.encoder.layer.8.intermediate.dense.bias ----- False
Parameter: bert.encoder.layer.8.output.dense.weight ----- False
Parameter: bert.encoder.layer.8.output.dense.bias ----- False
Parameter: bert.encoder.layer.8.output.LayerNorm.weight ----- False
Parameter: bert.encoder.layer.8.output.LayerNorm.bias ----- False
Parameter: bert.encoder.layer.9.attention.self.query.weight ----- False
Parameter: bert.encoder.layer.9.attention.self.query.bias ----- False
Parameter: bert.encoder.layer.9.attention.self.key.weight ----- False
Parameter: bert.encoder.layer.9.attention.self.key.bias ----- False
Parameter: bert.encoder.layer.9.attention.self.value.weight ----- False
Parameter: bert.encoder.layer.9.attention.self.value.bias ----- False
Parameter: bert.encoder.layer.9.attention.output.dense.weight ----- False
Parameter: bert.encoder.layer.9.attention.output.dense.bias ----- False
Parameter: bert.encoder.layer.9.attention.output.LayerNorm.weight ----- False
Parameter: bert.encoder.layer.9.attention.output.LayerNorm.bias ----- False
Parameter: bert.encoder.layer.9.intermediate.dense.weight ----- False
Parameter: bert.encoder.layer.9.intermediate.dense.bias ----- False
Parameter: bert.encoder.layer.9.output.dense.weight ----- False
Parameter: bert.encoder.layer.9.output.dense.bias ----- False
Parameter: bert.encoder.layer.9.output.LayerNorm.weight ----- False
Parameter: bert.encoder.layer.9.output.LayerNorm.bias ----- False
Parameter: bert.encoder.layer.10.attention.self.query.weight ----- False
Parameter: bert.encoder.layer.10.attention.self.query.bias ----- False
Parameter: bert.encoder.layer.10.attention.self.key.weight ----- False
Parameter: bert.encoder.layer.10.attention.self.key.bias ----- False
Parameter: bert.encoder.layer.10.attention.self.value.weight ----- False
Parameter: bert.encoder.layer.10.attention.self.value.bias ----- False
Parameter: bert.encoder.layer.10.attention.output.dense.weight ----- False
Parameter: bert.encoder.layer.10.attention.output.dense.bias ----- False
Parameter: bert.encoder.layer.10.attention.output.LayerNorm.weight ----- False
Parameter: bert.encoder.layer.10.attention.output.LayerNorm.bias ----- False
Parameter: bert.encoder.layer.10.intermediate.dense.weight ----- False
Parameter: bert.encoder.layer.10.intermediate.dense.bias ----- False
Parameter: bert.encoder.layer.10.output.dense.weight ----- False
Parameter: bert.encoder.layer.10.output.dense.bias ----- False
Parameter: bert.encoder.layer.10.output.LayerNorm.weight ----- False
Parameter: bert.encoder.layer.10.output.LayerNorm.bias ----- False
Parameter: bert.encoder.layer.11.attention.self.query.weight ----- False
Parameter: bert.encoder.layer.11.attention.self.query.bias ----- False
Parameter: bert.encoder.layer.11.attention.self.key.weight ----- False
Parameter: bert.encoder.layer.11.attention.self.key.bias ----- False
Parameter: bert.encoder.layer.11.attention.self.value.weight ----- False
Parameter: bert.encoder.layer.11.attention.self.value.bias ----- False
Parameter: bert.encoder.layer.11.attention.output.dense.weight ----- False
Parameter: bert.encoder.layer.11.attention.output.dense.bias ----- False
Parameter: bert.encoder.layer.11.attention.output.LayerNorm.weight ----- False
Parameter: bert.encoder.layer.11.attention.output.LayerNorm.bias ----- False
Parameter: bert.encoder.layer.11.intermediate.dense.weight ----- False
Parameter: bert.encoder.layer.11.intermediate.dense.bias ----- False
Parameter: bert.encoder.layer.11.output.dense.weight ----- False
Parameter: bert.encoder.layer.11.output.dense.bias ----- False
Parameter: bert.encoder.layer.11.output.LayerNorm.weight ----- False
Parameter: bert.encoder.layer.11.output.LayerNorm.bias ----- False
Parameter: bert.pooler.dense.weight ----- False
Parameter: bert.pooler.dense.bias ----- False
Parameter: classifier.weight ----- True
Parameter: classifier.bias ----- True

3.1 Train and Evaluate the completely frozen BERT model

Code
trainer = Trainer(
   model=model,
   args=training_args,
   train_dataset=tokenized_train,
   eval_dataset=tokenized_test,
   tokenizer=tokenizer,
   data_collator=data_collator,
   compute_metrics=compute_metrics,
)


trainer.train()
TrainOutput(global_step=500, training_loss=0.6961300659179688, metrics={'train_runtime': 249.3582, 'train_samples_per_second': 32.082, 'train_steps_per_second': 2.005, 'total_flos': 2084785113806400.0, 'train_loss': 0.6961300659179688, 'epoch': 1.0})
Code
trainer.evaluate()
{'eval_loss': 0.6841261982917786,
 'eval_f1': 0.5947006869479883,
 'eval_runtime': 59.5484,
 'eval_samples_per_second': 33.586,
 'eval_steps_per_second': 2.099,
 'epoch': 1.0}

As expected, the training time is the shortest in this scenario (249.36 seconds), but the model’s F1 score is only about 0.59, the worst performance among all the experiments, since none of the BERT layers were fine-tuned.

4 Conclusion

In this project, we demonstrated:

  • Fine-tuning BERT for movie review sentiment classification.

  • Reducing training time with only a slight drop in performance by selectively fine-tuning only part of the model (summarized in the table below).
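
For reference, the three experiments compared as follows (numbers taken from the runs above):

  Experiment                                   Training time (s)   F1 score
  Full fine-tuning                             742.96              0.926
  Last two encoder blocks + head trainable     322.82              0.907
  Classification head only (frozen BERT)       249.36              0.595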
