# Imports

In [None]:
import os

# Move into the toolbox root; presumably this is what lets the project-local
# packages imported in the cells below (config, data, model, util) resolve.
# The target is a relative path, so the call is guarded: re-running this cell
# in a live kernel would otherwise climb two more directories and either fail
# or silently land in the wrong place.
if os.path.basename(os.getcwd()) != 'vlm_toolbox':
    os.chdir('../../vlm_toolbox/')

In [None]:
# Enable IPython's autoreload extension in mode 2, which reloads all modules
# before each cell executes, so edits to the imported project modules are
# picked up without restarting the kernel.
%load_ext autoreload
%reload_ext autoreload
%autoreload 2

In [None]:
import gc
import warnings

import holoviews as hv
import pandas as pd
import torch
from cuml import TSNE
from plotly import graph_objects as go

from config.annotations import AnnotationsConfig
from config.enums import (
    CLIPBackbones,
    DataStatus,
    ImageDatasets,
    Modalities,
    ModalityType,
    ModelType,
    Setups,
    Stages,
    Trainers,
)
from config.image_datasets import ImageDatasetConfig
from config.model import ModelConfigManager
from config.path import VISUALIZATIONS_ROOT_DIR
from config.setup import Setup
from data.data_access.image_factory import ImageHandlerFactory
from data.data_access.label_factory import LabelHandleFactory
from data.data_access.text_factory import TextHandlerFactory
from model.vlm_factory import VLMFactory
from util.color import generate_diverse_colors

In [None]:
# Set the TOKENIZERS_PARALLELISM environment variable to 'false' for this process.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
# NOTE(review): this suppresses every warning for the rest of the session,
# including ones that may point at real problems.
warnings.filterwarnings('ignore')
# Load the HoloViews plotting extensions (bokeh is listed first, then matplotlib).
hv.extension('bokeh', 'matplotlib')

In [None]:
def flush() -> None:
    """Free memory that is no longer in use.

    Runs a Python garbage-collection pass and then asks PyTorch to release
    the cached, currently unoccupied CUDA memory.
    """
    gc.collect()
    torch.cuda.empty_cache()

# Config

In [None]:
# Path to a model checkpoint. It is only referenced by the commented-out
# `load_state_dict` call in the model-loading cell.
# NOTE(review): absolute, user-specific path — consider making it configurable.
CHECKPOINT_PATH = '/home/alireza/novel_coop/few_shot/imagenet1k/clip_vit_b_16/open_ai/coop/16_shots/default/novel/pytorch_model.bin'
# Directory the t-SNE HTML exports are written to. `os.path.join` replaces the
# plain string concatenation so a root without a trailing separator still
# produces a valid path; the final '' keeps the directory-style trailing slash.
OUTPUT_DIR = os.path.join(VISUALIZATIONS_ROOT_DIR, 'embeds_tsne', '')

### Device

In [None]:
# Use the GPU whenever CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE_TYPE = 'cuda'
else:
    DEVICE_TYPE = 'cpu'
DEVICE = torch.device(DEVICE_TYPE)
DEVICE

### Training

In [None]:
# Batch size handed to `to_embedding` when the dataset embeddings are pre-computed.
PREPROCESS_BATCH_SIZE = 512
# Random seed constant for the notebook.
RANDOM_STATE = 42

### Setup

In [None]:
# Experiment description. Its attributes (dataset_name, backbone_name,
# trainer_name, source, get_is_soft()) are read by the config, label and model
# cells further down.
setup = Setup(
    dataset_name=ImageDatasets.IMAGENET_1K,
    backbone_name=CLIPBackbones.CLIP_VIT_B_16,
    trainer_name=Trainers.COOP,
    model_type=ModelType.FEW_SHOT,
    setup_type=Setups.FULL,
    num_epochs=200,
    train_batch_size=1000,
    eval_batch_size=1024,
    validation_size=0.15,
    n_shots=16,
    coarse_column_name='coarse',
    # Optional arguments below are disabled and kept only for reference.
    # annotations_key_value_criteria={'kingdom': ['Animalia']},
    # top_k=67,
    # NOTE(review): the first path ends in "pytorch_model.binpytorch_model.bin",
    # which looks like a paste error — fix before re-enabling.
    # model_checkpoint_path='/home/alireza/io/model/few_shot/imagenet1k/clip_vit_b_16/open_ai/coop/16_shots/coarse/novel/pytorch_model.binpytorch_model.bin'
    # model_checkpoint_path='/home/alireza/io/model/zero_shot/imagenet1k/clip_vit_b_16/open_ai/coop/16_shots/default/pytorch_model.bin'
)
# Show the setup object as the cell output.
setup

### Data

In [None]:
# Passed as `data_type` when the image dataset config is built.
IMAGE_MODALITY_TYPE = DataStatus.EMBEDDING
# Passed as `split` to the image dataset config and as the stage argument of
# the text dataset handler.
SPLIT = Stages.EVAL

In [None]:
# Annotation settings for the dataset selected in `setup`.
annotations_config = AnnotationsConfig.get_config(
    dataset_name=setup.dataset_name,
)
# Image dataset settings for the configured split and data type.
image_dataset_config = ImageDatasetConfig.get_config(setup, split=SPLIT, data_type=IMAGE_MODALITY_TYPE)

In [None]:
# NOTE(review): assigned here but not read by any other cell visible in this
# notebook — confirm whether it is still needed.
SOFT_PROMPT_GROUP = None

# Utils

In [None]:
def export_embeddings_tsne_visualization(visualization_df, setup, modality_types=ModalityType.get_values(), hue_col_name='coarse', directory=OUTPUT_DIR, show_plot=False):
    """Write an interactive Plotly scatter of 2-D t-SNE embeddings to an HTML file.

    One trace is created per (category, modality) pair, coloured by category.
    Two button groups are attached: one switches between all / image-only /
    text-only markers, the other shows or hides every trace at once.

    Args:
        visualization_df: DataFrame with the columns 'x', 'y', 'type', 'label',
            'label_id' and ``hue_col_name``.
        setup: Object exposing ``backbone_name`` and ``trainer_name``; both are
            used to build the output file name.
        modality_types: Modalities (values of the 'type' column) to draw.
        hue_col_name: Column whose distinct values define the colour groups.
        directory: Output directory; created when it does not exist.
        show_plot: When true, also display the figure (resized to 800x600)
            after it has been written.

    Returns:
        None. The figure is written to ``directory``; an existing file is never
        overwritten, a numeric suffix is appended instead.
    """
    os.makedirs(directory, exist_ok=True)

    # Rows with a missing hue value can never match the equality filter below,
    # so they are dropped up front instead of producing an empty, invalid trace.
    categories = visualization_df[hue_col_name].dropna().unique()
    colors = generate_diverse_colors(len(categories))
    color_map = dict(zip(categories, colors))

    traces = []
    trace_types = []
    trace_opacities = []
    for category in categories:
        df_filtered = visualization_df[visualization_df[hue_col_name] == category]
        for modality in modality_types:
            df_modality = df_filtered[df_filtered['type'] == modality]
            if df_modality.empty:
                # Nothing to draw for this pair; skip it so the legend only
                # lists traces that actually contain points.
                continue

            # Vectorised hover text (a row-wise apply is slow on large frames).
            hover_text = (
                'label: ' + df_modality['label'].astype(str)
                + '<br>label_id: ' + df_modality['label_id'].astype(str)
            )
            is_text = modality == ModalityType.TEXT
            marker_size = 5 if is_text else 2
            opacity = 0.8 if is_text else 0.6

            trace = go.Scatter(
                x=df_modality['x'],
                y=df_modality['y'],
                mode='markers',
                marker=dict(size=marker_size, opacity=opacity, color=color_map[category]),
                name=f"{category} - {modality}",
                legendgroup=str(category),
                text=hover_text,
                hoverinfo='text+name'
            )
            traces.append(trace)
            trace_types.append(modality)
            trace_opacities.append(opacity)

    # The buttons restyle `marker.opacity`, i.e. the same attribute the traces
    # were created with. Restyling the trace-level `opacity` instead would be
    # multiplied with the marker opacity, so "All" would not restore the
    # initial appearance.
    opacity_all = list(trace_opacities)
    opacity_image_only = [0.6 if modality == ModalityType.IMAGE else 0 for modality in trace_types]
    opacity_text_only = [0.8 if modality == ModalityType.TEXT else 0 for modality in trace_types]

    updatemenus = [
        {
            "type": "buttons",
            "buttons": [
                {"label": "All", "method": "update", "args": [{"marker.opacity": opacity_all}]},
                {"label": "Image Only", "method": "update", "args": [{"marker.opacity": opacity_image_only}]},
                {"label": "Text Only", "method": "update", "args": [{"marker.opacity": opacity_text_only}]}
            ],
            "direction": "down",
            "showactive": True,
        },
        {
            "type": "buttons",
            "buttons": [
                {"label": "Select All", "method": "update", "args": [{"visible": True}]},
                {"label": "Hide All", "method": "update", "args": [{"visible": "legendonly"}]}
            ],
            "direction": "left",
            "showactive": True,
            "x": 0.5,
            "xanchor": 'center',
            "y": 1.1,
            "yanchor": 'top'
        }
    ]

    layout = go.Layout(
        title='Combined t-SNE Visualization',
        showlegend=True,
        xaxis=dict(title='t-SNE Dimension 1'),
        yaxis=dict(title='t-SNE Dimension 2'),
        updatemenus=updatemenus
    )

    fig = go.Figure(data=traces, layout=layout)
    fig.update_layout(
        legend_title_text=hue_col_name,
        legend={'itemsizing': 'constant', 'groupclick': "toggleitem"}
    )

    # Never overwrite an earlier export: append _1, _2, ... until the name is free.
    file_name = f'{setup.backbone_name}_{setup.trainer_name}_image_text_embeds_tsne.html'
    base_name, extension = os.path.splitext(file_name)
    full_path = os.path.join(directory, file_name)
    file_index = 1
    while os.path.exists(full_path):
        full_path = os.path.join(directory, f"{base_name}_{file_index}{extension}")
        file_index += 1

    fig.write_html(full_path)
    print(f'saved in {full_path}')
    if show_plot:
        # The fixed size is applied only to the inline display, after the
        # HTML file has already been written.
        fig.update_layout(
            width=800,
            height=600,
        )
        fig.show()

def create_visualization_df(label_handler, image_embeds_tsne, image_embeds_id, text_embeds_tsne, text_embeds_id, coarse_col_name='coarse'):
    """Combine image and text t-SNE coordinates into one annotated DataFrame.

    Args:
        label_handler: Object providing ``get_mapping('class_id', 'label_id')``
            (a tensor indexable by class id) and a ``labels_df`` DataFrame with
            the columns 'label_id', ``coarse_col_name`` and 'label'.
        image_embeds_tsne: 2-column array of image coordinates.
        image_embeds_id: Class id of every image row; translated to label ids
            through the handler's mapping. May be empty.
        text_embeds_tsne: 2-column array of text coordinates.
        text_embeds_id: Label id of every text row. May be empty.
        coarse_col_name: Name of the coarse-category column to attach.

    Returns:
        DataFrame with image rows first, then text rows, and the columns
        'x', 'y', 'label_id', 'type', ``coarse_col_name`` and 'label'.
    """
    mapping = label_handler.get_mapping('class_id', 'label_id').int().numpy()

    df_images = pd.DataFrame(image_embeds_tsne, columns=['x', 'y'])
    if len(image_embeds_id):
        df_images['label_id'] = image_embeds_id
        # Vectorised class_id -> label_id lookup instead of a per-row apply.
        df_images['label_id'] = mapping[df_images['label_id'].to_numpy()]
    df_images['type'] = ModalityType.IMAGE

    df_texts = pd.DataFrame(text_embeds_tsne, columns=['x', 'y'])
    if len(text_embeds_id):
        df_texts['label_id'] = text_embeds_id
    df_texts['type'] = ModalityType.TEXT

    # ignore_index=True already yields a fresh 0..n-1 index.
    df_combined = pd.concat([df_images, df_texts], ignore_index=True)

    # Left merge keeps every point, including those without label metadata.
    label_info = label_handler.labels_df[['label_id', coarse_col_name, 'label']].drop_duplicates()
    return df_combined.merge(label_info, on='label_id', how='left')

# Labels Loading

In [None]:
# Build the label handler from the annotation config, select soft or hard
# prompts according to the setup, and configure the prompts. The result of
# `.show()` is what gets bound to `label_handler`, so this relies on `show()`
# returning the handler itself.
label_handler = (
    LabelHandleFactory.create_from_config(annotations_config)
    .set_prompt_mode(is_soft=setup.get_is_soft())
    .config_prompts()
).show()

# Values consumed by later cells: `labels` and `label_id_prompt_id_mapping`
# (model config), `prompts_df` (text dataset handler), `class_ids` (image
# dataset handler).
# NOTE(review): `labels_df`, `class_id_label_id_adj_matrix` and `classes_df`
# are not read by any other cell visible in this notebook.
labels = label_handler.get_labels()
labels_df = label_handler.get_labels_df()
prompts_df = label_handler.get_prompts_df()
class_ids = label_handler.get_class_ids()
class_id_label_id_adj_matrix = label_handler.get_class_id_label_id_adj_matrix()
label_id_prompt_id_mapping = label_handler.get_label_id_prompt_id_mapping()

classes_df = label_handler.get_classes_df()

# Model Loading

In [None]:
# Build the model configuration from the setup and the label data gathered
# from the label handler.
model_config = ModelConfigManager.get_config(
    backbone_name=setup.backbone_name,
    source=setup.source,
    context_initialization=None,
    trainer_name=setup.trainer_name,
    labels=labels,
    label_id_prompt_id_mapping=label_id_prompt_id_mapping,
)

# Instantiate the model, move it to DEVICE and switch it to eval mode.
vlm = VLMFactory.from_pretrained(model_config=model_config).to(DEVICE).eval()
# Checkpoint loading is disabled, so the model is used exactly as returned by
# `from_pretrained`.
# if CHECKPOINT_PATH:
#     vlm.load_state_dict(torch.load(CHECKPOINT_PATH))
vlm.show()

# Dataset Loading

In [None]:
# Image dataset handler for modality M1; `to_keep_ids=class_ids` presumably
# restricts it to the classes known to the label handler.
# NOTE(review): `stage` is the literal 'validation' while the dataset config
# was built with SPLIT (Stages.EVAL) — confirm the two are meant to differ.
image_dataset_handler = (
    ImageHandlerFactory.create_from_config(
        key=Modalities.M1,
        stage='validation',
        dataset_config=image_dataset_config,
        to_keep_ids=class_ids,
    )
    # Few-shot subsampling is disabled.
    # .to_few_shot_dataset(16)
).show()

In [None]:
# Text dataset handler for modality M2, built from the prompts DataFrame of the
# label handler. As with the other handlers, the value bound here is whatever
# `.show()` returns.
text_dataset_handler = TextHandlerFactory.create_from_df(
    Modalities.M2,
    SPLIT,
    prompts_df,
    annotations_config,
).show()

### Pre-compute Features

In [None]:
# Convert every handler that does not already hold embeddings, using the
# model's embedding function for that handler's modality. Gradients are
# disabled and the forward passes run under float16 autocast.
for dataset_handler in (image_dataset_handler, text_dataset_handler):
    if dataset_handler.is_embedded():
        continue
    with torch.no_grad(), torch.autocast(device_type=DEVICE_TYPE, dtype=torch.float16):
        vlm.eval()
        embedding_fn = vlm.get_embedding_fn_for_modality(dataset_handler.modality)
        dataset_handler.to_embedding(embedding_fn, batch_size=PREPROCESS_BATCH_SIZE)
flush()

# T-SNE

In [None]:
# Fetch the datasets held by the two handlers (text first, then image).
text_dataset = text_dataset_handler.get_dataset()
image_dataset = image_dataset_handler.get_dataset()

In [None]:
# Pull the embedding tensors out of the datasets.
image_embeds = image_dataset['image_embeds']
text_embeds = text_dataset['text_embeds']

# L2-normalise every embedding along the feature axis. This is done out of
# place so that the tensors handed back by the datasets are never modified.
image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

# Ids aligned with the embedding rows: class ids for images, label ids for text.
image_embeds_id = image_dataset['class_id'].int().numpy()
text_embeds_id = text_dataset['label_id'].int().numpy()
# `all_embeds` only exists inside perform_dimension_reduction, so referencing
# it here raised a NameError; display the shapes of both tensors instead.
image_embeds.shape, text_embeds.shape

In [None]:
def perform_dimension_reduction(image_embeds, text_embeds, image_ids, text_ids, separate=False, num_neighbors=None, random_state=None):
    """Project image and text embeddings to 2-D with t-SNE (cosine metric, PCA init).

    Args:
        image_embeds: Tensor of image embeddings, one row per image.
        text_embeds: Tensor of text embeddings, one row per prompt.
        image_ids: Unused; kept so existing callers keep working.
        text_ids: Unused; kept so existing callers keep working.
        separate: When true, fit one t-SNE per modality. Otherwise fit a single
            t-SNE on the concatenation so both modalities share one space.
        num_neighbors: Optional ``n_neighbors`` for the joint fit. When None the
            estimator's default is used. (This used to be read from an
            undefined global, which raised a NameError.)
        random_state: Optional seed forwarded to the estimator; omitted when
            None so the estimator keeps its default behaviour.

    Returns:
        Tuple ``(text_embeds_tsne, image_embeds_tsne)`` of 2-column arrays.
    """
    tsne_kwargs = dict(n_components=2, metric='cosine', init='pca')
    if random_state is not None:
        tsne_kwargs['random_state'] = random_state

    if separate:
        text_embeds_tsne = TSNE(**tsne_kwargs).fit_transform(text_embeds.numpy())
        image_embeds_tsne = TSNE(**tsne_kwargs).fit_transform(image_embeds.numpy())
    else:
        if num_neighbors is not None:
            tsne_kwargs['n_neighbors'] = num_neighbors
        # Text rows come first in the concatenation ...
        all_embeds = torch.cat((text_embeds, image_embeds), dim=0).numpy()
        all_embeds_tsne = TSNE(**tsne_kwargs).fit_transform(all_embeds)
        # ... so the split must take the text block first as well. Slicing by
        # the image count here would swap the two modalities.
        num_text = text_embeds.shape[0]
        text_embeds_tsne = all_embeds_tsne[:num_text]
        image_embeds_tsne = all_embeds_tsne[num_text:]
    return text_embeds_tsne, image_embeds_tsne

In [None]:
# Disabled neighbour-count heuristic. The original comment was garbled by a
# stray paste; it appears to have read:
# num_neighbors = len(image_dataset) // len(image_dataset.unique('class_id')) + 1
# With separate=True, text and image embeddings are each reduced by their own t-SNE fit.
text_embeds_tsne, image_embeds_tsne = perform_dimension_reduction(image_embeds, text_embeds, image_embeds_id, text_embeds_id, separate=True)

In [None]:
# Join the 2-D coordinates of both modalities with their label metadata.
visualization_df = create_visualization_df(
    label_handler,
    image_embeds_tsne,
    image_embeds_id,
    text_embeds_tsne,
    text_embeds_id,
    coarse_col_name='coarse',
)
# Preview a few random rows; seeded so the displayed sample is the same on
# every run of the notebook.
visualization_df.sample(frac=0.1, random_state=RANDOM_STATE).head()

In [None]:
# Export the plot for the text modality only, coloured by the 'coarse' column.
export_embeddings_tsne_visualization(
    visualization_df,
    setup=setup,
    modality_types=[ModalityType.TEXT],
    hue_col_name='coarse',
    show_plot=False,
)