Source code for pgtoolkit.conf

"""\
.. currentmodule:: pgtoolkit.conf

This module implements ``postgresql.conf`` file format. This is the same format
for ``recovery.conf``. The main entry point of the API is :func:`parse`. The
module can be used as a CLI script.


API Reference
-------------

.. autofunction:: parse
.. autofunction:: parse_string
.. autoclass:: Configuration
.. autoclass:: ParseError


Using as a CLI Script
---------------------

You can use this module to dump a configuration file as JSON object

.. code:: console

    $ python -m pgtoolkit.conf postgresql.conf | jq .
    {
      "lc_monetary": "fr_FR.UTF8",
      "datestyle": "iso, dmy",
      "log_rotation_age": "1d",
      "log_min_duration_statement": "3s",
      "log_lock_waits": true,
      "log_min_messages": "notice",
      "log_directory": "log",
      "port": 5432,
      "log_truncate_on_rotation": true,
      "log_rotation_size": 0
    }
    $

"""

from __future__ import annotations

import contextlib
import copy
import enum
import json
import pathlib
import re
import sys
from collections import OrderedDict
from collections.abc import Iterable, Iterator
from dataclasses import dataclass, field
from datetime import timedelta
from typing import IO, Any, ClassVar, NoReturn, Union
from warnings import warn

from ._helpers import JSONDateEncoder, open_or_return


[docs] class ParseError(Exception): """Error while parsing configuration content."""
class IncludeType(enum.Enum): """Include directive types. https://www.postgresql.org/docs/13/config-setting.html#CONFIG-INCLUDES """ include_dir = enum.auto() include_if_exists = enum.auto() include = enum.auto()
[docs] def parse(fo: str | pathlib.Path | IO[str]) -> Configuration: """Parse a configuration file. The parser tries to return Python object corresponding to value, based on some heuristics. booleans, octal number, decimal integers and floating point numbers are parsed. Multiplier units like kB or MB are applied and you get an int. Interval value like ``3s`` are returned as :class:`datetime.timedelta`. In case of doubt, the value is kept as a string. It's up to you to enforce format. Include directives are processed recursively, when 'fo' is a file path (not a file object). If some included file is not found a FileNotFoundError exception is raised. If a loop is detected in include directives, a RuntimeError is raised. :param fo: A line iterator such as a file-like object or a path. :returns: A :class:`Configuration` containing parsed configuration. """ with open_or_return(fo) as f: conf = Configuration(getattr(f, "name", None)) list(_consume(conf, f)) return conf
def _consume(conf: Configuration, content: Iterable[str]) -> Iterator[None]: for include_path, include_type in conf.parse(content): yield from parse_include(conf, include_path, include_type)
[docs] def parse_string(string: str, source: str | None = None) -> Configuration: """Parse configuration data from a string. Optional *source* argument can be used to set the context path of built Configuration. :raises ParseError: if the string contains include directives referencing a relative path and *source* is unspecified. """ conf = Configuration(source) conf.parse_string(string) return conf
def parse_include( conf: Configuration, path: pathlib.Path, include_type: IncludeType, *, _processed: set[pathlib.Path] | None = None, ) -> Iterator[None]: """Parse on include directive with 'path' value of type 'include_type' into 'conf' object. """ if _processed is None: _processed = set() def notfound( path: pathlib.Path, include_type: str, reference_path: str | None ) -> FileNotFoundError: ref = ( f"{reference_path!r}" if reference_path is not None else "<string literal>" ) return FileNotFoundError( f"{include_type} '{path}', included from {ref}, not found" ) if not path.is_absolute(): if not conf.path: raise ParseError( "cannot process include directives referencing a relative path" ) relative_to = pathlib.Path(conf.path).absolute() assert relative_to.is_absolute() if relative_to.is_file(): relative_to = relative_to.parent path = relative_to / path if include_type == IncludeType.include_dir: if not path.exists() or not path.is_dir(): raise notfound(path, "directory", conf.path) for confpath in sorted(path.glob("*.conf")): if not confpath.name.startswith("."): yield from parse_include( conf, confpath, IncludeType.include, _processed=_processed, ) elif include_type == IncludeType.include_if_exists: if path.exists(): yield from parse_include( conf, path, IncludeType.include, _processed=_processed ) elif include_type == IncludeType.include: if not path.exists(): raise notfound(path, "file", conf.path) if path in _processed: raise RuntimeError(f"loop detected in include directive about '{path}'") _processed.add(path) subconf = Configuration(path=str(path)) with path.open() as f: for sub_include_path, sub_include_type in subconf.parse(f): yield from parse_include( subconf, sub_include_path, sub_include_type, _processed=_processed, ) conf.entries.update(subconf.entries) else: assert False, include_type # pragma: nocover MEMORY_MULTIPLIERS = { "kB": 1024, "MB": 1024 * 1024, "GB": 1024 * 1024 * 1024, "TB": 1024 * 1024 * 1024 * 1024, } _memory_re = re.compile(r"^\s*(?P<number>\d+)\s*(?P<unit>[kMGT]B)\s*$") TIMEDELTA_ARGNAME = { "ms": "milliseconds", "s": "seconds", "min": "minutes", "h": "hours", "d": "days", } _timedelta_re = re.compile(r"^\s*(?P<number>\d+)\s*(?P<unit>ms|s|min|h|d)\s*$") _minute = 60 _hour = 60 * _minute _day = 24 * _hour _timedelta_unit_map = [ ("d", _day), ("h", _hour), # The space before 'min' is intentional. I find '1 min' more readable # than '1min'. (" min", _minute), ("s", 1), ] Value = Union[str, bool, float, int, timedelta] def parse_value(raw: str) -> Value: # Ref. # https://www.postgresql.org/docs/current/static/config-setting.html#CONFIG-SETTING-NAMES-VALUES quoted = False if raw.startswith("'"): if not raw.endswith("'"): raise ValueError(raw) # unquote value and unescape quotes raw = raw[1:-1].replace("''", "'").replace(r"\'", "'") quoted = True if raw.startswith("0") and raw != "0": try: int(raw, base=8) return raw except ValueError: pass m = _memory_re.match(raw) if m: return raw.strip() m = _timedelta_re.match(raw) if m: unit = m.group("unit") arg = TIMEDELTA_ARGNAME[unit] kwargs = {arg: int(m.group("number"))} return timedelta(**kwargs) if raw.lower() in ("true", "yes", "on"): return True if raw.lower() in ("false", "no", "off"): return False if not quoted: try: return int(raw) except ValueError: try: return float(raw) except ValueError: return raw return raw def serialize_value(value: Value) -> str: # This is the reverse of parse_value. if isinstance(value, bool): value = "on" if value else "off" elif isinstance(value, str): # Only quote if not already quoted. if not (value.startswith("'") and value.endswith("'")): # Only double quotes, if not already done; we assume this is # done everywhere in the string or nowhere. if "''" not in value and r"\'" not in value: value = value.replace("'", "''") value = "'%s'" % value elif isinstance(value, timedelta): seconds = value.days * _day + value.seconds if value.microseconds: unit = " ms" value = seconds * 1000 + value.microseconds // 1000 else: for unit, mod in _timedelta_unit_map: if seconds % mod: continue value = seconds // mod break value = f"'{value}{unit}'" else: value = str(value) return value _unspecified: Any = object() @dataclass class Entry: """Configuration entry, parsed from a line in the configuration file.""" name: str _value: Value # _: KW_ONLY from Python 3.10 commented: bool = False comment: str | None = None raw_line: str = field(default=_unspecified, compare=False, repr=False) def __post_init__(self) -> None: if self.raw_line is _unspecified: # We parse value only if not already parsed from a file if isinstance(self._value, str): self._value = parse_value(self._value) # Store the raw_line to track the position in the list of lines. self.raw_line = str(self) + "\n" @property def value(self) -> Value: return self._value @value.setter def value(self, value: str | Value) -> None: if isinstance(value, str): value = parse_value(value) self._value = value def serialize(self) -> str: return serialize_value(self.value) def __str__(self) -> str: line = "%(name)s = %(value)s" % dict(name=self.name, value=self.serialize()) if self.comment: line += " # " + self.comment if self.commented: line = "#" + line return line class EntriesProxy(dict[str, Entry]): """Proxy object used during Configuration edition. >>> p = EntriesProxy(port=Entry('port', '5432'), ... shared_buffers=Entry('shared_buffers', '1GB')) Existing entries can be edited: >>> p['port'].value = '5433' New entries can be added as: >>> p.add('listen_addresses', '*', commented=True, comment='IP address') >>> p # doctest: +NORMALIZE_WHITESPACE {'port': Entry(name='port', _value=5433, commented=False, comment=None), 'shared_buffers': Entry(name='shared_buffers', _value='1GB', commented=False, comment=None), 'listen_addresses': Entry(name='listen_addresses', _value='*', commented=True, comment='IP address')} >>> del p['shared_buffers'] >>> p # doctest: +NORMALIZE_WHITESPACE {'port': Entry(name='port', _value=5433, commented=False, comment=None), 'listen_addresses': Entry(name='listen_addresses', _value='*', commented=True, comment='IP address')} Adding an existing entry fails: >>> p.add('port', 5433) Traceback (most recent call last): ... ValueError: 'port' key already present So does adding a value to the underlying dict: >>> p['bonjour_name'] = 'pgserver' Traceback (most recent call last): ... TypeError: cannot set a key """ def __setitem__(self, key: str, value: Any) -> NoReturn: raise TypeError("cannot set a key") def add( self, name: str, value: Value, *, commented: bool = False, comment: str | None = None, ) -> None: """Add a new entry.""" if name in self: raise ValueError(f"'{name}' key already present") entry = Entry(name, value, commented=commented, comment=comment) super().__setitem__(name, entry)
[docs] @dataclass class Configuration: r"""Holds a parsed configuration. You can access parameter using attribute or dictionary syntax. >>> conf = parse(['port=5432\n', 'pg_stat_statement.min_duration = 3s\n']) >>> conf.port 5432 >>> conf.port = 5433 >>> conf.port 5433 >>> conf['port'] = 5434 >>> conf.port 5434 >>> conf['pg_stat_statement.min_duration'].total_seconds() 3.0 >>> conf.get("ssl") >>> conf.get("ssl", False) False Configuration instances can be merged: >>> otherconf = parse(["listen_addresses='*'\n", "port = 5454\n"]) >>> sumconf = conf + otherconf >>> print(json.dumps(sumconf.as_dict(), cls=JSONDateEncoder, indent=2)) { "port": 5454, "pg_stat_statement.min_duration": "3s", "listen_addresses": "*" } though, lines are discarded in the operation: >>> sumconf.lines [] >>> conf += otherconf >>> print(json.dumps(conf.as_dict(), cls=JSONDateEncoder, indent=2)) { "port": 5454, "pg_stat_statement.min_duration": "3s", "listen_addresses": "*" } >>> conf.lines [] .. attribute:: path Path to a file. Automatically set when calling :func:`parse` with a path to a file. This is default target for :meth:`save`. .. automethod:: edit .. automethod:: save """ # noqa # Internally, lines property contains an updated list of all comments and # entries serialized. When adding a setting or updating an existing one, # the serialized line is updated accordingly. This allows to keep comments # and serialize only what's needed. Other lines are just written as-is. path: str | None = None lines: list[str] = field(default_factory=list, init=False) entries: dict[str, Entry] = field(default_factory=OrderedDict, init=False) _parameter_re: ClassVar = re.compile( r"^(?P<name>[a-z_.]+)(?: +(?!=)| *= *)(?P<value>.*?)" "[\\s\t]*" r"(?P<comment>#.*)?$" ) def parse(self, fo: Iterable[str]) -> Iterator[tuple[pathlib.Path, IncludeType]]: for raw_line in fo: self.lines.append(raw_line) line = raw_line.strip() if not line: continue commented = False if line.startswith("#"): # Try to parse the commented line as a commented parameter, # but only if in the form of 'name = value' since we cannot # discriminate a commented sentence (with whitespaces) from a # commented parameter in the form of 'name value'. if "=" not in line: continue line = line.lstrip("#").lstrip() m = self._parameter_re.match(line) if not m: # This is a real comment continue commented = True else: m = self._parameter_re.match(line) if not m: raise ValueError("Bad line: %r." % raw_line) kwargs = m.groupdict() name = kwargs.pop("name") value = parse_value(kwargs.pop("value")) if name in IncludeType.__members__: if not commented: include_type = IncludeType[name] assert isinstance(value, str), type(value) yield (pathlib.Path(value), include_type) else: comment = kwargs["comment"] if comment is not None: kwargs["comment"] = comment.lstrip("#").lstrip() if commented: # Only overwrite a previous entry if it is commented. try: existing_entry = self.entries[name] except KeyError: pass else: if not existing_entry.commented: continue self.entries[name] = Entry( name, value, commented=commented, raw_line=raw_line, **kwargs ) def parse_string(self, string: str) -> None: list(_consume(self, string.splitlines(keepends=True))) def __add__(self, other: Any) -> Configuration: cls = self.__class__ if not isinstance(other, cls): return NotImplemented s = cls() s.entries.update(self.entries) s.entries.update(other.entries) return s def __iadd__(self, other: Any) -> Configuration: cls = self.__class__ if not isinstance(other, cls): return NotImplemented self.lines[:] = [] self.entries.update(other.entries) return self def __getattr__(self, name: str) -> Value: try: return self.entries[name].value except KeyError: raise AttributeError(name) def __setattr__(self, name: str, value: Value) -> None: if name in self.__dataclass_fields__: super().__setattr__(name, value) else: self[name] = value def __contains__(self, key: str) -> bool: return key in self.entries def __getitem__(self, key: str) -> Value: return self.entries[key].value def __setitem__(self, key: str, value: Value) -> None: if key in IncludeType.__members__: raise ValueError("cannot add an include directive") if key in self.entries: e = self.entries[key] e.value = value self._update_entry(e) else: self._add_entry(Entry(key, value)) def get(self, key: str, default: Value | None = None) -> Value | None: try: return self[key] except KeyError: return default def _add_entry(self, entry: Entry) -> None: assert entry.name not in self.entries self.entries[entry.name] = entry # Append serialized line. entry.raw_line = str(entry) + "\n" self.lines.append(entry.raw_line) def _update_entry(self, entry: Entry) -> None: key = entry.name old_entry, self.entries[key] = self.entries[key], entry if old_entry.commented: # If the entry was previously commented, we uncomment it (assuming # that setting a value to a commented entry does not make much # sense.) entry.commented = False # Update serialized entry. old_line = old_entry.raw_line entry.raw_line = str(entry) + "\n" try: lineno = self.lines.index(old_line) except ValueError: if not entry.commented: msg = ( f"entry {key!r} not directly found in {self.path or 'parsed content'}" " (it might be defined in an included file)," " appending a new line to set requested value" ) warn(msg, UserWarning) self.lines.append(entry.raw_line) else: self.lines[lineno : lineno + 1] = [entry.raw_line] def __iter__(self) -> Iterator[Entry]: return iter(self.entries.values()) def as_dict(self) -> dict[str, Value]: return {k: v.value for k, v in self.entries.items() if not v.commented}
[docs] @contextlib.contextmanager def edit(self) -> Iterator[EntriesProxy]: r"""Context manager allowing edition of the Configuration instance. >>> import sys >>> cfg = Configuration() >>> includes = cfg.parse([ ... "#listen_addresses = 'localhost' # what IP address(es) to listen on;\n", ... " # comma-separated list of addresses;\n", ... "port = 5432 # (change requires restart)\n", ... "max_connections = 100 # (change requires restart)\n", ... ]) >>> list(includes) [] >>> cfg.save(sys.stdout) #listen_addresses = 'localhost' # what IP address(es) to listen on; # comma-separated list of addresses; port = 5432 # (change requires restart) max_connections = 100 # (change requires restart) >>> with cfg.edit() as entries: ... entries["port"].value = 2345 ... entries["port"].comment = None ... entries["listen_addresses"].value = '*' ... del entries["max_connections"] ... entries.add( ... "unix_socket_directories", ... "'/var/run/postgresql'", ... comment="comma-separated list of directories", ... ) >>> cfg.save(sys.stdout) listen_addresses = '*' # what IP address(es) to listen on; # comma-separated list of addresses; port = 2345 unix_socket_directories = '/var/run/postgresql' # comma-separated list of directories """ # noqa: E501 entries = EntriesProxy({k: copy.copy(v) for k, v in self.entries.items()}) try: yield entries except Exception: raise else: # Add or update entries. for k, entry in entries.items(): assert isinstance(entry, Entry), "expecting Entry values" if k not in self: self._add_entry(entry) elif self.entries[k] != entry: self._update_entry(entry) # Discard removed entries. for k, entry in list(self.entries.items()): if k not in entries: del self.entries[k] if entry.raw_line is not None: self.lines.remove(entry.raw_line)
[docs] def save(self, fo: str | pathlib.Path | IO[str] | None = None) -> None: """Write configuration to a file. Configuration entries order and comments are preserved. :param fo: A path or file-like object. Required if :attr:`path` is None. """ with open_or_return(fo or self.path, mode="w") as fo: for line in self.lines: fo.write(line)
def _main(argv: list[str]) -> int: # pragma: nocover try: conf = parse(argv[0] if argv else sys.stdin) print(json.dumps(conf.as_dict(), cls=JSONDateEncoder, indent=2)) return 0 except Exception as e: print(str(e), file=sys.stderr) return 1 if __name__ == "__main__": # pragma: nocover exit(_main(sys.argv[1:]))