决策树分类算法学习

XGBoost算法及其应用

对于模型的保存

  • 个人犯的错:

  • # 保存模型时应使用XGBoost原生方法
    self.xgb_model.save_model("./classifier/model/xgboost_model/xgboost_model.ubj") 
    #直接使用路径对模型进行加载
    class ContextAwareClassifier:
        def __init__(self, xgb_model, bert_model, tokenizer_model, db_path, db_manager: ControlChatHistoryData, history_size):
            """传入参数:两种模型,分词器,数据库路径,历史记录长度"""
    
            # 初始化数据库连接池
            self.connection_pool = SQLiteConnectionPool(db_path)  # 初始化并列的类
            self.xgb_model = xgb_model#这里我直接传入路径
            self.bert_model = bert_model
            self.tokenizer = BertTokenizer.from_pretrained(tokenizer_model)
            self.db_manager = db_manager
            self.history = deque(maxlen=history_size)  # 设置最大长度,超过最大长度的历史记录会被删除
            
    #在后续步奏中,采用self.xgb_model.predict的方法直接对值进行预测,造成出错。
    
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15

    * 正确的方法:

    * ```python
    class ContextAwareClassifier:
    def __init__(self, xgb_model, bert_model, tokenizer_model, db_path, db_manager: ControlChatHistoryData, history_size):
    """传入参数:两种模型,分词器,数据库路径,历史记录长度"""

    # 初始化数据库连接池
    self.connection_pool = SQLiteConnectionPool(db_path) # 初始化并列的类
    self.xgb_model = xgb.Booster()#先进行实例化,后加载模型
    self.xgb_model.load_model(xgb_model)
    self.bert_model = bert_model
    self.tokenizer = BertTokenizer.from_pretrained(tokenizer_model)
    self.db_manager = db_manager