决策树分类算法学习
决策树分类算法学习
XGBoost算法及其应用
对于模型的保存
个人犯的错:
# 保存模型时应使用XGBoost原生方法 self.xgb_model.save_model("./classifier/model/xgboost_model/xgboost_model.ubj") #直接使用路径对模型进行加载 class ContextAwareClassifier: def __init__(self, xgb_model, bert_model, tokenizer_model, db_path, db_manager: ControlChatHistoryData, history_size): """传入参数:两种模型,分词器,数据库路径,历史记录长度""" # 初始化数据库连接池 self.connection_pool = SQLiteConnectionPool(db_path) # 初始化并列的类 self.xgb_model = xgb_model#这里我直接传入路径 self.bert_model = bert_model self.tokenizer = BertTokenizer.from_pretrained(tokenizer_model) self.db_manager = db_manager self.history = deque(maxlen=history_size) # 设置最大长度,超过最大长度的历史记录会被删除 #在后续步奏中,采用self.xgb_model.predict的方法直接对值进行预测,造成出错。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
* 正确的方法:
* ```python
class ContextAwareClassifier:
def __init__(self, xgb_model, bert_model, tokenizer_model, db_path, db_manager: ControlChatHistoryData, history_size):
"""传入参数:两种模型,分词器,数据库路径,历史记录长度"""
# 初始化数据库连接池
self.connection_pool = SQLiteConnectionPool(db_path) # 初始化并列的类
self.xgb_model = xgb.Booster()#先进行实例化,后加载模型
self.xgb_model.load_model(xgb_model)
self.bert_model = bert_model
self.tokenizer = BertTokenizer.from_pretrained(tokenizer_model)
self.db_manager = db_manager
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来源 Dedsec的博客!