• An analysis of the logger.py module in the baselines library


    The code of the logger.py module under the baselines root directory:

    import os
    import sys
    import shutil
    import os.path as osp
    import json
    import time
    import datetime
    import tempfile
    from collections import defaultdict
    from contextlib import contextmanager
    
    DEBUG = 10
    INFO = 20
    WARN = 30
    ERROR = 40
    
    DISABLED = 50
    
    class KVWriter(object):
        def writekvs(self, kvs):
            raise NotImplementedError
    
    class SeqWriter(object):
        def writeseq(self, seq):
            raise NotImplementedError
    
    class HumanOutputFormat(KVWriter, SeqWriter):
        def __init__(self, filename_or_file):
            if isinstance(filename_or_file, str):
                self.file = open(filename_or_file, 'wt')
                self.own_file = True
            else:
                assert hasattr(filename_or_file, 'read'), 'expected file or str, got %s'%filename_or_file
                self.file = filename_or_file
                self.own_file = False
    
        def writekvs(self, kvs):
            # Create strings for printing
            key2str = {}
            for (key, val) in sorted(kvs.items()):
                if hasattr(val, '__float__'):
                    valstr = '%-8.3g' % val
                else:
                    valstr = str(val)
                key2str[self._truncate(key)] = self._truncate(valstr)
    
            # Find max widths
            if len(key2str) == 0:
                print('WARNING: tried to write empty key-value dict')
                return
            else:
                keywidth = max(map(len, key2str.keys()))
                valwidth = max(map(len, key2str.values()))
    
            # Write out the data
            dashes = '-' * (keywidth + valwidth + 7)
            lines = [dashes]
            for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
                lines.append('| %s%s | %s%s |' % (
                    key,
                    ' ' * (keywidth - len(key)),
                    val,
                    ' ' * (valwidth - len(val)),
                ))
            lines.append(dashes)
            self.file.write('\n'.join(lines) + '\n')
    
            # Flush the output to the file
            self.file.flush()
    
        def _truncate(self, s):
            maxlen = 30
            return s[:maxlen-3] + '...' if len(s) > maxlen else s
    
        def writeseq(self, seq):
            seq = list(seq)
            for (i, elem) in enumerate(seq):
                self.file.write(elem)
                if i < len(seq) - 1: # add space unless this is the last one
                    self.file.write(' ')
            self.file.write('\n')
            self.file.flush()
    
        def close(self):
            if self.own_file:
                self.file.close()
    
    class JSONOutputFormat(KVWriter):
        def __init__(self, filename):
            self.file = open(filename, 'wt')
    
        def writekvs(self, kvs):
            for k, v in sorted(kvs.items()):
                if hasattr(v, 'dtype'):
                    kvs[k] = float(v)  # convert numpy scalars to JSON-serializable floats
            self.file.write(json.dumps(kvs) + '\n')
            self.file.flush()
    
        def close(self):
            self.file.close()
    
    class CSVOutputFormat(KVWriter):
        def __init__(self, filename):
            self.file = open(filename, 'w+t')
            self.keys = []
            self.sep = ','
    
        def writekvs(self, kvs):
            # If this row introduces keys we haven't seen before, rewrite the
            # file with an extended header and pad earlier rows with empty columns
            extra_keys = list(kvs.keys() - self.keys)
            extra_keys.sort()
            if extra_keys:
                self.keys.extend(extra_keys)
                self.file.seek(0)
                lines = self.file.readlines()
                self.file.seek(0)
                for (i, k) in enumerate(self.keys):
                    if i > 0:
                        self.file.write(',')
                    self.file.write(k)
                self.file.write('\n')
                for line in lines[1:]:
                    self.file.write(line[:-1])
                    self.file.write(self.sep * len(extra_keys))
                    self.file.write('\n')
            for (i, k) in enumerate(self.keys):
                if i > 0:
                    self.file.write(',')
                v = kvs.get(k)
                if v is not None:
                    self.file.write(str(v))
            self.file.write('\n')
            self.file.flush()
    
        def close(self):
            self.file.close()
    
    
    class TensorBoardOutputFormat(KVWriter):
        """
        Dumps key/value pairs into TensorBoard's numeric format.
        """
        def __init__(self, dir):
            os.makedirs(dir, exist_ok=True)
            self.dir = dir
            self.step = 1
            prefix = 'events'
            path = osp.join(osp.abspath(dir), prefix)
            import tensorflow as tf
            from tensorflow.python import pywrap_tensorflow
            from tensorflow.core.util import event_pb2
            from tensorflow.python.util import compat
            self.tf = tf
            self.event_pb2 = event_pb2
            self.pywrap_tensorflow = pywrap_tensorflow
            self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
    
        def writekvs(self, kvs):
            def summary_val(k, v):
                kwargs = {'tag': k, 'simple_value': float(v)}
                return self.tf.Summary.Value(**kwargs)
            summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
            event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
            event.step = self.step # is there any reason why you'd want to specify the step?
            self.writer.WriteEvent(event)
            self.writer.Flush()
            self.step += 1
    
        def close(self):
            if self.writer:
                self.writer.Close()
                self.writer = None
    
    def make_output_format(format, ev_dir, log_suffix=''):
        os.makedirs(ev_dir, exist_ok=True)
        if format == 'stdout':
            return HumanOutputFormat(sys.stdout)
        elif format == 'log':
            return HumanOutputFormat(osp.join(ev_dir, 'log%s.txt' % log_suffix))
        elif format == 'json':
            return JSONOutputFormat(osp.join(ev_dir, 'progress%s.json' % log_suffix))
        elif format == 'csv':
            return CSVOutputFormat(osp.join(ev_dir, 'progress%s.csv' % log_suffix))
        elif format == 'tensorboard':
            return TensorBoardOutputFormat(osp.join(ev_dir, 'tb%s' % log_suffix))
        else:
            raise ValueError('Unknown format specified: %s' % (format,))
    
    # ================================================================
    # API
    # ================================================================
    
    def logkv(key, val):
        """
        Log a value of some diagnostic
        Call this once for each diagnostic quantity, each iteration
        If called many times, last value will be used.
        """
        get_current().logkv(key, val)
    
    def logkv_mean(key, val):
        """
        The same as logkv(), but if called many times, values averaged.
        """
        get_current().logkv_mean(key, val)
    
    def logkvs(d):
        """
        Log a dictionary of key-value pairs
        """
        for (k, v) in d.items():
            logkv(k, v)
    
    def dumpkvs():
        """
        Write all of the diagnostics from the current iteration
        """
        return get_current().dumpkvs()
    
    def getkvs():
        return get_current().name2val
    
    
    def log(*args, level=INFO):
        """
        Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
        """
        get_current().log(*args, level=level)
    
    def debug(*args):
        log(*args, level=DEBUG)
    
    def info(*args):
        log(*args, level=INFO)
    
    def warn(*args):
        log(*args, level=WARN)
    
    def error(*args):
        log(*args, level=ERROR)
    
    
    def set_level(level):
        """
        Set logging threshold on current logger.
        """
        get_current().set_level(level)
    
    def set_comm(comm):
        get_current().set_comm(comm)
    
    def get_dir():
        """
        Get directory that log files are being written to.
        will be None if there is no output directory (i.e., if you didn't call start)
        """
        return get_current().get_dir()
    
    record_tabular = logkv
    dump_tabular = dumpkvs
    
    @contextmanager
    def profile_kv(scopename):
        logkey = 'wait_' + scopename
        tstart = time.time()
        try:
            yield
        finally:
            get_current().name2val[logkey] += time.time() - tstart
    
    def profile(n):
        """
        Usage:
        @profile("my_func")
        def my_func(): code
        """
        def decorator_with_name(func):
            def func_wrapper(*args, **kwargs):
                with profile_kv(n):
                    return func(*args, **kwargs)
            return func_wrapper
        return decorator_with_name
    
    
    # ================================================================
    # Backend
    # ================================================================
    
    def get_current():
        if Logger.CURRENT is None:
            _configure_default_logger()
    
        return Logger.CURRENT
    
    
    class Logger(object):
        DEFAULT = None  # A logger with no output files. (See right below class definition)
                        # So that you can still log to the terminal without setting up any output files
        CURRENT = None  # Current logger being used by the free functions above
    
        def __init__(self, dir, output_formats, comm=None):
            self.name2val = defaultdict(float)  # values this iteration
            self.name2cnt = defaultdict(int)
            self.level = INFO
            self.dir = dir
            self.output_formats = output_formats
            self.comm = comm
    
        # Logging API, forwarded
        # ----------------------------------------
        def logkv(self, key, val):
            self.name2val[key] = val
    
        def logkv_mean(self, key, val):
            oldval, cnt = self.name2val[key], self.name2cnt[key]
            self.name2val[key] = oldval*cnt/(cnt+1) + val/(cnt+1)  # incremental running mean
            self.name2cnt[key] = cnt + 1
    
        def dumpkvs(self):
            if self.comm is None:
                d = self.name2val
            else:
                from baselines.common import mpi_util
                d = mpi_util.mpi_weighted_mean(self.comm,
                    {name : (val, self.name2cnt.get(name, 1))
                        for (name, val) in self.name2val.items()})
                if self.comm.rank != 0:
                    d['dummy'] = 1 # so we don't get a warning about empty dict
            out = d.copy() # Return the dict for unit testing purposes
            for fmt in self.output_formats:
                if isinstance(fmt, KVWriter):
                    fmt.writekvs(d)
            self.name2val.clear()
            self.name2cnt.clear()
            return out
    
        def log(self, *args, level=INFO):
            if self.level <= level:
                self._do_log(args)
    
        # Configuration
        # ----------------------------------------
        def set_level(self, level):
            self.level = level
    
        def set_comm(self, comm):
            self.comm = comm
    
        def get_dir(self):
            return self.dir
    
        def close(self):
            for fmt in self.output_formats:
                fmt.close()
    
        # Misc
        # ----------------------------------------
        def _do_log(self, args):
            for fmt in self.output_formats:
                if isinstance(fmt, SeqWriter):
                    fmt.writeseq(map(str, args))
    
    def get_rank_without_mpi_import():
        # check environment variables here instead of importing mpi4py
        # to avoid calling MPI_Init() when this module is imported
        for varname in ['PMI_RANK', 'OMPI_COMM_WORLD_RANK']:
            if varname in os.environ:
                return int(os.environ[varname])
        return 0
    
    
    def configure(dir=None, format_strs=None, comm=None, log_suffix=''):
        """
        If comm is provided, average all numerical stats across that comm
        """
        if dir is None:
            dir = os.getenv('OPENAI_LOGDIR')
        if dir is None:
            dir = osp.join(tempfile.gettempdir(),
                datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"))
        assert isinstance(dir, str)
        dir = os.path.expanduser(dir)
        os.makedirs(os.path.expanduser(dir), exist_ok=True)
    
        rank = get_rank_without_mpi_import()
        if rank > 0:
            log_suffix = log_suffix + "-rank%03i" % rank
    
        if format_strs is None:
            if rank == 0:
                format_strs = os.getenv('OPENAI_LOG_FORMAT', 'stdout,log,csv').split(',')
            else:
                format_strs = os.getenv('OPENAI_LOG_FORMAT_MPI', 'log').split(',')
        format_strs = filter(None, format_strs)
        output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
    
        Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
        if output_formats:
            log('Logging to %s'%dir)
    
    def _configure_default_logger():
        configure()
        Logger.DEFAULT = Logger.CURRENT
    
    def reset():
        if Logger.CURRENT is not Logger.DEFAULT:
            Logger.CURRENT.close()
            Logger.CURRENT = Logger.DEFAULT
            log('Reset logger')
    
    @contextmanager
    def scoped_configure(dir=None, format_strs=None, comm=None):
        prevlogger = Logger.CURRENT
        configure(dir=dir, format_strs=format_strs, comm=comm)
        try:
            yield
        finally:
            Logger.CURRENT.close()
            Logger.CURRENT = prevlogger
    
    # ================================================================
    
    def _demo():
        info("hi")
        debug("shouldn't appear")
        set_level(DEBUG)
        debug("should appear")
        dir = "/tmp/testlogging"
        if os.path.exists(dir):
            shutil.rmtree(dir)
        configure(dir=dir)
        logkv("a", 3)
        logkv("b", 2.5)
        dumpkvs()
        logkv("b", -2.5)
        logkv("a", 5.5)
        dumpkvs()
        info("^^^ should see a = 5.5")
        logkv_mean("b", -22.5)
        logkv_mean("b", -44.4)
        logkv("a", 5.5)
        dumpkvs()
        info("^^^ should see b = -33.3")
    
        logkv("b", -2.5)
        dumpkvs()
    
        logkv("a", "longasslongasslongasslongasslongasslongassvalue")
        dumpkvs()
    
    
    # ================================================================
    # Readers
    # ================================================================
    
    def read_json(fname):
        import pandas
        ds = []
        with open(fname, 'rt') as fh:
            for line in fh:
                ds.append(json.loads(line))
        return pandas.DataFrame(ds)
    
    def read_csv(fname):
        import pandas
        return pandas.read_csv(fname, index_col=None, comment='#')
    
    def read_tb(path):
        """
        path : a tensorboard file OR a directory, where we will find all TB files
               of the form events.*
        """
        import pandas
        import numpy as np
        from glob import glob
        import tensorflow as tf
        if osp.isdir(path):
            fnames = glob(osp.join(path, "events.*"))
        elif osp.basename(path).startswith("events."):
            fnames = [path]
        else:
            raise NotImplementedError("Expected tensorboard file or directory containing them. Got %s"%path)
        tag2pairs = defaultdict(list)
        maxstep = 0
        for fname in fnames:
            for summary in tf.train.summary_iterator(fname):
                if summary.step > 0:
                    for v in summary.summary.value:
                        pair = (summary.step, v.simple_value)
                        tag2pairs[v.tag].append(pair)
                    maxstep = max(summary.step, maxstep)
        data = np.empty((maxstep, len(tag2pairs)))
        data[:] = np.nan
        tags = sorted(tag2pairs.keys())
        for (colidx,tag) in enumerate(tags):
            pairs = tag2pairs[tag]
            for (step, value) in pairs:
                data[step-1, colidx] = value
        return pandas.DataFrame(data, columns=tags)
    
    if __name__ == "__main__":
        _demo()

    This module contains quite a lot of code and its logic looks fairly involved, but the functionality it implements is actually simple: it formats Python dictionaries of key-value data and writes them to the screen and to files.
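
    As a quick orientation, here is a minimal usage sketch of the module's free-function API (configure, logkv, logkv_mean, dumpkvs, all defined in the code above); the directory name and the logged values are made up for illustration:

    from baselines import logger

    # Write logs to a chosen directory, in human-readable and CSV formats.
    logger.configure(dir="/tmp/my_run", format_strs=["stdout", "csv"])

    for step in range(3):
        logger.logkv("step", step)                 # last value wins within an iteration
        logger.logkv_mean("reward", float(step))   # repeated calls are averaged
        logger.dumpkvs()                           # flush this iteration's key-value pairs

    Alternatively, as configure() shows, the output directory and format list can also be set through the OPENAI_LOGDIR and OPENAI_LOG_FORMAT environment variables.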

    Since the module defines a main entry point (the if __name__ == "__main__" block at the bottom, which calls _demo()), it can be run directly as a module:

    python -m baselines.logger
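
    Based on HumanOutputFormat.writekvs above, each dumpkvs() call prints an ASCII table to stdout whose column widths adapt to the longest key and value. For the first dump in _demo() (a=3, b=2.5), the table printed looks roughly like this:

    ----------------
    | a | 3        |
    | b | 2.5      |
    ----------------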

    The saved key-value data can then be found under /tmp (the demo configures /tmp/testlogging as its log directory).
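
    The module's reader helpers can load this data back into pandas DataFrames. A minimal sketch, assuming the run wrote its progress files to /tmp/testlogging with the 'csv' and 'json' formats enabled (csv is in the default format list; json is not):

    from baselines import logger

    # progress.csv and progress.json are the filenames used by make_output_format
    df_csv = logger.read_csv("/tmp/testlogging/progress.csv")
    df_json = logger.read_json("/tmp/testlogging/progress.json")
    print(df_csv.head())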

    This logging module feels over-engineered and rather awkward to use; it is a case of reinventing the wheel and of limited practical value, so I won't analyze it in further detail.

    ==============================================

  • Original article: https://www.cnblogs.com/devilmaycry812839668/p/16059128.html