# -*- coding: utf-8 -*-
"""Incrementally maintain text embeddings for a Dataiku dataset.

Reuses embeddings cached in the managed folder, prunes entries whose source
rows disappeared from the dataset, and computes embeddings only for rows
that do not have one yet. Results are persisted as two aligned .npy files
(row ids and embedding matrix).
"""

import logging

import dataiku
import numpy as np
from transformers import AutoModel, AutoTokenizer

from project_utils import compute_embeddings, load, save

BATCH_SIZE = 16
# Filenames inside the managed folder; ids[k] labels row emb[k, :].
IDS_FILE = "ids.npy"
EMB_FILE = "embeddings.npy"

# Recipe configuration comes from project variables.
id_label = dataiku.get_custom_variables()["id_label"]
text_label = dataiku.get_custom_variables()["text_label"]
model_name = dataiku.get_custom_variables()["model_name"]

df = dataiku.Dataset("data").get_dataframe().set_index(id_label)
embeddings_folder = dataiku.Folder("P4SttKJS")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model.eval()  # inference mode: disables dropout/batch-norm updates

# Compute embeddings
list_paths = embeddings_folder.list_paths_in_partition()
# Cache is usable only when both files are present (they must stay aligned).
has_cache = f"/{IDS_FILE}" in list_paths and f"/{EMB_FILE}" in list_paths

if has_cache:
    ids = load(embeddings_folder, IDS_FILE)
    emb = load(embeddings_folder, EMB_FILE)
    # Vectorized membership test: keep only embeddings whose row still exists.
    still_valid = np.isin(ids, df.index)
    kept = int(still_valid.sum())
    logging.info("%d embeddings kept", kept)
    logging.info("%d embeddings discarded", len(ids) - kept)
    ids, emb = ids[still_valid], emb[still_valid, :]
    save(embeddings_folder, EMB_FILE, emb)
    save(embeddings_folder, IDS_FILE, ids)
    # Rows already embedded need no recomputation.
    df = df[~df.index.isin(ids)]
else:
    logging.info("No existing embeddings")

if len(df) > 0:
    # Probe once with an empty string to learn the embedding dimensionality.
    dim_embeddings = int(compute_embeddings(model, tokenizer, [""]).shape[1])
    emb = np.empty((len(df), dim_embeddings), dtype=np.float32)
    # Hoist the column extraction out of the loop: one pandas access instead
    # of one .iloc[] lookup per row.
    texts = df[text_label].tolist()
    for start in range(0, len(df), BATCH_SIZE):
        if start % (100 * BATCH_SIZE) == 0:  # progress log every 100 batches
            logging.info("Embedding computation: step %d", start)
        end = min(start + BATCH_SIZE, len(df))
        emb[start:end, :] = compute_embeddings(model, tokenizer, texts[start:end])
    ids = np.array(df.index)
    if has_cache:
        # Append the new embeddings after the (already pruned) cached ones.
        previous_ids = load(embeddings_folder, IDS_FILE)
        ids = np.concatenate((previous_ids, ids))
        previous_emb = load(embeddings_folder, EMB_FILE)
        emb = np.concatenate((previous_emb, emb), axis=0)
    logging.info("%d embeddings computed", len(df))
    save(embeddings_folder, EMB_FILE, emb)
    save(embeddings_folder, IDS_FILE, ids)
    logging.info("%d embeddings in total", len(ids))
else:
    logging.info("No additional embedding computed")

https://gemini.google.com/share/7d8e6a0ff3c5
텍스트 임베딩 로딩 및 계산
Gemini로 생성됨
gemini.google.com
_
반응형
광고
광고