Coverage for src/nos/data/xdmf_to_torch.py: 0%
31 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-19 11:29 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-19 11:29 +0000
1import pathlib
2import xml.etree.ElementTree as ET
4import h5py
5import numpy as np
6import torch
def get_array(data_dir: pathlib.Path, element: ET.Element) -> torch.Tensor:
    """Gets the first DataItem from a given element.

    The text of the DataItem is interpreted as ``<h5 file name>:<dataset path>``.
    The referenced dataset is read completely from the h5 file in ``data_dir``.

    Args:
        data_dir: directory in which both the xdmf and h5 files are located
        element: parent element that is searched (recursively) for a DataItem

    Returns:
        tensor holding all entries of the dataset referenced by the first
        DataItem. The data is read into a buffer of numpy's default dtype, so
        the dtype of the dataset on disk is not preserved.

    Raises:
        ValueError: if the element holds no DataItem, or the DataItem text is
            not of the form ``<h5 file name>:<dataset path>``.
    """
    data = element.find(".//DataItem")
    if data is None:
        raise ValueError(f"No DataItem found below element <{element.tag}>.")

    # Pretty-printed xdmf files may surround the reference with whitespace or
    # newlines; these are neither part of the file name nor of the dataset path.
    reference = (data.text or "").strip()

    # Split only at the first colon so a dataset path containing ':' stays intact.
    h5_file, separator, path = reference.partition(":")
    if not separator or not h5_file or not path:
        raise ValueError(f"DataItem text {reference!r} is not of the form '<h5 file>:<dataset path>'.")

    with h5py.File(data_dir.joinpath(h5_file), "r") as file:
        d_set = file[path]
        out_arr = np.empty(d_set.shape)
        d_set.read_direct(out_arr)

    # out_arr is a plain numpy array, so it can be converted after the file is closed.
    return torch.tensor(out_arr)
def xdmf_to_torch(file: pathlib.Path) -> dict:
    """Converts a xdmf file to a dictionary with all values stored within it.

    The xdmf structure is validated completely before any h5 data is read, so a
    malformed file fails early with a descriptive error.

    Args:
        file: path to the xdmf file (h5 file needs to be located in the same dir).

    Returns:
        Dictionary with the keys

        - "Geometry": data of the first DataItem below the Geometry element,
        - "Values": for every grid of the collection, the "real_f" and "imag_f"
          attribute data stacked along dim 1 (and squeezed); all grids are
          stacked along dim 0,
        - "Encoding": 1-D tensor with the Time value of every grid.

    Raises:
        ValueError: if the Geometry element, the grid collection, its grids, or
            a grid's Time value / "real_f" / "imag_f" attribute is missing.
    """
    file = pathlib.Path(file)
    data_dir = file.parent
    root = ET.parse(file).getroot()

    # geometry
    geometry_element = root.find(".//Geometry")
    if geometry_element is None:
        raise ValueError(f"No Geometry element found in {file}.")

    # values
    collection = root.find('.//Grid[@GridType="Collection"]')
    if collection is None:
        raise ValueError(f'No Grid with GridType="Collection" found in {file}.')

    grids = collection.findall(".//Grid")
    if not grids:
        raise ValueError(f"The grid collection in {file} does not contain any grids.")

    # Collect (time value, real element, imag element) for every grid first, so
    # structural problems are reported before any h5 file is touched.
    frames = []
    for index, grid in enumerate(grids):
        time = grid.find("Time")
        real_f = grid.find('.//Attribute[@Name="real_f"]')
        imag_f = grid.find('.//Attribute[@Name="imag_f"]')
        if time is None or "Value" not in time.attrib:
            raise ValueError(f"Grid {index} in {file} has no Time element with a Value.")
        if real_f is None or imag_f is None:
            raise ValueError(f'Grid {index} in {file} is missing the "real_f" or "imag_f" attribute.')
        frames.append((float(time.attrib["Value"]), real_f, imag_f))

    geometry = get_array(data_dir, geometry_element)
    values = torch.stack(
        [
            torch.stack([get_array(data_dir, real_f), get_array(data_dir, imag_f)], dim=1).squeeze()
            for _, real_f, imag_f in frames
        ],
        dim=0,
    )
    encodings = torch.tensor([time_value for time_value, _, _ in frames])

    return {
        "Geometry": geometry,
        "Values": values,
        "Encoding": encodings,
    }