1 module boilerplate.conditions;
2 
3 import std.typecons;
4 
5 version(unittest)
6 {
7     import core.exception : AssertError;
8     import unit_threaded.should;
9 }
10 
/++
`GenerateInvariants` is a string meant to be mixed into an aggregate body with
`mixin(GenerateInvariants);`.

It generates a single `invariant {}` block that checks every field carrying a condition
attribute (`@NonEmpty`, `@NonNull`, `@NonInit`, `@AllNonNull`). A condition attribute on a
static field is rejected at compile time.
+/
public enum string GenerateInvariants = `
    import boilerplate.conditions : GenerateInvariantsTemplate;
    mixin GenerateInvariantsTemplate;
    mixin(typeof(this).generateInvariantsImpl());
`;
20 
/++
Condition attribute. For a field tagged `@NonEmpty`, the generated check asserts
`!field.empty`.
+/
public struct NonEmpty {}
27 
/// A `@NonEmpty` array field that starts out empty trips the generated invariant.
@("throws when a NonEmpty field is initialized empty")
unittest
{
    class Holder
    {
        @NonEmpty
        int[] items_;

        this(int[] initial)
        {
            items_ = initial;
        }

        mixin(GenerateInvariants);
    }

    shouldThrow!AssertError(new Holder(null));
}
47 
/// Emptying a `@NonEmpty` field through a public method trips the invariant after the call.
@("throws when a NonEmpty field is assigned empty")
unittest
{
    class Holder
    {
        @NonEmpty
        private int[] items_;

        this(int[] initial)
        {
            items_ = initial;
        }

        // Public member function, so the class invariant is checked around the call.
        public void setItems(int[] replacement)
        {
            items_ = replacement;
        }

        mixin(GenerateInvariants);
    }

    (new Holder([2])).setItems(null).shouldThrow!AssertError;
}
72 
/++
Condition attribute. For a field tagged `@NonNull`, the generated check asserts
`field !is null`; for types exposing `isNull`, it asserts `!field.isNull` instead.
+/
public struct NonNull {}
79 
/// A `@NonNull` object field that starts out null trips the generated invariant.
@("throws when a NonNull field is initialized null")
unittest
{
    class Holder
    {
        @NonNull
        Object target_;

        this(Object target)
        {
            target_ = target;
        }

        mixin(GenerateInvariants);
    }

    shouldThrow!AssertError(new Holder(null));
}
99 
/++
Condition attribute. Every element of a field tagged `@AllNonNull` must be non-null.

For a range, the generated check asserts `field.all!"a !is null"` (or `field.all!"!a.isNull"`
when the elements expose `isNull`). For an associative array, it checks the keys and/or
the values, whichever of them can be compared against `null`.
+/
public struct AllNonNull {}
106 
/// An `@AllNonNull` array may be empty, but may not contain a null element.
@("throws when an AllNonNull field is initialized with an array containing null")
unittest
{
    class Holder
    {
        @AllNonNull
        Object[] entries;

        this(Object[] entries)
        {
            this.entries = entries;
        }

        mixin(GenerateInvariants);
    }

    // A null array has no elements, so the condition holds.
    auto holder = new Holder(null);
    holder.entries.shouldEqual(Object[].init);

    // Any null element violates it.
    shouldThrow!AssertError(new Holder([null]));
    shouldThrow!AssertError(new Holder([new Object, null]));
}
128 
/// `@AllNonNull` may be used with associative arrays; here the (class-typed) values are checked.
@("supports AllNonNull on associative arrays")
unittest
{
    class Holder
    {
        @AllNonNull
        Object[int] table;

        this(Object[int] table)
        {
            this.table = table;
        }

        mixin(GenerateInvariants);
    }

    auto holder = new Holder(null);
    holder.table.shouldEqual(null);

    shouldThrow!AssertError(new Holder([0: null]));
    shouldThrow!AssertError(new Holder([0: new Object, 1: null]));
}
150 
/// On associative arrays, `@AllNonNull` checks whichever of keys and values can be null; here, the keys.
@("supports AllNonNull on associative array keys")
unittest
{
    class Holder
    {
        @AllNonNull
        int[Object] table;

        this(int[Object] table)
        {
            this.table = table;
        }

        mixin(GenerateInvariants);
    }

    auto holder = new Holder(null);
    holder.table.shouldEqual(null);

    shouldThrow!AssertError(new Holder([null: 0]));
    shouldThrow!AssertError(new Holder([new Object: 0, null: 1]));
}
172 
/++
Condition attribute. For a field tagged `@NonInit`, the generated check asserts
`field !is T.init`, where `T` is the field's type. The comparison uses `!is`, so it is bitwise.
+/
public struct NonInit {}
179 
/// A `@NonInit` field that still holds its type's `.init` value (NaN for `float`) fails the invariant.
@("throws when a NonInit field is initialized with T.init")
unittest
{
    // Removed the unused `import core.time : Duration;`.
    class Class
    {
        @NonInit
        float f_;

        this(float f) { this.f_ = f; }

        mixin(GenerateInvariants);
    }

    (new Class(float.init)).shouldThrow!AssertError;
}
198 
/++
Condition checks on a `Nullable` field apply only to the contained value, when there is one.
A `Nullable` in the null state is never rejected by the generated checks.
+/
@("doesn't throw when a Nullable field is null")
unittest
{
    class Holder
    {
        @NonInit
        Nullable!float value_;

        this(Nullable!float value)
        {
            value_ = value;
        }

        mixin(GenerateInvariants);
    }

    auto present = new Holder(5f.nullable);
    present.value_.isNull.shouldBeFalse;

    auto absent = new Holder(Nullable!float());
    absent.value_.isNull.shouldBeTrue;

    // A contained value equal to float.init still violates @NonInit.
    shouldThrow!AssertError(new Holder(float.init.nullable));
}
223 
/++
Conditions cannot be applied to static fields. D has no static invariants, so `GenerateInvariants`
emits a failing `static assert` for any static field that carries a condition attribute.
+/
@("does not allow invariants on static fields")
unittest
{
    static assert(!__traits(compiles, ()
    {
        class Class
        {
            @NonNull
            private static Object obj;

            mixin(GenerateInvariants);
        }
    }), "invariant on static field compiled when it shouldn't");
}
241 
@("works with classes inheriting from templates")
unittest
{
    interface Left(T)
    {
    }

    interface Right(T)
    {
    }

    class Both : Left!ubyte, Right!ubyte
    {
    }

    // Compiling is the test: the field's class type has templated base interfaces.
    class Owner
    {
        Both both;

        mixin(GenerateInvariants);
    }
}
264 
/++
Mixin template that adds `generateInvariantsImpl` to the aggregate it is mixed into.
The `GenerateInvariants` string mixes this template in and then mixes in the code returned by
`typeof(this).generateInvariantsImpl()`.
+/
mixin template GenerateInvariantsTemplate()
{
    /*
     * Returns the source code of one `invariant { ... }` block for `typeof(this)`.
     *
     * For each name in `NormalMemberTuple` (provided by `boilerplate.util.GenNormalMemberTuple`;
     * NOTE(review): presumably the aggregate's ordinary members — confirm there), the checks from
     * `generateChecksForAttributes` for its condition attributes are appended.
     * A static member that carries a condition attribute additionally gets a
     * `static assert(false, ...)`, because there are no static invariants.
     * Outside of CTFE this returns null.
     */
    private static string generateInvariantsImpl()
    {
        // Code generation only happens at compile time; at run time there is nothing to build.
        if (!__ctfe)
        {
            return null;
        }

        import boilerplate.conditions : generateChecksForAttributes, IsConditionAttribute;
        import boilerplate.util : GenNormalMemberTuple, isStatic;
        import std.format : format;
        import std.meta : StdMetaFilter = Filter;

        string result = null;

        // Open the invariant and bring into its scope the names the generated checks may use.
        result ~= `invariant {` ~
            `import std.format : format;` ~
            `import std.array : empty;`;

        // TODO blocked by https://issues.dlang.org/show_bug.cgi?id=18504
        // note: synchronized without lock contention is basically free
        // IMPORTANT! Do not enable this until you have a solution for reliably detecting which attributes actually
        // require synchronization! overzealous synchronize has the potential to lead to needless deadlocks.
        // (consider implementing @GuardedBy)
        enum synchronize = false;

        result ~= synchronize ? `synchronized (this) {` : ``;

        // Declares `NormalMemberTuple`, the member names iterated below.
        mixin GenNormalMemberTuple;

        foreach (member; NormalMemberTuple)
        {
            // Alias the member by name so its attributes and type can be inspected.
            mixin(`alias symbol = typeof(this).` ~ member ~ `;`);

            // Only attributes accepted by IsConditionAttribute produce checks.
            alias ConditionAttributes = StdMetaFilter!(IsConditionAttribute, __traits(getAttributes, symbol));

            // An instance invariant cannot cover a static field: make the generated code fail to compile.
            static if (mixin(isStatic(member)) && ConditionAttributes.length > 0)
            {
                result ~= format!(`static assert(false, `
                    ~ `"Cannot add constraint on static field %s: no support for static invariants");`
                )(member);
            }

            // Members whose type has no usable `.init` are skipped.
            // NOTE(review): presumably this filters out non-field members — confirm.
            static if (__traits(compiles, typeof(symbol).init))
            {
                result ~= generateChecksForAttributes!(typeof(symbol), ConditionAttributes)(`this.` ~ member);
            }
        }

        result ~= synchronize ? ` }` : ``;

        // Close the invariant block.
        result ~= ` }`;

        return result;
    }
}
322 
/++
Overload selected when no condition attributes apply to the member: there is nothing to
check, so no code is generated (the result is `null`).
+/
public string generateChecksForAttributes(T, Attributes...)(string memberExpression, string info = "")
if (Attributes.length == 0)
{
    string noChecks; // default-initialized, i.e. null
    return noChecks;
}
328 
/// Named values substituted into the generated check templates: `expr` is the expression being
/// checked, `info` is extra text for failure messages, and `typename` is the checked type's name.
private alias InfoTuple = Tuple!(string, `expr`, string, `info`, string, `typename`);
330 
/++
Generates the D source of the runtime checks for the condition attributes `Attributes`
applied to a member of type `T`.

Params:
    memberExpression = expression, in the generated code, that reads the member (e.g. `this.field`).
    info = text inserted into each failure message directly after "failed".

Returns: the check statements. When `T` is a `Nullable`, the checks apply to the contained
value and are wrapped in `if (!member.isNull) { ... }`, so a null `Nullable` is not checked.
+/
public string generateChecksForAttributes(T, Attributes...)(string memberExpression, string info = "")
if (Attributes.length > 0)
{
    import boilerplate.util : udaIndex;

    enum isNullable = is(T: Nullable!Args, Args...);

    // Conditions on a Nullable apply to the wrapped value, accessed through `.get`.
    // (static if does not open a scope, so both declarations remain visible below.)
    static if (isNullable)
    {
        alias MemberType = typeof(T.init.get);
        string expression = memberExpression ~ `.get`;
    }
    else
    {
        alias MemberType = typeof(T.init);
        string expression = memberExpression;
    }

    auto values = InfoTuple(expression, info, MemberType.stringof);

    string checks;

    static if (udaIndex!(NonEmpty, Attributes) != -1)
    {
        checks ~= generateNonEmpty!MemberType(values);
    }

    static if (udaIndex!(NonNull, Attributes) != -1)
    {
        checks ~= generateNonNull!MemberType(values);
    }

    static if (udaIndex!(NonInit, Attributes) != -1)
    {
        checks ~= generateNonInit!MemberType(values);
    }

    static if (udaIndex!(AllNonNull, Attributes) != -1)
    {
        checks ~= generateAllNonNull!MemberType(values);
    }

    static if (isNullable)
    {
        // The null state itself is never a violation; only a present value is checked.
        return `if (!` ~ memberExpression ~ `.isNull) {` ~ checks ~ `}`;
    }
    else
    {
        return checks;
    }
}
386 
/*
 * Generates the `@NonEmpty` check: `assert(!expr.empty(), ...)`.
 * If `empty()` cannot be called on a `T`, a failing `static assert` is emitted as well.
 *
 * The failure message does not include the field's value. Formatting it would put
 * std.format's `format` (which allocates and may throw) into the generated check.
 */
private string generateNonEmpty(T)(InfoTuple values)
{
    import boilerplate.util : formatNamed;
    import std.array : empty;

    string checks;

    static if (!__traits(compiles, T.init.empty()))
    {
        checks ~= formatNamed!`static assert(false, "Cannot call std.array.empty() on '%(expr)'");`.values(values);
    }

    checks ~= formatNamed!`assert(!%(expr).empty(), "@NonEmpty: assert(!%(expr).empty) failed%(info)");`
        .values(values);

    return checks;
}
415 
/*
 * Generates the `@NonNull` check for a value of type `T`:
 *  - types with an `isNull` property: `assert(!expr.isNull, ...)`;
 *  - types comparable with null:      `assert(expr !is null, ...)`;
 *  - anything else: a failing `static assert`.
 */
private string generateNonNull(T)(InfoTuple values)
{
    import boilerplate.util : formatNamed;

    static if (__traits(compiles, T.init.isNull))
    {
        return formatNamed!`assert(!%(expr).isNull, "@NonNull: assert(!%(expr).isNull) failed%(info)");`
            .values(values);
    }
    else static if (__traits(compiles, T.init !is null))
    {
        // The value is not included in the message: when this fires, it is null.
        return formatNamed!`assert(%(expr) !is null, "@NonNull: assert(%(expr) !is null) failed%(info)");`
            .values(values);
    }
    else
    {
        return formatNamed!`static assert(false, "Cannot compare '%(expr)' to null");`.values(values);
    }
}
440 
/*
 * Generates the `@NonInit` check: `assert(expr !is typeof(expr).init, ...)`.
 * `!is` compares bitwise, so e.g. a `float` holding NaN (`float.init`) is detected.
 * If `T` values cannot be compared with `!is`, a failing `static assert` is emitted as well.
 *
 * The failure message does not include the field's value. Formatting it would put
 * std.format's `format` (which allocates and may throw) into the generated check.
 */
private string generateNonInit(T)(InfoTuple values)
{
    import boilerplate.util : formatNamed;

    string checks;

    static if (!__traits(compiles, T.init !is T.init))
    {
        checks ~= formatNamed!`static assert(false, "Cannot compare '%(expr)' to %(typename).init");`
            .values(values);
    }

    checks ~= formatNamed!(`assert(%(expr) !is typeof(%(expr)).init, `
        ~ `"@NonInit: assert(%(expr) !is %(typename).init) failed%(info)");`)
        .values(values);

    return checks;
}
470 
/*
 * Generates the `@AllNonNull` check for a value of type `T`:
 *  - ranges of nullable references: `assert(expr.all!"a !is null", ...)`;
 *  - ranges of elements with `isNull`: `assert(expr.all!"!a.isNull", ...)`;
 *  - associative arrays: checks `byValue` and/or `byKey`, whichever can be compared with null
 *    (a failing `static assert` if neither can);
 *  - anything else: a failing `static assert`.
 * The generated code starts with a statement-level `import std.algorithm: all;`.
 *
 * The failure message does not include the field's value. Formatting it would put
 * std.format's `format` (which allocates and may throw) into the generated check.
 */
private string generateAllNonNull(T)(InfoTuple values)
{
    import boilerplate.util : formatNamed;
    import std.algorithm : all;
    import std.traits : isAssociativeArray;

    string checks;

    checks ~= `import std.algorithm: all;`;

    static if (__traits(compiles, T.init.all!"a !is null"))
    {
        checks ~= formatNamed!(`assert(%(expr).all!"a !is null", `
            ~ `"@AllNonNull: assert(%(expr).all!\"a !is null\") failed%(info)");`)
            .values(values);
    }
    else static if (__traits(compiles, T.init.all!"!a.isNull"))
    {
        checks ~= formatNamed!(`assert(%(expr).all!"!a.isNull", `
            ~ `"@AllNonNull: assert(%(expr).all!\"!a.isNull\") failed%(info)");`)
            .values(values);
    }
    else static if (isAssociativeArray!T)
    {
        enum checkValues = __traits(compiles, T.init.byValue.all!`a !is null`);
        enum checkKeys = __traits(compiles, T.init.byKey.all!"a !is null");

        static if (!checkKeys && !checkValues)
        {
            checks ~= formatNamed!(`static assert(false, "Neither key nor value of associative array `
                ~ `'%(expr)' can be checked against null.");`).values(values);
        }

        static if (checkValues)
        {
            checks ~=
                formatNamed!(`assert(%(expr).byValue.all!"a !is null", `
                    ~ `"@AllNonNull: assert(%(expr).byValue.all!\"a !is null\") failed%(info)");`)
                    .values(values);
        }

        static if (checkKeys)
        {
            checks ~=
                formatNamed!(`assert(%(expr).byKey.all!"a !is null", `
                    ~ `"@AllNonNull: assert(%(expr).byKey.all!\"a !is null\") failed%(info)");`)
                    .values(values);
        }
    }
    else
    {
        checks ~= formatNamed!`static assert(false, "Cannot compare all '%(expr)' to null");`.values(values);
    }

    return checks;
}
547 
/++
True when `A` is one of this module's condition attributes: `NonEmpty`, `NonNull`,
`NonInit` or `AllNonNull`. The comparison uses `__traits(isSame)`.
+/
public template IsConditionAttribute(alias A)
{
    enum IsConditionAttribute = __traits(isSame, A, NonEmpty)
        || __traits(isSame, A, NonNull)
        || __traits(isSame, A, NonInit)
        || __traits(isSame, A, AllNonNull);
}