Coverage for gws-app/gws/lib/xmlx/validator.py: 76%
110 statements
Generated by coverage.py v7.11.0, created at 2025-10-16 22:59 +0200
1"""Schema validator."""
3import re
4import os
5import lxml.etree
6import requests
8import gws
class Error(gws.Error):
    """Validation error with a message and a line number.

    Positional args are ``(message, lineno)``. Both are optional; missing values
    default to ``''`` and ``0``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.message = args[0] if args else ''
        self.lineno = args[1] if len(args) > 1 else 0
        # Backward-compatible alias for the previously misspelled attribute name.
        self.linenoe = self.lineno
def validate(xml: str | bytes):
    """Validate an XML document against the schemas named in its ``schemaLocation``.

    A wrapper XSD importing every referenced schema is built and compiled.
    http(s) schema URLs are fetched through ``_CachingResolver``.

    Args:
        xml: The document, as text or bytes.

    Returns:
        True if the document is valid.

    Raises:
        Error: If the document or its schemas cannot be loaded/compiled, or the
            document does not validate.
    """
    try:
        # NOTE(review): resolve_entities=True lets the document pull in external
        # entities; confirm inputs are trusted before validating external XML.
        parser = lxml.etree.XMLParser(resolve_entities=True)
        parser.resolvers.add(_CachingResolver())

        schema_locations = _extract_schema_locations(xml)
        xsd = _create_combined_xsd(schema_locations)

        xml_tree = _etree(xml, parser)
        schema_tree = _etree(xsd, parser)
        schema = lxml.etree.XMLSchema(schema_tree)
    except Exception as exc:
        # Not only lxml errors: failures inside the resolver (requests errors,
        # ValueError on a non-200 download) can surface here too. Report every
        # setup failure uniformly as Error, like the validation step below.
        raise _error(exc) from exc

    try:
        schema.assertValid(xml_tree)
    except Exception as exc:
        raise _error(exc) from exc
    return True
def _extract_schema_locations(xml: str | bytes) -> dict:
    """Return a ``{namespace: schema_url}`` map from the root's schemaLocation.

    ``xsi:schemaLocation`` is read first, falling back to an unqualified
    ``schemaLocation`` attribute. The value is a whitespace-separated list of
    namespace/location pairs. A dangling unpaired trailing token is ignored
    instead of raising IndexError. If a namespace repeats, the last location wins.
    """
    root = _etree(xml, None).getroot()

    xsi_ns = '{http://www.w3.org/2001/XMLSchema-instance}'
    attr = root.get(f'{xsi_ns}schemaLocation') or root.get('schemaLocation')
    if not attr:
        return {}

    tokens = attr.split()
    # Even positions are namespaces, odd positions are locations; zip stops at
    # the shorter slice, so an odd leftover token is dropped.
    return dict(zip(tokens[0::2], tokens[1::2]))
61def _create_combined_xsd(schema_locations: dict) -> str:
62 xml = []
63 xml.append('<?xml version="1.0" encoding="UTF-8"?>')
64 xml.append('<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">')
66 for ns, loc in schema_locations.items():
67 xml.append(f'<xs:import namespace="{ns}" schemaLocation="{loc}"/>')
69 xml.append('</xs:schema>\n')
71 return '\n'.join(xml)
def _etree(xml: str | bytes, parser: lxml.etree.XMLParser | None) -> lxml.etree.ElementTree:
    """Parse *xml* with *parser* (None means lxml's default) and wrap the root in an ElementTree.

    Text input is UTF-8 encoded to bytes before parsing.
    """
    data = xml.encode('utf-8') if isinstance(xml, str) else xml
    root = lxml.etree.fromstring(data, parser)
    return lxml.etree.ElementTree(root)
80def _error(exc):
81 # exc is either {'message': ..., 'lineno': ...}
82 # or {'error_log': '<string>:17:0:ERROR:...}
84 cls = exc.__class__.__name__
86 s = getattr(exc, 'error_log', None)
87 if s:
88 try:
89 lineno = int(s.split(':')[1])
90 except Exception:
91 lineno = 0
92 return Error(f'{cls}: {s}', lineno)
94 lineno = getattr(exc, 'lineno', 0)
95 return Error(f'{cls}: {exc}', lineno)
class _CachingResolver(lxml.etree.Resolver):
    """lxml resolver that serves http(s) URLs via ``_download_url``.

    Other URLs are left to lxml's default resolution.
    """

    def resolve(self, url, id, context):
        if not url.startswith(('http://', 'https://')):
            return super().resolve(url, id, context)

        # URLs containing '.loc' or 'local' (presumably dev/local hosts) always bypass the cache.
        looks_local = '.loc' in url or 'local' in url
        content = _download_url(url, with_cache=not looks_local)
        return self.resolve_string(content, context, base_url=url)
def _download_url(url: str, with_cache: bool) -> bytes:
    """Return the body of *url*, optionally through the on-disk cache in ``CACHE_DIR/xmlx``.

    With *with_cache*, an existing cached copy is returned as-is (entries are
    never refreshed). Otherwise the URL is fetched and the result is stored.
    Without *with_cache*, the URL is always fetched.
    """
    if with_cache:
        path = _cache_path(gws.u.ensure_dir(gws.c.CACHE_DIR + '/xmlx'), url)
        if os.path.exists(path):
            return gws.u.read_file_b(path)
        body = _raw_download_url(url)
        gws.u.write_file_b(path, body)
        return body

    return _raw_download_url(url)
def _raw_download_url(url: str) -> bytes:
    """GET *url* (10 s timeout) and return the body.

    Raises:
        ValueError: If the response status is anything other than 200.
    """
    gws.log.debug(f'xmlx.validator: downloading {url!r}')
    resp = requests.get(url, timeout=10)
    status = resp.status_code
    if status == 200:
        return resp.content
    raise ValueError(f'Failed to download {url!r}: {status}')
133def _cache_path(cache_dir: str, url: str) -> str:
134 u = url.strip().split('//')[-1]
135 if '?' in u:
136 u = u.split('?', 1)[0]
137 fname = 'index.xml'
138 parts = u.split('/')
140 if u.endswith('/'):
141 parts.pop()
142 else:
143 m = re.search(r'[^/]+\.[a-z]+$', parts[-1])
144 if m:
145 fname = m.group(0)
146 parts.pop()
148 d = '/'.join(_to_dirname(p) for p in parts)
149 if not d:
150 return cache_dir + '/' + fname
151 d = gws.u.ensure_dir(cache_dir + '/' + d)
152 return d + '/' + fname
155def _to_dirname(s: str) -> str:
156 s = s.lower().strip().lstrip('.')
157 s = re.sub(r'[^a-zA-Z0-9.]+', '_', s).strip('_')
158 return s