diff --git a/modelcache/adapter/adapter_query.py b/modelcache/adapter/adapter_query.py index e3be30a..5c2afcc 100644 --- a/modelcache/adapter/adapter_query.py +++ b/modelcache/adapter/adapter_query.py @@ -7,11 +7,13 @@ from modelcache.processor.pre import multi_analysis from FlagEmbedding import FlagReranker -USE_RERANKER = True # 如果为 True 则启用 reranker,否则使用原有逻辑 +# USE_RERANKER = True # 如果为 True 则启用 reranker,否则使用原有逻辑 + def adapt_query(cache_data_convert, *args, **kwargs): chat_cache = kwargs.pop("cache_obj", cache) scope = kwargs.pop("scope", None) + use_reranker = kwargs.pop("use_reranker", False) model = scope['model'] if not chat_cache.has_init: raise NotInitError() @@ -76,7 +78,7 @@ def adapt_query(cache_data_convert, *args, **kwargs): if rank_pre < rank_threshold: return - if USE_RERANKER: + if use_reranker: reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=False) for cache_data in cache_data_list: primary_id = cache_data[1] diff --git a/modelcache/manager/vector_data/faiss.py b/modelcache/manager/vector_data/faiss.py index 0f8445c..35c261e 100644 --- a/modelcache/manager/vector_data/faiss.py +++ b/modelcache/manager/vector_data/faiss.py @@ -14,8 +14,9 @@ def __init__(self, index_file_path, dimension, top_k): self._dimension = dimension self._index = faiss.index_factory(self._dimension, "IDMap,Flat", faiss.METRIC_L2) self._top_k = top_k - if os.path.isfile(index_file_path): - self._index = faiss.read_index(index_file_path) + self.index_file_path = index_file_path + if os.path.isfile(self.index_file_path): + self._index = faiss.read_index(self.index_file_path) def mul_add(self, datas: List[VectorData], model=None): data_array, id_array = map(list, zip(*((data.data, data.id) for data in datas))) @@ -54,3 +55,9 @@ def close(self): def count(self): return self._index.ntotal + + def create(self, model=None): + if os.path.isfile(self.index_file_path): + self._index = faiss.read_index(self.index_file_path) + + return 'create_success' diff --git a/modelcache/manager/vector_data/milvus.py b/modelcache/manager/vector_data/milvus.py index 50d6ab1..2c07204 100644 --- a/modelcache/manager/vector_data/milvus.py +++ b/modelcache/manager/vector_data/milvus.py @@ -180,6 +180,14 @@ def rebuild_col(self, model): def rebuild(self, ids=None): # pylint: disable=unused-argument self.col.compact() + def create(self, model=None): + collection_name_model = self.collection_name + '_' + model + if utility.has_collection(collection_name_model, using=self.alias): + return 'already_exists' + else: + self._create_collection(collection_name_model) + return 'create_success' + def flush(self): self.col.flush(_async=True) diff --git a/modelcache/manager/vector_data/redis.py b/modelcache/manager/vector_data/redis.py index afa1088..83b701d 100644 --- a/modelcache/manager/vector_data/redis.py +++ b/modelcache/manager/vector_data/redis.py @@ -71,9 +71,11 @@ def create_index(self, index_name, index_prefix): definition = IndexDefinition(prefix=[index_prefix], index_type=IndexType.HASH) # create Index - self._client.ft(index_name).create_index( - fields=fields, definition=definition - ) + print('self._client: {}'.format(self._client)) + resp = self._client.ft(index_name).create_index( + fields=fields, definition=definition + ) + print('resp: {}'.format(resp)) return 'create_success' def mul_add(self, datas: List[VectorData], model=None):