# -*- coding: utf-8 -*-
import logging
import dataiku
import numpy as np
from transformers import AutoTokenizer, AutoModel
from project_utils import compute_embeddings, save, load
# --- Configuration and model setup ---
# Number of texts embedded per forward pass.
BATCH_SIZE = 16

# Read the project variables once instead of calling the API three times.
custom_vars = dataiku.get_custom_variables()
id_label = custom_vars["id_label"]
text_label = custom_vars["text_label"]
model_name = custom_vars["model_name"]

# Input dataset, indexed by the record-identifier column.
df = dataiku.Dataset("data").get_dataframe().set_index(id_label)

# Managed folder used as a cache for previously computed embeddings
# (holds "ids.npy" and "embeddings.npy").
embeddings_folder = dataiku.Folder("P4SttKJS")

# Hugging Face tokenizer/model; eval() disables dropout so inference is deterministic.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model.eval()
# --- Reuse cached embeddings when present ---
# NOTE: list_paths is also consulted further down when merging newly
# computed embeddings, so it must stay defined at module level.
list_paths = embeddings_folder.list_paths_in_partition()
if "/ids.npy" in list_paths and "/embeddings.npy" in list_paths:
    ids = load(embeddings_folder, "ids.npy")
    emb = load(embeddings_folder, "embeddings.npy")
    # Keep only cached embeddings whose id still exists in the dataset
    # (vectorized membership test instead of a per-element Python loop).
    keep_mask = np.isin(ids, df.index.values)
    n_kept = int(keep_mask.sum())
    logging.info(f"{n_kept} embeddings kept")
    logging.info(f"{len(ids) - n_kept} embeddings discarded")
    ids, emb = ids[keep_mask], emb[keep_mask, :]
    # Persist the pruned cache so folder contents match the dataset.
    save(embeddings_folder, "embeddings.npy", emb)
    save(embeddings_folder, "ids.npy", ids)
    # Remaining rows are the ones that still need an embedding.
    df = df[~df.index.isin(ids)]
else:
    logging.info("No existing embeddings")
# --- Compute embeddings for rows that have no cached vector ---
if len(df) > 0:
    # Probe the model with a single empty string to learn the embedding width.
    dim_embeddings = int(compute_embeddings(model, tokenizer, [""]).shape[1])
    # Pre-allocate the output matrix and fill it batch by batch.
    emb = np.empty((len(df), dim_embeddings), dtype=np.float32)
    for start in range(0, len(df), BATCH_SIZE):
        if start % (100 * BATCH_SIZE) == 0:
            # Periodic progress log (every 100 batches).
            logging.info(f"Embedding computation: step {start}")
        end = min(start + BATCH_SIZE, len(df))
        # Slice the text column directly rather than doing a full-row
        # .iloc lookup for every element of the batch.
        texts = df[text_label].iloc[start:end].tolist()
        emb[start:end, :] = compute_embeddings(model, tokenizer, texts)
    ids = np.array(df.index)
    # If a cache existed, prepend the previously saved (already pruned)
    # ids/embeddings so the saved files cover the whole dataset.
    if "/ids.npy" in list_paths and "/embeddings.npy" in list_paths:
        previous_ids = load(embeddings_folder, "ids.npy")
        ids = np.concatenate((previous_ids, ids))
        previous_emb = load(embeddings_folder, "embeddings.npy")
        emb = np.concatenate((previous_emb, emb), axis=0)
    logging.info(f"{len(df)} embeddings computed")
    save(embeddings_folder, "embeddings.npy", emb)
    save(embeddings_folder, "ids.npy", ids)
    logging.info(f"{len(ids)} embeddings in total")
else:
    logging.info("No additional embedding computed")
# Provenance: this recipe was drafted with assistance from Gemini
# (chat export: https://gemini.google.com/share/7d8e6a0ff3c5).