-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
loaders.py
251 lines (213 loc) · 8.32 KB
/
loaders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
import importlib.util
import inspect
import os
import typing as t
import zipimport
from collections import abc
from contextlib import suppress
from importlib import import_module
from pathlib import Path

from aiopath import AsyncPath
from jinja2.environment import Template
from jinja2.exceptions import TemplateNotFound
from jinja2.utils import internalcode

from .environment import AsyncEnvironment
class PackageSpecNotFound(TemplateNotFound):
    """Error signalling that no import spec could be located for a package."""
class LoaderNotFound(TemplateNotFound):
    """Error signalling that a package's import spec carries no loader."""
class AsyncBaseLoader:
    """Base class for asynchronous template loaders.

    Subclasses override :meth:`get_source` (and optionally
    :meth:`list_templates`). :meth:`load` turns the source returned by
    :meth:`get_source` into a template object, consulting the environment's
    bytecode cache when one is configured.
    """

    # Subclasses that cannot expose raw template source set this to False.
    has_source_access = True

    def __init__(self, searchpath: AsyncPath | t.Sequence[AsyncPath]) -> None:
        """Store *searchpath*, normalising a single path to a one-item list.

        Args:
            searchpath: One path or a sequence of paths. A plain ``str`` is
                accepted and treated as a single path.
        """
        # A str is itself iterable; without this it would be treated as a
        # sequence of one-character search paths.
        if isinstance(searchpath, str):
            searchpath = AsyncPath(searchpath)
        self.searchpath = searchpath
        # Path objects are not iterable, so a lone path is wrapped in a list.
        if not isinstance(searchpath, abc.Iterable):
            self.searchpath = [searchpath]

    async def get_source(self, template: AsyncPath) -> t.Any:
        """Default lookup that never finds anything.

        Raises:
            RuntimeError: If ``has_source_access`` is False.
            TemplateNotFound: Otherwise, always.
        """
        if not self.has_source_access:
            raise RuntimeError(
                f"{type(self).__name__} cannot provide access to the source"
            )
        raise TemplateNotFound(template.name)

    async def list_templates(self) -> list[str] | t.NoReturn:
        """Raise TypeError; subclasses that can enumerate templates override this."""
        raise TypeError("this loader cannot iterate over all templates")

    @internalcode
    async def load(
        self,
        environment: AsyncEnvironment,
        name: str,
        env_globals: t.Optional[t.MutableMapping[str, t.Any]] = None,
    ) -> Template:
        """Load and compile the template *name* for *environment*.

        Compiled code is taken from the environment's bytecode cache when it
        has an entry; otherwise the source is compiled and, if a cache is
        configured, stored back into it.
        """
        if env_globals is None:
            env_globals = {}
        code: t.Any = None
        bucket: t.Any = None
        source, path, uptodate = await self.get_source(AsyncPath(name))
        bcc = environment.bytecode_cache
        if bcc is not None:
            bucket = await bcc.get_bucket(environment, name, path, source)
            code = bucket.code
        if code is None:
            code = environment.compile(source, name, path)
        # Only write back when a cache exists and it had no code for us.
        if bucket is not None and bucket.code is None:
            bucket.code = code
            await bcc.set_bucket(bucket)
        return environment.template_class.from_code(
            environment, code, env_globals, uptodate
        )
class FileSystemLoader(AsyncBaseLoader):
    """Load templates from one or more directories on the filesystem.

    Search paths are tried in order; the first one that contains the
    requested template as a regular file wins.
    """

    def __init__(
        self,
        searchpath: AsyncPath | t.Sequence[AsyncPath],
        encoding: str = "utf-8",
        followlinks: bool = False,
    ) -> None:
        super().__init__(searchpath)
        # Encoding used to decode template bytes in get_source().
        self.encoding = encoding
        # Stored for API compatibility; this class does not consult it.
        self.followlinks = followlinks

    async def get_source(self, template: AsyncPath) -> t.Any:
        """Return ``(source, filename, uptodate)`` for *template*.

        *source* is the file content decoded with ``self.encoding``;
        *uptodate* is an async callable reporting whether the file's mtime
        is unchanged since it was read.

        Raises:
            TemplateNotFound: No search path contains the template, or the
                file disappeared while it was being read.
        """
        for searchpath in self.searchpath:  # type: ignore
            path = searchpath / template
            if await path.is_file():
                break
        else:
            raise TemplateNotFound(template.name)
        try:
            resp = await path.read_bytes()
            # stat() is inside the try: the file can vanish after the read.
            mtime = (await path.stat()).st_mtime
        except FileNotFoundError as e:
            raise TemplateNotFound(path.name) from e

        async def uptodate() -> bool:
            try:
                return (await path.stat()).st_mtime == mtime
            except OSError:
                return False

        # Decode so callers receive text, matching PackageLoader.get_source.
        return resp.decode(self.encoding), str(path), uptodate

    async def list_templates(self) -> list[str]:
        """Return every ``*.html`` path found recursively under the search
        paths, as full path strings, deduplicated and sorted."""
        results = set()
        for searchpath in self.searchpath:  # type: ignore
            results.update([str(path) async for path in searchpath.rglob("*.html")])
        return sorted(results)
class PackageLoader(AsyncBaseLoader):
    """Load templates from a directory shipped inside an importable package.

    The package is imported and its import spec inspected to find the
    directory ``<package>/<package_path>``. Packages on disk and packages
    imported from a zip archive (via :mod:`zipimport`) are both handled.

    Args:
        package_name: Dotted name of the package to import.
        searchpath: Passed through to :class:`AsyncBaseLoader`; this class
            locates templates via the package, not via ``searchpath``.
        package_path: Directory inside the package holding the templates.
        encoding: Text encoding used to decode template bytes.

    Raises:
        PackageSpecNotFound: ``importlib.util.find_spec`` returned no spec.
        LoaderNotFound: The spec has no loader.
        ValueError: No ``package_path`` directory was found in the package.
    """

    def __init__(
        self,
        package_name: str,
        searchpath: AsyncPath | t.Sequence[AsyncPath],
        package_path: AsyncPath = AsyncPath("templates"),
        encoding: str = "utf-8",
    ) -> None:
        super().__init__(searchpath)
        self.package_path = package_path
        self.package_name = package_name
        self.encoding = encoding
        import_module(package_name)
        spec = importlib.util.find_spec(package_name)
        if not spec:
            raise PackageSpecNotFound("An import spec was not found for the package")
        loader = spec.loader
        if not loader:
            raise LoaderNotFound("A loader was not found for the package")
        self._loader = loader
        # Path of the zip archive, or None for packages living on disk.
        self._archive = None
        template_root: AsyncPath | None = None
        if isinstance(loader, zipimport.zipimporter):
            self._archive = loader.archive
            pkgdir = next(iter(spec.submodule_search_locations))  # type: ignore
            template_root = AsyncPath(pkgdir) / package_path
        else:
            roots: list[Path] = []
            if spec.submodule_search_locations:
                roots.extend(Path(s) for s in spec.submodule_search_locations)
            elif spec.origin is not None:
                # Single-module spec: origin is the .py file itself, so the
                # templates directory sits next to it.
                roots.append(Path(spec.origin).parent)
            for root in roots:
                candidate = root / package_path
                if candidate.is_dir():
                    # The template root is the templates directory itself,
                    # matching the zip branch above.
                    template_root = AsyncPath(candidate)
                    break
        if template_root is None:
            raise ValueError(
                f"The {package_name!r} package was not installed in a"
                " way that PackageLoader understands"
            )
        self._template_root = template_root

    async def get_source(self, template: AsyncPath) -> t.Any:
        """Return ``(source, filename, uptodate)`` for *template*.

        *template* is resolved relative to the package's template directory
        and *source* is decoded with ``self.encoding``.

        Raises:
            TemplateNotFound: The template does not exist.
        """
        path = self._template_root / template
        if self._archive is None:
            # On-disk package: read the file directly and track its mtime.
            if not await path.is_file():
                raise TemplateNotFound(path.name)
            source = await path.read_bytes()
            mtime = (await path.stat()).st_mtime

            async def uptodate() -> bool:
                return await path.is_file() and (await path.stat()).st_mtime == mtime

        else:
            # Zip package: member paths do not exist on the OS filesystem,
            # so the data must come from the zipimporter.
            try:
                source = self._loader.get_data(str(path))  # type: ignore
            except OSError as e:
                raise TemplateNotFound(path.name) from e
            uptodate = None  # type: ignore
        return source.decode(self.encoding), str(path), uptodate  # type: ignore

    async def list_templates(self) -> list[str]:
        """Return sorted template names.

        For on-disk packages these are the ``*.html`` paths under the
        template directory. For zip packages they are the archive member
        names under the template directory.

        Raises:
            TypeError: The zipimporter exposes no ``_files`` listing.
        """
        results: list[str] = []
        if self._archive is None:
            paths = self._template_root.rglob("*.html")
            results.extend([str(p) async for p in paths])
        else:
            if not hasattr(self._loader, "_files"):
                raise TypeError(
                    "This zip import does not have the required"
                    " metadata to list templates"
                )
            # zipimporter keys are archive-relative and use os.sep, so strip
            # the archive location from the template root to get the prefix.
            prefix = (
                str(self._template_root)[len(self._archive) :].lstrip(os.sep)
                + os.sep
            )
            for name in self._loader._files.keys():  # type: ignore
                # Keys ending in a separator are directory entries.
                if name.startswith(prefix) and not name.endswith(os.sep):
                    results.append(name)
        results.sort()
        return results
class DictLoader(AsyncBaseLoader):
    """Serve templates from an in-memory mapping of template name to source."""

    def __init__(
        self,
        mapping: t.Mapping[str, str],
        searchpath: AsyncPath | t.Sequence[AsyncPath],
    ) -> None:
        super().__init__(searchpath)
        self.mapping = mapping

    async def get_source(self, template: AsyncPath) -> t.Any:
        """Return ``(source, None, uptodate)`` for *template*.

        *uptodate* reports whether the mapping still holds the same source.

        Raises:
            TemplateNotFound: The name is not a key of ``self.mapping``.
        """
        # Use the full '/'-separated name: ``template.name`` would drop the
        # directory part, so keys like "partials/nav.html" could never match.
        key = template.as_posix()
        if key not in self.mapping:
            raise TemplateNotFound(key)
        source = self.mapping[key]
        return source, None, lambda: source == self.mapping.get(key)

    async def list_templates(self) -> list[str]:
        """Return the mapping's keys, sorted."""
        return sorted(self.mapping)
class FunctionLoader(AsyncBaseLoader):
    """Delegate template lookup to a user-supplied callable.

    ``load_func`` receives the template path and returns ``None`` (not
    found), a source string, or a full ``(source, filename, uptodate)``
    tuple. Plain callables and ``async`` callables are both accepted.
    """

    def __init__(
        self,
        load_func: t.Callable[[AsyncPath], t.Any],
        searchpath: AsyncPath | t.Sequence[AsyncPath],
    ) -> None:
        super().__init__(searchpath)
        self.load_func = load_func

    async def get_source(self, template: str | AsyncPath) -> t.Any:
        """Return the source tuple produced by ``load_func`` for *template*.

        Raises:
            TemplateNotFound: ``load_func`` returned ``None``.
        """
        path = AsyncPath(template)
        source = self.load_func(path)
        # Support ``async def`` load functions as well as plain callables.
        if inspect.isawaitable(source):
            source = await source
        if source is None:
            raise TemplateNotFound(path.name)
        if isinstance(source, str):
            # The uptodate slot must be a callable or None; a bare string gives
            # nothing to re-check, so report None ("no reload check").
            return source, str(path), None
        return source
class ChoiceLoader(AsyncBaseLoader):
    """Consult several loaders in order and use the first that has a template."""

    # Class-level default; every instance replaces it in __init__.
    loaders: list[AsyncBaseLoader] = []

    def __init__(
        self,
        loaders: list[AsyncBaseLoader],
        searchpath: AsyncPath | t.Sequence[AsyncPath],
    ) -> None:
        super().__init__(searchpath)
        self.loaders = loaders

    async def get_source(self, template: AsyncPath) -> t.Any:
        """Return the first successful ``get_source`` result among the loaders.

        Raises:
            TemplateNotFound: No loader could provide the template.
        """
        for candidate in self.loaders:
            try:
                return await candidate.get_source(template)
            except TemplateNotFound:
                continue
        raise TemplateNotFound(template.name)

    async def list_templates(self) -> list[str]:
        """Return the sorted union of every loader's template listing."""
        names = {
            name
            for candidate in self.loaders
            for name in await candidate.list_templates()
        }
        return sorted(names)