Coverage for src/nos/data/xdmf_to_torch.py: 0%

31 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-19 11:29 +0000

1import pathlib 

2import xml.etree.ElementTree as ET 

3 

4import h5py 

5import numpy as np 

6import torch 

7 

8 

def get_array(data_dir: pathlib.Path, element: ET.Element) -> torch.Tensor:
    """Gets the first DataItem from a given element and reads the data it references.

    The DataItem text must be an HDF5 reference of the form
    ``<h5 file>:<dataset path>``; the h5 file is resolved relative to ``data_dir``.

    Args:
        data_dir: directory in which both the xdmf and h5 files are located
            (a ``str`` path is accepted as well).
        element: parent element; the first ``DataItem`` below it is read.

    Returns:
        Float64 tensor with the shape of the referenced dataset, holding all of
        its elements.

    Raises:
        ValueError: if ``element`` is None, has no DataItem with text, or the
            DataItem text is not of the form ``<h5 file>:<dataset path>``.
    """
    if element is None:
        raise ValueError("expected an XML element, got None")
    data = element.find(".//DataItem")
    # Element truthiness depends on the child count, so test for None explicitly.
    if data is None or not data.text:
        raise ValueError(f"<{element.tag}> has no DataItem with a data reference")

    # Strip surrounding whitespace (pretty-printed XML) and split only at the first
    # colon so that a dataset path containing ':' is kept intact instead of truncated.
    h5_file, sep, path = data.text.strip().partition(":")
    if not sep or not h5_file or not path:
        raise ValueError(
            f"unsupported DataItem content {data.text!r}; expected '<h5 file>:<dataset path>'"
        )

    with h5py.File(pathlib.Path(data_dir) / h5_file, "r") as h5:
        d_set = h5[path]
        # np.empty defaults to float64, so the dataset is read into a float64 buffer.
        out_arr = np.empty(d_set.shape)
        d_set.read_direct(out_arr)

    # out_arr is private to this function, so wrap it without the extra copy that
    # torch.tensor would make.
    return torch.from_numpy(out_arr)

32 

33 

def xdmf_to_torch(file: pathlib.Path) -> dict:
    """Converts a xdmf file to a dictionary with the values stored within it.

    Only the geometry and the ``real_f``/``imag_f`` attributes are read; the
    topology is not loaded. All heavy data is read through ``get_array``.

    Args:
        file: path to the xdmf file (h5 file needs to be located in the same dir).
            A ``str`` path is accepted as well.

    Returns:
        Dictionary with the keys
            "Geometry": array read from the first ``Geometry`` element.
            "Values": one entry per sub-grid of the first ``Collection`` grid,
                stacked along dim 0. Each entry stacks the ``real_f`` and
                ``imag_f`` arrays along dim 1 and drops a trailing size-1 dim
                (e.g. ``(N, 1)`` datasets give ``(N, 2)`` entries).
            "Encoding": 1-D tensor with the ``Time`` value of each sub-grid
                (presumably the frequency of that sub-grid — confirm with writer).

    Raises:
        ValueError: if the Geometry element or the Collection grid is missing,
            the Collection grid has no sub-grids, or a sub-grid lacks a ``Time``
            value or one of the ``real_f``/``imag_f`` attributes.
    """
    file = pathlib.Path(file)
    root = ET.parse(file).getroot()

    # Element truthiness reflects the child count, so compare with None explicitly.
    geometry_element = root.find(".//Geometry")
    if geometry_element is None:
        raise ValueError(f"{file}: no Geometry element found")

    collection = root.find('.//Grid[@GridType="Collection"]')
    if collection is None:
        raise ValueError(f'{file}: no Grid with GridType="Collection" found')

    # Validate every sub-grid before reading any heavy data so malformed files fail fast.
    entries = []
    for grid in collection.findall(".//Grid"):
        time_element = grid.find("Time")
        time_value = None if time_element is None else time_element.get("Value")
        real_f = grid.find('.//Attribute[@Name="real_f"]')
        imag_f = grid.find('.//Attribute[@Name="imag_f"]')
        if time_value is None or real_f is None or imag_f is None:
            name = grid.get("Name", "<unnamed>")
            raise ValueError(
                f"{file}: grid {name!r} needs a Time value and real_f/imag_f attributes"
            )
        entries.append((float(time_value), real_f, imag_f))
    if not entries:
        raise ValueError(f"{file}: the Collection grid contains no sub-grids")

    data_dir = file.parent
    geometry = get_array(data_dir, geometry_element)
    # squeeze(-1) rather than squeeze(): only the trailing singleton dim is dropped,
    # so a single-node mesh keeps its node dimension and all entries stack uniformly.
    values = torch.stack(
        [
            torch.stack(
                [get_array(data_dir, real_f), get_array(data_dir, imag_f)], dim=1
            ).squeeze(-1)
            for _, real_f, imag_f in entries
        ],
        dim=0,
    )
    encodings = torch.tensor([time_value for time_value, _, _ in entries])

    return {
        "Geometry": geometry,
        "Values": values,
        "Encoding": encodings,
    }