# Initialization

In [None]:
import os

# Switch to the toolbox root, presumably so the project packages imported below
# (config, data, metric, model) resolve from the working directory.
# Guarded so that re-running this cell does not walk further up the tree:
# the path is relative, so an unconditional second chdir would move elsewhere.
if os.path.basename(os.path.normpath(os.getcwd())) != 'vlm_toolbox':
    os.chdir('../../vlm_toolbox/')

In [None]:
%load_ext autoreload
%reload_ext autoreload
%autoreload 2

In [None]:
import gc
import warnings

import torch
from transformers import Trainer, TrainingArguments

from config.annotations import AnnotationsConfig
from config.enums import (
    CLIPBackbones,
    DataStatus,
    Granularities,
    ImageDatasets,
    Modalities,
    Setups,
    Sources,
    Stages,
    Trainers,
)
from config.image_datasets import ImageDatasetConfig
from config.model import ModelConfigManager
from config.setup import Setup
from config.train import TrainingArgumentsConfig
from data.data_access.image_factory import ImageHandlerFactory
from data.data_access.label_factory import LabelHandleFactory
from data.data_access.text_factory import TextHandlerFactory
from data.data_collate.factory import DataCollatorFactory
from metric.accuracy import AccuracyMetricEvaluator
from metric.visualization.accuracy import plot_model_accuracy
from model.vlm_factory import VLMFactory

In [None]:
# Turn off HuggingFace tokenizers parallelism for this process and install a
# blanket "ignore" filter for Python warnings.
# NOTE(review): the blanket filter also hides useful warnings (e.g. autocast
# being disabled for an unsupported dtype) — consider narrowing it.
os.environ.update(TOKENIZERS_PARALLELISM="false")
warnings.simplefilter('ignore')

In [None]:
def flush():
    """Free memory between experiments.

    First forces a full garbage-collection pass so unreachable objects
    (including tensors) are released, then asks PyTorch to hand cached CUDA
    allocator blocks back to the device. Returns ``None``.
    """
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

# Config

In [None]:
# Optional path to a saved state dict. When truthy, `experiment` loads it into
# the freshly built model via `load_state_dict(torch.load(...))`; None keeps the
# pretrained weights only.
CHECKPOINT_PATH = None

### Device

In [None]:
# Prefer the GPU when one is visible; otherwise fall back to the CPU.
DEVICE_TYPE = 'cpu'
if torch.cuda.is_available():
    DEVICE_TYPE = 'cuda'
DEVICE = torch.device(DEVICE_TYPE)
DEVICE

### Data Type

In [None]:
# Half precision switch; `dtype` is the matching torch dtype used for autocast.
ENABLE_FP16 = True
if ENABLE_FP16:
    dtype = torch.float16
else:
    dtype = torch.float32

### Training

In [None]:
# Embedding pre-computation batch size and evaluation batch size. These are
# independent knobs that currently share the same value.
PREPROCESS_BATCH_SIZE = BATCH_SIZE = 512
# NOTE(review): not referenced by any of the cells shown here — confirm what it seeds.
RANDOM_STATE = 42

### Setup

In [None]:
setup = Setup(
    # Dataset, label granularity and model being evaluated.
    dataset_name=ImageDatasets.IMAGENET_1K,
    granularity=Granularities.FINE,
    backbone_name=CLIPBackbones.CLIP_VIT_B_16,
    source=Sources.OPEN_AI,
    trainer_name=Trainers.CLIP,
    # Few-shot configuration.
    setup=Setups.FEW_SHOT,
    n_shots=16,
    k_ways=53,
    is_supervised=True,
    # Evaluation settings.
    validation_batch_size=BATCH_SIZE,
    metric_for_best_model=AccuracyMetricEvaluator.get_main_metric_name(),
)
setup

### Data

In [None]:
# Images are requested in their embedding representation, from the eval split.
IMAGE_MODALITY_TYPE, SPLIT = DataStatus.EMBEDDING, Stages.EVAL

In [None]:
# Annotation config for the selected dataset (consumed by LabelHandleFactory in
# the evaluation cells), and the image dataset config for SPLIT in the
# IMAGE_MODALITY_TYPE representation.
annotations_config = AnnotationsConfig.get_config(dataset_name=setup.dataset_name)
image_dataset_config = ImageDatasetConfig.get_config(setup, split=SPLIT, data_type=IMAGE_MODALITY_TYPE)

# Dataset Loading

In [None]:
# Build the image-side (modality M1) handler from its config. The object
# returned by `.show()` is what gets stored.
image_dataset_handler = ImageHandlerFactory.create_from_config(
    Modalities.M1, image_dataset_config
).show()

In [None]:
def experiment(
    label_handler,
    image_dataset_handler=image_dataset_handler,
    setup=setup,
    checkpoint_path=CHECKPOINT_PATH,
    batch_size=BATCH_SIZE,
    split=SPLIT,
    device_type=DEVICE_TYPE,
    fp16=ENABLE_FP16,
):
    """Evaluate the configured VLM on ``image_dataset_handler`` against ``label_handler``'s labels.

    Steps:
    - Build the text side from ``label_handler.get_prompts_df()``.
    - Embed any dataset handler that is not embedded yet.
    - Reduce the text side with ``to_prototypical_representation()``.
    - Run ``Trainer.predict`` with an ``AccuracyMetricEvaluator`` as
      ``compute_metrics``.
    - Plot the overall accuracy and return it.

    Args:
        label_handler: Label handler providing the class/label adjacency
            matrix, the prompts dataframe and, for soft-prompt setups, the
            labels and label-id -> prompt-id mapping.
        image_dataset_handler: Handler for the images to evaluate. It is
            embedded in place via ``to_embedding`` when ``is_embedded()`` is
            false.
        setup: ``Setup`` supplying backbone, source, trainer name,
            ``is_soft``, ``validation_batch_size`` and ``metric_for_best_model``.
        checkpoint_path: Optional path to a saved state dict loaded into the
            model after construction. Falsy means pretrained weights only.
        batch_size: Not used by this function. The evaluation batch size is
            ``setup.validation_batch_size``. Kept for backward compatibility.
        split: Split passed to ``TextHandlerFactory.create_from_df``.
        device_type: ``torch.device`` string (e.g. ``'cuda'``, ``'cuda:0'``,
            ``'cpu'``) used for model placement and autocast.
        fp16: When true, embeddings are computed under float16 autocast.

    Returns:
        The result of ``metric_evaluator.calculate_overall_accuracy()`` with a
        ``model_name`` column set to ``setup.trainer_name`` (presumably a
        pandas DataFrame — confirm against ``AccuracyMetricEvaluator``).
    """
    device = torch.device(device_type)

    class_id_label_id_adj_matrix = label_handler.get_class_id_label_id_adj_matrix()
    text_dataset_handler = TextHandlerFactory.create_from_df(
        Modalities.M2,
        split,
        label_handler.get_prompts_df(),
    )
    # Labels and the label->prompt mapping are only passed for soft-prompt setups.
    model_config = ModelConfigManager.get_config(
        backbone_name=setup.backbone_name,
        source=setup.source,
        context_initialization=None,
        trainer_name=setup.trainer_name,
        labels=None if not setup.is_soft else label_handler.get_labels(),
        label_id_prompt_id_mapping=None if not setup.is_soft else label_handler.get_label_id_prompt_id_mapping(),
    )

    vlm = VLMFactory.from_pretrained(model_config=model_config).to(device).eval()
    if checkpoint_path:
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        vlm.load_state_dict(torch.load(checkpoint_path, map_location=device))

    # Embed whichever side is still raw. Autocast follows the `fp16` argument
    # (not the module-level `dtype`). `device.type` strips any device index,
    # because autocast expects a bare type such as 'cuda' or 'cpu'.
    for dataset_handler in (image_dataset_handler, text_dataset_handler):
        if dataset_handler.is_embedded():
            continue
        with torch.no_grad(), torch.autocast(
            device_type=device.type, dtype=torch.float16, enabled=fp16
        ):
            vlm.eval()
            dataset_handler.to_embedding(
                vlm.get_embedding_fn_for_modality(dataset_handler.modality),
                batch_size=PREPROCESS_BATCH_SIZE,
            )
    # Presumably collapses per-prompt text embeddings into one representation
    # per label — confirm against the text handler implementation.
    text_dataset_handler = text_dataset_handler.to_prototypical_representation()
    data_collator = DataCollatorFactory.create_multimodal_collator(
        class_id_label_id_adj_matrix,
        image_dataset_handler,
        text_dataset_handler,
        is_classification=True,
    )
    evaluation_args = TrainingArguments(
        per_device_train_batch_size=setup.validation_batch_size,
        per_device_eval_batch_size=setup.validation_batch_size,
        metric_for_best_model=setup.metric_for_best_model,
        label_names=data_collator.get_label_names(),
        **TrainingArgumentsConfig.get_config(),
    )

    metric_evaluator = AccuracyMetricEvaluator(
        label_handler,
        temperature=vlm.get_logit_scale().detach().cpu(),
    )
    image_dataset = image_dataset_handler.get_dataset(return_pt=True)
    trainer = Trainer(
        model=vlm,
        args=evaluation_args,
        train_dataset=image_dataset,
        eval_dataset=image_dataset,
        data_collator=data_collator,
        compute_metrics=metric_evaluator,
    )
    # The return value of predict() is unused: results are read back from
    # metric_evaluator, which is the compute_metrics hook.
    trainer.predict(image_dataset)
    overall_accuracy_df = metric_evaluator.calculate_overall_accuracy()
    overall_accuracy_df['model_name'] = setup.trainer_name
    plot_model_accuracy(overall_accuracy_df, title=f'Accuracy Performance On {label_handler.label_column}')
    return overall_accuracy_df

## Standard Evaluation

In [None]:
# Baseline: the label handler as configured by the annotations (no label
# column change), with its prompts built.
label_handler = (
    LabelHandleFactory
    .create_from_config(annotations_config)
    .config_prompts()
    .show()
)
flush()
experiment(label_handler)

## Evaluate on Coarse Labels

In [None]:
# Switch the label column to 'coarse' before building prompts.
label_handler = (
    LabelHandleFactory
    .create_from_config(annotations_config)
    .update_label('coarse')
    .config_prompts()
    .show()
)
flush()
experiment(label_handler)

## Evaluate on the Direct Parent Label

In [None]:
def get_direct_parent_label(row):
    """Return the first entry of ``row['parents']`` with ``_`` replaced by spaces.

    Returns ``None`` when the parents sequence is empty.
    """
    parents = row['parents']
    if len(parents) == 0:
        return None
    first_parent = parents[0]
    return first_parent.replace('_', ' ')

In [None]:
# Derive a 'direct_parent_label' metadata column from each row's first parent,
# make it the label column, then build prompts on it.
label_handler = (
    LabelHandleFactory
    .create_from_config(annotations_config)
    .add_column_to_metadata(get_direct_parent_label, 'direct_parent_label')
    .update_label('direct_parent_label')
    .config_prompts()
    .show()
)
flush()
experiment(label_handler)

## Evaluate on Direct Children (Plus the Class Itself)

In [None]:
def get_subclasses_labels(row):
    """Return a class's direct children together with the class's own label.

    Duplicates are removed while keeping first-seen order: children first,
    then ``class_label``. The result is therefore identical on every run,
    whereas ``set`` ordering of strings changes with hash randomization.

    Args:
        row: Mapping (e.g. a metadata DataFrame row) with a ``'children'``
            iterable (list, tuple, array, or ``None`` for no children) and a
            ``'class_label'`` value.

    Returns:
        List of unique labels.
    """
    children = row['children']
    if children is None:
        children = []
    # dict preserves insertion order, giving an order-stable de-duplication.
    return list(dict.fromkeys([*children, row['class_label']]))

In [None]:
# Add a 'child_label' metadata column holding each class's children plus the
# class itself (flatten=True, presumably one row per entry — confirm), then
# build prompts on that column.
label_handler = (
    LabelHandleFactory
    .create_from_config(annotations_config)
    .add_column_to_metadata(get_subclasses_labels, 'child_label', flatten=True)
    .config_prompts(apply_on_col='child_label')
    .show()
)
flush()
experiment(label_handler)