• Implementing TensorFlow's `next_batch` for your own data


    NumPy version (data held in a NumPy array)

    import numpy as np
    
    class Dataset:
        """Serve fixed-size mini-batches from an in-memory NumPy array.

        A minimal stand-in for TensorFlow's ``next_batch``: samples are read in
        order from an (optionally) shuffled copy of ``data``. When a request
        runs past the end of the current epoch, the leftover samples are joined
        with the first samples of the next (optionally reshuffled) epoch, so
        every returned batch has exactly ``batch_size`` samples.
        """

        def __init__(self, data):
            """Wrap ``data``; its first axis (``data.shape[0]``) indexes samples."""
            self._index_in_epoch = 0  # position of the next unread sample
            self._epochs_completed = 0  # how many times reading wrapped around
            self._data = data
            self._num_examples = data.shape[0]

        @property
        def data(self):
            """The underlying array in its current (possibly shuffled) order."""
            return self._data

        def _shuffled(self):
            """Return the current data with its samples randomly permuted."""
            idx = np.arange(0, self._num_examples)
            np.random.shuffle(idx)
            return self._data[idx]

        def next_batch(self, batch_size, shuffle=True):
            """Return the next ``batch_size`` samples.

            Args:
                batch_size: number of samples to return; must satisfy
                    ``1 <= batch_size <= number of samples``.
                shuffle: if true, shuffle the data before the very first batch
                    and again at every epoch boundary; if false, keep the
                    current order.

            Returns:
                An array holding exactly ``batch_size`` samples.

            Raises:
                ValueError: if ``batch_size`` is outside the valid range. (An
                    oversized batch would otherwise push the epoch index past
                    the end and corrupt all later batches.)
            """
            if not 1 <= batch_size <= self._num_examples:
                raise ValueError(
                    f"batch_size must be between 1 and {self._num_examples}, "
                    f"got {batch_size}"
                )

            start = self._index_in_epoch
            if start == 0 and self._epochs_completed == 0 and shuffle:
                self._data = self._shuffled()

            if start + batch_size <= self._num_examples:
                # The whole batch fits inside the current epoch.
                self._index_in_epoch += batch_size
                return self._data[start:self._index_in_epoch]

            # The batch straddles an epoch boundary: take what is left of this
            # epoch (an empty slice when start == num_examples), then top up
            # from the beginning of the next epoch.
            self._epochs_completed += 1
            rest_num_examples = self._num_examples - start
            data_rest_part = self._data[start:self._num_examples]
            if shuffle:
                self._data = self._shuffled()
            self._index_in_epoch = batch_size - rest_num_examples
            data_new_part = self._data[0:self._index_in_epoch]
            return np.concatenate((data_rest_part, data_new_part), axis=0)
    
    dataset = Dataset(np.arange(10))  # ten toy samples: 0..9
    for i in range(10):
        # A batch size of 6 does not divide 10, so batches regularly
        # straddle an epoch boundary.
        batch = dataset.next_batch(6)
        print(batch)
    print(dataset.data)  # the underlying data in its current order
    

    pandas version (data held in a DataFrame)

    import numpy as np
    import pandas as pd
    class Dataset:
        """Serve fixed-size mini-batches of rows from a pandas DataFrame.

        Rows are read in order from an (optionally) shuffled copy of ``data``.
        When a request runs past the end of the current epoch, the leftover
        rows are concatenated with the first rows of the next (optionally
        reshuffled) epoch, so every returned batch has exactly ``batch_size``
        rows. All row selection is positional (``.iloc``).
        """

        def __init__(self, data):
            """Wrap ``data``; its rows (``data.shape[0]``) are the samples."""
            self._index_in_epoch = 0  # position of the next unread row
            self._epochs_completed = 0  # how many times reading wrapped around
            self._data = data
            self._num_examples = data.shape[0]

        @property
        def data(self):
            """The underlying DataFrame in its current (possibly shuffled) order."""
            return self._data

        def _shuffled(self):
            """Return the current DataFrame with its rows randomly permuted."""
            idx = np.arange(0, self._num_examples)
            np.random.shuffle(idx)
            return self._data.iloc[idx, :]

        def next_batch(self, batch_size, shuffle=True):
            """Return the next ``batch_size`` rows as a DataFrame.

            Args:
                batch_size: number of rows to return; must satisfy
                    ``1 <= batch_size <= number of rows``.
                shuffle: if true, shuffle the rows before the very first batch
                    and again at every epoch boundary; if false, keep the
                    current order.

            Returns:
                A DataFrame holding exactly ``batch_size`` rows (original index
                labels are kept, so labels may repeat across an epoch boundary).

            Raises:
                ValueError: if ``batch_size`` is outside the valid range. (An
                    oversized batch would otherwise push the epoch index past
                    the end and corrupt all later batches.)
            """
            if not 1 <= batch_size <= self._num_examples:
                raise ValueError(
                    f"batch_size must be between 1 and {self._num_examples}, "
                    f"got {batch_size}"
                )

            start = self._index_in_epoch
            if start == 0 and self._epochs_completed == 0 and shuffle:
                self._data = self._shuffled()

            if start + batch_size <= self._num_examples:
                # The whole batch fits inside the current epoch.
                self._index_in_epoch += batch_size
                return self._data.iloc[start:self._index_in_epoch, :]

            # The batch straddles an epoch boundary: take what is left of this
            # epoch (an empty slice when start == num_examples), then top up
            # from the beginning of the next epoch.
            self._epochs_completed += 1
            rest_num_examples = self._num_examples - start
            data_rest_part = self._data.iloc[start:self._num_examples, :]
            if shuffle:
                self._data = self._shuffled()
            self._index_in_epoch = batch_size - rest_num_examples
            data_new_part = self._data.iloc[0:self._index_in_epoch, :]
            return pd.concat((data_rest_part, data_new_part), axis=0)
    
    # Two columns, with 'b' always ten times 'a', so each row stays recognisable after shuffling.
    df = pd.DataFrame({'a': np.arange(10), 'b': np.arange(10) * 10})
    dataset = Dataset(df)
    for i in range(10):
        batch = dataset.next_batch(5)
        print(batch)
    print(dataset.data)  # the underlying DataFrame in its current order
    
  • 相关阅读:
    Android:res之selector背景选择器
    工作备份 build.gradle
    Android studio听云接入另外一种方式
    自由开发者_免费可商用的图片资源推荐
    Duplicate files copied in APK META-INF/LICENSE.txt
    模仿九宫格拼音输入法,根据输入的数字键,形成对应的汉字拼音
    Map转Bean小工具
    验证身份证是否合法算法
    jqzoom插件图片放大功能的一些BUG
    外层div高度不随内层div高度改变的解决办法
  • 原文地址:https://www.cnblogs.com/ZeroTensor/p/10394989.html
Copyright © 2020-2023  润新知