""".. currentmodule:: pgtoolkit.hba
This module supports reading, validating, editing and rendering a
``pg_hba.conf`` file. See `Client Authentication
<https://www.postgresql.org/docs/current/static/auth-pg-hba-conf.html>`__ in
the PostgreSQL documentation for details on the format and values of a
``pg_hba.conf`` file.
API Reference
-------------
The main entrypoint of this API is the :func:`parse` function. It returns a
:class:`HBA` object containing :class:`HBARecord` instances.
.. autofunction:: parse
.. autoclass:: HBA
.. autoclass:: HBARecord
Examples
--------
Loading a ``pg_hba.conf`` file :
.. code:: python
hba = parse('my_pg_hba.conf')
You can also pass a file-object:
.. code:: python
with open('my_pg_hba.conf', 'r') as fo:
hba = parse(fo)
Creating a ``pg_hba.conf`` file from scratch :
.. code:: python
hba = HBA()
record = HBARecord(
conntype='local', database='all', user='all', method='peer',
)
hba.lines.append(record)
with open('pg_hba.conf', 'w') as fo:
hba.save(fo)
Using as a script
-----------------
:mod:`pgtoolkit.hba` is usable as a CLI script. It accepts a pg_hba file path
as first argument, reads it, validates it and re-renders it. Fields are aligned
to fit pseudo-column width. If filename is ``-``, stdin is read instead.
.. code:: console
$ python -m pgtoolkit.hba - < data/pg_hba.conf
# TYPE DATABASE USER ADDRESS METHOD
# "local" is for Unix domain socket connections only
local all all trust
# IPv4 local connections:
host all all 127.0.0.1/32 ident map=omicron
""" # noqa
from __future__ import annotations
import os
import re
import sys
import warnings
from collections.abc import Iterable, Iterator, Sequence
from pathlib import Path
from typing import IO, Any, Callable
from ._helpers import open_or_return, open_or_stdin
from .errors import ParseError
class HBAComment(str):
    """A comment or blank line of a ``pg_hba.conf`` file, kept as a plain string."""

    def __repr__(self) -> str:
        # Only the first 32 characters are shown to keep the repr short.
        preview = self[:32]
        return "<{} {}>".format(self.__class__.__name__, preview)
[docs]
class HBARecord:
    """Holds a HBA record composed of fields and a comment.

    Common fields are accessible through attributes: ``conntype``,
    ``databases``, ``users``, ``address``, ``netmask``, ``method``.

    Auth-options fields are also accessible through attributes like ``map``,
    ``ldapserver``, etc.

    ``address`` and ``netmask`` fields are not always defined. Accessing an
    undefined attribute triggers an :exc:`AttributeError`.

    ``databases`` and ``users`` have a single value variant, respectively
    :attr:`database` and :attr:`user`, computed from the list representation
    of the field.

    .. automethod:: parse
    .. automethod:: __init__
    .. automethod:: __str__
    .. automethod:: matches
    .. autoattribute:: database
    .. autoattribute:: user
    """

    COMMON_FIELDS = [
        "conntype",
        "databases",
        "users",
        "address",
        "netmask",
        "method",
    ]
    CONNECTION_TYPES = [
        "local",
        "host",
        "hostssl",
        "hostnossl",
        "hostgssenc",
        "hostnogssenc",
    ]

    @classmethod
    def parse(cls, line: str) -> HBARecord:
        """Parse a HBA record

        :param line: A single record line, optionally followed by a
            ``#`` comment.
        :rtype: :class:`HBARecord`
        :raises ValueError: If the connection type is unknown, if the
            connection type, database or user field is missing, or if an
            auth option is not of the form ``name=value``.
        """
        # Find all elements separated by spaces unless they are enclosed in
        # double-quotes:
        #     (?: … )+  = non-capturing group
        #     \"+.*?\"+ = any element with or without spaces enclosed within
        #                 double-quotes (alternative 1)
        #     \S        = any non-whitespace character (alternative 2)
        tokens = [
            t for t in re.findall(r"(?:\"+.*?\"+|\S)+", line.strip()) if t.strip()
        ]

        # Detach the trailing comment *before* interpreting any field. A token
        # starting with an unquoted '#' opens the comment, whether or not the
        # '#' is followed by a space.
        comment: str | None = None
        hash_pos = next(
            (i for i, t in enumerate(tokens) if t.startswith("#")), None
        )
        if hash_pos is not None:
            glued = tokens[hash_pos][1:]
            comment_words = ([glued] if glued else []) + tokens[hash_pos + 1 :]
            comment = " ".join(comment_words)
            tokens = tokens[:hash_pos]

        if not tokens:
            raise ValueError("Empty record")
        if tokens[0] not in cls.CONNECTION_TYPES:
            raise ValueError("Unknown connection type '%s'" % tokens[0])
        if len(tokens) < 3:
            raise ValueError(
                "Expected at least connection type, database and user fields"
            )

        # Split databases and users lists.
        values: list[Any] = list(tokens)
        values[1] = values[1].split(",")
        values[2] = values[2].split(",")

        record_fields = ["conntype", "databases", "users"]
        if values[0] != "local":
            record_fields.append("address")
        # Six or more values that are not auth options mean the address is
        # given as a separate IP address and netmask.
        common_values = [v for v in values if "=" not in v]
        if len(common_values) >= 6:
            record_fields.append("netmask")
        record_fields.append("method")

        n_fields = len(record_fields)
        base_options = list(zip(record_fields, values[:n_fields]))

        auth_options = []
        for option in values[n_fields:]:
            name, sep, value = option.partition("=")
            if not sep:
                raise ValueError(
                    "Invalid auth option '%s', expected name=value" % option
                )
            # Remove extra outer double quotes for auth options values if any
            auth_options.append((name, re.sub(r"^\"|\"$", "", value)))

        return cls(base_options + auth_options, comment=comment)

    conntype: str | None
    databases: list[str]
    users: list[str]

    def __init__(
        self,
        values: Iterable[tuple[str, str]] | dict[str, Any] | None = None,
        comment: str | None = None,
        **kw_values: str | Sequence[str],
    ) -> None:
        """
        :param values: A dict of fields, or an iterable of ``(name, value)``
            pairs.
        :param kw_values: Fields passed as keyword.
        :param comment: Comment at the end of the line.
        """
        dict_values: dict[str, Any] = dict(values or {}, **kw_values)
        # Single-value variants are stored as one-element lists.
        if "database" in dict_values:
            dict_values["databases"] = [dict_values.pop("database")]
        if "user" in dict_values:
            dict_values["users"] = [dict_values.pop("user")]
        self.__dict__.update(dict_values)
        # Remember which fields were provided, in order.
        self.fields = list(dict_values)
        self.comment = comment

    def __repr__(self) -> str:
        return "<{} {}{}>".format(
            self.__class__.__name__,
            " ".join(self.common_values),
            "..." if self.auth_options else "",
        )

    def __str__(self) -> str:
        """Serialize a record line, without EOL."""
        # Stolen from default pg_hba.conf
        widths = [8, 16, 16, 16, 8]

        fmt = ""
        for i, field in enumerate(self.COMMON_FIELDS):
            # Fields beyond the known widths (i.e. method) are not padded.
            width = widths[i] if i < len(widths) else 0
            if field not in self.fields:
                # Keep the column empty so that following fields stay aligned.
                fmt += " " * width
            elif width:
                fmt += "%%(%s)-%ds " % (field, width - 1)
            else:
                fmt += f"%({field})s "

        # Serialize database and user list using property.
        values = dict(self.__dict__, databases=self.database, users=self.user)
        line = fmt.rstrip() % values

        auth_options = ['%s="%s"' % i for i in self.auth_options]
        if auth_options:
            line += " " + " ".join(auth_options)

        if self.comment is not None:
            line += "  # " + self.comment
        else:
            line = line.rstrip()

        return line

    def __eq__(self, other: object) -> bool:
        # Records are compared through their serialized form.
        return str(self) == str(other)

    def as_dict(self, serialized: bool = False) -> dict[str, Any]:
        """Return the common fields that are set as a dict.

        :param serialized: If true, use the single-value ``database`` and
            ``user`` variants instead of the ``databases`` and ``users`` lists.
        """
        str_fields = self.COMMON_FIELDS[:]
        if serialized:
            str_fields[1:3] = ["database", "user"]
        return {f: getattr(self, f) for f in str_fields if hasattr(self, f)}

    @property
    def common_values(self) -> list[str]:
        """Values of the common fields, in ``COMMON_FIELDS`` order."""
        str_fields = self.COMMON_FIELDS[:]
        # Use serialized variant.
        str_fields[1:3] = ["database", "user"]
        return [getattr(self, f) for f in str_fields if f in self.fields]

    @property
    def auth_options(self) -> list[tuple[str, str]]:
        """``(name, value)`` pairs of the fields that are not common fields."""
        return [
            (f, getattr(self, f)) for f in self.fields if f not in self.COMMON_FIELDS
        ]

    @property
    def database(self) -> str:
        """Hold database column as a single value.

        Use ``databases`` attribute to get parsed database list. ``database``
        is guaranteed to be a string.
        """
        return ",".join(self.databases)

    @property
    def user(self) -> str:
        """Hold user column as a single value.

        Use ``users`` property to get parsed user list. ``user`` is guaranteed
        to be a string.
        """
        return ",".join(self.users)

    def matches(self, **attrs: str) -> bool:
        """Tells if the current record is matching provided attributes.

        :param attrs: keyword/values pairs corresponding to one or more
            HBARecord attributes (ie. user, conntype, etc…)
        :raises AttributeError: If a keyword is not a common field,
            ``database`` or ``user``.
        """
        # Provided attributes should be comparable to HBARecord attributes
        allowed = self.COMMON_FIELDS + ["database", "user"]
        for k in attrs:
            if k not in allowed:
                raise AttributeError("%s is not a valid attribute" % k)

        # An attribute that is not set on the record never matches.
        return all(getattr(self, k, None) == v for k, v in attrs.items())
[docs]
class HBA:
    """Represents pg_hba.conf records

    .. attribute:: lines

        List of :class:`HBARecord` and comments.

    .. attribute:: path

        Path to a file. Is automatically set when calling :meth:`parse` with a
        path to a file. :meth:`save` will write to this file if set.

    .. automethod:: __iter__
    .. automethod:: parse
    .. automethod:: save
    .. automethod:: remove
    .. automethod:: merge
    """

    lines: list[HBAComment | HBARecord]
    path: str | Path | None

    def __init__(self, entries: Iterable[HBAComment | HBARecord] | None = None) -> None:
        """HBA constructor

        :param entries: A list of HBAComment or HBARecord. Optional.
        :raises ValueError: If ``entries`` is non-empty and not a list.
        """
        if entries and not isinstance(entries, list):
            # Wrap in a 1-tuple so that tuple entries are rendered instead of
            # being consumed as several formatting arguments.
            raise ValueError("%s should be a list" % (entries,))
        self.lines = list(entries) if entries is not None else []
        self.path = None

    def __iter__(self) -> Iterator[HBARecord]:
        """Iterate on records, ignoring comments and blank lines."""
        for line in self.lines:
            if isinstance(line, HBARecord):
                yield line

    def parse(self, fo: Iterable[str]) -> None:
        """Parse records and comments from file object

        :param fo: An iterable returning lines
        :raises ParseError: If a record line cannot be parsed.
        """
        for lineno, line in enumerate(fo, start=1):
            stripped = line.lstrip()
            record: HBARecord | HBAComment
            if not stripped or stripped.startswith("#"):
                # Lines read in text mode end with "\n" on every platform, so
                # strip the line ending itself rather than os.linesep.
                record = HBAComment(line.rstrip("\r\n"))
            else:
                try:
                    record = HBARecord.parse(line)
                except Exception as e:
                    raise ParseError(lineno, line, str(e)) from e
            self.lines.append(record)

    def save(self, fo: str | Path | IO[str] | None = None) -> None:
        """Write records and comments in a file

        :param fo: a file-like object. Is not required if :attr:`path` is set.

        Line order is preserved. Record fields are vertically aligned to match
        the column size of column headers from default configuration file.

        .. code::

            # TYPE  DATABASE        USER            ADDRESS                 METHOD
            local   all             all                                     trust
        """  # noqa
        with open_or_return(fo or self.path, mode="w") as out:
            for line in self.lines:
                # Use "\n", not os.linesep: text-mode streams translate it to
                # the platform line ending themselves.
                out.write(str(line) + "\n")

    def remove(
        self,
        filter: Callable[[HBARecord], bool] | None = None,
        **attrs: str,
    ) -> bool:
        """Remove records matching the provided attributes.

        One can for example remove all records for which user is 'david'.

        :param filter: a function to be used as filter. It is passed the record
            to test against. If it returns True, the record is removed. It is
            kept otherwise.
        :param attrs: keyword/values pairs correspond to one or more
            HBARecord attributes (ie. user, conntype, etc...)
        :returns: ``True`` if records have changed.
        :raises ValueError: If neither ``filter`` nor ``attrs`` is given.

        Usage examples:

        .. code:: python

            hba.remove(filter=lambda r: r.user == 'david')
            hba.remove(user='david')
        """
        if filter is not None and attrs:
            warnings.warn("Only filter will be taken into account")

        # Attributes list to look for must not be empty
        if filter is None and not attrs:
            raise ValueError("Attributes dict cannot be empty")

        filter = filter or (lambda line: line.matches(**attrs))

        lines_before = self.lines
        # Comments and blank lines are always kept.
        self.lines = [
            line
            for line in self.lines
            if not (isinstance(line, HBARecord) and filter(line))
        ]
        return lines_before != self.lines

    def merge(self, other: HBA) -> bool:
        """Add new records to HBAFile or replace them if they are matching
        (ie. same conntype, database, user and address)

        :param other: HBAFile to merge into the current one.

        Lines with matching conntype, database, user and address will be
        replaced by the new one, preceded by the comments found right before
        it in ``other``. Otherwise they will be added at the end.
        Comments from the original hba are preserved.

        :returns: ``True`` if records have changed.
        """
        before = self.lines[:]
        # Entries of ``other`` not merged yet; consumed as matches are found.
        pending = other.lines[:]

        # Build the result in a separate list so that inserting comments never
        # shifts the position of lines still to be examined.
        merged: list[HBAComment | HBARecord] = []
        for line in before:
            replacement = None
            if isinstance(line, HBARecord):
                replacement = self._take_match(line, pending)
            if replacement is None:
                merged.append(line)
            else:
                merged.extend(replacement)

        # Then add remaining new lines (not merged)
        merged.extend(pending)
        # Update in place, as existing references to ``lines`` stay valid.
        self.lines[:] = merged
        return before != self.lines

    @staticmethod
    def _take_match(
        line: HBARecord, pending: list[HBAComment | HBARecord]
    ) -> list[HBAComment | HBARecord] | None:
        """Pop from ``pending`` the first record matching ``line``, along with
        the comments right before it; return ``None`` if nothing matches."""
        # Index in ``pending`` where the current run of comments starts.
        run_start = 0
        for pos, candidate in enumerate(pending):
            if isinstance(candidate, HBAComment):
                # preserve comments until next record
                continue
            criteria = {
                name: getattr(candidate, name)
                for name in ("conntype", "database", "user", "address")
                if hasattr(candidate, name)
            }
            if line.matches(**criteria):
                # Remove by position: removing comments by value could drop
                # an equal comment (e.g. a blank line) located elsewhere.
                taken = pending[run_start : pos + 1]
                del pending[run_start : pos + 1]
                return taken
            # Comments before a non-matching record stay attached to it.
            run_start = pos + 1
        return None
[docs]
def parse(file: str | Iterable[str] | Path) -> HBA:
    """Parse a `pg_hba.conf` file.

    :param file: Either a line iterator such as a file-like object, a path or a
        string corresponding to the path to the file to open and parse.
    :rtype: :class:`HBA`.
    """
    # Anything that is not a path is consumed directly as an iterable of lines.
    if not isinstance(file, (str, Path)):
        result = HBA()
        result.parse(file)
        return result

    # A path: open it, parse the stream, and remember where it came from.
    with open(file) as stream:
        result = parse(stream)
    result.path = file
    return result
if __name__ == "__main__":  # pragma: nocover
    # Read the file given as first argument, falling back to stdin ("-") when
    # no argument is provided, then re-render it on stdout.
    argv = sys.argv[1:] + ["-"]
    try:
        with open_or_stdin(argv[0]) as fo:
            hba = parse(fo)
        hba.save(sys.stdout)
    except Exception as e:
        print(str(e), file=sys.stderr)
        # sys.exit() rather than the exit() builtin, which is only injected by
        # the site module and is meant for the interactive interpreter.
        sys.exit(1)