-
Notifications
You must be signed in to change notification settings - Fork 62
Expand file tree
/
Copy pathRerankModelSelect.tsx
More file actions
107 lines (99 loc) · 3.19 KB
/
RerankModelSelect.tsx
File metadata and controls
107 lines (99 loc) · 3.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import { useMemo, useCallback, useEffect } from 'react';
import { Select, theme } from 'antd';
import { ModelIcon } from '@lobehub/icons';
import { useProviderStore } from '@/stores';
import { parseModelValue, useProviderNameMap } from './ModelSelect';
/**
 * Decides whether a model entry should be offered as a rerank model.
 *
 * A model qualifies when its declared `model_type` is exactly `'Rerank'`
 * (case-sensitive). Otherwise it qualifies when its `model_id` contains
 * "rerank" or "colbert" in any letter case.
 *
 * @param model Model descriptor; only `model_id` and `model_type` are read.
 * @returns `true` if the model is treated as a rerank model.
 */
function isRerankModel(model: { model_id: string; model_type?: string }) {
  if (model.model_type === 'Rerank') {
    return true;
  }
  // Name-based fallback for models not explicitly typed as 'Rerank'.
  const rerankNamePattern = /rerank|colbert/i;
  return rerankNamePattern.test(model.model_id);
}
/**
 * Builds grouped antd `Select` options that contain only rerank models.
 *
 * Reads the provider list from the provider store. Whenever the effect runs
 * while the store holds no providers, it asks the store to fetch them.
 * Each enabled provider with at least one enabled model accepted by
 * `isRerankModel` becomes one option group. Each option's value has the form
 * `"<providerId>::<modelId>"`.
 *
 * @returns Option groups (memoized on the provider list) for `Select.options`.
 */
function useRerankModelOptions() {
  const providers = useProviderStore((state) => state.providers);
  const fetchProviders = useProviderStore((state) => state.fetchProviders);
  const providerCount = providers.length;

  // Populate the store on demand when it is empty.
  useEffect(() => {
    if (providerCount !== 0) return;
    void fetchProviders();
  }, [fetchProviders, providerCount]);

  return useMemo(
    () =>
      providers
        .filter((provider) => provider.enabled)
        .map((provider) => ({
          provider,
          models: provider.models.filter((model) => model.enabled && isRerankModel(model)),
        }))
        // Providers contributing no rerank models get no group at all.
        .filter(({ models }) => models.length > 0)
        .map(({ provider, models }) => ({
          label: (
            <span style={{ display: 'inline-flex', alignItems: 'center', gap: 6 }}>
              <ModelIcon model={provider.name} size={16} type="avatar" />
              {provider.name}
            </span>
          ),
          title: provider.name,
          options: models.map((model) => ({
            label: model.name,
            value: `${provider.id}::${model.model_id}`,
            modelId: model.model_id,
            providerName: provider.name,
          })),
        })),
    [providers],
  );
}
/**
 * Searchable select input restricted to rerank models of enabled providers.
 *
 * Options come from `useRerankModelOptions`. They are grouped by provider
 * and carry values of the form `"<providerId>::<modelId>"`, which
 * `labelRender` decodes with `parseModelValue`. The selected value is shown
 * with the model icon and its label. The provider name is appended in
 * parentheses only when it can be resolved.
 *
 * @param value       Currently selected `"<providerId>::<modelId>"` value, if any.
 * @param onChange    Called with the new value, or `undefined` when cleared.
 * @param placeholder Placeholder text shown when nothing is selected.
 * @param allowClear  Whether the clear control is shown (default `true`).
 * @param style       Inline style forwarded to the underlying antd `Select`.
 */
export function RerankModelSelect({
  value,
  onChange,
  placeholder,
  allowClear = true,
  style,
}: {
  value?: string;
  onChange: (value: string | undefined) => void;
  placeholder?: string;
  allowClear?: boolean;
  style?: React.CSSProperties;
}) {
  const { token } = theme.useToken();
  const rerankOptions = useRerankModelOptions();
  const providerNameMap = useProviderNameMap();

  // Dropdown rows: model icon (keyed by the option's modelId) plus its label.
  const optionRender = useCallback(
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    (option: any) => (
      <span style={{ display: 'inline-flex', alignItems: 'center', gap: 6 }}>
        <ModelIcon model={option.data?.modelId ?? ''} size={18} type="avatar" />
        {option.label}
      </span>
    ),
    [],
  );

  // Selected-value rendering: icon, label and, if resolvable, "(provider)".
  const labelRender = useCallback(
    (props: { label?: React.ReactNode; value?: string | number }) => {
      const parsed = parseModelValue(String(props.value ?? ''));
      if (!parsed) return <span>{props.label}</span>;
      const providerName = providerNameMap.get(parsed.providerId);
      return (
        <span style={{ display: 'inline-flex', alignItems: 'center', gap: 6 }}>
          <ModelIcon model={parsed.modelId} size={18} type="avatar" />
          {/* `label` is optional here; fall back to the model id so the selection never renders blank. */}
          {props.label ?? parsed.modelId}
          {/* Only show the provider suffix when a name resolved; previously an unknown provider rendered a stray "()". */}
          {providerName ? (
            <span style={{ fontSize: 11, color: token.colorTextSecondary }}>
              ({providerName})
            </span>
          ) : null}
        </span>
      );
    },
    [providerNameMap, token.colorTextSecondary],
  );

  return (
    <Select
      value={value}
      onChange={onChange}
      placeholder={placeholder}
      allowClear={allowClear}
      showSearch
      optionFilterProp="label"
      optionRender={optionRender}
      labelRender={labelRender}
      options={rerankOptions}
      style={style}
    />
  );
}