File: //opt/imh-cwp-dns/zone_interface.py
import re
from typing import NamedTuple, Optional, Union
from datetime import datetime
import dns.rdata
import dns.rdataclass
import dns.rdatatype
import dns.zone
class TextLine(NamedTuple):
    """
    A zone-file line that is carried through unchanged.

    ZoneFile.parse produces one of these for blank lines, ";" comments,
    "$" directives, and any line it could not read as a record.
    """

    # the line(s) exactly as read from the file; ZoneFile.generate writes
    # them back verbatim
    record_raw: list[str]
class Record(NamedTuple):
    """
    A resource record read from, or added to, a zone file.

    Records built by ZoneFile.parse carry their position in the original
    file; records added with ZoneFile.insert_record use -1 for both line
    fields.
    """

    # 0-based index of the record's first line in the original file (-1 if added)
    line_start: int
    # index where the record ends; equals line_start for single-line records (-1 if added)
    line_end: int
    # owner name (lower-cased by ZoneFile.parse; as given to insert_record)
    subdomain: str
    # record type such as "A" or "TXT" (upper-cased by ZoneFile.parse)
    record_type: str
    # value text as written in the zone file, surrounding whitespace stripped
    value_raw: str
    # value normalized via dnspython, or None when it was not computed
    # (insert_record may store None here)
    value_parsed: Optional[str]
    # the zone-file lines that ZoneFile.generate writes for this record
    record_raw: list[str]
    # TTL in seconds
    ttl: Optional[int] = 900
class ZoneFile:
    """
    Simple interface for DNS zone files.

    Usage: construct with the file path and the zone's domain name, call
    parse(), inspect or modify records, then save(). Every line that is not
    recognised as a record (blank lines, comments, $ directives, records in
    a format the pattern does not match) is kept verbatim as a TextLine so
    generate() writes it back unchanged.
    """

    # "<owner> [<ttl>] IN <type> <value>"; trailing whitespace is ignored and
    # the value may be a single character (e.g. "www IN CNAME @")
    _RECORD_RE = re.compile(
        r"^(\S+)\s+"  # subdomain
        r"(\d+)?\s*"  # optional TTL
        r"IN\s+"  # class
        r"([a-zA-Z]+)\s+"  # record type
        r"(\S(?:.*\S)?)\s*$",  # record
        re.IGNORECASE,
    )

    # TTL given to parsed records that have no explicit TTL and follow no
    # numeric $TTL directive
    _DEFAULT_TTL = 900

    def __init__(self, file_path, domain_name):
        """
        Read the zone file into memory; call parse() to build self.records.

        :param file_path: path of the zone file, also used by save().
        :param domain_name: the zone's domain name, e.g. "example.com".
        """
        self.file_path = file_path
        self.records: list[Union[Record, TextLine]] = []
        self.domain_name = domain_name
        with open(self.file_path, "r") as zone_file:
            self.original_lines = zone_file.readlines()
        # an empty file has no last line; otherwise ensure the file ends with
        # a newline so records appended later start on their own line
        if self.original_lines and not self.original_lines[-1].endswith("\n"):
            self.original_lines[-1] += "\n"

    def _is_root(self, target):
        """
        Return True if `target` names the zone apex: "@" or the fully
        qualified domain name. Case is ignored (DNS names are
        case-insensitive) and a trailing dot on domain_name is tolerated.
        """
        apex = f"{self.domain_name.rstrip('.').lower()}."
        return target.lower() in ("@", apex)

    def parse(self, print_warnings=False):
        """
        Parse the zone file into records and comments.

        Rebuilds self.records (discarding the result of any earlier call)
        with one Record per recognised resource record and one TextLine per
        other line, in file order.

        :param print_warnings: print a warning for each non-blank,
            non-comment, non-directive line the record pattern does not match.
        :raises Exception: if dnspython rejects the value of a matched record.
        """
        self.records = []
        lines = self.original_lines
        default_ttl = self._DEFAULT_TTL
        i = 0
        while i < len(lines):
            line = lines[i].lstrip()
            if not line or line.startswith((";", "$")):
                if line.startswith("$"):
                    default_ttl = self._ttl_from_directive(line, default_ttl)
                self.records.append(TextLine(record_raw=[lines[i]]))
                i += 1
                continue
            parsed = self._RECORD_RE.match(line)
            if not parsed:
                if print_warnings:
                    print(f"Warning: Could not parse line {i + 1}: {line}")
                self.records.append(TextLine(record_raw=[lines[i]]))
                i += 1
                continue
            sub, ttl, rr_type, value = parsed.groups()
            sub = sub.lower()
            rr_type = rr_type.upper()
            end = i
            # an unclosed "(" outside quotes/comments continues the record on
            # the following lines; common in SOA records but legal anywhere
            no_quotes = self._strip_quotes(value, discard=True)
            if "(" in no_quotes and ")" not in no_quotes:
                value, end = self._read_continuation(lines, i, value)
            try:
                value_parsed = self._parse_value(rr_type, value)
            except Exception as e:
                raise Exception(
                    f"Error parsing {self.domain_name} {sub} {rr_type} {value}"
                ) from e
            self.records.append(
                Record(
                    line_start=i,
                    line_end=end,  # inclusive: last line belonging to the record
                    subdomain=sub,
                    record_type=rr_type,
                    value_raw=value.strip(),
                    value_parsed=value_parsed,
                    record_raw=lines[i : end + 1],
                    ttl=int(ttl) if ttl else default_ttl,
                )
            )
            i = end + 1

    def _ttl_from_directive(self, directive, current):
        """
        Return the TTL set by a "$TTL <seconds>" directive line, or `current`
        for any other directive. Non-numeric TTLs such as "1h" are not
        interpreted and also leave `current` unchanged.
        """
        parts = self._strip_quotes(directive).split()
        if len(parts) >= 2 and parts[0].upper() == "$TTL" and parts[1].isdigit():
            return int(parts[1])
        return current

    def _read_continuation(self, lines, start, value):
        """
        Collect the continuation of a record whose value opened "(" on line
        `start`.

        Following lines are appended to `value` until one contains ")"
        outside quotes and comments. Lines that are empty once comments and
        quoted text are removed are not appended, but they still fall inside
        the record's line span.

        :returns: (accumulated value, index of the line holding ")"). If no
            ")" is found, the index is `start`, so the following lines are
            parsed on their own, and the caller parses the accumulated value
            as it stands.
        """
        # the regex dropped the first line's newline; restore it so a comment
        # at the end of that line does not run into the next line
        value += "\n"
        for x in range(start + 1, len(lines)):
            next_line = lines[x]
            cleaned_line = self._strip_quotes(next_line, discard=True)
            if not cleaned_line:
                continue
            value += next_line
            if ")" in cleaned_line:
                return value, x
        return value, start

    def _parse_value(self, rr_type, value):
        """
        Validate `value` as IN-class rdata of type `rr_type` with dnspython
        and return its normalized text. For TXT the character-strings are
        joined and decoded as UTF-8 (invalid bytes replaced), giving the
        literal value.

        :raises Exception: whatever the type lookup or dnspython raises.
        """
        rr_type = rr_type.upper()
        rdata = dns.rdata.from_text(
            dns.rdataclass.IN, dns.rdatatype.RdataType[rr_type], value.strip()
        )
        if rr_type == "TXT":
            # decode the raw strings: to_text() renders non-printable bytes as
            # \DDD escapes, which unquoting alone cannot turn back into text
            return b"".join(rdata.strings).decode("utf-8", errors="replace")
        return rdata.to_text()

    def _try_parse_value(self, rr_type, value):
        """
        Best-effort _parse_value(): return None if the value is not valid.
        Used for records created or edited in memory; generate() validates
        the whole zone before it is written.
        """
        try:
            return self._parse_value(rr_type, value)
        except Exception:
            return None

    def encode_txt(self, txt: str, max_len=255):
        """
        Perform RFC-compliant TXT record splitting and return the combined
        string ready to insert into a zone file.

        `txt` is cut into chunks of at most `max_len` characters (a single
        empty chunk for empty input). Each chunk has backslashes and double
        quotes escaped and is wrapped in double quotes, and the chunks are
        joined with single spaces.

        NOTE: the 255 limit of a DNS character-string is in bytes; chunking
        by characters matches it only for ASCII text.
        """
        chunks = [txt[pos : pos + max_len] for pos in range(0, len(txt), max_len)]
        if not chunks:
            chunks = [""]
        quoted = []
        for chunk in chunks:
            # escape the escape character first so quote escapes stay intact
            chunk = chunk.replace("\\", "\\\\").replace('"', '\\"')
            quoted.append(f'"{chunk}"')
        return " ".join(quoted)

    def generate(self):
        """
        Return the zone file text assembled from every entry's record_raw.

        :raises Exception: if dns.zone.from_text rejects the result.
        """
        result = "".join(line for entry in self.records for line in entry.record_raw)
        # validate the result
        try:
            dns.zone.from_text(result, self.domain_name)
        except Exception as e:
            raise Exception(f"Updated zone file {self.domain_name} is invalid") from e
        return result

    def find_record(self, subdomain, record_type, value_regex=None) -> "list[Record]":
        """
        Return the records with the given owner name and type, in file order.

        Owner names and types are compared case-insensitively. "@" and the
        fully qualified domain name both select the zone apex. If
        `value_regex` is given, only records whose value_raw it matches
        (re.search) are returned.
        """
        want_root = self._is_root(subdomain)
        wanted_name = subdomain.lower()
        wanted_type = record_type.upper()
        matches = []
        for record in self.records:
            if isinstance(record, TextLine):
                continue
            if record.record_type.upper() != wanted_type:
                continue
            same_owner = (
                want_root and self._is_root(record.subdomain)
            ) or record.subdomain.lower() == wanted_name
            if not same_owner:
                continue
            if value_regex and not re.search(value_regex, record.value_raw):
                continue
            matches.append(record)
        return matches

    def remove_record(self, record: "Record"):
        """
        Remove the first entry equal to `record`.

        :returns: True.
        :raises ValueError: if `record` is not present (from list.remove).
        """
        self.records.remove(record)
        return True

    @staticmethod
    def _format_record_line(subdomain, ttl, record_type, value_raw):
        """Return one tab-separated zone-file line; the TTL column is omitted when ttl is None."""
        ttl_field = "" if ttl is None else f"{ttl}\t"
        return f"{subdomain}\t{ttl_field}IN\t{record_type}\t{value_raw}\n"

    def replace_record_value(self, record: "Record", value_raw: str, ttl=None):
        """
        Replace the first entry equal to `record` with a single-line record
        carrying `value_raw`.

        :param record: a record previously obtained from this ZoneFile.
        :param value_raw: new value text, written verbatim into the zone.
        :param ttl: new TTL. None keeps the record's TTL (0 is a valid TTL).
        :returns: True if the record was found and replaced, else False.
        """
        if ttl is None:
            ttl = record.ttl
        new_line = self._format_record_line(
            record.subdomain, ttl, record.record_type, value_raw
        )
        for index, existing in enumerate(self.records):
            if existing == record:
                # refresh every value-derived field so later lookups
                # (find_record's value_regex, bump_soa_serial) see the new data
                self.records[index] = record._replace(
                    value_raw=value_raw,
                    value_parsed=self._try_parse_value(record.record_type, value_raw),
                    record_raw=[new_line],
                    ttl=ttl,
                )
                return True
        return False

    def insert_record(self, subdomain: str, record_type: str, value_raw: str, ttl=900):
        """
        Append a new single-line record at the end of the zone.

        :param ttl: TTL written to the zone and stored on the Record. None
            omits the TTL column.
        """
        self.records.append(
            Record(
                line_start=-1,
                line_end=-1,
                subdomain=subdomain,
                record_type=record_type,
                value_raw=value_raw,
                value_parsed=self._try_parse_value(record_type, value_raw),
                record_raw=[
                    self._format_record_line(subdomain, ttl, record_type, value_raw)
                ],
                ttl=ttl,
            )
        )

    @staticmethod
    def _next_soa_serial(original_serial: int, now: Optional[datetime] = None) -> int:
        """
        Return the serial that should follow `original_serial` under the
        YYYYMMDDnn convention.

        Serials below today's YYYYMMDD00 jump to it. Any other serial is
        incremented by one, even past today's YYYYMMDD99, because secondaries
        only transfer a zone whose serial increased; lowering it would stop
        propagation.
        """
        if now is None:
            now = datetime.now()
        min_serial = int(now.strftime("%Y%m%d") + "00")
        if original_serial < min_serial:
            return min_serial
        # serials are unsigned 32-bit; RFC 1982 arithmetic wraps 2**32 - 1 to 0
        return (original_serial + 1) % 2**32

    @staticmethod
    def _replace_soa_serial(value_raw, old_serial, new_serial):
        """
        Return `value_raw` (raw SOA rdata text) with its serial, the third
        field, replaced by `new_serial`. Parentheses and ";" comments are not
        counted as fields, so multi-line SOA values are handled and names or
        comments that contain the same digits are left alone.

        :raises Exception: if the third field is not `old_serial`.
        """
        fields = [
            match
            for match in re.finditer(r";[^\n]*|(?:\\.|[^\s();\\])+", value_raw)
            if not match.group().startswith(";")
        ]
        serial = fields[2] if len(fields) > 2 else None
        if (
            serial is None
            or not serial.group().isdigit()
            or int(serial.group()) != old_serial
        ):
            raise Exception(f"Could not locate SOA serial {old_serial} in {value_raw}")
        return value_raw[: serial.start()] + str(new_serial) + value_raw[serial.end() :]

    def bump_soa_serial(self):
        """
        Update SOA serial according to date.

        :returns: the new serial.
        :raises Exception: if there is not exactly one SOA record at the
            zone apex, or its serial cannot be read.
        """
        find_result = self.find_record("@", "SOA")
        if not find_result:
            raise Exception("No SOA record found")
        if len(find_result) != 1:
            raise Exception("Multiple SOA records found")
        record = find_result[0]
        source = (
            record.value_parsed if record.value_parsed is not None else record.value_raw
        )
        value_parts = source.split()
        try:
            original_serial = int(value_parts[2])
        except (IndexError, ValueError) as e:
            raise Exception(f"Invalid SOA serial number in {value_parts}") from e
        new_serial = self._next_soa_serial(original_serial)
        self.replace_record_value(
            record,
            self._replace_soa_serial(record.value_raw, original_serial, new_serial),
        )
        return new_serial

    def save(self):
        """
        Write generate()'s output to file_path.

        The zone is generated and validated before the file is opened, so an
        invalid zone raises without truncating the existing file.
        """
        content = self.generate()
        with open(self.file_path, "w") as zone_file:
            zone_file.write(content)

    def _strip_quotes(self, value, discard=False, strip_whitespace=False):
        """
        Remove a ";" comment and the double-quote characters from a DNS
        record value.

        Backslash escapes are resolved first: the backslash is dropped and
        the next character is kept literally, so an escaped quote or
        semicolon neither toggles a string nor starts a comment.

        :param discard: also drop the contents of quoted strings (including
            escaped characters inside them).
        :param strip_whitespace: drop spaces and tabs outside quoted strings;
            in a split TXT value they only delimit the pieces.
        """
        kept = []
        in_quote = False
        escaped = False
        for char in value:
            if escaped:
                escaped = False
                if not (in_quote and discard):
                    kept.append(char)
                continue
            if char == "\\":
                escaped = True
                continue
            if char == '"':
                in_quote = not in_quote
                continue
            if in_quote:
                if discard:
                    continue
            elif char == ";":
                break  # the rest of the line is a comment
            elif strip_whitespace and char in (" ", "\t"):
                continue
            kept.append(char)
        return "".join(kept)

    def _unsplit_txt(self, txt: str):
        """
        Get "literal" value of a TXT record, joining split records and
        removing quotes
        """
        return self._strip_quotes(txt, discard=False, strip_whitespace=True)