天池新闻推荐代码解读_2
总体分析
两阶段推荐系统架构
1 | 召回层 (Recall) → 排序层 (Ranking) → 重排层 (Reranking) |
召回策略
- ItemCF召回:基于物品相似度的协同过滤
- 相似度计算:改进的余弦相似度 + 时间衰减
- 优点:个性化强,能发现长尾物品
- 热度召回:基于时间窗口的热门推荐
- 24小时时间窗口过滤
- 优点:实时性强,覆盖热门内容
排序模型
- 模型类型:LightGBM排序模型(LGBMRanker)
- 训练方式:Pairwise排序学习
- 特征工程:
- 用户特征:设备、地区、OS等
- 物品特征:类别、字数、创建时间
- 交叉特征:时间差等
技术亮点
- 负采样策略解决类别不平衡
- 时间衰减因子提高推荐实时性
- MRR指标评估推荐质量
- 多路召回融合提高覆盖率
评估指标
- MRR(Mean Reciprocal Rank):衡量推荐排名的质量
- 召回率:衡量召回阶段的效果
- 特征重要性:分析模型决策依据
代码解读
导入库和基础设置
实现功能:
- 导入库:包括系统库和数据分析库(pandas,numpy),还有机器学习库
- 忽略所有警告(需要import warnings)
- 设置路径
1 | # 导入系统库 |
数据加载
1 | # 加载训练集和测试集A,并合并 |
train_click:train的20w用户训练集+testA的5w用户测试集(总共为25w用户训练)
test_click:testB的5w用户测试集(用于真正的测试)
articles:还没弄懂干什么……
基础工具函数
1 | def get_all_click_df(train=True, test=True): |
ItemCF协同过滤召回
其中包含
def get_past_click()
def get_user_item_time(click_df)
- def make_item_time_pair(df)
def itemcf_sim(df)
- def item_based_recommend(user_id, user_item_time_dict, i2i_sim, sim_item_topk, recall_item_num)
1 | def itemcf_recall(topk=10): |
第182~183行:对每个物品的相似度列表进行排序
1 | for i in tqdm(i2i_sim.keys()): |
.items()- 字典方法,返回一个键值对视图(可迭代对象),每个元素是
(key, value)元组。
- 字典方法,返回一个键值对视图(可迭代对象),每个元素是
sorted(..., key=lambda x: x[1], reverse=True)sorted():Python 内置函数,对可迭代对象排序,返回新列表。key=lambda x: x[1]lambda x: x[1]是一个匿名函数,表示“取每个元组的第 1 个元素(即索引为 1 的值)”。- 因为
x是(item_j, sim_score),所以x[1]就是相似度得分。 - 排序将基于相似度数值大小进行。
reverse=True- 表示降序排序(从高到低),即最相似的排在前面。
第204行:分离训练集和测试集的召回结果
1 | tst_recall = recall_df[recall_df['user_id'].isin(test_last_click['user_id'].unique())] |
isin()是 pandas 的方法,用于判断recall_df中每个user_id是否出现在给定的集合中。- 返回一个 布尔 Series(True/False),长度与
recall_df相同。
- 返回一个 布尔 Series(True/False),长度与
recall_df[...]- 使用布尔索引,只保留
user_id属于测试集用户的那些行。
- 使用布尔索引,只保留
- ✅ 结果
tst_recall:只包含在测试集中出现过的用户的召回结果。
分离用户历史点击和最后一次点击
1 | def get_past_click(): |
第8行:train = train_click.sort_values([‘user_id’, ‘click_timestamp’]).reset_index().copy()
.sort_values([‘user_id’, ‘click_timestamp’])
首先按 user_id排序,然后在每个用户内部按 click_timestamp排序
这样可以将每个用户的数据按时间顺序排列
- .reset_index()
- 重置 DataFrame 的索引为默认的整数索引(0, 1, 2, …)
- 除非使用 reset_index(drop=False),否则会丢弃旧索引
- 排序后使用这个功能可以获得干净的索引
- .copy()
- 创建 DataFrame 的深拷贝
- 确保对 train的修改不会影响原始的 train_clickDataFrame
- 防止 pandas 中的 SettingWithCopyWarning 警告
第13~20行:遍历训练集用户
1 | for user_id in tqdm(train['user_id'].unique()): |
user = train[train[‘user_id’] == user_id]
功能:筛选出当前用户的所有行
优缺点:这个写法虽然直观,但效率可能不高,特别是当数据量大时。
改进方法:可以使用groupby + 聚合等
1 | # 方法1:使用 groupby + 聚合(推荐) |
list1.append(row.values.tolist()[0])
- row.values - 将 pandas Series(行)转换为 numpy 数组
- tolist() - 将 numpy 数组转换为 Python 列表,例如:[[1, 101, 1609459200]]
- [0] - 获取列表的第一个元素(也就是唯一的那一行)
第29行:train_past_clicks = train[~train.index.isin(train_indexs)]
train.index.isin(train_indexs)
检查
trainDataFrame 的索引是否在train_indexs列表中返回一个布尔序列,True 表示索引在列表中,False 表示不在
~符号
- 取反操作符(NOT)
- 将 True 变为 False,False 变为 True
- 相当于 “不在列表中”
train[…]
- 使用布尔序列筛选 DataFrame
- 只选择布尔值为 True 的行
第30行:all_click_df = all_click_df.reset_index().drop(columns=[‘index’])
- .reset_index()
- 将 DataFrame 的当前索引重置为默认的整数索引(0, 1, 2, …)
- 原来的索引会变成新的列,默认列名为 ‘index’
- .drop(columns=[‘index’]):
- 删除名为 ‘index’ 的列
- 也就是删除刚刚从索引转换来的那列
- 将结果重新赋值给
all_click_df
第38行:all_click_df = all_click_df.drop_duplicates(([‘user_id’, ‘click_article_id’, ‘click_timestamp’]))
- subset=[‘user_id’, ‘click_article_id’, ‘click_timestamp’]
- 指定判断重复行的列组合
- 只有这三列的值完全相同的行才会被视为重复
- 其他列的值可以不同
- drop_duplicates()默认行为
- 保留第一个出现的重复行
- 删除后续的重复行
- 默认对所有列进行比较(如果不指定
subset)
根据点击时间获取用户的点击文章序列
1 | def get_user_item_time(click_df): |
第14行:list(zip(df[‘click_article_id’], df[‘click_timestamp’]))
zip(...)- Python 内置函数,将两个可迭代对象“配对”。
- 例如:
zip([101, 102], [1609, 1610])→ 生成(101, 1609), (102, 1610)。
list(...)- 将 zip 对象转为实际的列表(因为 zip 返回的是迭代器)。
第17~19行:user_item_time_df = click_df.groupby(‘user_id’)[[‘click_article_id’, ‘click_timestamp’]] \
.apply(lambda x: make_item_time_pair(x)) \
.reset_index().rename(columns={0: ‘item_time_list’})
click_df.groupby('user_id')groupby('user_id')- pandas 的核心方法,按
user_id分组,将同一个用户的点击记录归到一起。 - 返回一个
GroupBy对象。
- pandas 的核心方法,按
[['click_article_id', 'click_timestamp']]- 选择分组后每个组中只保留这两列,减少后续处理的数据量。
.apply(lambda x: make_item_time_pair(x)).apply(func)- 对每个分组(即每个用户的子 DataFrame)应用自定义函数。
- 这里的
lambda x: make_item_time_pair(x)等价于直接传make_item_time_pair(可简写为.apply(make_item_time_pair))。 - 每个用户的结果是一个
[(item, time), ...]列表。
.reset_index()- 把
user_id从索引变回普通列,方便后续操作。
- 把
.rename(columns={0: 'item_time_list'})- 将默认列名
0(因为 apply 返回的是匿名列)重命名为'item_time_list'。
- 将默认列名
第22行:user_item_time_dict = dict(zip(user_item_time_df[‘user_id’], user_item_time_df[‘item_time_list’]))
zip(...)- 将两列(
user_id和item_time_list)一一配对。
- 将两列(
dict(...)- 将配对结果转为字典。
- 例如:{ 1: [(101, 1609), (102, 1610)], 2: [(105, 1620)], … }
计算物品相似度
1 | def itemcf_sim(df): |
第20行:i2i_sim.setdefault(i, {})
如果
i不在i2i_sim中,则初始化为一个空字典{}。等价于:
1
2if i not in i2i_sim:
i2i_sim[i] = {}
第26行:i2i_sim[i].setdefault(j, 0)
- 确保
i2i_sim[i][j]存在,初始为 0。
第29行: i2i_sim[i][j] += 1 / math.log(len(item_time_list) + 1)
len(item_time_list):该用户总共点击了多少个物品(即行为序列长度)。- 为什么叫“时间衰减”?
- 实际上这里不是基于真实时间差,而是基于用户活跃度(点击越多,说明越“泛”,其共现信息越不可靠)。
- 这是一种启发式惩罚:如果一个用户点击了 100 个物品,那么任意两个物品的共现可能只是偶然;而如果只点了 2 个,共现更可能是强关联。
- 因此,用
1 / log(N + 1)作为共现权重,N 越大,权重越小。 +1是为了避免log(0)或log(1)=0导致除零错误。
✅ 举例:
- 用户 A 点了 2 个物品 → 权重 =
1 / log(2+1) ≈ 1 / 1.1 ≈ 0.91- 用户 B 点了 100 个物品 → 权重 =
1 / log(101) ≈ 1 / 4.6 ≈ 0.22→ 同样的共现,在活跃用户那里贡献更小。
第32~35行:归一化(余弦相似度)
1 | i2i_sim_ = i2i_sim.copy() |
目的:将原始共现得分转换为余弦相似度。
标准 ItemCF 余弦相似度公式:
其中:
- ∣N(i)∩N(j)∣:同时点击过 i 和 j 的用户数(这里是加权后的值
wij) - ∣N(i)∣:点击过 i 的用户总数(即
item_cnt[i])
- ∣N(i)∩N(j)∣:同时点击过 i 和 j 的用户数(这里是加权后的值
这里做了什么?
- 分子
wij:是加权共现次数(考虑了用户活跃度) - 分母:item_cnt[i]×item_cnt[j]
- 结果是一个 0~1 之间的相似度分数,可比性更强。
- 分子
⚠️ 注意:严格来说,分母也应使用加权后的物品流行度,但此处简化为原始点击次数,是常见近似做法。
第38行:pickle.dump(i2isim, open(save_path + ‘itemcf_i2i_sim.pkl’, ‘wb’))
pickle.dump()- Python 内置模块
pickle用于序列化对象(把 Python 字典保存到磁盘)。 'wb':以二进制写模式打开文件。
- Python 内置模块
save_path- 应该是函数外部定义的路径变量(代码中未展示,需确保已定义)。
- 例如:
save_path = './model/'
基于商品的召回i2i
1 | def item_based_recommend(user_id, user_item_time_dict, i2i_sim, sim_item_topk, recall_item_num): |
第20/21行:userhist_items = {userid for user_id, in user_hist_items}
- 作用:
- 将用户历史点击的 物品 ID 提取出来,存入一个
set。 - 使用
set是为了O(1) 快速判断某个物品是否已被点击过,避免重复推荐。
- 将用户历史点击的 物品 ID 提取出来,存入一个
- ✅ 示例:
user_hist_items_ = {101, 105, 110}
热度召回
其中包括
def getitem_topk_click(hot_articles, hot_articles_dict, click_time, past_click_articles, k)
1 | def hot_recall(topk=10, train_past_clicks=None, test_last_click=None): |
第22行:articles_copy = articles.rename(columns={‘article_id’: ‘click_article_id’})
- 作用:重命名列以对齐键名
rename(...)的作用- 将
articles中的'article_id'列重命名为'click_article_id'。 - ✅ 目的:让两个表在后续
merge时有相同的连接键(join key)。
- 将
第23/24行:train_click_df = train_click_df.merge(articles_copy, on=’click_article_id’, how=’left’)
merge(...)是 pandas 的核心函数,用于表连接(类似 SQL JOIN)on='click_article_id':- 指定连接键:两个表都必须有这一列。
- 现在
train_click_df和articles_copy都有click_article_id,可以匹配。
how='left':- 表示左连接(left join):
- 保留
train_click_df中的所有行; - 如果某
click_article_id在articles_copy中找不到,则对应的文章字段填NaN; - 不会丢掉任何点击记录。
- 保留
- 表示左连接(left join):
第27行:train_last_click = train_past_clicks.groupby(‘user_id’).agg({‘click_timestamp’: ‘max’}).reset_index()
groupby('user_id')- 将数据按
user_id分组,每个用户的所有点击记录被归到一组。
- 将数据按
.agg({'click_timestamp': 'max'})- 对每个分组,对
'click_timestamp'列应用聚合函数'max'(取最大值)。 - 因为时间戳越大表示越晚,所以
max(click_timestamp)就是该用户最后一次点击的时间。
- 对每个分组,对
.reset_index()- 将
user_id从索引转回普通列。
- 将
第28行:train_last_click_time = train_last_click.set_index(‘user_id’)[‘click_timestamp’].to_dict()
.set_index('user_id')- 将
user_id列设为 DataFrame 的行索引。
- 将
['click_timestamp']- 选取
click_timestamp这一列,得到一个 pandas Series,其索引是user_id,值是时间戳。
- 选取
.to_dict()将 Series 转换为 Python 原生字典:
1
2
3
4
5{
1: 1609466400,
2: 1609470000,
...
}
第54~55行:
1 | train_hot_articles = pd.DataFrame( |
train_click_df['click_article_id']从训练点击日志 DataFrame 中取出
'click_article_id'列(Series)。示例:
1
[101, 105, 101, 110, 105, 101, ...]
.value_counts()- pandas 方法:统计每个唯一值出现的次数,并自动按频次降序排列。
- 返回一个 Series,索引是
article_id,值是点击次数。
.index获取上述 Series 的索引,即文章 ID 列表(已按点击频次从高到低排序)。
示例:
1
Int64Index([101, 105, 110], dtype='int64')
.to_list()- 将 Index 转换为 Python 原生列表:[101, 105, 110]
pd.DataFrame(..., columns=['article_id'])- 用这个列表创建一个新的 DataFrame,列名为
'article_id'。
- 用这个列表创建一个新的 DataFrame,列名为
第56行:train_hot_articles = train_hot_articles.merge(articles).drop(columns=[‘category_id’, ‘words_count’])
train_hot_articles.merge(articles)merge()是 pandas 的表连接函数(类似 SQL 的 JOIN)。- 默认行为:
- 自动根据两个表中同名的列进行连接(这里就是
'article_id')。 - 默认是 inner join(只保留两个表都存在的
article_id)。
- 自动根据两个表中同名的列进行连接(这里就是
- 等价于:train_hot_articles.merge(articles, on=’article_id’, how=’inner’)
召回结果整合
def get_test_recall(itemcf=False, hot=False)
def get_train_recall(itemcf=False, hot=False, train_last_click=None)
1 | def get_test_recall(itemcf=False, hot=False): |
第36行:与真实标签左连接(关键步骤!)
1 | itemcf_train_recall = itemcf_train_recall.merge(train_last_click, on=['user_id', 'article_id'], how='left') |
on=['user_id', 'article_id']:- 连接条件:同一个用户 + 同一个文章 ID
- 即:检查召回的
(user, item)是否出现在train_last_click(真实最后一次点击)中。
how='left':- 保留所有召回结果;
- 如果某
(user, item)不在train_last_click中,则click_timestamp等字段为NaN; - 如果在,则填充真实的
click_timestamp。
结果示例:
| user_id | article_id | score | click_timestamp |
| ———- | ————— | ——- | ————————————————————- |
| 1 | 205 | 0.9 | 1609470000 ← 被召回且是真实点击(正样本) |
| 1 | 301 | 0.8 | NaN ← 被召回但不是真实点击(负样本) |
| 2 | 150 | 0.7 | NaN ← … |✅ 只有当召回的物品恰好是该用户的最后一次点击时,
click_timestamp才非空。
第37~39行:生成二值标签(label)
1 | itemcf_train_recall['label'] = itemcf_train_recall['click_timestamp'].apply(lambda x: 0.0 if np.isnan(x) else 1.0) |
np.isnan(x):判断click_timestamp是否为NaN。逻辑:
- 如果是
NaN→ 说明该召回物品不是用户的真实最后一次点击 →label = 0 - 如果非
NaN→ 说明命中了真实点击 →label = 1
✅ 这样就为每个召回结果打上了 0/1 标签。
- 如果是
第40~42行:计算并打印召回率(Recall)
1 | print('Train ItemCF RECALL:{}%'.format( |
分子:
itemcf_train_recall['label'].value_counts()[1]- 统计
label=1的数量 → 成功召回的用户数(注意:每个用户最多贡献 1 次命中,因为train_last_click每用户只有一个正样本)
✅ 假设用户1的最后一次点击是205,只要205出现在他的召回列表中,就算命中1次。
- 统计
分母:
len(train_last_click['user_id'].unique())train_last_click中的用户总数 → 总测试用户数
公式本质:
📌 这是 Hit Rate @K(或称为 Recall@K)的一种形式,常用于 Top-K 推荐评估。
第59行:train[‘pred_score’] = train[‘pred_score’].fillna(-100) # 填充缺失的预测分数
- 作用:将
trainDataFrame 中pred_score列的所有缺失值(NaN)替换为 -100。 train['pred_score']:选取pred_score列(通常表示模型对某(user, item)的打分或相似度)。.fillna(-100):pandas 方法,将该列中所有NaN(缺失值)替换为-100。- 赋值回原列,确保后续操作不会因 NaN 出错。
第100行:neg_data_user_sample = neg_data.groupby(‘user_id’, group_keys=False).apply(neg_sample_func)
目的:对每个用户的负样本单独进行采样,确保每个用户都有一定数量的负样本被保留下来。
从
neg_data(所有负样本)中:- 按
user_id分组 - 对每个用户的负样本子集,调用
neg_sample_func进行采样(比如最多保留 5 个) - 最终得到一个每个用户都包含少量负样本的新 DataFrame
✅ 核心思想:避免某些活跃用户“垄断”负样本,也防止不活跃用户完全没有负样本,保证训练时每个用户都能贡献梯度。
- 按
.groupby('user_id', group_keys=False)groupby('user_id')- 将
neg_data按user_id分成多个小组。 - 例如:
- 用户 1 的负样本 → Group A
- 用户 2 的负样本 → Group B
- …
- 将
group_keys=False- 作用:禁止在
apply结果中自动添加分组键(user_id)作为多级索引(MultiIndex)。 - 为什么重要?
- 默认
group_keys=True时,apply返回的 DataFrame 会带有(user_id, original_index)的复合索引。 - 这会导致后续
pd.concat()或去重时索引混乱。 - 设为
False后,返回的是干净的普通 DataFrame,索引是连续的或原始的(取决于neg_sample_func)。
- 默认
✅ 最佳实践:在
groupby().apply()后要合并数据时,通常设group_keys=False。- 作用:禁止在
.apply(neg_sample_func)apply()对每一个分组(即每个用户的负样本子集),调用函数
neg_sample_func。相当于循环:
1
2
3
4
5results = []
for user_id, group_df in neg_data.groupby('user_id'):
sampled = neg_sample_func(group_df)
results.append(sampled)
neg_data_user_sample = pd.concat(results, ignore_index=True)
neg_sample_func(group_df):回顾定义的负采样函数1
2
3
4
5def neg_sample_func(group_df):
neg_num = len(group_df)
sample_num = max(int(neg_num * sample_rate), 1) # 至少1个
sample_num = min(sample_num, 5) # 最多5个
return group_df.sample(n=sample_num, replace=False)- 输入:某个用户的全部负样本(DataFrame)
- 逻辑:
- 计算该用户有多少负样本(
neg_num) - 按比例
sample_rate(如 0.001)计算要采多少个 - 限制在
[1, 5]范围内 - 随机无放回抽样
sample_num个(在 Pandas 的DataFrame.sample()方法中,参数replace控制采样时是否允许重复(即“放回”还是“不放回”)。)
- 计算该用户有多少负样本(
- 输出:采样后的子 DataFrame
📌 举例:
- 用户 1 有 1000 个负样本 →
int(1000 * 0.001) = 1→ 采 1 个 - 用户 2 有 3 个负样本 →
int(3 * 0.001) = 0→max(0,1)=1→ 采 1 个 - 用户 3 有 10000 个 →
10→ 但min(10,5)=5→ 采 5 个
第107~108行:对合并后的负样本去重
1 | neg_data_new = neg_data_new.sort_values(['user_id', 'pred_score']) \ |
- 去重:同一个
(user, item)可能在两种采样中都被选中,保留一个(keep='last'按pred_score排序后保留最后一个,即分数较高的?需注意排序逻辑)。 .sort_values(['user_id', 'pred_score'])默认是升序排列,所有后面的ItemCF相似度越高。
⚠️ 注意:
sort_values(['user_id', 'pred_score'])中,若pred_score是 ItemCF 相似度,则分数越高越相关。keep='last'会保留同用户同物品中分数最高的那个——这其实是合理的,因为高分负样本更“难”(Hard Negative)。
排序模型训练与预测
1 | def train_and_predict(itemcf=False, itemcf_topk=10, hot=False, hot_topk=10, offline=True): |
召回阶段
1 | # 召回阶段 |
第8行:train_past_clicks = train_past_clicks.groupby('user_id').agg({'click_timestamp': 'max'})
.agg({'click_timestamp': 'max'})- 作用:聚合。这是紧接在
groupby之后的操作,用于对每个分组(小组DataFrame)应用一个或多个聚合函数,并将每个分组的结果合并成一个新的DataFrame。 - 参数解析:
agg接受一个字典,这个字典定义了 “对哪一列执行什么操作”。{'click_timestamp': 'max'}- 键(Key):
'click_timestamp',表示我们要操作的列名。 - 值(Value):
'max',表示要应用的聚合函数,这里是取最大值。对于时间戳列,最大值就是最近的时间。
- 作用:聚合。这是紧接在
特征工程
1 | # 特征工程 |
模型训练
1 | # 准备训练数据 |
第6行:X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=66)
- 使用
sklearn.model_selection.train_test_split - 80% 训练,20% 测试
random_state=66:保证结果可复现- 注意:这是随机打乱划分,不考虑用户分组(后续需重建 group 信息)
⚠️ 在推荐系统中,更严谨的做法是按用户划分(避免同一用户的数据同时出现在 train/test),但这里采用简单随机划分,适用于某些场景(如样本独立性假设成立)
第11行:g_train = X_train.groupby(['user_id'], as_index=False).count()['label'].values
X_train.groupby(['user_id'], as_index=False)- 按
user_id分组 as_index=False:表示分组后不将user_id设为索引,而是保留在列中(返回普通 DataFrame)。
- 按
.count()- 对每列统计非空值数量(因为
x_train包含user_id,article_id, 特征列,label等)
- 对每列统计非空值数量(因为
['label']- 取
label列的计数值(其实任意列都一样,因为每行都有值)
- 取
.values- 转为 NumPy 数组,形如:
[5, 3, 8, 2, ...],表示:- 用户1 有 5 个样本
- 用户2 有 3 个样本
- …
- 转为 NumPy 数组,形如:
✅ 这个数组
g_train就是 LightGBM 中group参数所需的格式。
第20~34行:初始化LightGBM排序模型
1 | lgb_ranker = lgb.LGBMRanker( |
- 小学习率(0.01) + 大树数量(1000):
→ 配合验证集早停(early stopping),找到最优迭代次数。 min_child_weight=50:
→ 在排序任务中,防止模型对稀疏用户/物品过拟合(因每个 group 样本数可能很少)。- 正则化(L2=1) + 采样(subsample/colsample=0.7):
→ 典型的防过拟合组合,适合高维稀疏特征。
✅ 这是一套稳健、防过拟合、适合工业级排序任务的参数配置。
第37~42行:训练模型
1 | lgb_ranker.fit( |
核心机制:
group参数作用:告诉模型哪些样本属于同一个 query(在推荐中 = 同一个 user)
格式:
g_train是一个数组,表示每个用户的样本数量,例如:1
g_train = [5, 3, 8, ...] # 用户1有5个候选,用户2有3个...
内部行为:
- LightGBM 会按
group切分数据:[样本0~4] → 用户1,[5~7] → 用户2, … - 在每个用户内部计算 pairwise loss(如 LambdaRank),优化 NDCG 等排序指标
❗ 如果没有
group,模型会退化为普通分类/回归,完全失去排序能力!- LightGBM 会按
评估预测
离线评估
1 | # 离线评估 |
第2行:X_off['pred_score'] = lgb_ranker.predict(X_off[lgb_cols], num_iteration=lgb_ranker.best_iteration_)
作用:
- 使用训练好的
lgb_ranker模型对X_off中的样本进行预测。 - 将预测得分存入新列
'pred_score'。
- 使用训练好的
关键点:
X_off[lgb_cols]:只传入模型训练时使用的特征列。num_iteration=lgb_ranker.best_iteration_:使用验证集效果最好时的迭代轮数(需在fit()时启用早停),防止过拟合。
✅ 这是标准做法,确保使用最优模型状态。
第12行:
recall_df['rank'] = recall_df.groupby(['user_id'])['pred_score'].rank(ascending=False, method='first')
groupby(...).rank(ascending=False, method='first')ascending=False:分数越高,排名越靠前(rank=1 是最高分)method='first':当分数相同时,按 DataFrame 中出现顺序定 rank(避免随机性)
✅ 这是 Learning-to-Rank 评估的核心步骤。
第15行:del recall_df['pred_score'], recall_df['label']
- 作用:删除
pred_score和label列,因为生成提交格式时不需要它们。 - 说明:
del df['col']是直接从 DataFrame 中删除列的快捷方式。- 此时
recall_df应只包含:user_id,article_id,rank
第16行:筛选 Top-5 并转为宽表
1 | submit = recall_df[recall_df['rank'] <= 5].set_index(['user_id', 'rank']).unstack(-1).reset_index() |
作用:
- 保留每个用户的 Top-5 推荐;
- 将“长表”(每行一个推荐)转为“宽表”(每行一个用户,5 列为文章 ID)。
函数说明:
recall_df[recall_df['rank'] <= 5]
→ 筛选 Top-5。.set_index(['user_id', 'rank'])
→ 设置双层索引:(user_id, rank).unstack(-1)
→ 将最内层索引rank展开为列,形成 MultiIndex 列:1
2
3
4article_id
rank 1 2 3 ...
user_id
101 205 301 401 ....reset_index()
→ 将user_id从索引变回普通列。
第17行:max_article = int(recall_df['rank'].value_counts().index.max())
第18行:处理 MultiIndex 列名
submit.columns = [int(col) if isinstance(col, int) else col for col in submit.columns.droplevel(0)]
作用:
- 去掉 MultiIndex 的第一层(通常是
'article_id'),只保留数字 rank。 - 将列名转为:
['', 1, 2, 3, 4, 5](其中''是user_id列)
- 去掉 MultiIndex 的第一层(通常是
函数说明:
submit.columns.droplevel(0):去掉 MultiIndex 的第 0 层。- 列表推导式尝试将数字列名转为
int。
列表推导式:
[int(col) if isinstance(col, int) else col for col in ...]- 遍历
droplevel(0)后的每个列名col; - 如果
col已经是int类型,就转成int(其实没变); - 否则保留原样(比如字符串
'user_id'或'')。
⚠️ 注意:这个
int(col)转换对字符串数字无效!
例如:如果col = '1'(字符串),isinstance('1', int)是False,所以不会转成整数 1。所以这行代码并不能保证把字符串数字转为整数,它的实际作用很有限。
- 遍历
第26~33行:计算 MRR 指标
1 | sums = 0 |
- MRR 定义:
- 若真实点击出现在第 k 位,贡献 k1
- MRR = 所有用户贡献的平均值
- 问题分析:
- 效率极低
submit.loc[submit['user_id'] == user_id]是 O(n) 查询,总复杂度 O(n²)- 大数据下非常慢
- 无异常处理
- 若
user_id不在train_last_click中,.values[0]报错
- 若
- 效率极低
在线测试
1 | # 在线预测(生成最终提交文件) |
第28行:tmp = recall_df.groupby('user_id').apply(lambda x: x['rank'].max())
功能:按用户分组,计算每个用户的最大排名
- 对
recall_df按user_id分组; - 对每个用户,找出其
rank列的最大值; - 返回一个 Series,索引是
user_id,值是该用户最大的rank。
- 对
示例:
假设某用户有 7 个候选文章,其
rank为[1, 2, 3, 4, 5, 6, 7],则x['rank'].max()返回7。
若另一用户只有 3 个候选,则返回3。💡 注意:这里的
rank是从 1 开始连续编号的(由.rank(method='first')或cumcount()+1生成),所以 最大 rank = 候选文章数量。因此,
tmp[user_id] == 该用户的候选文章总数。
第29行:assert tmp.min() >= topk 断言所有用户的候选数 ≥ topk
功能:
检查所有用户中最小的候选数量是否 ≥
topk(如 5);如果有任何一个用户的候选数 <
topk,assert会抛出AssertionError,程序中断。
目的:
- 确保后续执行
recall_df[recall_df['rank'] <= topk]时,每个用户都能取出完整的 topk 个推荐; - 避免因某些用户候选不足,导致提交格式中出现过多
-1,或评估指标偏差。
- 确保后续执行
主程序执行
1 | # 执行完整的训练和预测流程 |
存在问题以及可改进点
存在问题
- ItemCF中的时间衰减因子并非与时间有关而是与用户的活跃度有关。
- hot召回中为了获取分离后的训练集和测试集,会重新调用一次Item_CF,导致花费大量时间
- 解决方法:把分离数据集的函数从Item_CF中提出来
改进点
- 可以考虑加召回通道
- 有的用户匹配可以换成groupby
- 例如:user = train_past_clicks.loc[train_past_clicks[‘user_id’] == user_id] # 可以换成groupby 执行更快(取自hot召回)
