summaryrefslogtreecommitdiff
path: root/servers/hello_world/tests/test_cli.py
blob: 48fa89cbcecdfe47b7fee0d6972927343d3ff44d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import pytest
import sys
from unittest.mock import patch, MagicMock
import argparse
from mcp_server_hello_world.cli import run_server, parse_args, validate_args

def test_cli_argument_parsing():
    """Check that ``parse_args`` hands back the namespace argparse produced.

    ``argparse.ArgumentParser.parse_args`` is patched so no real command line
    is read. The scenario is run twice: once with the default-looking values
    and once with custom ones, and each expected attribute is compared.
    """
    scenarios = [
        # Default-looking values.
        dict(transport="stdio", host="0.0.0.0", port=8080, log_level="INFO"),
        # Custom values.
        dict(transport="remote", host="127.0.0.1", port=9090, log_level="DEBUG"),
    ]

    for expected in scenarios:
        fake_namespace = argparse.Namespace(**expected)
        with patch('argparse.ArgumentParser.parse_args', return_value=fake_namespace):
            parsed = parse_args()
            for attribute, wanted in expected.items():
                assert getattr(parsed, attribute) == wanted

def test_run_server():
    """Check that ``run_server`` parses, validates, sets up logging and runs the server.

    All collaborators are patched. ``validate_args`` is made to return the very
    namespace ``parse_args`` produced, so whatever ``run_server`` does after
    validation sees realistic arguments rather than a stray ``MagicMock``.
    """
    import inspect
    from unittest.mock import DEFAULT

    # Build the namespace up front: patch(return_value=...) captures its value
    # when the patch is created, so it must already be the real namespace.
    mock_args = argparse.Namespace(
        transport="stdio",
        host="0.0.0.0",
        port=8080,
        log_level="INFO",
    )

    def _fake_asyncio_run(main, *args, **kwargs):
        # The patched asyncio.run never awaits what it is given; close any
        # coroutine so Python does not warn "coroutine ... was never awaited".
        if inspect.iscoroutine(main):
            main.close()
        # Returning DEFAULT makes the mock return its normal return_value.
        return DEFAULT

    with patch('mcp_server_hello_world.cli.parse_args', return_value=mock_args) as mock_parse_args, \
         patch('mcp_server_hello_world.cli.validate_args', return_value=mock_args) as mock_validate_args, \
         patch('mcp_server_hello_world.cli.setup_logging') as mock_setup_logging, \
         patch('asyncio.run', side_effect=_fake_asyncio_run) as mock_run:

        # Call the run_server function
        run_server()

        # Check that the functions were called with the correct arguments
        mock_parse_args.assert_called_once()
        mock_validate_args.assert_called_once_with(mock_args)
        # Only the call itself is checked; setup_logging's exact signature is
        # not pinned by this test.
        mock_setup_logging.assert_called_once()
        mock_run.assert_called_once()

def test_argument_validation():
    """Check argparse ``choices`` and ``type=int`` validation for transport/port flags.

    Each case patches ``sys.argv`` and parses it with a freshly built
    single-option parser. Invalid values must make argparse raise
    ``SystemExit``; its error output is swallowed by patching ``sys.stderr``.

    NOTE(review): these parsers are built locally, not taken from
    ``mcp_server_hello_world.cli``, so this exercises argparse itself rather
    than the project's CLI definition. Consider pointing it at the real parser.
    """
    prog = 'mcp-server-hello-world'
    transports = ["stdio", "remote"]

    def parse(flag, value, **option):
        # Parse "<prog> <flag> <value>" with a parser that only knows `flag`.
        with patch('sys.argv', [prog, flag, value]):
            parser = argparse.ArgumentParser()
            parser.add_argument(flag, **option)
            return parser.parse_args()

    def assert_rejected(flag, value, **option):
        # argparse prints the error to stderr and exits; silence it and expect the exit.
        with patch('sys.stderr', MagicMock()), pytest.raises(SystemExit):
            parse(flag, value, **option)

    # Every declared transport is accepted verbatim; anything else is rejected.
    for transport in transports:
        assert parse('--transport', transport, choices=transports).transport == transport
    assert_rejected('--transport', 'invalid', choices=transports)

    # --port is coerced to int; a non-numeric value is rejected.
    assert parse('--port', '8080', type=int).port == 8080
    assert_rejected('--port', 'invalid', type=int)