From f530cf40f392de8e082416b7938fc689ae9d5f8d Mon Sep 17 00:00:00 2001
From: Edwin Eefting <edwin@datux.nl>
Date: Tue, 22 Feb 2022 17:40:38 +0100
Subject: [PATCH] fixes. supports stdin

---
 tests/test_check.py        |  2 ++
 zfs_autobackup/ZfsCheck.py | 31 +++++++++++++++++--------------
 2 files changed, 19 insertions(+), 14 deletions(-)

diff --git a/tests/test_check.py b/tests/test_check.py
index fd30231..f7d82c3 100644
--- a/tests/test_check.py
+++ b/tests/test_check.py
@@ -60,6 +60,7 @@ dir/testfile	0	2e863f1fcccd6642e4e28453eba10d2d3f74d798
         #breaks pipe when grep exists:
         #important to use --debug, since that generates extra output which would be problematic if we didnt do correct SIGPIPE handling
         shelltest("python -m zfs_autobackup.ZfsCheck test_source1@test --debug | grep -m1 'Hashing tree'")
+        time.sleep(1)
 
         #should NOT be mounted anymore if cleanup went ok:
         self.assertNotRegex(shelltest("mount"), "test_source1@test")
@@ -73,6 +74,7 @@ dir/testfile	0	2e863f1fcccd6642e4e28453eba10d2d3f74d798
         #breaks pipe when grep exists:
         #important to use --debug, since that generates extra output which would be problematic if we didnt do correct SIGPIPE handling
         shelltest("python -m zfs_autobackup.ZfsCheck test_source1/vol@test --debug | grep -m1 'Hashing dev'")
+        time.sleep(1)
 
         r = shelltest("zfs list -H -o name -r -t all " + TEST_POOLS)
         self.assertMultiLineEqual(r, """
diff --git a/zfs_autobackup/ZfsCheck.py b/zfs_autobackup/ZfsCheck.py
index 4a1290b..8697a61 100644
--- a/zfs_autobackup/ZfsCheck.py
+++ b/zfs_autobackup/ZfsCheck.py
@@ -34,8 +34,8 @@ class ZfsCheck(CliBase):
         group.add_argument('--count', metavar="COUNT", default=int((100 * (1024 ** 2)) / 4096),
                            help="Hash chunks of COUNT blocks. Default %(default)s . (Chunk size is BYTES * COUNT) ", type=int)  # 100MiB
 
-        group.add_argument('--check', '-c', metavar="FILE", default=None,
-                           help="Read hashes from FILE and check them")
+        group.add_argument('--check', '-c', metavar="FILE", default=None, const=True, nargs='?',
+                           help="Read hashes from STDIN (or FILE) and check them")
 
         return parser
 
@@ -138,12 +138,10 @@ class ZfsCheck(CliBase):
 
         snapshot = self.node.get_dataset(self.args.target)
         if not snapshot.exists:
-            snapshot.error("Snapshot not found")
-            sys.exit(1)
+            raise Exception("Snapshot {} not found".format(snapshot))
 
         if not snapshot.is_snapshot:
-            snapshot.error("Dataset should be a snapshot")
-            sys.exit(1)
+            raise Exception("Dataset {} should be a snapshot".format(snapshot))
 
         dataset_type = snapshot.parent.properties['type']
 
@@ -194,12 +192,17 @@ class ZfsCheck(CliBase):
 
     def input_parser(self, file_name):
         """parse input lines and generate items to use in compare functions"""
-        with open(file_name, 'r') as input_fh:
-            for line in input_fh:
-                i=line.rstrip().split("\t")
-                #ignores lines without tabs
-                if (len(i)>1):
-                    yield i
+
+        if self.args.check is True:
+            input_fh=sys.stdin
+        else:
+            input_fh=open(file_name, 'r')
+
+        for line in input_fh:
+            i=line.rstrip().split("\t")
+            #ignores lines without tabs
+            if (len(i)>1):
+                yield i
 
     def run(self):
 
@@ -213,7 +216,7 @@ class ZfsCheck(CliBase):
                         print("{}\t{}".format(*i))
                     sys.stdout.flush()
 
-                sys.exit(0)
+                return 0
 
             #run as compare
             else:
@@ -228,7 +231,7 @@ class ZfsCheck(CliBase):
                             (chunk_nr, compare_hexdigest, actual_hexdigest) = i
                             self.log.error("{}\t{}\t{}".format(chunk_nr, compare_hexdigest, actual_hexdigest))
 
-                sys.exit(min(255,errors))
+                return min(255,errors)
 
         except Exception as e:
             self.error("Exception: " + str(e))