import unittest
from typing import List

from pymilvus import FieldSchema, DataType

from app.services.milvus import MilvusClient


class TestCreateCollectionFunction(unittest.TestCase):
    """Integration tests that create the Milvus collection and its vector index.

    Both tests talk to a live Milvus instance through ``MilvusClient`` and
    contain no assertions of their own: each passes as long as the client
    calls do not raise.
    """

    # Name of the Milvus collection shared by both tests.
    COLLECTION_NAME = "pc3"

    def test_create_collection(self):
        """Create the collection with a VARCHAR primary key, a key field and a float vector field."""
        fields: List[FieldSchema] = [
            FieldSchema(name="image", dtype=DataType.VARCHAR, max_length=256, is_primary=True, auto_id=False),
            FieldSchema(name="key", dtype=DataType.VARCHAR, max_length=256),
            # CLIP ViT-B/32 produces 512-dimensional features.
            # For CLIP ViT-L/14@336px (768-dimensional features) use dim=768 instead.
            FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=512),
        ]
        MilvusClient().connect().create_collection(self.COLLECTION_NAME, fields, "PC3 图片向量存储")

    def test_create_index(self):
        """Build an HNSW index using inner-product similarity on the ``vector`` field.

        Requires the collection created by ``test_create_collection``. unittest
        runs test methods in alphabetical order, so that test runs first when
        the whole class is executed. Running this test on its own needs the
        collection to already exist.
        """
        index_params = {
            "index_type": "HNSW",
            "metric_type": "IP",  # inner-product similarity
            # HNSW index-building parameters only. ``ef`` is a search-time
            # parameter in Milvus (previously 400 here); supply it with the
            # search params at query time rather than when building the index.
            "params": {"M": 40, "efConstruction": 600},
        }
        MilvusClient().connect().create_index(self.COLLECTION_NAME, "vector", index_params)



if __name__ == "__main__":
    # Discover and run every TestCase in this module when executed as a script.
    unittest.main()