第一关:随机森林

from sklearn.ensemble import RandomForestClassifier

def data_classification(train_data, train_label, test_data):
    '''
    使用随机森林对训练集数据进行训练,并对测试集数据进行预测,并返回预测结果
    :param train_data: 训练集数据,类型为ndarray
    :param train_label: 训练集标签,类型为ndarrayfff
    :param test_data: 测试集数据,类型为ndarray
    :return: 分类结果
    '''
    #********* Begin *********#
    rfc = RandomForestClassifier()
    rfc.fit(train_data,train_label)
    return rfc.predict(test_data)
    #********* End *********#

第二关:手写数字识别

from sklearn.ensemble import RandomForestClassifier

def digit_predict(train_image, train_label, test_image):
    '''
    实现功能:训练模型并输出预测结果
    :param train_image: 包含多条训练样本的样本集,类型为ndarray,shape为[-1, 8, 8]
    :param train_label: 包含多条训练样本标签的标签集,类型为ndarray
    :param test_image: 包含多条测试样本的测试集,类型为ndarry
    :return: test_image对应的预测标签,类型为ndarray
    '''

    #************* Begin ************#
    rfc = RandomForestClassifier(n_estimators=500)
    rfc.fit(train_image,train_label)
    return rfc.predict(test_image)
    #************* End **************#
最后修改:2021 年 07 月 01 日
如果觉得我的文章对你有用,请随意赞赏