Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
I
image_search
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
service
image_search
Commits
ef587af1
Commit
ef587af1
authored
May 24, 2025
by
zhengyaoqiu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
特征提取模块重构
parent
9858b763
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
387 additions
and
46 deletions
+387
-46
feature_extractor.py
app/models/feature_extractor.py
+47
-44
image_search.py
app/models/image_search.py
+224
-0
milvus.py
app/models/milvus.py
+112
-0
requirements.txt
requirements.txt
+3
-1
test_feature_extractor.py
tests/test_feature_extractor.py
+1
-1
No files found.
app/models/feature_extractor.py
View file @
ef587af1
import
threading
import
torch
import
torch
import
clip
import
clip
import
requests
import
requests
from
PIL
import
Image
from
PIL
import
Image
from
io
import
BytesIO
from
io
import
BytesIO
from
typing
import
Optional
,
Tuple
import
numpy
as
np
import
logging
class
FeatureExtractor
:
"""
使用CLIP模型提取图像特征向量的工具类
"""
# 类变量,用于存储模型实例
_model
=
None
_preprocess
=
None
_device
=
None
@
classmethod
def
initialize
(
cls
,
model_name
=
"ViT-B/32"
):
"""
初始化CLIP模型
Args:
class
FeatureExtractor
:
model_name (str): CLIP模型名称
__model
=
None
"""
__preprocess
=
None
cls
.
_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
__logger
=
logging
.
getLogger
(
__name__
)
cls
.
_model
,
cls
.
_preprocess
=
clip
.
load
(
model_name
,
device
=
cls
.
_device
)
__instance
=
None
print
(
f
"CLIP model {model_name} loaded on {cls._device}"
)
__lock
=
threading
.
Lock
()
__device
=
"ViT-B/32"
def
__new__
(
cls
,
device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
,
model_name
=
"ViT-B/32"
):
# 第一次检查 - 不带锁
if
cls
.
__instance
is
None
:
# 只有在可能需要创建实例时才获取锁
with
cls
.
__lock
:
# 第二次检查 - 带锁
if
cls
.
__instance
is
None
:
print
(
f
"创建并初始化 CLIP 模型: {model_name} 在设备: {device}"
)
# 创建实例
cls
.
__instance
=
super
()
.
__new__
(
cls
)
# 在这里直接完成初始化
cls
.
__instance
.
__model
,
cls
.
__instance
.
__preprocess
=
clip
.
load
(
model_name
,
device
=
device
)
cls
.
__instance
.
__device
=
device
return
cls
.
__instance
@
staticmethod
@
staticmethod
def
resize_with_padding
(
img
):
def
resize_with_padding
(
img
,
target_size
=
(
224
,
224
)
):
"""
"""
调整图像大小,保持纵横比并添加填充
调整图像大小,保持纵横比并添加填充
Args:
Args:
img (Image.Image): 输入图像
img: 输入图像
target_size: 目标尺寸,默认为(224, 224)
Returns:
Returns:
PIL.Image:
调整大小后的图像
调整大小后的图像
"""
"""
target_size
=
(
224
,
224
)
# 计算调整大小的比例
# 计算调整大小的比例
ratio
=
min
(
target_size
[
0
]
/
img
.
width
,
target_size
[
1
]
/
img
.
height
)
ratio
=
min
(
target_size
[
0
]
/
img
.
width
,
target_size
[
1
]
/
img
.
height
)
new_size
=
(
int
(
img
.
width
*
ratio
),
int
(
img
.
height
*
ratio
))
new_size
=
(
int
(
img
.
width
*
ratio
),
int
(
img
.
height
*
ratio
))
...
@@ -58,62 +64,59 @@ class FeatureExtractor:
...
@@ -58,62 +64,59 @@ class FeatureExtractor:
return
new_img
return
new_img
@
classmethod
def
extract_from_url
(
self
,
image_url
):
def
extract_from_url
(
cls
,
image_url
):
"""
"""
从URL加载图像并提取特征向量
从URL加载图像并提取特征向量
Args:
Args:
image_url
(str)
: 图像URL
image_url: 图像URL
Returns:
Returns:
numpy.ndarray: 特征向量
特征向量,如果提取失败则返回None
"""
"""
if
cls
.
_model
is
None
:
cls
.
initialize
()
try
:
try
:
# 下载图片
# 下载图片
response
=
requests
.
get
(
image_url
,
stream
=
True
)
response
=
requests
.
get
(
image_url
,
stream
=
True
,
timeout
=
10
)
response
.
raise_for_status
()
# 确保请求成功
response
.
raise_for_status
()
# 确保请求成功
# 将图片数据转换为 PIL Image 对象
# 将图片数据转换为 PIL Image 对象
image
=
Image
.
open
(
BytesIO
(
response
.
content
))
.
convert
(
"RGB"
)
image
=
Image
.
open
(
BytesIO
(
response
.
content
))
.
convert
(
"RGB"
)
return
cls
.
extract_from_image
(
image
)
return
self
.
extract_from_image
(
image
)
except
requests
.
RequestException
as
e
:
self
.
__logger
.
error
(
f
"Network error when downloading image from {image_url}: {e}"
)
return
None
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"Error extracting features from URL
: {e}"
)
self
.
__logger
.
error
(
f
"Error extracting features from URL {image_url}
: {e}"
)
return
None
return
None
@
classmethod
def
extract_from_image
(
self
,
img
):
def
extract_from_image
(
cls
,
img
):
"""
"""
从PIL图像对象提取特征向量
从PIL图像对象提取特征向量
Args:
Args:
img
(Image.Image)
: 输入图像
img: 输入图像
Returns:
Returns:
numpy.ndarray: 特征向量
特征向量,如果提取失败则返回None
"""
"""
if
cls
.
_model
is
None
:
cls
.
initialize
()
try
:
try
:
# 调整图像大小并添加填充
# 调整图像大小并添加填充
image
=
cls
.
resize_with_padding
(
img
)
image
=
self
.
resize_with_padding
(
img
)
# 预处理并提取特征
# 预处理并提取特征
image_tensor
=
cls
.
_preprocess
(
image
)
.
unsqueeze
(
0
)
.
to
(
cls
.
_device
)
image_tensor
=
self
.
__preprocess
(
image
)
.
unsqueeze
(
0
)
.
to
(
self
.
_
_device
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
image_features
=
cls
.
_model
.
encode_image
(
image_tensor
)
image_features
=
self
.
_
_model
.
encode_image
(
image_tensor
)
# 归一化特征向量
# 归一化特征向量
image_features
/=
image_features
.
norm
(
dim
=-
1
,
keepdim
=
True
)
image_features
/=
image_features
.
norm
(
dim
=-
1
,
keepdim
=
True
)
return
image_features
.
cpu
()
.
numpy
()
.
flatten
()
return
image_features
.
cpu
()
.
numpy
()
.
flatten
()
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"Error extracting features from image: {e}"
)
self
.
__logger
.
error
(
f
"Error extracting features from image: {e}"
)
return
None
return
None
\ No newline at end of file
app/models/image_search.py
0 → 100644
View file @
ef587af1
from
typing
import
List
,
Tuple
,
Optional
,
Union
import
logging
import
numpy
as
np
from
PIL
import
Image
from
pymilvus
import
Collection
# Import the feature extractor and the Milvus client wrapper
from
feature_extractor
import
FeatureExtractor
from
milvus
import
MilvusClient
# MilvusClient is defined in milvus.py (imported above as the `milvus` module)
class ImageSearch:
    """Image-similarity search built on FeatureExtractor and MilvusClient.

    Every method is a classmethod; shared state lives in class attributes.
    """

    _logger = logging.getLogger(__name__)
    # Set to True once the feature extractor has been obtained successfully.
    _feature_extractor_initialized = False
    # Lazily created client shared by all callers (see get_milvus_client).
    _milvus_client = None

    @classmethod
    def _initialize_feature_extractor(cls) -> None:
        """Make sure the feature extractor is ready before first use.

        The extractor is obtained by calling ``FeatureExtractor()`` rather
        than a class-level ``initialize()``: the refactored extractor builds
        and caches its single instance in ``__new__`` and exposes no
        ``initialize`` method. Calling the constructor also works for a
        classmethod-based extractor.

        Raises:
            RuntimeError: If the extractor cannot be constructed.
        """
        if cls._feature_extractor_initialized:
            return
        try:
            FeatureExtractor()
            cls._feature_extractor_initialized = True
        except Exception as e:
            cls._logger.error(f"Failed to initialize feature extractor: {e}")
            raise RuntimeError(f"Failed to initialize feature extractor: {e}") from e

    @classmethod
    def get_milvus_client(
        cls,
        host: str = "localhost",
        port: str = "19530",
        collection_name: str = "image_collection",
    ) -> MilvusClient:
        """Return the shared Milvus client, creating and connecting it on first call.

        NOTE: the client is cached after the first call. Later calls return
        that same client even when they pass a different ``host``, ``port``
        or ``collection_name``; those arguments are only used at creation.

        Args:
            host: Milvus server address.
            port: Milvus server port.
            collection_name: Name of the collection the client works on.

        Returns:
            The connected MilvusClient instance.
        """
        if cls._milvus_client is None:
            cls._milvus_client = MilvusClient(host, port, collection_name).connect()
        return cls._milvus_client

    @classmethod
    def get_collection(
        cls,
        host: str = "localhost",
        port: str = "19530",
        collection_name: str = "image_collection",
    ) -> Collection:
        """Return the Milvus collection handled by the shared client.

        Args:
            host: Milvus server address.
            port: Milvus server port.
            collection_name: Name of the collection.

        Returns:
            The collection object returned by ``MilvusClient.get_collection``.

        Raises:
            RuntimeError: If the client or the collection cannot be obtained.
        """
        try:
            client = cls.get_milvus_client(host, port, collection_name)
            return client.get_collection()
        except Exception as e:
            cls._logger.error(f"Failed to get collection: {e}")
            raise RuntimeError(f"Failed to get collection: {e}") from e

    @classmethod
    def extract_features(cls, image: Union[str, Image.Image]) -> np.ndarray:
        """Extract a feature vector from an image URL or a PIL image.

        Args:
            image: Either a URL string or a ``PIL.Image.Image``.

        Returns:
            The feature vector produced by the extractor.

        Raises:
            ValueError: If the input type is unsupported, the extractor
                returns ``None``, or extraction raises.
            RuntimeError: If the feature extractor cannot be initialized.
        """
        # Raises RuntimeError on failure; deliberately outside the try below
        # so it is not re-wrapped as a ValueError.
        cls._initialize_feature_extractor()

        try:
            # Extraction methods are called on an instance so that they work
            # with the singleton-style extractor (instance methods) as well
            # as with a classmethod-based one.
            extractor = FeatureExtractor()
            if isinstance(image, str):
                features = extractor.extract_from_url(image)
            elif isinstance(image, Image.Image):
                features = extractor.extract_from_image(image)
            else:
                raise ValueError(f"Unsupported image type: {type(image)}")

            # The extractor reports failure by returning None.
            if features is None:
                raise ValueError("Feature extraction returned None")
            return features
        except Exception as e:
            cls._logger.error(f"Feature extraction failed: {e}")
            raise ValueError(f"Failed to extract features: {e}") from e

    @classmethod
    def image_to_image_search(
        cls,
        image: Union[str, Image.Image],
        top_k: int = 100,
        host: str = "localhost",
        port: str = "19530",
        collection_name: str = "image_collection",
    ) -> List[Tuple[str, float]]:
        """Find the images most similar to a query image.

        Args:
            image: Query image as a URL string or a PIL image.
            top_k: Maximum number of hits to return.
            host: Milvus server address.
            port: Milvus server port.
            collection_name: Name of the collection.

        Returns:
            ``(product_id, score)`` pairs in the order the search returned
            them; an empty list when the search yields nothing.

        Raises:
            ValueError: If feature extraction or the search fails.
        """
        try:
            query_embedding = cls.extract_features(image)

            client = cls.get_milvus_client(host, port, collection_name)
            results = client.search(query_embedding, limit=top_k)

            if not results:
                return []

            # One query vector was sent, so only the first result set matters.
            return [(hit.entity.get('product_id'), hit.score) for hit in results[0]]
        except Exception as e:
            cls._logger.error(f"Image search failed: {e}")
            raise ValueError(f"Image search failed: {e}") from e

    @classmethod
    def batch_image_search(
        cls,
        images: List[Union[str, Image.Image]],
        top_k: int = 100,
        host: str = "localhost",
        port: str = "19530",
        collection_name: str = "image_collection",
    ) -> List[List[Tuple[str, float]]]:
        """Run a similarity search for several query images in one request.

        Images whose features cannot be extracted are logged and skipped;
        their slot in the returned list stays an empty list.

        Args:
            images: Query images, each a URL string or a PIL image.
            top_k: Maximum number of hits per query image.
            host: Milvus server address.
            port: Milvus server port.
            collection_name: Name of the collection.

        Returns:
            One list of ``(product_id, score)`` pairs per input image, in the
            same order as ``images``.

        Raises:
            ValueError: If obtaining the collection or the search fails.
            RuntimeError: If the feature extractor cannot be initialized.
        """
        cls._initialize_feature_extractor()

        batch_embeddings = []
        # Position in `images` of every embedding that was extracted, so the
        # search results can be mapped back to the right slot.
        valid_indices = []
        for i, image in enumerate(images):
            try:
                features = cls.extract_features(image)
            except Exception as e:
                cls._logger.warning(f"Feature extraction failed for image at index {i}: {e}")
                continue
            batch_embeddings.append(features)
            valid_indices.append(i)

        all_results: List[List[Tuple[str, float]]] = [[] for _ in images]

        # Nothing to search for: every slot stays empty.
        if not batch_embeddings:
            return all_results

        try:
            collection = cls.get_collection(host, port, collection_name)
            collection.load()

            search_params = {"metric_type": "IP", "params": {"ef": 100}}
            batch_results = collection.search(
                data=batch_embeddings,
                anns_field="embedding",
                param=search_params,
                limit=top_k,
                output_fields=["image", "product_id"],
            )

            for original_idx, hits in zip(valid_indices, batch_results):
                all_results[original_idx] = [
                    (hit.entity.get('product_id'), hit.score) for hit in hits
                ]
            return all_results
        except Exception as e:
            cls._logger.error(f"Batch image search failed: {e}")
            raise ValueError(f"Batch image search failed: {e}") from e
app/models/milvus.py
0 → 100644
View file @
ef587af1
from
typing
import
Dict
,
List
,
Any
,
Optional
,
Union
import
numpy
as
np
from
pymilvus
import
connections
,
Collection
,
FieldSchema
,
CollectionSchema
,
DataType
,
utility
class MilvusClient:
    """Small wrapper around one pymilvus collection holding image embeddings."""

    def __init__(
        self,
        host: str = "localhost",
        port: str = "19530",
        collection_name: str = "image_collection",
    ) -> None:
        """Remember the connection settings; nothing is contacted here.

        Args:
            host: Milvus server address.
            port: Milvus server port.
            collection_name: Name of the collection to work with.
        """
        self.host: str = host
        self.port: str = port
        self.collection_name: str = collection_name
        # Filled in by get_collection(); None until then or after a drop.
        self.collection: Optional[Collection] = None

    def connect(self) -> 'MilvusClient':
        """Open the "default" pymilvus connection to the configured server.

        Returns:
            This client, so calls can be chained.
        """
        connections.connect("default", host=self.host, port=self.port)
        return self

    def get_collection(self) -> Collection:
        """Open the configured collection, creating it (with an index) if absent.

        Returns:
            The collection object, also stored on ``self.collection``.
        """
        # Schema: image (primary key), product_id, and the embedding vector.
        image_field = FieldSchema(
            name="image",
            dtype=DataType.VARCHAR,
            max_length=256,
            is_primary=True,
            auto_id=False,
        )
        product_field = FieldSchema(
            name="product_id", dtype=DataType.VARCHAR, max_length=256
        )
        # 512 is the feature dimension of CLIP ViT-B/32 (per the original note).
        embedding_field = FieldSchema(
            name="embedding", dtype=DataType.FLOAT_VECTOR, dim=512
        )
        layout = CollectionSchema(
            [image_field, product_field, embedding_field], "图像特征集合"
        )

        if not utility.has_collection(self.collection_name):
            # New collection: create it from the schema, then index it.
            self.collection = Collection(name=self.collection_name, schema=layout)
            self._create_index()
        else:
            self.collection = Collection(name=self.collection_name)

        return self.collection

    def _create_index(self) -> None:
        """Create the HNSW index on the embedding field, if a collection is set."""
        if self.collection is None:
            return

        hnsw_settings: Dict[str, Any] = {
            "index_type": "HNSW",
            "metric_type": "IP",  # inner-product similarity
            "params": {"M": 16, "efConstruction": 200},
        }
        self.collection.create_index(
            field_name="embedding", index_params=hnsw_settings
        )

    def search(self, vector: Union[List[float], np.ndarray], limit: int = 10) -> Any:
        """Search the collection for the vectors closest to ``vector``.

        Args:
            vector: Query vector, as a list or a numpy array.
            limit: Maximum number of hits to return.

        Returns:
            Whatever the collection's ``search`` call returns.
        """
        # Obtain the collection lazily on first use.
        if self.collection is None:
            self.get_collection()

        self.collection.load()

        query_settings: Dict[str, Any] = {"metric_type": "IP", "params": {"ef": 100}}
        return self.collection.search(
            data=[vector],
            anns_field="embedding",
            param=query_settings,
            limit=limit,
            output_fields=["image", "product_id"],
        )

    def insert(self, data: List[Dict[str, Any]]) -> None:
        """Insert rows into the collection.

        Args:
            data: One dict per row; each must contain the keys ``"image"``,
                ``"product_id"`` and ``"embedding"``.
        """
        if self.collection is None:
            self.get_collection()

        # Turn the row dicts into one column list per field, in schema order.
        columns: List[List[Any]] = [
            [row[field] for row in data]
            for field in ("image", "product_id", "embedding")
        ]
        self.collection.insert(columns)

    def drop_collection(self) -> None:
        """Drop the configured collection if it exists and forget the handle."""
        if utility.has_collection(self.collection_name):
            utility.drop_collection(self.collection_name)
            self.collection = None

    def close(self) -> None:
        """Disconnect the "default" pymilvus connection."""
        connections.disconnect("default")
requirements.txt
View file @
ef587af1
...
@@ -32,4 +32,6 @@ networkx~=3.4.2
...
@@ -32,4 +32,6 @@ networkx~=3.4.2
urllib3
~=2.4.0
urllib3
~=2.4.0
fsspec
~=2025.3.2
fsspec
~=2025.3.2
ujson
~=5.10.0
ujson
~=5.10.0
pandas
~=2.2.3
pandas
~=2.2.3
\ No newline at end of file
pymilvus
~=2.5.9
typing_extensions
~=4.13.2
\ No newline at end of file
tests/test_feature_extractor.py
View file @
ef587af1
...
@@ -6,7 +6,7 @@ from app.models.feature_extractor import FeatureExtractor
...
@@ -6,7 +6,7 @@ from app.models.feature_extractor import FeatureExtractor
class
TestFeatureExtractorFunction
(
unittest
.
TestCase
):
class
TestFeatureExtractorFunction
(
unittest
.
TestCase
):
def
test_feature_extractor
(
self
):
def
test_feature_extractor
(
self
):
url
=
"https://pc3oscdn.chillcy.com/3359847025/QSIiPR0XExYACM/00f9bdfa63158ec9477e4f7fe70f5989.jpg"
url
=
"https://pc3oscdn.chillcy.com/3359847025/QSIiPR0XExYACM/00f9bdfa63158ec9477e4f7fe70f5989.jpg"
feature
=
FeatureExtractor
.
extract_from_url
(
url
)
feature
=
FeatureExtractor
()
.
extract_from_url
(
url
)
print
(
feature
)
print
(
feature
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment