Skip to content

Commit

Permalink
Extend import functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
jubich committed Mar 21, 2023
1 parent 0d97116 commit 33f9542
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 14 deletions.
14 changes: 12 additions & 2 deletions src/hsd/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from hsd.common import HSD_ATTRIB_NAME, np, ATTRIB_SUFFIX, HSD_ATTRIB_SUFFIX, HsdError,\
QUOTING_CHARS, SPECIAL_CHARS
from hsd.eventhandler import HsdEventHandler, HsdEventPrinter
from hsd.interrupts import IncludeHsd, IncludeText

_ItemType = Union[float, complex, int, bool, str]

Expand Down Expand Up @@ -158,7 +159,10 @@ def add_text(self, text):
if self._curblock or self._data is not None:
msg = "Data appeared in an invalid context"
raise HsdError(msg)
self._data = self._text_to_data(text)
if isinstance(text, IncludeText) or isinstance(text, IncludeHsd):
self._data = text
else:
self._data = self._text_to_data(text)


def _text_to_data(self, txt: str) -> _DataType:
Expand Down Expand Up @@ -243,8 +247,14 @@ def walk(self, dictobj):
self._eventhandler.close_tag(key)

else:
nextline = False
self._eventhandler.open_tag(key, attrib, hsdattrib)
self._eventhandler.add_text(_to_text(value))
if isinstance(value, IncludeHsd) or isinstance(value, IncludeText):
text = str(value)
nextline = True
else:
text = _to_text(value)
self._eventhandler.add_text(text, nextline)
self._eventhandler.close_tag(key)


Expand Down
7 changes: 4 additions & 3 deletions src/hsd/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ def close_tag(self, tagname: str):
self._level -= 1


def add_text(self, text: str):
def add_text(self, text: str, nextline: bool = False):

equal = self._followed_by_equal[-1]
multiline = "\n" in text
multiline = "\n" in text or nextline
if equal is None and not multiline:
if len(self._followed_by_equal) > 1:
equal = not self._followed_by_equal[-2]
Expand All @@ -108,7 +108,8 @@ def add_text(self, text: str):
self._indent_level += 1
indentstr = self._indent_level * _INDENT_STR
self._fobj.write(f" {{\n{indentstr}")
text = text.replace("\n", "\n" + indentstr)
if not nextline:
text = text.replace("\n", "\n" + indentstr)

self._fobj.write(text)
self._fobj.write("\n")
Expand Down
31 changes: 31 additions & 0 deletions src/hsd/interrupts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#--------------------------------------------------------------------------------------------------#
# hsd-python: package for manipulating HSD-formatted data in Python #
# Copyright (C) 2011 - 2022 DFTB+ developers group #
# Licensed under the BSD 2-clause license. #
#--------------------------------------------------------------------------------------------------#
#
"""
Contains hsd interrupts
"""

from hsd.common import unquote

class IncludeText:
"""class for dealing with text interrupts/inclueds"""

def __init__(self, file):
self.file = unquote(file.strip())
self.operator = "<<<"

def __str__(self):
return self.operator + ' "' + self.file + '"'

class IncludeHsd:
"""class for dealing with hsd interrupts/inclueds"""

def __init__(self, file):
self.file = unquote(file.strip())
self.operator = "<<+"

def __str__(self):
return self.operator + ' "' + self.file + '"'
14 changes: 9 additions & 5 deletions src/hsd/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@


def load(hsdfile: Union[TextIO, str], lower_tag_names: bool = False,
include_hsd_attribs: bool = False, flatten_data: bool = False) -> dict:
include_hsd_attribs: bool = False, flatten_data: bool = False,
include_file: bool = True) -> dict:
"""Loads a file with HSD-formatted data into a Python dictionary
Args:
Expand All @@ -36,6 +37,7 @@ def load(hsdfile: Union[TextIO, str], lower_tag_names: bool = False,
flatten_data: Whether multiline data in the HSD input should be
flattened into a single list. Othewise a list of lists is created,
with one list for every line (default).
include_file: Whether files via "<<<"/"<<+" should be included or not
Returns:
Dictionary representing the HSD data.
Expand All @@ -45,7 +47,7 @@ def load(hsdfile: Union[TextIO, str], lower_tag_names: bool = False,
"""
dictbuilder = HsdDictBuilder(lower_tag_names=lower_tag_names, flatten_data=flatten_data,
include_hsd_attribs=include_hsd_attribs)
parser = HsdParser(eventhandler=dictbuilder)
parser = HsdParser(eventhandler=dictbuilder, include_file=include_file)
if isinstance(hsdfile, str):
with open(hsdfile, "r") as hsddescr:
parser.parse(hsddescr)
Expand All @@ -56,8 +58,8 @@ def load(hsdfile: Union[TextIO, str], lower_tag_names: bool = False,

def load_string(
hsdstr: str, lower_tag_names: bool = False,
include_hsd_attribs: bool = False, flatten_data: bool = False
) -> dict:
include_hsd_attribs: bool = False, flatten_data: bool = False,
include_file: bool = True) -> dict:
"""Loads a string with HSD-formatted data into a Python dictionary.
Args:
Expand All @@ -75,6 +77,7 @@ def load_string(
flatten_data: Whether multiline data in the HSD input should be
flattened into a single list. Othewise a list of lists is created,
with one list for every line (default).
include_file: Whether files via "<<<"/"<<+" should be included or not
Returns:
Dictionary representing the HSD data.
Expand Down Expand Up @@ -130,7 +133,8 @@ def load_string(
"""
fobj = io.StringIO(hsdstr)
return load(fobj, lower_tag_names, include_hsd_attribs, flatten_data)
return load(fobj, lower_tag_names, include_hsd_attribs, flatten_data,
include_file)


def dump(data: dict, hsdfile: Union[TextIO, str],
Expand Down
18 changes: 14 additions & 4 deletions src/hsd/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Optional, TextIO, Union
from hsd import common
from hsd.eventhandler import HsdEventHandler, HsdEventPrinter

from hsd.interrupts import IncludeHsd, IncludeText

SYNTAX_ERROR = 1
UNCLOSED_TAG_ERROR = 2
Expand Down Expand Up @@ -50,11 +50,13 @@ class HsdParser:
{'Temperature': 100, 'Temperature.attrib': 'Kelvin'}}}}}
"""

def __init__(self, eventhandler: Optional[HsdEventHandler] = None):
def __init__(self, eventhandler: Optional[HsdEventHandler] = None,
include_file: bool = True):
"""Initializes the parser.
Args:
eventhandler: Instance of the HsdEventHandler class or its children.
include_file: Whether files via "<<<"/"<<+" should be included or not
"""
if eventhandler is None:
self._eventhandler = HsdEventPrinter()
Expand All @@ -75,6 +77,7 @@ def __init__(self, eventhandler: Optional[HsdEventHandler] = None):
self._has_child = True # Whether current node has a child already
self._has_text = False # whether current node contains text already
self._oldbefore = "" # buffer for tagname
self._include_file = include_file # Whether files via "<<<"/"<<+" should be included or not


def parse(self, fobj: Union[TextIO, str]):
Expand Down Expand Up @@ -216,10 +219,17 @@ def _parse(self, line):
if txtinc:
self._text("".join(self._buffer) + before)
self._buffer = []
self._eventhandler.add_text(self._include_txt(after[2:]))
if self._include_file:
text = self._include_txt(after[2:])
else:
text = IncludeText(after[2:])
self._eventhandler.add_text(text)
break
if hsdinc:
self._include_hsd(after[2:])
if self._include_file:
self._include_hsd(after[2:])
else:
self._eventhandler.add_text(IncludeHsd(after[2:]))
break
self._buffer.append(before + sign)

Expand Down

0 comments on commit 33f9542

Please sign in to comment.